diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 9ae78a64b6..4efd64a5db 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -408,10 +408,6 @@ jobs: cd python ccache --zero-stats pip install -v -e '.[tests]' - - name: Clean up after an unsuccessful build - if: ${{ !success() && steps.amd-install-triton.outcome != 'success' }} - run: | - rm -rf ~/.triton - name: CCache Stats run: ccache --print-stats - name: Run lit tests @@ -477,8 +473,11 @@ jobs: ~/.ccache key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - name: Clean up caches + # Always cleanup the worker, even if builds or tests failed + if: always() run: | - rm -rf ~/.triton/cache + rm -rf ~/.triton + rm -rf ~/.ccache Build-Tests: needs: Runner-Preparation if: needs.Runner-Preparation.outputs.matrix-MACOS != '' diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index 6b58be6571..fd4b0e980e 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -402,11 +402,6 @@ jobs: ccache --zero-stats pip install -v -e '.[tests]' - - name: Clean up after an unsuccessful build - if: ${{ !success() && steps.amd-install-triton.outcome != 'success' }} - run: | - rm -rf ~/.triton - - *print-ccache-stats - *run-lit-tests-step @@ -442,8 +437,11 @@ jobs: - *save-build-artifacts-step - name: Clean up caches + # Always cleanup the worker, even if builds or tests failed + if: always() run: | - rm -rf ~/.triton/cache + rm -rf ~/.triton + rm -rf ~/.ccache Build-Tests: needs: Runner-Preparation diff --git a/README.md b/README.md index 4e5bf22346..79f42adf93 100644 --- a/README.md +++ b/README.md @@ -211,10 +211,10 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi - `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass. - `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma). - `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks. -- `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx. -- `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx when `TRITON_KERNEL_DUMP` is set to 1. -- `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx at the beginning of each compilation stage. -- `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx files when `TRITON_KERNEL_OVERRIDE` is set to 1. +- `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx/amdgcn. +- `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx/amdgcn when `TRITON_KERNEL_DUMP` is set to 1. +- `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage. +- `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx/amdgcn files when `TRITON_KERNEL_OVERRIDE` is set to 1. 
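A minimal sketch of the dump workflow these variables enable (the kernel, tensor sizes, and dump path below are illustrative examples, not part of this change); the variables must be set before the kernel is first compiled, and the resulting per-stage IR and ptx/amdgcn can then be copied into `TRITON_OVERRIDE_DIR` for the override steps shown below:

```python
import os

# Set the knobs before any Triton compilation happens.
os.environ["TRITON_KERNEL_DUMP"] = "1"
os.environ["TRITON_DUMP_DIR"] = "/tmp/triton_dump"  # example path

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
# Compiling and running the kernel once populates $TRITON_DUMP_DIR with the
# IR from each compilation stage and the final ptx/amdgcn.
add_kernel[(1,)](x, y, out, x.numel(), BLOCK=1024)
```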
**Kernel Override Steps** @@ -224,7 +224,7 @@ export TRITON_KERNEL_DUMP=1 export TRITON_DUMP_DIR= export TRITON_KERNEL_OVERRIDE=1 export TRITON_OVERRIDE_DIR= -# Step 1: Run the kernel once to dump kernel's IRs and ptx in $TRITON_DUMP_DIR +# Step 1: Run the kernel once to dump kernel's IRs and ptx/amdgcn in $TRITON_DUMP_DIR # Step 2: Copy $TRITON_DUMP_DIR/ to $TRITON_OVERRIDE_DIR # Step 3: Delete the stages that you do not want to override and modify the stage you do want to override # Step 4: Run the kernel again to see the overridden result diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index a3e38e177d..2e4cbbc651 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -161,6 +161,8 @@ class GatherLoweringHelper { // Get the shared memory scratch size required by this op. unsigned getScratchSizeInBytes(); + // Determine if the gather can be performed completely within a warp. + bool isWarpLocal(); private: triton::GatherOp gatherOp; diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 4b2179611a..fc1b731d15 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1123,8 +1123,8 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, return idx; } -// Emit code to compute the (blockId, warpId, laneId) for the current thread. -std::tuple +// Emit code to compute the (laneId, warpId, blockId) for the current thread. +std::tuple emitHardwareTuple(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, bool withCTAOffset, unsigned threadsPerWarp); diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 85c789635a..e592a9d6d1 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -214,7 +214,12 @@ LinearLayout ensureLayoutNotSmallerThan( const LinearLayout &layout, const llvm::SmallDenseMap &shape); +// Return a vector of the standard out dimension names for tensor layouts. These +// are "dim0", "dim1", etc. SmallVector standardOutDimNames(MLIRContext *ctx, int rank); +// Return an identity mapping from `inDimName` to the standard out dimensions, +// with the dimensions sized according to the shape. The bases are sorted +// according to `order`, with the most minor dimension first. LinearLayout identityStandardND(StringAttr inDimName, ArrayRef shape, ArrayRef order); diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h index cdf22d15d4..414313138b 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h @@ -2,6 +2,8 @@ #define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ #include "mlir/Dialect/SCF/IR/SCF.h" +#include +#include #include namespace mlir { @@ -38,6 +40,7 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse, // Return the minClusterId and maxClusterId for the given ForOp. 
std::pair getMinMaxCluster(scf::ForOp &forOp); std::pair getStageCluster(Operation *op); +std::optional> maybeGetStageCluster(Operation *op); void setStageCluster(Operation *op, int stage, int cluster); } // namespace triton } // namespace mlir diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h new file mode 100644 index 0000000000..8f36b7732f --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h @@ -0,0 +1,99 @@ +#pragma once +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton::nvidia_gpu { + +constexpr inline int TMA_SIZE_BYTES = 128; +constexpr inline int TMA_ALIGN = 128; + +template +mlir::LogicalResult createTMADesc(mlir::Value tmaPtr, + mlir::triton::MakeTensorDescOp op, + BuilderT &builder) { + using namespace mlir; + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + auto mkI32Constant = [&](int32_t val) { + return builder.template create( + loc, builder.getI32Type(), builder.getI32IntegerAttr(val)); + }; + + auto elemType = op.getBase().getType().getPointeeType(); + auto elemSize = elemType.getIntOrFloatBitWidth() / 8; + + int32_t contig_dim_size = op.getTensorShape().back(); + int32_t contig_dim_size_in_bytes = contig_dim_size * elemSize; + if (contig_dim_size_in_bytes > 128) { + contig_dim_size = 128 / elemSize; + } + llvm::SmallVector boxDim; + boxDim.push_back(mkI32Constant(contig_dim_size)); + for (int k = op.getTensorShape().size() - 2; k >= 0; --k) { + boxDim.push_back(mkI32Constant(op.getTensorShape()[k])); + } + + int32_t swizzle_mode; + if (contig_dim_size_in_bytes >= 128) { + swizzle_mode = 3; + } else if (contig_dim_size_in_bytes == 64) { + swizzle_mode = 2; + } else if (contig_dim_size_in_bytes == 32) { + swizzle_mode = 1; + } else { + op->emitError() + << "contiguous box dimension must be at least 32 bytes but got " + << contig_dim_size_in_bytes; + return failure(); + } + + Value elemSizeVal = builder.template create( + loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize)); + Value globalStride = builder.template create( + loc, op.getStrides()[0], elemSizeVal); + // TODO: Workaround for ptxas bug, remove when we update ptxas + Value four = builder.template create( + loc, builder.getI64Type(), builder.getI64IntegerAttr(4)); + globalStride = + builder.template create(loc, globalStride, four); + + int elemTypeEnum; + switch (elemSize) { + case 1: { + elemTypeEnum = 0; + break; + } + case 2: { + elemTypeEnum = 1; + break; + } + case 4: { + elemTypeEnum = 2; + break; + } + default: { + op->emitError() + << "Tensor descriptor element type must have size 1, 2, or 4 but got " + << elemSize; + return failure(); + } + } + + auto one = mkI32Constant(1); + builder.template create( + loc, + /*desc_ptr=*/tmaPtr, + /*global_address=*/op.getBase(), + /*box_dim=*/boxDim, + /*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]}, + /*global_stride=*/ValueRange{globalStride}, + /*element_strides=*/ValueRange{one, one}, + /*elem_type*/ builder.getI32IntegerAttr(elemTypeEnum), + /*interleave_layout*/ builder.getI32IntegerAttr(0), + /*swizzle_mode=*/builder.getI32IntegerAttr(swizzle_mode), + /*fill_mode=*/builder.getI32IntegerAttr(0)); + return success(); +} + +} // namespace mlir::triton::nvidia_gpu diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 8170b923ee..92a432fbec 100644 --- a/lib/Analysis/Utility.cpp +++ 
b/lib/Analysis/Utility.cpp @@ -419,13 +419,93 @@ GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp) : gatherOp(gatherOp) {} unsigned GatherLoweringHelper::getScratchSizeInBytes() { - // For now, lower the gather op by writing the source tensor to shared memory. - // TODO(jeff): Leverage locality to avoid using scratch space when possible. + // If the gather is warp-local, no scratch space is needed. + if (isWarpLocal()) + return 0; + + // Otherwise, performing the gather will require scratch space to communicate + // the source tensor across threads. For now, assume the whole source tensor + // is written back to shared memory. RankedTensorType srcType = gatherOp.getSrc().getType(); return product(srcType.getShape()) * ceil(srcType.getElementTypeBitWidth(), 8); } +bool GatherLoweringHelper::isWarpLocal() { + // The gather is warp-local if for each column along the gather axis in the + // source and index tensors, all the elements are owned by the same warp. + RankedTensorType srcType = gatherOp.getSrc().getType(); + RankedTensorType idxType = gatherOp.getIndices().getType(); + std::optional<LinearLayout> srcLayout = + toLinearLayout(srcType.getShape(), srcType.getEncoding()); + std::optional<LinearLayout> idxLayout = + toLinearLayout(idxType.getShape(), idxType.getEncoding()); + + // FIXME: If an unsupported layout was encountered, assume the gather is not + // warp-local. + if (!srcLayout || !idxLayout) + return false; + + Builder b(gatherOp.getContext()); + StringAttr kBlock = b.getStringAttr("block"); + StringAttr kWarp = b.getStringAttr("warp"); + StringAttr kLane = b.getStringAttr("lane"); + StringAttr kGatherDim = + b.getStringAttr("dim" + std::to_string(gatherOp.getAxis())); + + // The tensor layouts must be distributed layouts, where the basis matrix is a + // subpermutation matrix (permutation matrix plus zeros for broadcasting). + // FIXME(jeff): Check this invariant somehow. + // + // We want to know if all elements of a column along the gather axis are + // mapped to the same set of warps, which means the gather can be performed + // entirely within the warp. We need to query + // + // srcLayout.invert().sublayoutIsZero({kGatherDim}, {kBlock, kWarp}) + // + // But due to broadcasting, the matrix might not be invertible. But since the + // matrix is a permutation matrix (checked below), we can instead query + // + // srcLayout.sublayoutIsZero({kBlock, kWarp}, {kGatherDim}) + // + // Which implies that changing the warp will not change the gather dimension. + // And since there is no swizzling, this applies to all warps. + if (!srcLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim) || + !idxLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim)) + return false; + + SmallVector<StringAttr> otherDims; + for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + if (dim != gatherOp.getAxis()) { + otherDims.push_back(b.getStringAttr("dim" + Twine(dim))); + } + } + + // The gather axis `dimN` is invariant to the warp, but the `(block, warp)` + // mapping to all other dimensions must also be the same for both layouts. If + // so, then the warp that owns a particular index element also owns all the + // source elements it could index into. + if (srcLayout->sublayout({kBlock, kWarp}, otherDims) != + idxLayout->sublayout({kBlock, kWarp}, otherDims)) + return false; + + // The two constraints above ensure that the data movement needed to perform + // the gather is contained within a warp. The subsequent constraints simplify + // codegen.
+ + // Require that for any given gather column, the threads mapped to the column + // in the index and source tensors are the same. This means we don't need to + // xor shuffle across threads before emitting index shuffles; we push warp + // shuffling to layout conversions. + if (srcLayout->sublayout(kLane, otherDims) != + idxLayout->sublayout(kLane, otherDims)) + return false; + + // Otherwise, the source layout has to be invertible. This primarily means + // the codegen path doesn't support broadcasted source layouts. + return srcLayout->isInvertible(); +} + unsigned getNumScratchElements(ArrayRef shape) { if (shape.empty()) return 0; diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 5ab81eff81..673ac8e74f 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -1,8 +1,10 @@ #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" using namespace mlir; using namespace mlir::triton; +using namespace mlir::triton::gpu; namespace { class GatherOpConversion : public ConvertOpToLLVMPattern { @@ -17,12 +19,51 @@ class GatherOpConversion : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override; private: + // Codegen the gather by storing the source tensor into shared memory and then + // gathering directly from shared memory. + void emitGatherInShared(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + // Codegen a warp-local gather by shuffling elements across the warp and + // selecting from them. + void emitWarpLocalGather(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + const TargetInfoBase &targetInfo; }; LogicalResult GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + GatherLoweringHelper helper(op); + // Specialize the lowering based on the source layout. Given that the cost of + // a warp shuffle is approximately half the cost of a roundtrip to shared + // memory with zero bank conflicts, we will need a more precise heuristic to + // choose between the two codegen paths and rely on the middle end to pick the + // right layout. + if (helper.isWarpLocal()) { + emitWarpLocalGather(op, adaptor, rewriter); + } else { + emitGatherInShared(op, adaptor, rewriter); + } + return success(); +} + +static Value convertIndexToI32(Location loc, Value index, + ConversionPatternRewriter &rewriter) { + unsigned idxWidth = index.getType().getIntOrFloatBitWidth(); + // The LL index computations are performed with 32 bit integers. If the + // indices are something else, cast them to i32. + if (idxWidth > 32) { + index = trunc(i32_ty, index); + } else if (idxWidth < 32) { + // Negative indices don't make sense, so zero-extend. 
+ index = zext(i32_ty, index); + } + return index; +} + +void GatherOpConversion::emitGatherInShared( + GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); RankedTensorType srcType = op.getSrc().getType(); @@ -78,19 +119,10 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType, /*withCTAOffset=*/true); - unsigned idxWidth = op.getIndices().getType().getElementTypeBitWidth(); unsigned axis = op.getAxis(); SmallVector<Value> results(dstIndices.size()); for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) { - // The LL index computations are performed with 32 bit integers. If the - // indices are something else, cast them to i32. - if (idxWidth > 32) { - idx = trunc(i32_ty, idx); - } else if (idxWidth < 32) { - // Negative indices don't make sense, so zero-extend. - idx = zext(i32_ty, idx); - } - indices[axis] = idx; + indices[axis] = convertIndexToI32(loc, idx, rewriter); Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); Value ptr = gep(smemBase.getType(), elemType, smemBase, offset); results[i] = load(elemType, ptr); @@ -99,7 +131,224 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, Value packed = packLLElements(loc, getTypeConverter(), results, rewriter, dstType); rewriter.replaceOp(op, packed); - return success(); +} + +// High-level description of the algorithm: +// +// `isWarpLocal` checks that it is possible to compute each output element +// without data movement across warps. +// +// If the gather dim is `dimN`, then this means +// +// ll^-1(dimN)[(block, warp)] == 0 +// +// for both source and index tensors: moving along the gather axis does not +// change the warp. Broadcasted layouts are not supported, so we know the +// layouts are permutation matrices. +// +// We can check this with `ll((block, warp))[dimN] == 0`. +// +// Let `gatherCol` be a tuple of all dimensions except the gather dimension. +// We also check that the gather columns line up the same way with respect to +// the warp between the source and index tensors with +// +// ll_src((block, warp))[gatherCol] == ll_idx((block, warp))[gatherCol] +// +// This means that for all index columns, the corresponding column in the source +// tensor is owned by the same warp. +// +// We also check +// +// ll_src(lane)[gatherCol] == ll_idx(lane)[gatherCol] +// +// This boils down to the fact that the algorithm essentially emits a series of +// index shuffles for each index value owned by each thread, and then a pile of +// selects to pick the right value. We need to figure out, given an index value +// in a particular column, which source register values it could read from and +// who owns them. +// +// If this relationship did not hold, then the possible source registers for +// each index value would vary with the thread, meaning the value operand +// provided to each shuffle index instruction would depend on the thread ID. +// This isn't a big deal. It just means we would have to emit a pile of selects +// before each shuffle as well, to pick the right source register value. But we +// choose not to handle this. +// +// The codegen algorithm emits code: +// - Given the thread ID and a particular index tensor register, figure out +// which gather column it belongs to using a layout.
+// - Using the index value itself as the value for `dimN`, use another layout to +// figure out which lane in the warp owns the desired value and which register +// in that lane it is. +// - For the gather column, figure out the source registers in that column, and +// for each of them, emit an index shuffle with the same computed lane ID. +// - Use the register component to select the right value from the shuffle +// results. +void GatherOpConversion::emitWarpLocalGather( + GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + RankedTensorType srcType = op.getSrc().getType(); + RankedTensorType idxType = op.getIndices().getType(); + + // Layout dimension names. + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); + StringAttr kGatherDim = rewriter.getStringAttr("dim" + Twine(op.getAxis())); + SmallVector<StringAttr> allDims, otherDims; + for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + allDims.push_back(str_attr("dim" + Twine(dim))); + if (dim != op.getAxis()) { + otherDims.push_back(allDims.back()); + } + } + + // Compute the src and idx layouts. + LinearLayout srcLayout = + *toLinearLayout(srcType.getShape(), srcType.getEncoding()); + LinearLayout idxLayout = + *toLinearLayout(idxType.getShape(), idxType.getEncoding()); + + // Let `ll_src` be the source layout and `ll_idx` be the index layout. + // Let `src_col` be a tuple of all dimensions except the gather dimension, + // representing a specific column in the source tensor. Likewise for + // `idx_col`. Let `src_idx` be the index into the gather dimension in the + // source tensor. + // + // `(src_lane, src_reg) = ll_src^-1(src_col, src_idx)`, where `src_lane` is + // the thread that contains the required element and `src_reg` is the register + // within that thread. + // + // Because `ll_src(block=0, warp=0, lane=0)[otherDims] == + // ll_idx(0, 0, 0)[otherDims]`, we know given any `idx_reg` (element in the + // index tensor) the thread will need to read from the same column in the + // source tensor. + // + // Thus, we can obtain + // + // (src_lane, src_reg) = (ll_src^-1)( + // ll_idx(block, warp, lane, idx_reg)[otherDims], + // idxValues[idx_reg] + // )[{"lane", "register"}] + // + // And the mapping will be correct for each thread. + // + // Given `src_reg \in [0, K*N)`, we just need to emit N index shuffles for + // each `idx_reg` (the number of index shuffles is quadratic!) and + // `llvm.select` using `src_reg` to get the right one. `K` is the number of + // elements per column owned by a thread. + + // Fully invert the source layout. We know it is invertible because + // `isWarpLocal` checked this. + LinearLayout invSrcLayout = srcLayout.invert(); + + // Sanity check: the warp must be invariant to the index because otherwise the + // gather would need to read across warps!
+ assert(invSrcLayout.sublayoutIsZero(kGatherDim, {kBlock, kWarp}) && + "expected a warp-local gather"); + invSrcLayout = invSrcLayout.sublayout(allDims, {kLane, kRegister}); + + LinearLayout idxColLayout = + idxLayout.sublayout({kBlock, kWarp, kLane, kRegister}, otherDims); + + SmallVector<Value> srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector<Value> idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + + auto [laneId, warpId, blockId] = + emitHardwareTuple(loc, rewriter, targetInfo, /*withCTAOffset=*/true, + srcLayout.getInDimSize(kLane)); + + unsigned /*N=*/srcRegsPerThread = srcLayout.getInDimSize(kRegister); + assert(srcRegsPerThread == srcValues.size()); + + // Given an index value, we need to know which source register values it could + // index into. This is invariant to anything other than the register, which we + // checked already. Compute the full reverse map from + // + // idx_reg -> gather_column -> (src_reg0, src_reg1, ...) + // + LinearLayout invertSrcRegMap = invSrcLayout.sublayout(allDims, {kRegister}); + // Remove zero bases in the gather dimension to make the function injective + // (for a given column) over the same codomain. + LinearLayout::BasesT newInvertRegMapBases; + for (auto &[inDim, inDimBases] : invertSrcRegMap.getBases()) { + auto &newInDimBases = newInvertRegMapBases[inDim]; + if (inDim != kGatherDim) { + newInDimBases = inDimBases; + continue; + } + for (auto &basis : inDimBases) { + if (llvm::any_of(basis, [](int32_t val) { return val != 0; })) { + newInDimBases.push_back(basis); + } + } + } + invertSrcRegMap = LinearLayout( + newInvertRegMapBases, llvm::to_vector(invertSrcRegMap.getOutDimNames())); + // We are left with only non-zero bases in the gather dimension, which means + // the number of registers per column is the size of the "gather dimension". + unsigned numRegsPerColumn = invertSrcRegMap.getInDimSize(kGatherDim); + // Get a map from idx_reg to the column it indexes into. + LinearLayout idxRegToCol = idxLayout.sublayout({kRegister}, otherDims); + // Now given `idx_reg`, we can compute the column it belongs to in both src + // and index tensors, then partially apply `invertSrcRegMap` with this to + // obtain a function that outputs the corresponding registers in the src + // tensor in the same column. + + // L(column, i) = L(column, 0) xor L(0, i) + LinearLayout invertSrcRegMapColPart = + invertSrcRegMap.sublayout(otherDims, {kRegister}); + LinearLayout invertSrcRegMapRest = + invertSrcRegMap.sublayout({kGatherDim}, {kRegister}); + + SmallVector<Value> results; + for (auto [idxReg, idxVal] : llvm::enumerate(idxValues)) { + SmallVector<std::pair<StringAttr, Value>> column = + applyLinearLayout(loc, rewriter, idxColLayout, + {{kBlock, blockId}, + {kWarp, warpId}, + {kLane, laneId}, + {kRegister, i32_val(idxReg)}}); + assert(column.size() == otherDims.size()); + + // Combine the computed column with the data-dependent gather index. + column.emplace_back(kGatherDim, convertIndexToI32(loc, idxVal, rewriter)); + SmallVector<std::pair<StringAttr, Value>> srcLaneAndReg = + applyLinearLayout(loc, rewriter, invSrcLayout, column); + + auto [srcLaneName, srcLane] = srcLaneAndReg.back(); + auto [srcRegName, srcReg] = srcLaneAndReg.front(); + assert(srcLaneName == kLane && srcRegName == kRegister); + + assert(!srcValues.empty() && "can't gather from an empty tensor"); + + // Figure out which src registers we need to index shuffle from. This is + // invariant to anything else.
+ SmallVector> normalizedColumn = + idxRegToCol.apply({{kRegister, idxReg}}); + int32_t srcBase = + invertSrcRegMapColPart.apply(normalizedColumn).front().second; + + Value result = undef(srcValues.front().getType()); + for (unsigned i = 0; i != numRegsPerColumn; ++i) { + int32_t rest = + invertSrcRegMapRest.apply({{kGatherDim, i}}).front().second; + int32_t srcRegIdx = srcBase ^ rest; + + Value value = + targetInfo.shuffleIdx(rewriter, loc, srcValues[srcRegIdx], srcLane); + result = select(icmp_eq(i32_val(srcRegIdx), srcReg), value, result); + } + + results.push_back(result); + } + + rewriter.replaceOp(op, packLLElements(loc, getTypeConverter(), results, + rewriter, op.getType())); } } // namespace diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 02b3b121f4..eb2c82cfc9 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -110,7 +110,7 @@ std::tuple emitHardwareTuple(Location loc, Value warpId = udiv(threadId, threadsPerWarp); Value blockId = withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0); - return {blockId, warpId, laneId}; + return {laneId, warpId, blockId}; } SmallVector> @@ -130,7 +130,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, StringAttr kWarp = str_attr("warp"); StringAttr kBlock = str_attr("block"); - auto [blockId, warpId, laneId] = emitHardwareTuple( + auto [laneId, warpId, blockId] = emitHardwareTuple( loc, rewriter, target, withCTAOffset, ll->getInDimSize(kLane)); unsigned rank = shape.size(); SmallVector> ret; @@ -353,7 +353,7 @@ bool emitTransferBetweenRegistersAndShared( std::min(regToSharedLayout->getNumConsecutiveInOut(), maxVecElems.value_or(std::numeric_limits::max())); - auto [blockId, warpId, laneId] = + auto [laneId, warpId, blockId] = emitHardwareTuple(loc, rewriter, target, /*withCTAOffset=*/false, regToSharedLayout->getInDimSize(kLane)); @@ -746,7 +746,7 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, auto instrShape = mmaLayout.getInstrShape(); SmallVector mmaColIdx(2); SmallVector mmaRowIdx(2); - auto [blockId, warpId, laneId] = emitHardwareTuple( + auto [laneId, warpId, blockId] = emitHardwareTuple( loc, rewriter, targetInfo, /*withCTAOffset=*/false, 32); // TODO: fix the bug in MMAEncodingAttr document SmallVector multiDimWarpId(2); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 548702fdc0..fc037cd26f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -1,28 +1,24 @@ #include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/IRMapping.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" -#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include 
"triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" -#include - #define DEBUG_TYPE "triton-matmul-loop-pipeline" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") @@ -70,6 +66,30 @@ class OpBuilderWithStage : public OpBuilder { using OpBuilder::create; }; +class OpBuilderForStage : public OpBuilder { + std::optional stage_, cluster_; + +public: + explicit OpBuilderForStage(Operation *op, int stage, int cluster) + : OpBuilder(op, nullptr), stage_(stage), cluster_(cluster) {} + explicit OpBuilderForStage(Operation *op) : OpBuilder(op, nullptr) { + auto sc = tt::maybeGetStageCluster(op); + if (sc) { + stage_ = sc->first; + cluster_ = sc->second; + } + } + + template OpTy create(Args &&...args) { + OpTy op = OpBuilder::create(std::forward(args)...); + + if (stage_ && cluster_) { + tt::setStageCluster(op, *stage_, *cluster_); + } + return op; + } +}; + static bool sameStageCluster(Operation *op1, Operation *op2) { auto [s1, c1] = tt::getStageCluster(op1); auto [s2, c2] = tt::getStageCluster(op2); @@ -708,7 +728,128 @@ getFinalSchedule(scf::ForOp &forOp, int numStages) { return fSchedule; } -// Convert load ops into their asyn version and apply multi-buffering based on +LogicalResult +allocTMABuffers(scf::ForOp forOp, + llvm::MapVector &tmaBufferMapping, + int numStages) { + IRRewriter rewriter(forOp); + + // Create a multi-buffered allocation for each MakeTensorDescOp call in the + // loop + forOp.walk([&](tt::MakeTensorDescOp op) { + // TODO peter: walk to loop yield to find the init value if this is a + // loop-carried value. 
That would save us from allocating another buffer + // just for the init value + auto loc = op.getLoc(); + Value alloc = rewriter.create( + loc, triton::getPointerType(rewriter.getI8Type()), + numStages * ttng::TMA_SIZE_BYTES, ttng::TMA_ALIGN); + tmaBufferMapping[op.getOperation()] = alloc; + }); + return success(); +} + +template <typename BuilderT> +Value createIncrementModulo(BuilderT &builder, Location loc, Value counter, + Value modulus, Value zero, Value one) { + Value addOne = builder.template create<arith::AddIOp>(loc, counter, one); + Value inRangeCond = builder.template create<arith::CmpIOp>( + loc, arith::CmpIPredicate::slt, addOne, modulus); + return builder.template create<arith::SelectOp>(loc, inRangeCond, addOne, + zero); +} + +template <typename BuilderT> +Value subviewTMADescriptor(BuilderT &builder, Location loc, Value alloc, + Value counter) { + Value tmaSizeVal = builder.template create<arith::ConstantIntOp>( + loc, ttng::TMA_SIZE_BYTES, 32); + Value offset = + builder.template create<arith::MulIOp>(loc, tmaSizeVal, counter); + return builder.template create<triton::AddPtrOp>(loc, alloc.getType(), alloc, + offset); +} + +LogicalResult rewriteTMABufferUpdates( + scf::ForOp forOp, + const llvm::MapVector<Operation *, Value> &tmaBufferMapping, + ArrayRef<BlockArgument> tmaCounters, int numStages, Value one, Value zero) { + assert(tmaBufferMapping.size() == tmaCounters.size()); + + Value numStagesVal = mlir::OpBuilder(forOp).create<arith::ConstantIntOp>( + forOp.getLoc(), numStages, 32); + + for (auto [iOp, pair] : llvm::enumerate(tmaBufferMapping)) { + auto &[op, alloc] = pair; + + // Rewrite MakeTensorDescOp as writing a TMA descriptor + auto makeDescOp = cast<tt::MakeTensorDescOp>(op); + + OpBuilderForStage stageBuilder(makeDescOp); + auto loc = makeDescOp.getLoc(); + + BlockArgument counter = tmaCounters[iOp]; + Value nextBuf = subviewTMADescriptor(stageBuilder, loc, alloc, counter); + if (failed(ttng::createTMADesc(nextBuf, makeDescOp, stageBuilder))) { + return failure(); + } + stageBuilder.create( + loc, nextBuf); + Value nextDesc = stageBuilder.create( + loc, makeDescOp.getType(), nextBuf); + + makeDescOp.getResult().replaceAllUsesWith(nextDesc); + + // Increment the buffer index counter + Value nextCounter = createIncrementModulo(stageBuilder, loc, counter, + numStagesVal, zero, one); + + // If we are in a (potentially nested) if region, propagate the counter + // up to the main for op body scope + Operation *curOp = op; + Operation *parent = op->getParentOp(); + while (parent != forOp.getOperation()) { + auto ifOp = dyn_cast<scf::IfOp>(parent); + if (!ifOp) { + std::string msg; + llvm::raw_string_ostream ss(msg); + ss << "Cannot pipeline MakeTensorDescOp inside:\n"; + parent->print(ss); + ss << "\nOnly scf.if regions are supported"; + return makeDescOp->emitOpError(std::move(msg)); + } + + IRRewriter rewriter(parent); + auto newIfOp = + replaceIfOpWithNewSignature(rewriter, ifOp, {nextCounter.getType()}); + + auto yieldNewBlock = newIfOp.thenBlock(); + auto yieldOldBlock = newIfOp.elseBlock(); + + if (yieldNewBlock != curOp->getBlock()) { + std::swap(yieldNewBlock, yieldOldBlock); + } + cast<scf::YieldOp>(yieldNewBlock->getTerminator()) + .getResultsMutable() + .append(nextCounter); + cast<scf::YieldOp>(yieldOldBlock->getTerminator()) + .getResultsMutable() + .append(counter); + + ifOp.erase(); + nextCounter = newIfOp.getResults().back(); + curOp = newIfOp; + parent = newIfOp->getParentOp(); + } + + // Finally, rewrite the loop level yield + auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + forYield.setOperand(counter.getArgNumber() - 1, nextCounter); + } + return success(); +} + +// Convert load ops into their async version and apply multi-buffering based on // the required number of buffers.
static SmallVector createAsyncOps(scf::ForOp &forOp, @@ -732,6 +873,11 @@ createAsyncOps(scf::ForOp &forOp, numBuffers++; }; + llvm::MapVector<Operation *, Value> tmaBufferMapping; + if (failed(allocTMABuffers(forOp, tmaBufferMapping, numStages))) { + llvm_unreachable("TMA pipelining failed"); + } + SmallVector<AsyncLoad> asyncLoads; SmallVector<Value> allocs; bool hasTMALoad = false; @@ -755,7 +901,10 @@ createAsyncOps(scf::ForOp &forOp, builder.setInsertionPoint(forOp); Location loc = forOp.getLoc(); - // Create two new counters to index into the allocs. + // Create a counter to index into the allocations per loop iteration. + // NOTE: We create two duplicate values, insertIdx and extractIdx, so that the + // pipeliner will re-materialize the value in later stages of the pipeline + // instead of carrying it as a dependency across multiple iterations. Value minusOne = builder.create<arith::ConstantIntOp>(loc, -1, 32); Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32); Value one = builder.create<arith::ConstantIntOp>(loc, 1, 32); @@ -768,9 +917,19 @@ createAsyncOps(scf::ForOp &forOp, newOperands.push_back(insertIdx); newOperands.push_back(extractIdx); if (hasTMALoad) { + // A single barrier arrival sequence is a "phase" and two phases can + // overlap, provided the phases are differentiated with an alternating + // boolean value. phase = builder.create<arith::ConstantIntOp>(loc, 0, 32); newOperands.push_back(phase); } + // Also create one counter per TMA buffer. This allows the descriptors to be + // updated independently without needing to write duplicates of existing TMA + // descriptors. + for (int i = 0; i < tmaBufferMapping.size(); ++i) { + newOperands.push_back(zero); + } + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); // Patch the loop to add the new loop carried dependencies. scf::ForOp newForOp = @@ -782,6 +941,20 @@ createAsyncOps(scf::ForOp &forOp, if (phase) { phase = newForOp.getBody()->getArgument(newOperandIndex + 2); } + auto tmaCounters = ArrayRef(newForOp.getBody()->getArguments()) + .slice(newOperandIndex + (phase ? 3 : 2)); + + // Update yield op with temporary yield values + auto forYield = cast<scf::YieldOp>(newForOp.getBody()->getTerminator()); + for (unsigned i = 0; i < newOperands.size(); ++i) { + forYield.getResultsMutable().append(newOperands[i]); + } + + if (failed(rewriteTMABufferUpdates(newForOp, tmaBufferMapping, tmaCounters, + numStages, one, zero))) { + llvm_unreachable("Failed to rewrite TMA ops"); + } + tmaBufferMapping.clear(); // FIXME: loads can be in different (stage, cluster) // Create two counters for the insert and extract indices to avoid creating @@ -815,11 +988,12 @@ createAsyncOps(scf::ForOp &forOp, loadToInfo, numStages); } } - SmallVector newYieldOperands = {insertIdx, extractIdx}; - if (phase) - newYieldOperands.push_back(phase); // Patch the yield with the updated counters.
- appendToForOpYield(forOp, newYieldOperands); + forYield.setOperand(newOperandIndex + -1, insertIdx); + forYield.setOperand(newOperandIndex + 0, extractIdx); + if (phase) { + forYield.setOperand(newOperandIndex + 1, phase); + } tt::CoarseSchedule coarseSchedule(numStages); coarseSchedule.deSerialize(forOp); @@ -957,12 +1131,11 @@ static int minNumInterleavedCommitOps(Operation *waitOp) { if (thisHistorySum >= minCommitNumber) return minCommitNumber; - // get the value value assigned to the argument coming from outside the - // loop + // get the value assigned to the argument coming from outside the loop Value incomingVal = forOp.getInitArgs()[arg.getArgNumber() - 1]; int min1 = minOverHistories(incomingVal, forOp, thisHistorySum); - // get the value value assigned to the argument coming from the previous + // get the value assigned to the argument coming from the previous // iteration Operation *yieldOp = block->getTerminator(); Value prevVal = yieldOp->getOperand(arg.getArgNumber() - 1); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index aab5607707..3ab65c6105 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -7,6 +7,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Casting.h" using namespace mlir; namespace tt = mlir::triton; @@ -198,15 +199,23 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, op->erase(); } -std::pair<int, int> mlir::triton::getStageCluster(Operation *op) { - auto stage = cast<IntegerAttr>(op->getAttr(mlir::triton::kLoopStageAttrName)) - .getValue() - .getSExtValue(); +std::optional<std::pair<int, int>> +mlir::triton::maybeGetStageCluster(Operation *op) { + auto stage = + dyn_cast_if_present<IntegerAttr>(op->getAttr(tt::kLoopStageAttrName)); auto clusterId = - cast<IntegerAttr>(op->getAttr(mlir::triton::kLoopClusterAttrName)) - .getValue() - .getSExtValue(); - return std::make_pair(stage, clusterId); + dyn_cast_if_present<IntegerAttr>(op->getAttr(tt::kLoopClusterAttrName)); + if (!stage || !clusterId) { + return std::nullopt; + } + + return { + {stage.getValue().getSExtValue(), clusterId.getValue().getSExtValue()}}; +} +std::pair<int, int> mlir::triton::getStageCluster(Operation *op) { + auto res = maybeGetStageCluster(op); + assert(res.has_value() && "Operation is missing stage & cluster attribute"); + return *res; } void mlir::triton::setStageCluster(Operation *op, int stage, int cluster) { diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index cb9ae9dd0f..f412755d55 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -2,13 +2,13 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/Passes.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" + #include @@ -115,87 +115,11 @@ class TMACreateDescLowering : public
OpRewritePattern { PatternRewriter &rewriter) const override { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); - constexpr auto kTmaNbytes = 128; - constexpr auto kTmaAlignment = 128; auto alloc = rewriter.create( - loc, getPointerType(rewriter.getI8Type()), kTmaNbytes, kTmaAlignment); - auto mkI32Constant = [&](int32_t val) { - return rewriter.create( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(val)); - }; - - auto elemType = op.getBase().getType().getPointeeType(); - auto elemSize = elemType.getIntOrFloatBitWidth() / 8; - - int32_t contig_dim_size = op.getTensorShape().back(); - int32_t contig_dim_size_in_bytes = contig_dim_size * elemSize; - if (contig_dim_size_in_bytes > 128) { - contig_dim_size = 128 / elemSize; - } - llvm::SmallVector boxDim; - boxDim.push_back(mkI32Constant(contig_dim_size)); - for (int k = op.getTensorShape().size() - 2; k >= 0; --k) { - boxDim.push_back(mkI32Constant(op.getTensorShape()[k])); - } - - int32_t swizzle_mode; - if (contig_dim_size_in_bytes >= 128) { - swizzle_mode = 3; - } else if (contig_dim_size_in_bytes == 64) { - swizzle_mode = 2; - } else if (contig_dim_size_in_bytes == 32) { - swizzle_mode = 1; - } else { - op->emitError() - << "contiguous box dimension must be at least 32 bytes but got " - << contig_dim_size_in_bytes; - return failure(); - } - - Value elemSizeVal = rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(elemSize)); - Value globalStride = - rewriter.create(loc, op.getStrides()[0], elemSizeVal); - // TODO: Workaround for ptxas bug, remove when we update ptxas - Value four = rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(4)); - globalStride = rewriter.create(loc, globalStride, four); - - int elemTypeEnum; - switch (elemSize) { - case 1: { - elemTypeEnum = 0; - break; - } - case 2: { - elemTypeEnum = 1; - break; - } - case 4: { - elemTypeEnum = 2; - break; - } - default: { - op->emitError() - << "Tensor descriptor element type must have size 1, 2, or 4 but got " - << elemSize; + loc, getPointerType(rewriter.getI8Type()), TMA_SIZE_BYTES, TMA_ALIGN); + if (failed(createTMADesc(alloc, op, rewriter))) { return failure(); } - } - - auto one = mkI32Constant(1); - rewriter.create( - loc, - /*desc_ptr=*/alloc.getResult(), - /*global_address=*/op.getBase(), - /*box_dim=*/boxDim, - /*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]}, - /*global_stride=*/ValueRange{globalStride}, - /*element_strides=*/ValueRange{one, one}, - /*elem_type*/ rewriter.getI32IntegerAttr(elemTypeEnum), - /*interleave_layout*/ rewriter.getI32IntegerAttr(0), - /*swizzle_mode=*/rewriter.getI32IntegerAttr(swizzle_mode), - /*fill_mode=*/rewriter.getI32IntegerAttr(0)); rewriter.create( loc, alloc.getResult()); auto newDesc = rewriter.create( diff --git a/python/src/passes.cc b/python/src/passes.cc index 235eba4465..b0efc3cb88 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -31,6 +31,7 @@ void init_triton_passes_common(py::module &&m) { ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass); ADD_PASS_WRAPPER_0("add_cse", createCSEPass); ADD_PASS_WRAPPER_0("add_licm", createLoopInvariantCodeMotionPass); + ADD_PASS_WRAPPER_0("print_ir", createPrintIRPass); } void init_triton_passes_ttir(py::module &&m) { diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 23065953d6..c32c017dc7 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -538,3 
+538,107 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): ) torch.testing.assert_close(ref_out, A) assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"] + + +@triton.jit +def batched_gemm_kernel(a_ptr, b_ptr, c_ptr, # + B, M, N, K, # + dtype: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SMS: tl.constexpr): + start_pid = tl.program_id(axis=0) + num_tiles_m = tl.cdiv(M, BLOCK_M) + num_tiles_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles_per_batch = num_tiles_m * num_tiles_n + num_tiles = B * num_tiles_per_batch + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + tile_m = 0 + tile_n = 0 + tile_b = 0 + + offs_m = 0 + offs_n = 0 + offs_b = 0 + + a_desc = tl._experimental_make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], [BLOCK_M, BLOCK_K]) + b_desc = tl._experimental_make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], [BLOCK_N, BLOCK_K]) + c_desc = tl._experimental_make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], [BLOCK_M, BLOCK_N]) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for _ in range(k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + tile_b = tile_id // num_tiles_per_batch + tile_m = (tile_id // num_tiles_n) % num_tiles_m + tile_n = tile_id % num_tiles_n + + offs_b = tile_b + offs_m = tile_m * BLOCK_M + offs_n = tile_n * BLOCK_N + + a_desc = tl._experimental_make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], + [BLOCK_M, BLOCK_K]) + b_desc = tl._experimental_make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], + [BLOCK_N, BLOCK_K]) + c_desc = tl._experimental_make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], + [BLOCK_M, BLOCK_N]) + + offs_k = ki * BLOCK_K + + a = a_desc.load([offs_m, offs_k]) + b = b_desc.load([offs_n, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + if ki == k_tiles - 1: + c = accumulator.to(dtype) + + c_desc.store([offs_m, offs_n], c) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@requires_tma +def test_tensor_descriptor_batched_gemm(): + device = "cuda" + B, M, N, K = 2, 1024, 1024, 128 + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64 + NUM_SMS = 96 + num_stages = 3 + + grid = (min(NUM_SMS, B * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), ) + + a = torch.randn((B, M, K), device=device, dtype=torch.float16) + b = torch.randn((B, N, K), device=device, dtype=torch.float16) + c = torch.empty((B, M, N), device=device, dtype=torch.float16) + + expect = torch.bmm(a, b.mT) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + # TODO: should only need num_stages * 3 descriptors per SM + assert size == 128 * 3 * (num_stages + 1) * grid[0] + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + h = batched_gemm_kernel[grid]( + a, b, c, # + B, M, N, K, # + tl.float16, # + BLOCK_M, BLOCK_N, BLOCK_K, # + NUM_SMS, # + num_stages=num_stages, num_warps=8) + print(h.n_regs) + torch.cuda.synchronize() + + torch.testing.assert_close(c, expect, rtol=1e-3, atol=1e-3) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 4e908ace31..2b8134373c 100644 --- a/python/test/unit/language/test_core.py +++ 
b/python/test/unit/language/test_core.py @@ -6247,6 +6247,24 @@ def kernel(In, Out, # assert torch.all(ref == result) +@triton.jit +def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, + src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, + idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, + out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, + out_stride1: tl.constexpr): + src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) + src = tl.load(src_ptr + src_offs) + + idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) + idx = tl.load(idx_ptr + idx_offs) + + out = tl.gather(src, idx, axis) + + out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) + tl.store(out_ptr + out_offs, out) + + @pytest.mark.parametrize("src_shape, indices_shape, axis", [ ([4, 4], [8, 4], 0), ([128, 64], [256, 64], 0), @@ -6256,29 +6274,13 @@ def test_gather(src_shape, indices_shape, axis, device): if is_xpu(): pytest.skip("Fail on XPU") - @triton.jit - def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, - src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, - idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, - out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, - out_stride1: tl.constexpr): - src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) - src = tl.load(src_ptr + src_offs) - - idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) - idx = tl.load(idx_ptr + idx_offs) - - out = tl.gather(src, idx, axis) - - out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) - tl.store(out_ptr + out_offs, out) - def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) - gather_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1], - src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), - indices.stride(1), output.shape[0], output.shape[1], output.stride(0), output.stride(1)) + gather_test_kernel[(1, )](src, indices, output, axis, src.shape[0], + src.shape[1], src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], + indices.stride(0), indices.stride(1), output.shape[0], output.shape[1], + output.stride(0), output.stride(1)) return output @@ -6287,3 +6289,79 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): ref = torch.gather(src, axis, indices) result = triton_gather(src, axis, indices) torch.testing.assert_close(result, ref, rtol=0, atol=0) + + +# These layouts are specially chosen to trigger the warp shuffle codegen. 
+@pytest.mark.parametrize("src_shape, indices_shape, axis, src_layout, indices_layout", [ + ([32, 16], [32, 16], 0, + "linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", + "linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}>" + ), + ([128, 64], [256, 64], 0, + "linear<{register = [[0, 2], [32, 0], [2, 0], [0, 16], [0, 32], [64, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", + "linear<{register = [[0, 2], [32, 0], [0, 32], [2, 0], [0, 16], [64, 0], [128, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>" + ), +]) +def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path, + device): + if is_xpu(): + pytest.skip("warp-local gather has issues on XPU") + if is_hip(): + pytest.skip("warp-local gather has issues on HIP") + + def prepare_kernel(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + compiled = gather_test_kernel.warmup(src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0), + src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), + indices.stride(1), output.shape[0], output.shape[1], output.stride(0), + output.stride(1), grid=(1, )) + return output, compiled + + def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout, idx_layout): + ir = f""" +#src_layout = #ttg.{src_layout} +#idx_layout = #ttg.{idx_layout} +{ir}""" + + dtypes = {torch.int32: "i32", torch.float32: "f32", torch.int64: "i64", torch.float64: "f64"} + + src_spec = f"{src.shape[0]}x{src.shape[1]}x{dtypes[src.dtype]}" + indices_spec = f"{indices.shape[0]}x{indices.shape[1]}x{dtypes[indices.dtype]}" + output_spec = f"{indices.shape[0]}x{indices.shape[1]}x{dtypes[src.dtype]}" + + pat = r"(%[0-9]+) = tt.gather (%[0-9]+)\[(%[0-9]+)\] {axis = " + pat += str(axis) + pat += r" : i32} : \(tensor\<" + pat += src_spec + pat += r", (#[a-z]+[0-9]+)\>, tensor\<" + pat += indices_spec + pat += r", (#[a-z]+[0-9]+)\>\) -> tensor\<" + pat += output_spec + pat += r", (#[a-z]+[0-9]+)\>" + + repl = r""" + %src = ttg.convert_layout \2 : tensor<""" + src_spec + r""", \4> -> tensor<""" + src_spec + r""", #src_layout> + %idx = ttg.convert_layout \3 : tensor<""" + indices_spec + r""", \5> -> tensor<""" + indices_spec + r""", #idx_layout> + %out = tt.gather %src[%idx] {axis = """ + str( + axis + ) + r""" : i32} : (tensor<""" + src_spec + r""", #src_layout>, tensor<""" + indices_spec + r""", #idx_layout>) -> tensor<""" + output_spec + r""", #idx_layout> + \1 = ttg.convert_layout %out : tensor<""" + output_spec + r""", #idx_layout> -> tensor<""" + output_spec + r""", \6>""" + return re.sub(pat, repl, ir) + + src = torch.randn(src_shape, device=device) + indices = torch.randint(0, src.shape[axis], indices_shape, device=device) + ref = torch.gather(src, axis, indices) + + output, compiled = prepare_kernel(src, axis, indices) + ir = compiled.asm["ttgir"] + ir = inject_layout(ir, src, axis, indices, src_layout, indices_layout) + + temp_file = tmp_path / "test_warp_gather.ttgir" + temp_file.write_text(ir) + + kernel = triton.compile(str(temp_file)) + assert ("nvvm.shfl.sync.idx" in kernel.asm["llir"]) or ("llvm.amdgcn.ds.bpermute" in kernel.asm["llir"]) + + kernel[(1, 1, 1)](src, indices, output) + + torch.testing.assert_close(output, 
ref, rtol=0, atol=0) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 9c7e9f28a4..c9427c78fd 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -171,9 +171,9 @@ def parse(full_name, ext, context): module = ir.parse_mlir_module(full_name, context) module.context = context return module - if ext == "llir" or ext == "ptx": + if ext == "llir" or ext == "ptx" or ext == "amdgcn": return Path(full_name).read_text() - if ext == "cubin": + if ext == "cubin" or ext == "hsaco": return Path(full_name).read_bytes() diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 49a8bb32c4..4b4a08857d 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -28,6 +28,8 @@ import triton.profiler as proton from contextlib import contextmanager +from typing import Optional + if torch.cuda.is_available(): from triton._C.libtriton import nvidia cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) @@ -374,8 +376,7 @@ def matmul_tma_persistent(a, b): @triton.jit(launch_metadata=_matmul_launch_metadata) -def matmul_kernel_device_tma_persistent(workspace_ptr, # - tiles_per_update: tl.constexpr, # +def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, # a_ptr, b_ptr, c_ptr, # M, N, K, # BLOCK_SIZE_M: tl.constexpr, # @@ -391,24 +392,24 @@ def matmul_kernel_device_tma_persistent(workspace_ptr, # k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - TMA_SIZE: tl.constexpr = 128 - workspace_base = workspace_ptr + start_pid * 3 * TMA_SIZE - a_desc_ptr = workspace_base - b_desc_ptr = workspace_base + TMA_SIZE - c_desc_ptr = workspace_base + 2 * TMA_SIZE - - tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr, - load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K], - element_ty=a_ptr.dtype.element_ty) - tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr, - load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K], - element_ty=b_ptr.dtype.element_ty) - tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr, - load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N], - element_ty=c_ptr.dtype.element_ty) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + b_desc = tl._experimental_make_tensor_descriptor( + b_ptr, + shape=[N, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + c_desc = tl._experimental_make_tensor_descriptor( + c_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) tiles_per_SM = num_tiles // NUM_SMS if start_pid < num_tiles % NUM_SMS: @@ -426,6 +427,9 @@ def matmul_kernel_device_tma_persistent(workspace_ptr, # num_pid_in_group = GROUP_SIZE_M * num_pid_n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # Create an opaque value to prevent the descriptor creation from being + # hoisted out of the loop + zero = tl.inline_asm_elementwise("mov.b32 $0, 0;", "=r", [], dtype=tl.int32, is_pure=True, pack=1) for _ in range(0, k_tiles * tiles_per_SM): ki = tl.where(ki == k_tiles - 
1, 0, ki + 1) @@ -434,21 +438,24 @@ def matmul_kernel_device_tma_persistent(workspace_ptr, # # Simulate a grouped gemm if ni == tiles_per_update: - tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr, - load_size=[BLOCK_SIZE_M, - BLOCK_SIZE_K], global_size=[M, K], - element_ty=a_ptr.dtype.element_ty) - tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr, - load_size=[BLOCK_SIZE_N, - BLOCK_SIZE_K], global_size=[N, K], - element_ty=b_ptr.dtype.element_ty) - tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr, - load_size=[BLOCK_SIZE_M, - BLOCK_SIZE_N], global_size=[M, N], - element_ty=c_ptr.dtype.element_ty) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr + zero, + shape=[M, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + b_desc = tl._experimental_make_tensor_descriptor( + b_ptr + zero, + shape=[N, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + c_desc = tl._experimental_make_tensor_descriptor( + c_ptr + zero, + shape=[M, N], + strides=[N, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) ni = 0 tile_id += NUM_SMS @@ -463,19 +470,19 @@ def matmul_kernel_device_tma_persistent(workspace_ptr, # offs_k = ki * BLOCK_SIZE_K - a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) - b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype) + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) accumulator = tl.dot(a, b.T, accumulator) if ki == k_tiles - 1: c = accumulator.to(dtype) - tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn]) + c_desc.store([offs_am, offs_bn], c) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) -def matmul_device_tma_persistent(a, b, tiles_per_update): +def matmul_descriptor_persistent(a, b, tiles_per_update): # Autotuner does not work with TMA. Use manual config. 
configs = { torch.float8_e4m3fn: { @@ -497,12 +504,15 @@ def matmul_device_tma_persistent(a, b, tiles_per_update): c = torch.empty((M, N), device=a.device, dtype=dtype) NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - tma_size = 128 - workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda") + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) - matmul_kernel_device_tma_persistent[grid]( - workspace, # + matmul_kernel_descriptor_persistent[grid]( tiles_per_update, # a, b, c, # M, N, K, # @@ -576,7 +586,7 @@ def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000): bench_fn(reps, warmup_reps, matmul_persistent, a, b.T) if supports_tma(): bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b) - bench_fn(reps, warmup_reps, matmul_device_tma_persistent, a, b, tiles_per_update) + bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b, tiles_per_update) def validate(M, N, K, dtype, tiles_per_update): @@ -589,7 +599,7 @@ def validate(M, N, K, dtype, tiles_per_update): naive_result = matmul(a, b.T) persistent_result = matmul_persistent(a, b.T) tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None - device_tma_persistent_result = matmul_device_tma_persistent(a, b, tiles_per_update) if supports_tma() else None + descriptor_persistent_result = matmul_descriptor_persistent(a, b, tiles_per_update) if supports_tma() else None if torch_result is not None: naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16), @@ -602,9 +612,9 @@ def validate(M, N, K, dtype, tiles_per_update): if tma_persistent_result is not None: naive_vs_tma_persistent = "✅" if torch.allclose(cublas_result.to(torch.float16), tma_persistent_result.to(torch.float16), atol=1.0) else "❌" - if device_tma_persistent_result is not None: - naive_vs_device_tma_persistent = "✅" if torch.allclose(cublas_result.to( - torch.float16), device_tma_persistent_result.to(torch.float16), atol=1.0) else "❌" + if descriptor_persistent_result is not None: + naive_vs_descriptor_persistent = "✅" if torch.allclose(cublas_result.to( + torch.float16), descriptor_persistent_result.to(torch.float16), atol=1.0) else "❌" print(f"M={M}, N={N}, K={K} verification naive vs: ", end="") if torch_result is not None: print(f"torch: {naive_vs_torch} ", end="") @@ -613,8 +623,8 @@ def validate(M, N, K, dtype, tiles_per_update): print(f"persistent: {naive_vs_persistent} ", end="") if tma_persistent_result is not None: print(f"TMA persistent: {naive_vs_tma_persistent} ", end="") - if device_tma_persistent_result is not None: - print(f"Device TMA persistent: {naive_vs_device_tma_persistent} ", end="") + if descriptor_persistent_result is not None: + print(f"Device TMA persistent: {naive_vs_descriptor_persistent} ", end="") print() @@ -639,7 +649,7 @@ def show_profile(precision, profile_name): type=int, default=1, help= - "Number of output tiles calculated for each update of the tma descriptor in matmul_device_tma_persistent_kernel", + "Number of output tiles calculated for each update of the tma descriptor in matmul_descriptor_persistent_kernel", ) parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16") args = parser.parse_args() diff 
--git a/test/Conversion/allocate_shared_memory.mlir b/test/Conversion/allocate_shared_memory.mlir index 345714f5b2..f3c2ed7033 100644 --- a/test/Conversion/allocate_shared_memory.mlir +++ b/test/Conversion/allocate_shared_memory.mlir @@ -1,14 +1,16 @@ // RUN: triton-opt %s --allocate-shared-memory | FileCheck %s +#blocked = #ttg.blocked<{sizePerThread = [32, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> + // CHECK-LABEL: module // CHECK-SAME: ttg.shared = 131072 : i32 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @gather_op // TODO(jeff): Optimize the lowering to reduce shared memory usage. -tt.func @gather_op(%arg0: tensor<1024x256xi32>, %arg1: tensor<128x256xf32>) { +tt.func @gather_op(%arg0: tensor<1024x256xi32, #blocked>, %arg1: tensor<128x256xf32, #blocked>) { // CHECK-NEXT: allocation.offset = 0 : i32 - %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32>, tensor<1024x256xi32>) -> tensor<1024x256xf32> + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32, #blocked>, tensor<1024x256xi32, #blocked>) -> tensor<1024x256xf32, #blocked> tt.return } diff --git a/test/Conversion/gather_to_llvm.mlir b/test/Conversion/gather_to_llvm.mlir new file mode 100644 index 0000000000..28a8a7e6b2 --- /dev/null +++ b/test/Conversion/gather_to_llvm.mlir @@ -0,0 +1,271 @@ +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s + +// Check the optimized LLVMIR, since InstCombine makes the linear layout +// logic understandable enough (in simple cases) to check correctness by eye. + +#trivial_layout = #ttg.linear<{register = [], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> + +#trivial_layout_wider = #ttg.linear<{register = [[32]], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> + +#trivial_layout_wider_reg_stride_1 = #ttg.linear<{register = [[1]], lane = [[2], [4], [8], [16], [32]], warp = [], block = []}> + +#trivial_2d_one_col = #ttg.linear<{register = [[0, 1]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [], block = []}> + +#span_2d_cols = #ttg.linear<{register = [[1, 0]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [], block = []}> + +#crazy_2d_src = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}> +#crazy_2d_idx = #ttg.linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + +// Each source element is mapped to a single thread, so we expect one index shuffle. 
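For readers of the CHECK lines below, a minimal NumPy sketch (illustrative only, not part of the test or the patch) of the warp-local gather being verified: each of the 32 lanes owns exactly one source element, so the gather reduces to a single index shuffle per output element, with shfl.sync.idx returning the value held by the lane named in the index register.

import numpy as np

def shfl_sync_idx(values, src_lane):
    # values[i] is the register held by lane i; src_lane[i] is the lane that lane i reads from.
    return values[src_lane]

warp_size = 32
src = np.random.rand(warp_size).astype(np.float32)   # one f32 source element per lane
idx = np.random.randint(0, warp_size, warp_size)     # one gather index per lane

gathered = shfl_sync_idx(src, idx & 31)              # mask down to the lane id, as the IR does
assert np.array_equal(gathered, src[idx])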
+// CHECK-LABEL: @gather_warp_local_trivial +tt.func private @gather_warp_local_trivial(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32 + // CHECK-NEXT: [[RES_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32xf32, #trivial_layout>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Same as above, but there are two index elements per thread. Expect 2 index shuffles +// with the results packed together. +// CHECK-LABEL: @gather_warp_local_larger_output +tt.func private @gather_warp_local_larger_output(%arg0: tensor<64xi32, #trivial_layout_wider>, %arg1: tensor<32xf32, #trivial_layout>) -> tensor<64xf32, #trivial_layout_wider> { + // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[IDX0]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32 + // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID0]], i32 31) + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[IDX1]], 31 + // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID1]], i32 31) + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32xf32, #trivial_layout>, tensor<64xi32, #trivial_layout_wider>) -> tensor<64xf32, #trivial_layout_wider> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<64xf32, #trivial_layout_wider> +} + +// Each thread has 2 elements of the source tensor, strided 32 apart, so we +// expect two index shuffles, using the MSB to select between the two. 
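Before the checks for this case, an illustrative NumPy sketch (not part of the test or the patch) of the reasoning above: with two source registers per lane, 32 elements apart, the lowering splits each gather index into a lane id (low five bits) and a register id (bit 5), performs one shuffle per register, and selects between the two shuffled results using the register id.

import numpy as np

warp_size = 32
src = np.random.rand(2 * warp_size).astype(np.float32)
reg0, reg1 = src[:warp_size], src[warp_size:]          # lane i holds elements i and i + 32
idx = np.random.randint(0, 2 * warp_size, warp_size)   # one gather index per lane

lane_id = idx & 31            # which lane owns the element
reg_id = idx & 32             # MSB: which of the two registers holds it
res0 = reg0[lane_id]          # first index shuffle
res1 = reg1[lane_id]          # second index shuffle
gathered = np.where(reg_id == 0, res0, res1)
assert np.array_equal(gathered, src[idx])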
+// CHECK-LABEL: @gather_warp_local_larger_input +tt.func private @gather_warp_local_larger_input(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<64xf32, #trivial_layout_wider>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 31 + // CHECK-NEXT: [[REGID:%.*]] = and i32 [[IDX]], 32 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64xf32, #trivial_layout_wider>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: [[PICK:%.*]] = icmp eq i32 [[REGID]], 0 + // CHECK-NEXT: [[RES_i32:%.*]] = select i1 [[PICK]], i32 [[RES0]], i32 [[RES1]] + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Same as above, except the RegID comes from the LSB. +// CHECK-LABEL: @gather_warp_local_larger_input +tt.func private @gather_warp_local_larger_input_stride_1(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<64xf32, #trivial_layout_wider_reg_stride_1>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[REGID:%.*]] = and i32 [[IDX]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX]], 1 + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[TMP]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64xf32, #trivial_layout_wider_reg_stride_1>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: [[PICK:%.*]] = icmp eq i32 [[REGID]], 0 + // CHECK-NEXT: [[RES_i32:%.*]] = select i1 [[PICK]], i32 [[RES0]], i32 [[RES1]] + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Each thread has 1 element in 2 gather columns, so this is the same as the +// trivial case except now it's 2D. We expect 2 independent index shuffles. 
+// CHECK-LABEL: @gather_2d_trivial +tt.func private @gather_2d_trivial(%arg0: tensor<32x2xi32, #trivial_2d_one_col>, %arg1: tensor<32x2xf32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[IDX0]], 31 + // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID0]], i32 31) + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[IDX1]], 31 + // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID1]], i32 31) + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x2xf32, #trivial_2d_one_col>, tensor<32x2xi32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<32x2xf32, #trivial_2d_one_col> +} + +// The single warp is split into two columns. Each column has half contiguous +// threads, each with 2 contiguous elements. Expect 4 index shuffles: two per +// column. Thus, the index should be dependent on the thread id, since the +// register alone is not enough to determine the column. 
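As a reading aid for the checks that follow (illustrative only, not part of the test or the patch), the shuffle index in this case is rebuilt from two sources: the in-column lane bits come from the gather index, while the column bit is taken from the current thread id, since the register id alone cannot identify the column.

import numpy as np

warp_size, col_threads = 32, 16
tid = np.arange(warp_size)
col = tid & 16                                           # column bit, taken from the thread id
idx = np.random.randint(0, 2 * col_threads, warp_size)   # row to gather within this lane's column

reg_id = idx & 1                 # which of the 2 contiguous registers holds the row
lane_in_col = (idx >> 1) & 15    # lane within the 16-thread column
shuffle_idx = lane_in_col | col  # full lane to shuffle from (the "or disjoint" in the IR)
# Both source registers are shuffled with shuffle_idx, and reg_id then selects between the
# two results, mirroring the icmp/select pair in the checks below.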
+// CHECK-LABEL: @gather_2d_span_2 +tt.func private @gather_2d_span_2(%arg0: tensor<32x2xi32, #span_2d_cols>, %arg1: tensor<32x2xf32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // This uses tid to select between the two columns: + // CHECK-NEXT: [[TID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // CHECK-NEXT: [[COL:%.*]] = and i32 [[TID]], 16 + + // Break the index into reg and thread (within column) components: + // CHECK-NEXT: [[REGID0:%.*]] = and i32 [[IDX0]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX0]], 1 + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[TMP]], 15 + + // CHECK-NEXT: [[SHUFFLE_IDX:%.*]] = or disjoint i32 [[LANEID0]], [[COL]] + + // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[SRES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[SHUFFLE_IDX]], i32 31) + // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[SRES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[SHUFFLE_IDX]], i32 31) + + // Use the reg id to select between the two results: + // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID0]], 0 + // CHECK-NEXT: [[RES0_i32:%.*]] = select i1 [[PICK0]], i32 [[SRES0]], i32 [[SRES1]] + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[REGID1:%.*]] = and i32 [[IDX1]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX1]], 1 + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[TMP]], 15 + + // CHECK-NEXT: [[SHUFFLE_IDX:%.*]] = or disjoint i32 [[LANEID1]], [[COL]] + + // CHECK-NEXT: [[SRES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[SHUFFLE_IDX]], i32 31) + // CHECK-NEXT: [[SRES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[SHUFFLE_IDX]], i32 31) + + // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID1]], 0 + // CHECK-NEXT: [[RES1_i32:%.*]] = select i1 [[PICK0]], i32 [[SRES0]], i32 [[SRES1]] + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x2xf32, #span_2d_cols>, tensor<32x2xi32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<32x2xf32, #span_2d_cols> +} + +// CHECK-LABEL: @gather_2d_crazy +tt.func private @gather_2d_crazy(%arg0: tensor<32x16xi32, #crazy_2d_idx>, %arg1: tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx> { + // The specific logic becomes hard to grasp here. Just check the shuffles. 
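The layouts used here (for example #crazy_2d_src above) can still be decoded by hand: in a #ttg.linear layout, each input bit of (register, lane, warp) contributes a basis vector, and an element's coordinate is the XOR of the bases whose bits are set. A rough Python sketch of that decoding, offered only as a reading aid and not as Triton's API:

def linear_layout_coord(bases, reg, lane, warp):
    # XOR together the basis vectors for every set bit of the hardware indices.
    coord = [0, 0]
    for group, value in (("register", reg), ("lane", lane), ("warp", warp)):
        for bit, basis in enumerate(bases.get(group, [])):
            if (value >> bit) & 1:
                coord = [c ^ b for c, b in zip(coord, basis)]
    return tuple(coord)  # (dim0, dim1)

crazy_2d_src = {
    "register": [[0, 2], [2, 0]],
    "lane": [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]],
    "warp": [[0, 1], [0, 4]],
}
# e.g. register 1 of lane 2 in warp 0 holds element (8, 2) of the 32x16 source tensor:
print(linear_layout_coord(crazy_2d_src, reg=1, lane=2, warp=0))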
+ + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float, float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float, float, float } %1, 1 + // CHECK-NEXT: [[SRC2:%.*]] = extractvalue { float, float, float, float } %1, 2 + // CHECK-NEXT: [[SRC3:%.*]] = extractvalue { float, float, float, float } %1, 3 + + // CHECK: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], + // CHECK-NEXT: [[VALUE2:%.*]] = bitcast float [[SRC2]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]], + + // CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]], + + // CHECK: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], + // CHECK-NEXT: [[VALUE3:%.*]] = bitcast float [[SRC3]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]], + + // CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]], + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x16xf32, #crazy_2d_src>, tensor<32x16xi32, #crazy_2d_idx>) -> tensor<32x16xf32, #crazy_2d_idx> + tt.return %0 : tensor<32x16xf32, #crazy_2d_idx> +} + +// Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM +// from removing unused function results. +tt.func @anchor(%ptr: !llvm.ptr, + %arg0: tensor<32xi32, #trivial_layout>, + %arg1: tensor<32xf32, #trivial_layout>, + %arg2: tensor<64xi32, #trivial_layout_wider>, + %arg3: tensor<64xf32, #trivial_layout_wider>, + %arg4: tensor<64xf32, #trivial_layout_wider_reg_stride_1>, + %arg5: tensor<32x2xi32, #trivial_2d_one_col>, + %arg6: tensor<32x2xf32, #trivial_2d_one_col>, + %arg7: tensor<32x2xi32, #span_2d_cols>, + %arg8: tensor<32x2xf32, #span_2d_cols>, + %arg9: tensor<32x16xi32, #crazy_2d_idx>, + %arg10: tensor<32x16xf32, #crazy_2d_src>) { + + %0 = tt.call @gather_warp_local_trivial(%arg0, %arg1) : (tensor<32xi32, #trivial_layout>, tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + %1 = builtin.unrealized_conversion_cast %0 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %1, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %2 = tt.call @gather_warp_local_larger_output(%arg2, %arg1) : (tensor<64xi32, #trivial_layout_wider>, tensor<32xf32, #trivial_layout>) -> tensor<64xf32, #trivial_layout_wider> + %3 = builtin.unrealized_conversion_cast %2 : tensor<64xf32, #trivial_layout_wider> to !llvm.struct<(f32, f32)> + llvm.store volatile %3, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %4 = tt.call @gather_warp_local_larger_input(%arg0, %arg3) : (tensor<32xi32, #trivial_layout>, tensor<64xf32, #trivial_layout_wider>) -> tensor<32xf32, #trivial_layout> + %5 = builtin.unrealized_conversion_cast %4 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %5, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %6 = tt.call @gather_warp_local_larger_input_stride_1(%arg0, %arg4) : (tensor<32xi32, #trivial_layout>, tensor<64xf32, #trivial_layout_wider_reg_stride_1>) -> tensor<32xf32, #trivial_layout> + %7 = builtin.unrealized_conversion_cast %6 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %7, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %8 = tt.call 
@gather_2d_trivial(%arg5, %arg6) : (tensor<32x2xi32, #trivial_2d_one_col>, tensor<32x2xf32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> + %9 = builtin.unrealized_conversion_cast %8 : tensor<32x2xf32, #trivial_2d_one_col> to !llvm.struct<(f32, f32)> + llvm.store volatile %9, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %10 = tt.call @gather_2d_span_2(%arg7, %arg8) : (tensor<32x2xi32, #span_2d_cols>, tensor<32x2xf32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> + %11 = builtin.unrealized_conversion_cast %10 : tensor<32x2xf32, #span_2d_cols> to !llvm.struct<(f32, f32)> + llvm.store volatile %11, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %12 = tt.call @gather_2d_crazy(%arg9, %arg10) : (tensor<32x16xi32, #crazy_2d_idx>, tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx> + %13 = builtin.unrealized_conversion_cast %12 : tensor<32x16xf32, #crazy_2d_idx> to !llvm.struct<(f32, f32, f32, f32)> + llvm.store volatile %13, %ptr : !llvm.struct<(f32, f32, f32, f32)>, !llvm.ptr + + tt.return +} + +} diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir index 8cc3ae64f4..8cf3bdcafd 100644 --- a/test/TritonGPU/amd/amd-instruction-sched.mlir +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -1,8 +1,8 @@ -// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 -// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm_iglp_0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm_iglp_1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2 -// RUN: triton-opt %s -split-input-file 
-convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local_prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2 diff --git a/test/TritonGPU/samples/simulated-grouped-gemm.mlir b/test/TritonGPU/samples/simulated-grouped-gemm.mlir new file mode 100644 index 0000000000..24cbc4dc01 --- /dev/null +++ b/test/TritonGPU/samples/simulated-grouped-gemm.mlir @@ -0,0 +1,377 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// The script is designed to make adding checks to +// a test case fast, it is *not* designed to be authoritative +// about what constitutes a good test! The CHECK should be +// minimized and named to reflect the test intent. 
+ +// CHECK: #[[$ATTR_0:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +// CHECK: #[[$ATTR_1:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +// CHECK: #[[$ATTR_2:.+]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// CHECK: #[[$ATTR_3:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +// CHECK: #[[$ATTR_4:.+]] = #ttg.shared_memory +// To regenerate this test case, run the command +// triton-opt test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \ +// utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \ +// -o test/TritonGPU/samples/simulated-grouped-gemm.mlir +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +// CHECK-LABEL: tt.func public @matmul_kernel_descriptor_persistent( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_4:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_6:.*]] = arith.constant 4 : i64 +// CHECK: %[[VAL_7:.*]] = arith.constant 2 : i64 +// CHECK: %[[VAL_8:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_9:.*]] = arith.constant 3 : i32 +// CHECK: %[[VAL_10:.*]] = arith.constant false +// CHECK: %[[VAL_11:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_12:.*]] = arith.constant 132 : i32 +// CHECK: %[[VAL_13:.*]] = arith.constant -1 : i32 +// CHECK: %[[VAL_14:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_15:.*]] = arith.constant 8 : i32 +// CHECK: %[[VAL_16:.*]] = arith.constant 128 : i32 +// CHECK: %[[VAL_17:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_18:.*]] = arith.constant 64 : i32 +// CHECK: %[[VAL_19:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_20:.*]] = arith.constant 127 : i32 +// CHECK: %[[VAL_21:.*]] = arith.constant 255 : i32 +// CHECK: %[[VAL_22:.*]] = arith.constant 63 : i32 +// CHECK: %[[VAL_23:.*]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_24:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_3]], %[[VAL_20]] : i32 +// CHECK: %[[VAL_26:.*]] = arith.divsi %[[VAL_25]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_4]], %[[VAL_21]] : i32 +// CHECK: %[[VAL_28:.*]] = arith.divsi %[[VAL_27]], %[[VAL_17]] : i32 +// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_5]], %[[VAL_22]] : i32 +// CHECK: %[[VAL_30:.*]] = arith.divsi %[[VAL_29]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_31:.*]] = arith.muli %[[VAL_26]], %[[VAL_28]] : i32 +// CHECK: %[[VAL_32:.*]] = arith.extsi %[[VAL_5]] : i32 to i64 +// CHECK: %[[VAL_33:.*]] = tt.make_tensor_descriptor %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_5]]], {{\[}}%[[VAL_32]], %[[VAL_19]]] : , > +// CHECK: %[[VAL_34:.*]] = tt.make_tensor_descriptor %[[VAL_1]], {{\[}}%[[VAL_4]], %[[VAL_5]]], {{\[}}%[[VAL_32]], %[[VAL_19]]] : , > +// CHECK: %[[VAL_35:.*]] = arith.extsi %[[VAL_4]] : i32 to i64 +// CHECK: %[[VAL_36:.*]] = tt.make_tensor_descriptor %[[VAL_2]], {{\[}}%[[VAL_3]], %[[VAL_4]]], {{\[}}%[[VAL_35]], %[[VAL_19]]] : , > +// CHECK: 
%[[VAL_37:.*]] = arith.divsi %[[VAL_31]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_38:.*]] = arith.remsi %[[VAL_31]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_39:.*]] = arith.cmpi slt, %[[VAL_24]], %[[VAL_38]] : i32 +// CHECK: %[[VAL_40:.*]] = scf.if %[[VAL_39]] -> (i32) { +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_37]], %[[VAL_11]] : i32 +// CHECK: scf.yield %[[VAL_41]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_37]] : i32 +// CHECK: } +// CHECK: %[[VAL_42:.*]] = arith.subi %[[VAL_24]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_43:.*]] = arith.muli %[[VAL_28]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_44:.*]] = tt.elementwise_inline_asm "mov.b32 $0, 0;" {constraints = "=r", packed_element = 1 : i32, pure = true} -> i32 +// CHECK: %[[VAL_45:.*]] = arith.muli %[[VAL_30]], %[[VAL_40]] : i32 +// CHECK: %[[VAL_46:.*]] = arith.subi %[[VAL_30]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_47:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr +// CHECK: %[[VAL_48:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr +// CHECK: %[[VAL_49:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr +// CHECK: %[[VAL_50:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_51:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_52:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_53:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_14]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_53]], 1 : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_54:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_11]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_54]], 1 : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_55:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_8]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_55]], 1 : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_56:.*]] = arith.cmpi sgt, %[[VAL_45]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_56]], %[[VAL_24]], %[[VAL_42]] : i32 +// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_56]], %[[VAL_14]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_59:.*]]:2 = scf.if %[[VAL_56]] -> (i32, i32) { +// CHECK: %[[VAL_60:.*]] = arith.divsi %[[VAL_24]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_61:.*]] = arith.muli %[[VAL_60]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_62:.*]] = arith.subi %[[VAL_26]], %[[VAL_61]] : i32 +// CHECK: %[[VAL_63:.*]] = arith.minsi %[[VAL_62]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_64:.*]] = arith.remsi %[[VAL_24]], %[[VAL_63]] : i32 +// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_61]], %[[VAL_64]] : i32 +// CHECK: %[[VAL_66:.*]] = arith.remsi %[[VAL_24]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_67:.*]] = arith.divsi %[[VAL_66]], %[[VAL_63]] : i32 +// CHECK: %[[VAL_68:.*]] = arith.muli %[[VAL_65]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_69:.*]] = arith.muli %[[VAL_67]], %[[VAL_17]] : i32 +// CHECK: scf.yield %[[VAL_68]], %[[VAL_69]] : i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_14]], %[[VAL_14]] : i32, i32 
+// CHECK: } +// CHECK: %[[VAL_70:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_14]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_70]], 49152, %[[VAL_56]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_71:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_14]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_72:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_33]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_72]]{{\[}}%[[VAL_73:.*]]#0, %[[VAL_14]]] %[[VAL_71]], %[[VAL_70]], %[[VAL_56]] : , <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> <128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_74:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_14]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_75:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_34]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_75]]{{\[}}%[[VAL_73]]#1, %[[VAL_14]]] %[[VAL_74]], %[[VAL_70]], %[[VAL_56]] : , <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> <256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_76:.*]] = arith.cmpi sgt, %[[VAL_45]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_77:.*]] = arith.cmpi ne, %[[VAL_46]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_78:.*]] = arith.extui %[[VAL_77]] : i1 to i32 +// CHECK: %[[VAL_79:.*]] = arith.cmpi eq, %[[VAL_78]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_80:.*]] = arith.andi %[[VAL_76]], %[[VAL_79]] : i1 +// CHECK: %[[VAL_81:.*]]:10 = scf.if %[[VAL_80]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32) { +// CHECK: %[[VAL_82:.*]] = arith.addi %[[VAL_58]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_83:.*]] = arith.cmpi eq, %[[VAL_82]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_84:.*]] = arith.select %[[VAL_83]], %[[VAL_14]], %[[VAL_82]] : i32 +// CHECK: %[[VAL_85:.*]] = arith.extui %[[VAL_83]] : i1 to i32 +// CHECK: %[[VAL_86:.*]] = arith.extui %[[VAL_83]] : i1 to i32 +// CHECK: %[[VAL_87:.*]] = arith.extui %[[VAL_83]] : i1 to i32 +// CHECK: %[[VAL_88:.*]]:3 = scf.if %[[VAL_83]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>) { +// CHECK: %[[VAL_89:.*]] = tt.addptr %[[VAL_0]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_90:.*]] = arith.muli %[[VAL_32]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_91:.*]] = arith.shrsi %[[VAL_90]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_47]], %[[VAL_89]], {{\[}}%[[VAL_18]], %[[VAL_16]]], {{\[}}%[[VAL_5]], %[[VAL_3]]], {{\[}}%[[VAL_91]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_47]] : !tt.ptr +// CHECK: %[[VAL_92:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_47]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_93:.*]] = tt.addptr %[[VAL_1]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_94:.*]] = arith.muli %[[VAL_32]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_95:.*]] = arith.shrsi %[[VAL_94]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_48]], %[[VAL_93]], 
{{\[}}%[[VAL_18]], %[[VAL_17]]], {{\[}}%[[VAL_5]], %[[VAL_4]]], {{\[}}%[[VAL_95]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_48]] : !tt.ptr +// CHECK: %[[VAL_96:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_48]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_97:.*]] = tt.addptr %[[VAL_2]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_98:.*]] = arith.muli %[[VAL_35]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_99:.*]] = arith.shrsi %[[VAL_98]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_49]], %[[VAL_97]], {{\[}}%[[VAL_18]], %[[VAL_16]]], {{\[}}%[[VAL_4]], %[[VAL_3]]], {{\[}}%[[VAL_99]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_49]] : !tt.ptr +// CHECK: %[[VAL_100:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_49]] : !tt.ptr to !tt.tensordesc> +// CHECK: scf.yield %[[VAL_92]], %[[VAL_96]], %[[VAL_100]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc> +// CHECK: } else { +// CHECK: scf.yield %[[VAL_33]], %[[VAL_34]], %[[VAL_36]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc> +// CHECK: } +// CHECK: %[[VAL_101:.*]] = arith.addi %[[VAL_57]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_102:.*]] = arith.divsi %[[VAL_101]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_103:.*]] = arith.muli %[[VAL_102]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_104:.*]] = arith.subi %[[VAL_26]], %[[VAL_103]] : i32 +// CHECK: %[[VAL_105:.*]] = arith.minsi %[[VAL_104]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_106:.*]] = arith.remsi %[[VAL_101]], %[[VAL_105]] : i32 +// CHECK: %[[VAL_107:.*]] = arith.addi %[[VAL_103]], %[[VAL_106]] : i32 +// CHECK: %[[VAL_108:.*]] = arith.remsi %[[VAL_101]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_109:.*]] = arith.divsi %[[VAL_108]], %[[VAL_105]] : i32 +// CHECK: %[[VAL_110:.*]] = arith.muli %[[VAL_107]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_111:.*]] = arith.muli %[[VAL_109]], %[[VAL_17]] : i32 +// CHECK: scf.yield %[[VAL_112:.*]]#0, %[[VAL_112]]#1, %[[VAL_112]]#2, %[[VAL_101]], %[[VAL_84]], %[[VAL_110]], %[[VAL_111]], %[[VAL_85]], %[[VAL_86]], %[[VAL_87]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_33]], %[[VAL_34]], %[[VAL_36]], %[[VAL_57]], %[[VAL_58]], %[[VAL_73]]#0, %[[VAL_73]]#1, %[[VAL_14]], %[[VAL_14]], %[[VAL_14]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_113:.*]] = arith.muli %[[VAL_78]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_114:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_11]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_114]], 49152, %[[VAL_76]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_115:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_11]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_116:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_117:.*]]#0 : !tt.tensordesc> to !tt.ptr +// CHECK: 
ttng.async_tma_copy_global_to_local %[[VAL_116]]{{\[}}%[[VAL_117]]#5, %[[VAL_113]]] %[[VAL_115]], %[[VAL_114]], %[[VAL_76]] : , <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> <128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_118:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_11]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_119:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_117]]#1 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_119]]{{\[}}%[[VAL_117]]#6, %[[VAL_113]]] %[[VAL_118]], %[[VAL_114]], %[[VAL_76]] : , <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> <256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_120:.*]] = ttg.local_alloc : () -> !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_121:.*]]:24 = scf.for %[[VAL_122:.*]] = %[[VAL_14]] to %[[VAL_45]] step %[[VAL_11]] iter_args(%[[VAL_123:.*]] = %[[VAL_78]], %[[VAL_124:.*]] = %[[VAL_117]]#0, %[[VAL_125:.*]] = %[[VAL_117]]#1, %[[VAL_126:.*]] = %[[VAL_117]]#2, %[[VAL_127:.*]] = %[[VAL_117]]#3, %[[VAL_128:.*]] = %[[VAL_117]]#4, %[[VAL_129:.*]] = %[[VAL_117]]#5, %[[VAL_130:.*]] = %[[VAL_117]]#6, %[[VAL_131:.*]] = %[[VAL_23]], %[[VAL_132:.*]] = %[[VAL_10]], %[[VAL_133:.*]] = %[[VAL_11]], %[[VAL_134:.*]] = %[[VAL_13]], %[[VAL_135:.*]] = %[[VAL_14]], %[[VAL_136:.*]] = %[[VAL_117]]#7, %[[VAL_137:.*]] = %[[VAL_117]]#8, %[[VAL_138:.*]] = %[[VAL_117]]#9, %[[VAL_139:.*]] = %[[VAL_14]], %[[VAL_140:.*]] = %[[VAL_78]], %[[VAL_141:.*]] = %[[VAL_36]], %[[VAL_142:.*]] = %[[VAL_117]]#2, %[[VAL_143:.*]] = %[[VAL_73]]#0, %[[VAL_144:.*]] = %[[VAL_117]]#5, %[[VAL_145:.*]] = %[[VAL_73]]#1, %[[VAL_146:.*]] = %[[VAL_117]]#6) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_0]]>, i1, i32, i32, i32, i32, i32, i32, i32, i32, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) : i32 { +// CHECK: %[[VAL_147:.*]] = arith.subi %[[VAL_45]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_148:.*]] = arith.cmpi slt, %[[VAL_122]], %[[VAL_147]] : i32 +// CHECK: %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_123]], %[[VAL_46]] : i32 +// CHECK: %[[VAL_150:.*]] = arith.addi %[[VAL_123]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_151:.*]] = arith.select %[[VAL_149]], %[[VAL_14]], %[[VAL_150]] : i32 +// CHECK: %[[VAL_152:.*]] = arith.cmpi eq, %[[VAL_151]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_153:.*]] = arith.andi %[[VAL_148]], %[[VAL_152]] : i1 +// CHECK: %[[VAL_154:.*]]:10 = scf.if %[[VAL_153]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32) { +// CHECK: %[[VAL_155:.*]] = arith.addi %[[VAL_128]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_156:.*]] = arith.cmpi eq, %[[VAL_155]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_157:.*]] = arith.select %[[VAL_156]], %[[VAL_14]], %[[VAL_155]] : i32 +// CHECK: %[[VAL_158:.*]]:6 = scf.if %[[VAL_156]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32) { +// CHECK: %[[VAL_159:.*]] = tt.addptr %[[VAL_0]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_160:.*]] = arith.muli %[[VAL_136]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_161:.*]] = tt.addptr %[[VAL_47]], %[[VAL_160]] : !tt.ptr, i32 +// CHECK: %[[VAL_162:.*]] = arith.muli %[[VAL_32]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_163:.*]] = arith.shrsi %[[VAL_162]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_161]], %[[VAL_159]], 
{{\[}}%[[VAL_18]], %[[VAL_16]]], {{\[}}%[[VAL_5]], %[[VAL_3]]], {{\[}}%[[VAL_163]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_161]] : !tt.ptr +// CHECK: %[[VAL_164:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_161]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_165:.*]] = arith.addi %[[VAL_136]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_166:.*]] = arith.cmpi slt, %[[VAL_165]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_167:.*]] = arith.select %[[VAL_166]], %[[VAL_165]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_168:.*]] = tt.addptr %[[VAL_1]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_169:.*]] = arith.muli %[[VAL_137]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_170:.*]] = tt.addptr %[[VAL_48]], %[[VAL_169]] : !tt.ptr, i32 +// CHECK: %[[VAL_171:.*]] = arith.muli %[[VAL_32]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_172:.*]] = arith.shrsi %[[VAL_171]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_170]], %[[VAL_168]], {{\[}}%[[VAL_18]], %[[VAL_17]]], {{\[}}%[[VAL_5]], %[[VAL_4]]], {{\[}}%[[VAL_172]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_170]] : !tt.ptr +// CHECK: %[[VAL_173:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_170]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_174:.*]] = arith.addi %[[VAL_137]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_175:.*]] = arith.cmpi slt, %[[VAL_174]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_176:.*]] = arith.select %[[VAL_175]], %[[VAL_174]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_177:.*]] = tt.addptr %[[VAL_2]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_178:.*]] = arith.muli %[[VAL_138]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_179:.*]] = tt.addptr %[[VAL_49]], %[[VAL_178]] : !tt.ptr, i32 +// CHECK: %[[VAL_180:.*]] = arith.muli %[[VAL_35]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_181:.*]] = arith.shrsi %[[VAL_180]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_179]], %[[VAL_177]], {{\[}}%[[VAL_18]], %[[VAL_16]]], {{\[}}%[[VAL_4]], %[[VAL_3]]], {{\[}}%[[VAL_181]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_179]] : !tt.ptr +// CHECK: %[[VAL_182:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_179]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_183:.*]] = arith.addi %[[VAL_138]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_184:.*]] = arith.cmpi slt, %[[VAL_183]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_185:.*]] = arith.select %[[VAL_184]], %[[VAL_183]], %[[VAL_14]] : i32 +// CHECK: scf.yield %[[VAL_164]], %[[VAL_173]], %[[VAL_182]], %[[VAL_167]], %[[VAL_176]], %[[VAL_185]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_124]], %[[VAL_125]], %[[VAL_126]], %[[VAL_136]], %[[VAL_137]], %[[VAL_138]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_186:.*]] = arith.addi %[[VAL_127]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_187:.*]] = arith.divsi %[[VAL_186]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_188:.*]] = arith.muli %[[VAL_187]], 
%[[VAL_15]] : i32 +// CHECK: %[[VAL_189:.*]] = arith.subi %[[VAL_26]], %[[VAL_188]] : i32 +// CHECK: %[[VAL_190:.*]] = arith.minsi %[[VAL_189]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_191:.*]] = arith.remsi %[[VAL_186]], %[[VAL_190]] : i32 +// CHECK: %[[VAL_192:.*]] = arith.addi %[[VAL_188]], %[[VAL_191]] : i32 +// CHECK: %[[VAL_193:.*]] = arith.remsi %[[VAL_186]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_194:.*]] = arith.divsi %[[VAL_193]], %[[VAL_190]] : i32 +// CHECK: %[[VAL_195:.*]] = arith.muli %[[VAL_192]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_196:.*]] = arith.muli %[[VAL_194]], %[[VAL_17]] : i32 +// CHECK: scf.yield %[[VAL_197:.*]]#0, %[[VAL_197]]#1, %[[VAL_197]]#2, %[[VAL_186]], %[[VAL_157]], %[[VAL_195]], %[[VAL_196]], %[[VAL_197]]#3, %[[VAL_197]]#4, %[[VAL_197]]#5 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_124]], %[[VAL_125]], %[[VAL_126]], %[[VAL_127]], %[[VAL_128]], %[[VAL_129]], %[[VAL_130]], %[[VAL_136]], %[[VAL_137]], %[[VAL_138]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_198:.*]] = arith.addi %[[VAL_134]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_199:.*]] = arith.cmpi slt, %[[VAL_198]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_200:.*]] = arith.select %[[VAL_199]], %[[VAL_198]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_201:.*]] = arith.xori %[[VAL_135]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_202:.*]] = arith.select %[[VAL_199]], %[[VAL_135]], %[[VAL_201]] : i32 +// CHECK: %[[VAL_203:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_200]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.wait_barrier %[[VAL_203]], %[[VAL_202]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_204:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_200]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_205:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_200]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_206:.*]] = ttg.memdesc_trans %[[VAL_204]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> -> !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_207:.*]] = ttng.warp_group_dot %[[VAL_205]], %[[VAL_206]], %[[VAL_131]], %[[VAL_132]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> -> tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_208:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_207]], %[[VAL_205]], %[[VAL_206]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_209:.*]] = arith.addi %[[VAL_133]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_210:.*]] = arith.cmpi slt, %[[VAL_209]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_211:.*]] = arith.select %[[VAL_210]], %[[VAL_209]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_212:.*]] = arith.muli %[[VAL_151]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_213:.*]] = ttg.memdesc_subview 
%[[VAL_52]]{{\[}}%[[VAL_211]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_213]], 49152, %[[VAL_148]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_214:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_211]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_215:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_216:.*]]#0 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_215]]{{\[}}%[[VAL_216]]#5, %[[VAL_212]]] %[[VAL_214]], %[[VAL_213]], %[[VAL_148]] : , <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> <128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_217:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_211]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_218:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_216]]#1 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_218]]{{\[}}%[[VAL_216]]#6, %[[VAL_212]]] %[[VAL_217]], %[[VAL_213]], %[[VAL_148]] : , <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> <256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_219:.*]] = arith.cmpi eq, %[[VAL_139]], %[[VAL_46]] : i32 +// CHECK: %[[VAL_220:.*]] = arith.cmpi ne, %[[VAL_139]], %[[VAL_46]] : i32 +// CHECK: scf.if %[[VAL_219]] { +// CHECK: %[[VAL_221:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_208]]#0, %[[VAL_205]], %[[VAL_206]] {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_222:.*]] = arith.truncf %[[VAL_221]]#0 : tensor<128x256xf32, #[[$ATTR_0]]> to tensor<128x256xf16, #[[$ATTR_0]]> +// CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} +// CHECK: ttg.local_store %[[VAL_222]], %[[VAL_120]] : tensor<128x256xf16, #[[$ATTR_0]]> -> !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: ttng.fence_async_shared {bCluster = false} +// CHECK: %[[VAL_223:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_141]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_local_to_global %[[VAL_223]]{{\[}}%[[VAL_143]], %[[VAL_145]]] %[[VAL_120]] : , <128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: } +// CHECK: scf.yield %[[VAL_151]], %[[VAL_216]]#0, %[[VAL_216]]#1, %[[VAL_216]]#2, %[[VAL_216]]#3, %[[VAL_216]]#4, %[[VAL_216]]#5, %[[VAL_216]]#6, %[[VAL_208]]#0, %[[VAL_220]], %[[VAL_211]], %[[VAL_200]], %[[VAL_202]], %[[VAL_216]]#7, %[[VAL_216]]#8, %[[VAL_216]]#9, %[[VAL_140]], %[[VAL_151]], %[[VAL_142]], %[[VAL_216]]#2, %[[VAL_144]], %[[VAL_216]]#5, %[[VAL_146]], %[[VAL_216]]#6 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_0]]>, i1, i32, i32, i32, i32, i32, i32, i32, i32, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 +// CHECK: } +// CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} +// CHECK: ttg.local_dealloc %[[VAL_120]] : !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_224:.*]] = ttng.warp_group_dot_wait %[[VAL_225:.*]]#8 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_226:.*]] = 
ttg.async_wait {num = 0 : i32} +// CHECK: %[[VAL_227:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_14]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_227]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_228:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_11]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_228]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_229:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_8]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_229]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttg.local_dealloc %[[VAL_50]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: ttg.local_dealloc %[[VAL_51]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: tt.return +// CHECK: } +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_descriptor_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c132_i32 = arith.constant 132 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%8, %c1_i64] : , > + %10 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%8, %c1_i64] : , > + %11 = arith.extsi %arg4 : i32 to i64 + %12 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%11, %c1_i64] : , > + %13 = arith.divsi %7, %c132_i32 : i32 + %14 = arith.remsi %7, %c132_i32 : i32 + %15 = arith.cmpi slt, %0, %14 : i32 + %16 = scf.if %15 -> (i32) { + %23 = arith.addi %13, %c1_i32 : i32 + scf.yield %23 : i32 + } else { + scf.yield %13 : i32 + } + %17 = arith.subi %0, %c132_i32 : i32 + %18 = arith.muli %4, %c8_i32 : i32 + %19 = tt.elementwise_inline_asm "mov.b32 $0, 0;" {constraints = "=r", packed_element = 1 : i32, pure = true} -> i32 + %20 = arith.muli %6, %16 : i32 + %21 = arith.subi %6, %c1_i32 : i32 + %true = arith.constant true + %false = arith.constant false + %22:10 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %c-1_i32, %arg8 = %9, %arg9 = %10, %arg10 = 
%12, %arg11 = %17, %arg12 = %c-1_i32, %arg13 = %c0_i32, %arg14 = %c0_i32, %arg15 = %cst, %arg16 = %false) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1) : i32 { + %23 = arith.cmpi eq, %arg7, %21 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %24 = arith.addi %arg7, %c1_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %25 = arith.select %23, %c0_i32, %24 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %26 = arith.cmpi eq, %25, %c0_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %27:7 = scf.if %26 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) { + %37 = arith.addi %arg12, %c1_i32 : i32 + %38 = arith.cmpi eq, %37, %c1_i32 : i32 + %39:4 = scf.if %38 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32) { + %51 = tt.addptr %arg0, %19 : !tt.ptr, i32 + %52 = tt.make_tensor_descriptor %51, [%arg3, %arg5], [%8, %c1_i64] : , > + %53 = tt.addptr %arg1, %19 : !tt.ptr, i32 + %54 = tt.make_tensor_descriptor %53, [%arg4, %arg5], [%8, %c1_i64] : , > + %55 = tt.addptr %arg2, %19 : !tt.ptr, i32 + %56 = tt.make_tensor_descriptor %55, [%arg3, %arg4], [%11, %c1_i64] : , > + scf.yield %52, %54, %56, %c0_i32 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %37 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } + %40 = arith.addi %arg11, %c132_i32 : i32 + %41 = arith.divsi %40, %18 : i32 + %42 = arith.muli %41, %c8_i32 : i32 + %43 = arith.subi %2, %42 : i32 + %44 = arith.minsi %43, %c8_i32 : i32 + %45 = arith.remsi %40, %44 : i32 + %46 = arith.addi %42, %45 : i32 + %47 = arith.remsi %40, %18 : i32 + %48 = arith.divsi %47, %44 : i32 + %49 = arith.muli %46, %c128_i32 : i32 + %50 = arith.muli %48, %c256_i32 : i32 + scf.yield %39#0, %39#1, %39#2, %40, %39#3, %49, %50 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } {loop.cluster = 0 : i32, loop.stage = 0 : i32} + %28 = arith.muli %25, %c64_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + %29 = tt.experimental_descriptor_load %27#0[%27#5, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %30 = ttg.local_alloc %29 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %31 = tt.experimental_descriptor_load %27#1[%27#6, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %32 = ttg.local_alloc %31 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %33 = ttg.memdesc_trans %32 {loop.cluster = 1 : 
i32, loop.stage = 2 : i32, order = array} : !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> + %34 = ttng.warp_group_dot %30, %33, %arg15, %arg16 {inputPrecision = 0 : i32, loop.cluster = 1 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %35 = arith.cmpi eq, %25, %21 {loop.cluster = 3 : i32, loop.stage = 2 : i32} : i32 + %36 = scf.if %35 -> (i1) { + %37 = arith.truncf %34 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %38 = ttg.convert_layout %37 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %27#2[%27#5, %27#6], %38 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + scf.yield %false : i1 + } else { + scf.yield %true : i1 + } {loop.cluster = 3 : i32, loop.stage = 2 : i32} + scf.yield %25, %27#0, %27#1, %27#2, %27#3, %27#4, %27#5, %27#6, %34, %36 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1 + } + tt.return + } +} diff --git a/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in b/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in new file mode 100644 index 0000000000..7b096e8883 --- /dev/null +++ b/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in @@ -0,0 +1,104 @@ +// To regenerate this test case, run the command +// triton-opt test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \ +// utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \ +// -o test/TritonGPU/samples/simulated-grouped-gemm.mlir +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_descriptor_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c132_i32 = arith.constant 132 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c0_i32 = arith.constant 0 : 
i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%8, %c1_i64] : , > + %10 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%8, %c1_i64] : , > + %11 = arith.extsi %arg4 : i32 to i64 + %12 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%11, %c1_i64] : , > + %13 = arith.divsi %7, %c132_i32 : i32 + %14 = arith.remsi %7, %c132_i32 : i32 + %15 = arith.cmpi slt, %0, %14 : i32 + %16 = scf.if %15 -> (i32) { + %23 = arith.addi %13, %c1_i32 : i32 + scf.yield %23 : i32 + } else { + scf.yield %13 : i32 + } + %17 = arith.subi %0, %c132_i32 : i32 + %18 = arith.muli %4, %c8_i32 : i32 + %19 = tt.elementwise_inline_asm "mov.b32 $0, 0;" {constraints = "=r", packed_element = 1 : i32, pure = true} -> i32 + %20 = arith.muli %6, %16 : i32 + %21 = arith.subi %6, %c1_i32 : i32 + %true = arith.constant true + %false = arith.constant false + %22:10 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %c-1_i32, %arg8 = %9, %arg9 = %10, %arg10 = %12, %arg11 = %17, %arg12 = %c-1_i32, %arg13 = %c0_i32, %arg14 = %c0_i32, %arg15 = %cst, %arg16 = %false) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1) : i32 { + %23 = arith.cmpi eq, %arg7, %21 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %24 = arith.addi %arg7, %c1_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %25 = arith.select %23, %c0_i32, %24 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %26 = arith.cmpi eq, %25, %c0_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %27:7 = scf.if %26 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) { + %37 = arith.addi %arg12, %c1_i32 : i32 + %38 = arith.cmpi eq, %37, %c1_i32 : i32 + %39:4 = scf.if %38 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32) { + %51 = tt.addptr %arg0, %19 : !tt.ptr, i32 + %52 = tt.make_tensor_descriptor %51, [%arg3, %arg5], [%8, %c1_i64] : , > + %53 = tt.addptr %arg1, %19 : !tt.ptr, i32 + %54 = tt.make_tensor_descriptor %53, [%arg4, %arg5], [%8, %c1_i64] : , > + %55 = tt.addptr %arg2, %19 : !tt.ptr, i32 + %56 = tt.make_tensor_descriptor %55, [%arg3, %arg4], [%11, %c1_i64] : , > + scf.yield %52, %54, %56, %c0_i32 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %37 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } + %40 = arith.addi %arg11, %c132_i32 : i32 + %41 = arith.divsi %40, %18 : i32 + %42 = arith.muli %41, %c8_i32 : i32 + %43 = arith.subi %2, %42 : i32 + %44 = arith.minsi %43, %c8_i32 : i32 + %45 = arith.remsi %40, %44 : i32 + %46 = arith.addi %42, %45 : i32 + %47 = arith.remsi 
%40, %18 : i32 + %48 = arith.divsi %47, %44 : i32 + %49 = arith.muli %46, %c128_i32 : i32 + %50 = arith.muli %48, %c256_i32 : i32 + scf.yield %39#0, %39#1, %39#2, %40, %39#3, %49, %50 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } {loop.cluster = 0 : i32, loop.stage = 0 : i32} + %28 = arith.muli %25, %c64_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + %29 = tt.experimental_descriptor_load %27#0[%27#5, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %30 = ttg.local_alloc %29 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %31 = tt.experimental_descriptor_load %27#1[%27#6, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %32 = ttg.local_alloc %31 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %33 = ttg.memdesc_trans %32 {loop.cluster = 1 : i32, loop.stage = 2 : i32, order = array} : !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> + %34 = ttng.warp_group_dot %30, %33, %arg15, %arg16 {inputPrecision = 0 : i32, loop.cluster = 1 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %35 = arith.cmpi eq, %25, %21 {loop.cluster = 3 : i32, loop.stage = 2 : i32} : i32 + %36 = scf.if %35 -> (i1) { + %37 = arith.truncf %34 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %38 = ttg.convert_layout %37 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %27#2[%27#5, %27#6], %38 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + scf.yield %false : i1 + } else { + 
scf.yield %true : i1
+ } {loop.cluster = 3 : i32, loop.stage = 2 : i32}
+ scf.yield %25, %27#0, %27#1, %27#2, %27#3, %27#4, %27#5, %27#6, %34, %36 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1
+ }
+ tt.return
+ }
+}
diff --git a/third_party/amd/backend/include/hsa/amd_hsa_elf.h b/third_party/amd/backend/include/hsa/amd_hsa_elf.h
index 74f15d7d7a..65a77f041b 100644
--- a/third_party/amd/backend/include/hsa/amd_hsa_elf.h
+++ b/third_party/amd/backend/include/hsa/amd_hsa_elf.h
@@ -136,7 +136,7 @@ enum : unsigned {
EF_AMDGPU_MACH_AMDGCN_GFX942 = 0x04c,
EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4D = 0x04d,
EF_AMDGPU_MACH_AMDGCN_GFX1201 = 0x04e,
- EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4F = 0x04f,
+ EF_AMDGPU_MACH_AMDGCN_GFX950 = 0x04f,
EF_AMDGPU_MACH_AMDGCN_RESERVED_0X50 = 0x050,
EF_AMDGPU_MACH_AMDGCN_GFX9_GENERIC = 0x051,
EF_AMDGPU_MACH_AMDGCN_GFX10_1_GENERIC = 0x052,
diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt b/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt
index 25a57075be..094ecfc7d4 100644
--- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt
+++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt
@@ -11,6 +11,8 @@ add_mlir_doc(TritonAMDGPUOps TritonAMDGPUOps dialects/ -gen-op-doc)
add_public_tablegen_target(TritonAMDGPUTableGen)
set(LLVM_TARGET_DEFINITIONS TritonAMDGPUAttrDefs.td)
+mlir_tablegen(TritonAMDGPUEnums.h.inc -gen-enum-decls)
+mlir_tablegen(TritonAMDGPUEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(TritonAMDGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonAMDGPUAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(TritonAMDGPUAttrDefsIncGen)
diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h
index 486fd60293..a771e55609 100644
--- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h
+++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h
@@ -33,6 +33,7 @@
// clang-format off
#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc"
+#include "amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.h.inc"
// clang-format on
#define GET_ATTRDEF_CLASSES
diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td
index c0aa08421b..491989669c 100644
--- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td
+++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td
@@ -26,6 +26,7 @@
include "mlir/IR/AttrTypeBase.td"
include "TritonAMDGPUDialect.td"
+include "mlir/IR/EnumAttr.td"
class TritonAMDGPU_Attr traits = [], string baseCppClass = "::mlir::Attribute">
@@ -58,5 +59,32 @@ def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> {
let assemblyFormat = "`<` params `>`";
}
+class TritonAMDGPU_I32Enum cases>
+ : I32EnumAttr {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::triton::amdgpu";
+}
+
+class TritonAMDGPU_I32EnumAttr :
+ EnumAttr {
+ let assemblyFormat = "`<` $value `>`";
+ let cppNamespace = "::mlir::triton::amdgpu";
+}
+
+def SchedHintCaseNone : I32EnumAttrCase<"none", 0>;
+def SchedHintCaseLLVMIglp0 : I32EnumAttrCase<"llvm_iglp_0", 1>;
+def SchedHintCaseLLVMIglp1 : I32EnumAttrCase<"llvm_iglp_1", 2>;
+def SchedHintCaseLocalPrefetch : I32EnumAttrCase<"local_prefetch", 3>;
+
+def TritonAMDGPU_SchedHintsEnum : TritonAMDGPU_I32Enum<
+ "SchedHint", "Instruction Scheduling Hints for AMD GPUs", [
+ SchedHintCaseNone,
+ SchedHintCaseLLVMIglp0,
+ SchedHintCaseLLVMIglp1,
+ SchedHintCaseLocalPrefetch,
+ ]>;
+
+def TritonAMDGPU_SchedHintVariantAttr :
+ TritonAMDGPU_I32EnumAttr<"SchedHintVariant", TritonAMDGPU_SchedHintsEnum>;
#endif
diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
index 0e2a9304eb..3f141ee2ef 100644
--- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
+++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
@@ -48,6 +48,8 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() {
>();
}
+#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.cpp.inc"
+
#define GET_ATTRDEF_CLASSES
#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc"
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
index d93f2ca6c6..333a78b4c5 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
@@ -217,31 +217,22 @@ struct InstructionSchedHintsRewriter
: OpRewritePattern(ctx), numStages(numStages) {
this->machineDescr = MachineDescr::get(arch);
- std::transform(variant.begin(), variant.end(), variant.begin(),
- [](unsigned char c) { return std::tolower(c); });
-
- this->schedulingType =
- llvm::StringSwitch(variant)
- .Case("none", SchedulingType::NONE)
- .Case("llvm-iglp-0", SchedulingType::LLVM_IGLP_0)
- .Case("llvm-iglp-1", SchedulingType::LLVM_IGLP_1)
- .Case("local-prefetch", SchedulingType::LOCAL_PREFETCH)
- .Default(SchedulingType::UNKNOWN);
+ this->schedHint = mlir::triton::amdgpu::SchedHint::none;
if (this->numStages < 2) {
- this->schedulingType = SchedulingType::NONE;
LDBG("ignoring instruction scheduling due to a very low num. " "stages value. Must be >= 2");
+ return;
}
- }
- enum class SchedulingType : uint32_t {
- NONE = 0,
- LLVM_IGLP_0,
- LLVM_IGLP_1,
- LOCAL_PREFETCH,
- UNKNOWN
- };
+ std::transform(variant.begin(), variant.end(), variant.begin(),
+ [](unsigned char c) { return std::tolower(c); });
+ if (auto maybeSchedHint = triton::amdgpu::symbolizeSchedHint(variant))
+ this->schedHint = maybeSchedHint.value();
+ else
+ LDBG("ignoring instruction scheduling because "
+ "unknown instruction scheduling variant has been provided");
+ }
// The following is inspired by ROCm Composable Kernel library's V3 pipelining
// (see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp).
@@ -416,25 +407,18 @@ struct InstructionSchedHintsRewriter
LogicalResult matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, PatternRewriter &rewriter) const override {
- if (this->schedulingType == SchedulingType::NONE) {
+ if (this->schedHint == mlir::triton::amdgpu::SchedHint::none) {
rewriter.eraseOp(instructionSchedHint);
return success();
}
- if (this->schedulingType == SchedulingType::UNKNOWN) {
- instructionSchedHint.emitError(
- "unknown instruction scheduling variant has been provided");
- return failure();
- }
-
// The switch controls whether instructions are allowed to cross the basic
// block boundaries at the very top and at the very bottom. Note, this is
// not supposed to be used together with IGLP OPT according to the AMDGPU
// backend documentation.
const bool limitSchedulingRange =
- !(schedulingType == SchedulingType::NONE ||
- schedulingType == SchedulingType::LLVM_IGLP_0 ||
- schedulingType == SchedulingType::LLVM_IGLP_1);
+ this->schedHint == mlir::triton::amdgpu::SchedHint::local_prefetch;
Location loc = instructionSchedHint->getLoc();
Block *block = instructionSchedHint->getBlock();
if (limitSchedulingRange) {
@@ -445,15 +429,15 @@ struct InstructionSchedHintsRewriter
rewriter.setInsertionPoint(block, std::prev(block->end()));
- switch (schedulingType) {
- case SchedulingType::LLVM_IGLP_0:
- case SchedulingType::LLVM_IGLP_1:
- createIglpOpt(rewriter, loc, static_cast(schedulingType) - 1);
+ switch (this->schedHint) {
+ case mlir::triton::amdgpu::SchedHint::llvm_iglp_0:
+ case mlir::triton::amdgpu::SchedHint::llvm_iglp_1:
+ createIglpOpt(rewriter, loc, static_cast(this->schedHint) - 1);
break;
- case SchedulingType::LOCAL_PREFETCH:
+ case mlir::triton::amdgpu::SchedHint::local_prefetch:
createLocalPrefetchSchedule(rewriter, loc, instructionSchedHint);
break;
- case SchedulingType::NONE:
+ case mlir::triton::amdgpu::SchedHint::none:
default:
break;
}
@@ -468,7 +452,7 @@ struct InstructionSchedHintsRewriter
private:
int32_t numStages;
- SchedulingType schedulingType;
+ mlir::triton::amdgpu::SchedHint schedHint;
std::unique_ptr machineDescr;
};
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp
index 63fb972f79..7ab6fd68a5 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp
@@ -11,6 +11,7 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) {
// CDNA ISA cases
switch (kind) {
+ case llvm::AMDGPU::GK_GFX950:
case llvm::AMDGPU::GK_GFX942:
case llvm::AMDGPU::GK_GFX941:
case llvm::AMDGPU::GK_GFX940:
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp
index 459a00c1a1..1bc6eb7cf0 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp
@@ -9,12 +9,13 @@
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
using namespace mlir;
using namespace mlir::triton;
+using namespace mlir::triton::nvidia_gpu;
namespace {
-constexpr int64_t TMA_SIZE_BYTES = 128;
void tensormap_cp_fenceproxy(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, Value outPtr,