diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 1b8f65d76d..3eb8a3a3d7 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -7,13 +7,7 @@ add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED) # TODO: what's this? llvm_update_compile_flags(triton-opt) target_link_libraries(triton-opt PRIVATE - TritonLLVMIR - TritonAnalysis - TritonTransforms - TritonGPUTransforms - TritonNvidiaGPUTransforms TritonIntelLLVMIR - MLIRGPUToROCDLTransforms ${dialect_libs} ${conversion_libs} ${triton_libs} @@ -32,11 +26,6 @@ mlir_check_all_link_libraries(triton-reduce) llvm_update_compile_flags(triton-reduce) target_link_libraries(triton-reduce PRIVATE - TritonLLVMIR - TritonAnalysis - TritonTransforms - TritonGPUTransforms - TritonNvidiaGPUTransforms ${dialect_libs} ${conversion_libs} ${triton_libs} @@ -54,10 +43,6 @@ add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED) llvm_update_compile_flags(triton-lsp) target_link_libraries(triton-lsp PRIVATE - TritonAnalysis - TritonTransforms - TritonGPUTransforms - TritonNvidiaGPUTransforms ${dialect_libs} ${conversion_libs} ${triton_libs} @@ -96,8 +81,6 @@ export_executable_symbols_for_plugins(triton-llvm-opt) add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED) target_link_libraries(triton-tensor-layout PRIVATE - TritonGPUIR - TritonNvidiaGPUIR ${triton_libs} ${conversion_libs} ${dialect_libs} diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index e06db19c6d..a3e38e177d 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -231,6 +231,11 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy); +// Check if MFMA layout can be converted to the dot operand +// layout using warp shuffle. +bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy, + RankedTensorType dstTy); + // TODO: Move utility functions that belong to ConvertLayoutOp to class // ConvertLayoutOpHelper in the future bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index a1c37efb52..4b2179611a 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1154,15 +1154,15 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, // Returns true on success. [[nodiscard]] bool emitTransferBetweenRegistersAndShared( RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, - Type elemLlvmTy, std::optional maxVecElems, Value shmemBase, - ArrayRef shmemStrides, Location loc, RewriterBase &rewriter, + Type elemLlvmTy, std::optional maxVecElems, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback); inline DenseMap getSwizzledSharedPtrs( Location loc, const TargetInfoBase &target, unsigned inVec, RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout, - Type resElemTy, SharedMemoryObject smemObj, RewriterBase &rewriter, + Type resElemTy, const SharedMemoryObject &smemObj, RewriterBase &rewriter, ArrayRef offsetVals, ArrayRef srcStrides) { // This utility computes the pointers for accessing the provided swizzled // shared memory layout `resSharedLayout`. 
More specifically, it computes, @@ -1324,14 +1324,14 @@ inline DenseMap getSwizzledSharedPtrs( SmallVector loadSharedToDistributed(RankedTensorType dstTy, triton::gpu::MemDescType srcTy, Type elemLlvmTy, - SharedMemoryObject smemObj, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target); void storeDistributedToShared( triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, - ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, - Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + ArrayRef srcVals, const SharedMemoryObject &smemObj, Location loc, + RewriterBase &rewriter, const TargetInfoBase &target, std::pair *const llvmOpCount = nullptr); inline Value getStructFromSharedMemoryObject(Location loc, diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index cfc4c0d13b..344603795a 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -414,6 +414,10 @@ class LinearLayout { bool isSurjective() const { return surjective; } + bool isInvertible() const { + return surjective && getTotalInDimSize() == getTotalOutDimSize(); + } + const BasesT &getBases() const { return bases; } // Get the pos'th basis vector for the inDim -> outDim mapping. @@ -673,6 +677,9 @@ class LinearLayout { // don't place any guarantees on the choices made by this function. [[nodiscard]] LinearLayout invertAndCompose(const LinearLayout &outer) const; + // Get the layout that is the inverse of this layout. + [[nodiscard]] LinearLayout invert() const; + // For each in-dim, returns a bitmask of the "free variables" in the layout // function. // diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 2c2fc80999..8170b923ee 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LLVM.h" +#include "triton/Conversion/MLIRTypes.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -650,6 +651,25 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, return ans; } +bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto mfmaLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + if (!mfmaLayout || !dotOperandLayout) + return false; + + // Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case + return dotOperandLayout.getParent() == mfmaLayout && + dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && + dotOperandLayout.getKWidth() == 8 && + getContigPerThread(mfmaLayout)[1] == 4 && + ((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) || + (mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) && + triton::type::isFloat8(srcTy.getElementType()) && + triton::type::isFloat8(dstTy.getElementType()) && + mfmaLayout.getWarpsPerCTA()[1] == 1; +} + // We get the smallest submap of srcTy^{-1} * dstTy that is not the identity // under kBlock, kWarp or kLane (in that order). 
The idea here is that if we // have a transformation that's the identity on kBlock, we don't need to use @@ -749,7 +769,10 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { return !cvtReordersRegisters(srcTy, dstTy) && !triton::gpu::intel::isDpasToDotShortcut(srcTy, dstTy) && !isBlockedToDotShortcut(srcTy, dstTy) && - !matchMmaV3AndDotOperandLayout(srcTy, dstTy); + !matchMmaV3AndDotOperandLayout(srcTy, dstTy) && + // to be removed when generalized warp shuffle conversions + // are ready: + !matchMFMAAndDotOperandShuffleCase(srcTy, dstTy); } bool atomicNeedsSharedMemory(Value value) { diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 84d30097cb..9fd1570d55 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -402,6 +402,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return failure(); } + // The following check can be removed when generalized warp shuffle + // conversions are ready: + if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) { + return failure(); + } + assert(cvtNeedsSharedMemory(srcTy, dstTy)); SmallVector inVals = diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 1e6e1c1fd7..81a4d6f6bf 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -25,11 +25,9 @@ void lowerDistributedToShared( auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); auto elemTy = typeConverter->convertType(srcTy.getElementType()); - auto smemBase = smemObj.getBase(); - auto dstStrides = smemObj.getStrides(); auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); - storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides, - loc, rewriter, targetInfo, llvmOpCount); + storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter, + targetInfo, llvmOpCount); } struct GlobalScratchAllocOpConversion @@ -157,14 +155,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { // If we remove this one, ldmatrix will IMA. 
It can probably be relaxed // though canUseLdmatrix &= - srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth; - // To be removed in https://github.com/triton-lang/triton/pull/5154 - bool legacyLoweringIsBuggy = - (kWidth >= 8 || (kWidth == 4 && bitwidth == 32) || - dstTy.getRank() == 3) && - mma.isAmpere(); - return (mma.isHopper() && !canUseLdmatrix) || - (mma.isAmpere() && legacyLoweringIsBuggy); + srcTy.getShape()[0] >= 8 && + srcTy.getShape()[1] >= 4 * kWidth && dstTy.getRank() <= 2; + return !canUseLdmatrix; } if (isa(dot.getParent())) return true; diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index a310cdba5f..02b3b121f4 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -169,10 +169,139 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, return ret; } +namespace { + +Value getSmemVecAddr(RankedTensorType registerTy, + triton::gpu::MemDescType sharedTy, Type elemLlvmTy, + Location loc, RewriterBase &rewriter, + const LinearLayout &regToSharedLayout, Value regId, + Value laneId, Value warpId, + const SharedMemoryObject &smemObj) { + MLIRContext *ctx = rewriter.getContext(); + StringAttr kBlock = str_attr("block"); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + auto shape = sharedTy.getShape(); + auto rank = shape.size(); + auto allocShape = sharedTy.getAllocShape(); + auto sharedEnc = + dyn_cast(sharedTy.getEncoding()); + + auto smemBase = smemObj.getBase(); + auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding()); + auto smemOffsets = smemObj.getOffsets(); + auto smemStrides = smemObj.getStrides(); + Value smemOffset; + // When loading or storing to shared memory, we consider two cases for + // performance reasons: + // + // 1. Non-swizzled shared memory. + // 2. Swizzled shared memory. + // + // Consider lowering `ttg.local_load %a`. In the first case, we can + // directly construct a linear layout using `%a`'s shape and shared memory + // encoding, irrespective of `%a`'s rank or whether it represents a slice of a + // larger tensor. + // + // The method does not apply to swizzled shared memory in some scenarios. + // Key properties of swizzling in Triton are: + // + // - Swizzling applies only to tensors with rank ≥ 2. + // - It is restricted to the last two dimensions of the tensor. + // - These last two dimensions are always treated as the most "minor." + // + // An important edge case arises when `%a` results from `%a = ttg.subview %b`, + // where `%b` is swizzled (and so is `%a`). In this case, constructing a + // layout and determining shared memory offsets using `%a`'s shape is + // incorrect. This is because swizzling depends on the original shape of `%b`, + // which differs from `%a`'s shape. As a result, some locations may fall + // outside `%a`'s contiguous view of memory. Specifically, an element `[i + // (row_idx), j (col_idx)]` in `%a` might map to `[i, j']` after swizzling, + // where `j'` lies outside `%a`'s shape but still within `%b`'s shape. + // + // We propose case 2 (see comments below), which provides a more general + // solution for all swizzled shared memory scenarios, including the edge case + // mentioned above.
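For concreteness, a small Python sketch of the subview edge case described above. The XOR swizzle and its parameters (vec, perPhase, maxPhase) are schematic stand-ins rather than the exact formula of Triton's shared encoding; the point is only that a swizzled column computed for a view can land outside the view while staying inside the allocation:

# Schematic swizzle: the phase is derived from the row and XOR-ed into the
# column index at `vec` granularity. Parameters are illustrative only.
vec, per_phase, max_phase = 1, 1, 8

def swizzled_col(row, col):
    phase = (row // per_phase) % max_phase
    return ((col // vec) ^ phase) * vec + col % vec

alloc_cols = 8  # %b, the tensor at allocation time, has 8 columns
view_cols = 4   # %a = subview of %b covering the first 4 columns

j_prime = swizzled_col(4, 1)              # element [4, 1] of the view %a
assert j_prime == 5
assert view_cols <= j_prime < alloc_cols  # outside %a, but still inside %b

This is why case 2 below computes offsets against the allocated shape rather than the view's shape.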
+ if (/*no swizzling*/ sharedEnc.getMaxPhase() == 1 || + /*swizzling but same shape*/ shape == allocShape || + /*swizzling and rank-reduced and rank >= 2*/ + (shape == allocShape.take_back(rank) && rank >= 2)) { // Case 1 + // Get the address to load/store. The multi-dim address is (offsetX1, ..., + // offsetXN, block), where the offsets appear in minor-to-major order, and + // we drop_end to drop block, which we know from above will be 0. + smemOffsets = llvm::to_vector(llvm::drop_end(llvm::make_second_range( + applyLinearLayout(loc, rewriter, regToSharedLayout, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, i32_val(0)}})))); + // Reorder strides according to `order`. This way they match the + // multi-dimensional offsets in regToSharedLayout. + smemOffset = dot(rewriter, loc, smemOffsets, + applyPermutation(smemStrides, sharedOrder)); + } else { // Case 2 -> rank-reduced swizzling + assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2"); + // We define both tensor offsets and shared memory offsets: + // + // - Tensor offsets: Relative offsets within a given tensor. + // - Shared memory offsets: Absolute offsets within the shared memory. + // + // In Triton, the shared memory layout provides an invertible, one-to-one + // mapping between tensor offsets and shared memory offsets. The `base` + // field of any shared memory object represents both the shared memory + // offset and the tensor offset relative to the original tensor at + // allocation, prior to any subview operations. + // + // To determine the shared memory offsets for a specific register when + // dealing with swizzled and sliced tensors, the process involves: + // + // 1. Retrieving the original tensor's `invertAllocSharedLayout`, which + // maps the allocated tensor's offsets back to shared memory offsets. + // 2. Reconstructing the register's offsets in the allocated tensor by + // summing: + // - The shared memory offsets of the current view's base, and + // - The relative tensor offsets of the register. + // + // This approach ensures that "absolute" tensor offsets can be + // mapped to the correct shared memory addresses using + // `invertAllocSharedLayout`. 
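Restated as a runnable sketch (names mirror the C++ that follows; `invert_alloc_layout` stands in for `invertAllocSharedLayout` and is assumed to map multi-dimensional offsets in the allocated tensor to a linear shared-memory offset):

# Sketch of the case-2 arithmetic implemented below, not a drop-in helper.
def case2_smem_offset(tensor_offs, view_base_offs, view_strides, invert_alloc_layout):
    # 1. Register offsets relative to the view -> offsets in the allocated tensor.
    alloc_offs = [t + b for t, b in zip(tensor_offs, view_base_offs)]
    # 2. Absolute shared-memory offset of that element in the allocation.
    abs_off = invert_alloc_layout(alloc_offs)
    # 3. Rebase onto the view's base pointer, which already sits
    #    dot(view_base_offs, view_strides) elements past the allocation base.
    return abs_off - sum(b * s for b, s in zip(view_base_offs, view_strides))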
+ std::optional regLayout = + triton::gpu::toLinearLayout(shape, registerTy.getEncoding()); + auto allocSharedLayout = triton::gpu::toLinearLayout( + allocShape.take_back(rank), sharedTy.getEncoding(), + elemLlvmTy.getIntOrFloatBitWidth()); + assert(allocSharedLayout.has_value() && + "Failed to convert layout to linear layout"); + auto invertAllocSharedLayout = allocSharedLayout->invert(); + auto multiDimTensorOffsets = + llvm::to_vector(applyLinearLayout(loc, rewriter, *regLayout, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, i32_val(0)}})); + for (auto i = 0; i < rank; i++) { + multiDimTensorOffsets[i].second = + add(multiDimTensorOffsets[i].second, smemOffsets[i]); + } + smemOffset = applyLinearLayout(loc, rewriter, invertAllocSharedLayout, + multiDimTensorOffsets)[0] + .second; + Value baseToAllocBaseDist = dot(rewriter, loc, smemOffsets, smemStrides); + smemOffset = sub(smemOffset, baseToAllocBaseDist); + } + auto ptrTy = smemBase.getType(); + auto vecAddr = gep(ptrTy, elemLlvmTy, smemBase, smemOffset); + vecAddr.setInbounds(true); + return vecAddr; +} + +} // namespace + bool emitTransferBetweenRegistersAndShared( RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, - Type elemLlvmTy, std::optional maxVecElems, Value shmemBase, - ArrayRef shmemStrides, Location loc, RewriterBase &rewriter, + Type elemLlvmTy, std::optional maxVecElems, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback) { MLIRContext *ctx = rewriter.getContext(); @@ -230,28 +359,12 @@ bool emitTransferBetweenRegistersAndShared( int numElems = regToSharedLayout->getInDimSize(kRegister); auto vecTy = vec_ty(elemLlvmTy, vecElems); - auto ptrTy = shmemBase.getType(); Value zero = i32_val(0); SmallVector ret; for (int i = 0; i < numElems / vecElems; i++) { - // Get the address to load/store. The multi-dim address is (offsetX1, ..., - // offsetXN, block), where the offsets appear in minor-to-major order, and - // we drop_end to drop block, which we know from above will be 0. - auto multiDimShmemOffset = - llvm::to_vector(llvm::drop_end(llvm::make_second_range( - applyLinearLayout(loc, rewriter, *regToSharedLayout, - {{kRegister, i32_val(i * vecElems)}, - {kLane, laneId}, - {kWarp, warpId}, - {kBlock, zero}})))); - - // Reorder strides according to `order`. This way they match the - // multi-dimensional offsets in regToSharedLayout. 
- auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding()); - Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset, - applyPermutation(shmemStrides, sharedOrder)); - auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset); - vecAddr.setInbounds(true); + auto vecAddr = getSmemVecAddr( + registerTy, sharedTy, elemLlvmTy, loc, rewriter, *regToSharedLayout, + i32_val(i * vecElems), laneId, warpId, smemObj); perVectorCallback(vecTy, vecAddr); } @@ -261,14 +374,13 @@ bool emitTransferBetweenRegistersAndShared( SmallVector loadSharedToDistributed(RankedTensorType dstTy, triton::gpu::MemDescType srcTy, Type elemLlvmTy, - SharedMemoryObject smemObj, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target) { SmallVector ret; bool success = emitTransferBetweenRegistersAndShared( - dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.getBase(), - smemObj.getStrides(), loc, rewriter, target, - [&](VectorType vecTy, Value vecAddr) { + dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, + rewriter, target, [&](VectorType vecTy, Value vecAddr) { auto vecVal = load(vecTy, vecAddr); vecVal.setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); @@ -285,14 +397,14 @@ SmallVector loadSharedToDistributed(RankedTensorType dstTy, void storeDistributedToShared(triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, - ArrayRef srcVals, Value smemBase, - ArrayRef dstStrides, Location loc, + ArrayRef srcVals, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::pair *const llvmOpCount) { bool success = emitTransferBetweenRegistersAndShared( - srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase, - dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { + srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, + rewriter, target, [&](VectorType vecTy, Value vecAddr) { ArrayRef vals = srcVals.take_front(vecTy.getNumElements()); srcVals = srcVals.drop_front(vecTy.getNumElements()); diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index ea05490c7a..1dcba27e15 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -394,7 +394,7 @@ struct MemDescSubviewOpConversion int rankReduced = srcTy.getRank() - destRank; for (int i = rankReduced; i < opOffsetVals.size(); i++) { strides.push_back(smemObj.strides[i]); - offsetVals.push_back(opOffsetVals[i]); + offsetVals.push_back(add(opOffsetVals[i], smemObj.offsets[i])); } // Compute the offset based on the original strides of the shared memory // object diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index f0fe8d43f4..548702fdc0 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -272,7 +272,7 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, builder.setInsertionPointAfter(viewLoad); auto sharedLoad = builder.createWithStage( - loc, stage, clusterId, loadOp.getType(), + loc, stageForFirstUse, clusterForFirstUse, loadOp.getType(), viewLoad /*,wait->getResult(0)*/); auto result = sharedLoad->getResults(); loadOp->replaceAllUsesWith(result); diff --git a/lib/Tools/LinearLayout.cpp 
b/lib/Tools/LinearLayout.cpp index 3a81231ac8..66c8bb7f1a 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -337,6 +337,7 @@ LinearLayout::checkInvariants(bool requireSurjective) { "can be reached by some `in` coordinate, but was not:" + toString(); } + return std::nullopt; } @@ -918,6 +919,17 @@ LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { return flatComposed.reshapeIns(retInDims).reshapeOuts(retOutDims); } +LinearLayout LinearLayout::invert() const { + // A^-1(x) = A^-1(I(x)), thus A.invert() = I.invertAndCompose(A) + assert(isInvertible() && + "A linear layout must be surjective and square to be invertible"); + LinearLayout identity = LinearLayout::empty(); + for (auto outDim : getOutDimNames()) { + identity *= LinearLayout::identity1D(getOutDimSize(outDim), outDim, outDim); + } + return identity.invertAndCompose(*this); +} + llvm::MapVector LinearLayout::getFreeVariableMasks() const { std::unique_ptr mat = getMatrix(*this); diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index c775714859..f0ab578cbd 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -7,7 +7,7 @@ import triton.language as tl from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure import traceback -from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300, is_xpu +from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300, is_hip_mi200, is_xpu def test_err_undefined_variable(): @@ -380,6 +380,42 @@ def dtype_kernel(dtype: tl.constexpr): raise assertion_err from e.value +@pytest.mark.parametrize("dtype", [tl.float8e5, tl.int8, tl.float16]) +def test_min_dot_size(dtype): + error_msg = "Input shapes should have " + if is_cuda(): + if dtype.primitive_bitwidth == 8: + error_msg += "M >= 16, N >= 16 and K >= 32" + else: + error_msg = "M >= 16, N >= 16 and K >= 16" + elif is_hip_mi300(): + if dtype.is_int8(): + error_msg += "M >= 16, N >= 16 and K >= 16" + else: + error_msg += "M >= 16, N >= 16 and K >= 8" + elif is_hip_mi200(): + error_msg += "M >= 16, N >= 16 and K >= 8" + elif is_hip(): + error_msg = "M >= 16, N >= 16 and K >= 16" + else: + pytest.skip("Test only supported on CUDA and HIP") + + @triton.jit + def dot_kernel(dtype: tl.constexpr): + SIZE: tl.constexpr = 8 + a = tl.full((SIZE, SIZE), 0.0, dtype) + b = tl.full((SIZE, SIZE), 0.0, dtype) + tl.dot(a, b) + + with pytest.raises(CompilationError) as e: + triton.compile( + triton.compiler.ASTSource(fn=dot_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) + try: + assert (error_msg in str(e.value.__cause__)) + except AssertionError as assertion_err: + raise assertion_err from e.value + + def test_max_num_imprecise_acc_limit(): @triton.jit diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6d41b6eece..4e908ace31 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -21,6 +21,7 @@ from triton._internal_testing import ( integral_dtypes, int_dtypes, + str_to_triton_dtype, uint_dtypes, float_dtypes, float_dtypes_with_bfloat16, @@ -1680,7 +1681,7 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr): ('float32', 'bfloat16', False, 1024), ('bfloat16', 'float32', False, 1024), ('float32', 'int32', True, 1024), - ('float32', 'int1', False, 1024), + ('float32', 'bool', False, 1024), 
('int8', 'bfloat16', False, 1024), ] + [(f'uint{x}', f'int{x}', True, 1024) for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024) @@ -1726,19 +1727,21 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): # triton kernel @triton.jit - def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): + def kernel(X, Z, TO_TYPE: tl.constexpr, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): x_ptr = X + tl.arange(0, SIZE) z_ptr = Z + tl.arange(0, SIZE) x = tl.load(x_ptr) # Depending on the value of ARG_HASH (a "random" number determined by # the test parameters), spell the cast one of three different ways. - if ARG_HASH % 3 == 0: + if ARG_HASH % 4 == 0: z = x.to(Z.dtype.element_ty, bitcast=BITCAST) - elif ARG_HASH % 3 == 1: + elif ARG_HASH % 4 == 1: z = x.cast(Z.dtype.element_ty, bitcast=BITCAST) - else: + elif ARG_HASH % 4 == 2: z = tl.cast(x, Z.dtype.element_ty, bitcast=BITCAST) + else: + z = tl.cast(x, TO_TYPE, bitcast=BITCAST) tl.store(z_ptr, z) @@ -1746,7 +1749,7 @@ def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constex # This way we don't have to increase the number of tests. arg_hash = hash((dtype_x, dtype_z, bitcast, size, num_ctas)) - dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' + dtype_z_np = dtype_z if dtype_z != 'bool' else 'bool_' # triton result if dtype_z.startswith('bfloat'): z_tri = torch.empty((size, ), dtype=getattr(torch, dtype_z), device=device) @@ -1754,7 +1757,10 @@ def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constex z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) else: z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) - kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, num_ctas=num_ctas) + + dtype_z_tri = str_to_triton_dtype(dtype_z) + kernel[(1, )](x_tri, z_tri, TO_TYPE=dtype_z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, + num_ctas=num_ctas) # torch result if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith( 'float8') or dtype_x.startswith('float8'): diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index 377eed877a..87836a886c 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -9,7 +9,7 @@ from numpy.random import RandomState from typing import Optional, Union -from triton.runtime.jit import TensorWrapper, reinterpret +from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict int_dtypes = ['int8', 'int16', 'int32', 'int64'] uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] @@ -119,6 +119,10 @@ def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torc return torch.tensor(x, device=device) +def str_to_triton_dtype(x: str) -> tl.dtype: + return tl.str_to_ty(type_canonicalisation_dict[x]) + + def torch_dtype_name(dtype) -> str: if isinstance(dtype, triton.language.dtype): return dtype.name diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py index 202ae15686..6606b21ca8 100644 --- a/python/triton/backends/driver.py +++ b/python/triton/backends/driver.py @@ -1,4 +1,4 @@ -from abc import ABCMeta, abstractmethod, abstractclassmethod +from abc import ABCMeta, abstractmethod from typing import Callable, List, Protocol, Sequence @@ -10,7 +10,8 @@ def __call__(self, kernel_call: Callable, *, quantiles: List[float], 
**kwargs) - class DriverBase(metaclass=ABCMeta): - @abstractclassmethod + @classmethod + @abstractmethod def is_active(self): pass @@ -18,6 +19,10 @@ def is_active(self): def get_current_target(self): pass + @abstractmethod + def get_active_torch_device(self): + pass + @abstractmethod def get_benchmarker(self) -> Benchmarker: """ diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 85d5f6beba..13d097dc34 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1004,13 +1004,7 @@ def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: """ Alias for :py:func:`tensor.cast`. """ - # Triton doesn't like core functions calling other core functions, so we - # just copy-paste the implementation of cast here. It's not too bad. - dtype = _unwrap_if_constexpr(dtype) - bitcast = _unwrap_if_constexpr(bitcast) - if bitcast: - return semantic.bitcast(self, dtype, _builder) - return semantic.cast(self, dtype, _builder, fp_downcast_rounding) + return cast(self, dtype, fp_downcast_rounding, bitcast, _builder=_builder) # Type stubs for functions added by the _tensor_member_fn decorator. # (Unfortunately these can't be created automatically.) @@ -1594,8 +1588,9 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas :type bitcast: bool, optional """ input = semantic.to_tensor(input, _builder) - if isinstance(bitcast, constexpr): - bitcast = bitcast.value + dtype = _constexpr_to_value(dtype) + fp_downcast_rounding = _constexpr_to_value(fp_downcast_rounding) + bitcast = _constexpr_to_value(bitcast) if bitcast: return semantic.bitcast(input, dtype, _builder) return semantic.cast(input, dtype, _builder, fp_downcast_rounding) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 57c44e0fa6..7759adae80 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -828,10 +828,6 @@ def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tenso def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, fp_downcast_rounding: Optional[str] = None) -> tl.tensor: src_ty = input.type - if isinstance(dst_ty, tl.constexpr): - dst_ty = dst_ty.value - if isinstance(fp_downcast_rounding, tl.constexpr): - fp_downcast_rounding = fp_downcast_rounding.value if src_ty.is_block(): dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) if src_ty == dst_ty: @@ -1473,6 +1469,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}" if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + # We upcast because there's no fp8e4b15 type in MLIR lhs = cast(lhs, tl.float16, builder) rhs = cast(rhs, tl.float16, builder) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 0449b802e5..1e77ca7c1d 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -23,6 +23,8 @@ import triton import triton.language as tl +DEVICE = triton.runtime.driver.active.get_active_torch_device() + @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. @@ -60,7 +62,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. def add(x: torch.Tensor, y: torch.Tensor): # We need to preallocate the output. 
output = torch.empty_like(x) - assert x.is_xpu and y.is_xpu and output.is_xpu + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE n_elements = output.numel() # The SPMD launch grid denotes the number of kernel instances that run in parallel. # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. @@ -81,8 +83,8 @@ def add(x: torch.Tensor, y: torch.Tensor): torch.manual_seed(0) size = 98432 -x = torch.rand(size, device='xpu') -y = torch.rand(size, device='xpu') +x = torch.rand(size, device=DEVICE) +y = torch.rand(size, device=DEVICE) output_torch = x + y output_triton = add(x, y) print(output_torch.cpu()) @@ -116,8 +118,8 @@ def add(x: torch.Tensor, y: torch.Tensor): args={}, # Values for function arguments not in `x_names` and `y_name`. )) def benchmark(size, provider): - x = torch.rand(size, device='xpu', dtype=torch.float32) - y = torch.rand(size, device='xpu', dtype=torch.float32) + x = torch.rand(size, device=DEVICE, dtype=torch.float32) + y = torch.rand(size, device=DEVICE, dtype=torch.float32) quantiles = [0.5, 0.2, 0.8] if provider == 'torch': ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index f8e72e05a2..b6ab04fc82 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -27,6 +27,8 @@ import triton.language as tl from triton.runtime import driver +DEVICE = triton.runtime.driver.active.get_active_torch_device() + def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" @@ -110,8 +112,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n # %% # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor. -device = torch.xpu.current_device() -properties = driver.active.utils.get_device_properties(device) +properties = driver.active.utils.get_device_properties(DEVICE.index) NUM_SM = properties["multiprocessor_count"] SIZE_SMEM = properties["max_shared_mem"] WARPS_PER_EU = 8 # TODO: Get from properties @@ -194,7 +195,7 @@ def allocated_slm_size(size_smem): # This will allow us to verify that our padding mechanism works. 
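The tutorial changes in this diff all follow the same device-selection pattern; a minimal standalone sketch, assuming a backend whose driver implements the get_active_torch_device() hook added elsewhere in this diff:

import torch
import triton

# Ask the active Triton driver which torch device it targets (e.g. xpu:0, or
# cuda:0 on the NVIDIA and AMD backends) instead of hard-coding a device string.
DEVICE = triton.runtime.driver.active.get_active_torch_device()

x = torch.rand(1024, device=DEVICE)

# Backend-specific torch modules (torch.xpu, torch.cuda, ...) can be reached
# generically through the device type, as the softmax benchmark does for streams.
torch_backend = getattr(torch, DEVICE.type)
stream = torch_backend.Stream()
torch_backend.set_stream(stream)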
torch.manual_seed(0) -x = torch.randn(1823, 781, device='xpu') +x = torch.randn(1823, 781, device=DEVICE) y_triton = softmax(x) y_torch = torch.softmax(x, axis=1) assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) @@ -226,9 +227,9 @@ def allocated_slm_size(size_smem): args={'M': 4096}, # values for function arguments not in `x_names` and `y_name` )) def benchmark(M, N, provider): - x = torch.randn(M, N, device='xpu', dtype=torch.float32) - stream = torch.xpu.Stream() - torch.xpu.set_stream(stream) + x = torch.randn(M, N, device=DEVICE, dtype=torch.float32) + stream = getattr(torch, DEVICE.type).Stream() + getattr(torch, DEVICE.type).set_stream(stream) if provider == 'torch': ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1)) if provider == 'triton': diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index d7f4fda652..07121e6b9c 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -154,6 +154,8 @@ import triton import triton.language as tl +DEVICE = triton.runtime.driver.active.get_active_torch_device() + def is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda" @@ -390,8 +392,8 @@ def matmul(a, b, activation=""): # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS). torch.manual_seed(0) -a = torch.randn((512, 512), device='xpu', dtype=torch.float16) -b = torch.randn((512, 512), device='xpu', dtype=torch.float16) +a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) +b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) triton_output = matmul(a, b) torch_output = torch.matmul(a, b) print(f"triton_output_with_fp16_inputs={triton_output}") @@ -408,8 +410,8 @@ def matmul(a, b, activation=""): TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2") if TORCH_HAS_FP8 and is_cuda(): torch.manual_seed(0) - a = torch.randn((512, 512), device="xpu", dtype=torch.float16) - b = torch.randn((512, 512), device="xpu", dtype=torch.float16) + a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) + b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) a = a.to(torch.float8_e5m2) # pre-transpose b for efficiency. 
b = b.T @@ -458,8 +460,8 @@ def matmul(a, b, activation=""): @triton.testing.perf_report(configs) def benchmark(M, N, K, provider, fp8_inputs): - a = torch.randn((M, K), device='xpu', dtype=torch.float16) - b = torch.randn((K, N), device='xpu', dtype=torch.float16) + a = torch.randn((M, K), device=DEVICE, dtype=torch.float16) + b = torch.randn((K, N), device=DEVICE, dtype=torch.float16) if TORCH_HAS_FP8 and fp8_inputs: a = a.to(torch.float8_e5m2) b = b.T diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py index fc1fceb5a7..3dd84da47e 100644 --- a/python/tutorials/04-low-memory-dropout.py +++ b/python/tutorials/04-low-memory-dropout.py @@ -38,6 +38,8 @@ import triton import triton.language as tl +DEVICE = triton.runtime.driver.active.get_active_torch_device() + @triton.jit def _dropout( @@ -71,10 +73,10 @@ def dropout(x, x_keep, p): # Input tensor -x = torch.randn(size=(10, )).xpu() +x = torch.randn(size=(10, ), device=DEVICE) # Dropout mask p = 0.5 -x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).xpu() +x_keep = (torch.rand(size=(10, ), device=DEVICE) > p).to(torch.int32) # output = dropout(x, x_keep=x_keep, p=p) print(tabulate.tabulate([ @@ -138,7 +140,7 @@ def seeded_dropout(x, p, seed): return output -x = torch.randn(size=(10, )).xpu() +x = torch.randn(size=(10, ), device=DEVICE) # Compare this to the baseline - dropout mask is never instantiated! output = seeded_dropout(x, p=0.5, seed=123) output2 = seeded_dropout(x, p=0.5, seed=123) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index c9d00d4593..85a8500308 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -42,6 +42,8 @@ except ModuleNotFoundError: HAS_APEX = False +DEVICE = triton.runtime.driver.active.get_active_torch_device() + @triton.jit def _layer_norm_fwd_fused( @@ -290,7 +292,7 @@ def backward(ctx, dy): layer_norm = LayerNorm.apply -def test_layer_norm(M, N, dtype, eps=1e-5, device='xpu'): +def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE): # create data x_shape = (M, N) w_shape = (x_shape[-1], ) @@ -329,7 +331,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='xpu'): plot_name='layer-norm-backward', args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}, )) -def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='xpu'): +def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device=DEVICE): # create data x_shape = (M, N) w_shape = (x_shape[-1], ) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index b646683763..b753d331f6 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -19,6 +19,8 @@ import triton import triton.language as tl +DEVICE = triton.runtime.driver.active.get_active_torch_device() + def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" @@ -526,13 +528,13 @@ def backward(ctx, do): @pytest.mark.parametrize("causal", [True]) def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): torch.manual_seed(20) - q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xpu").normal_(mean=0.0, std=0.5).requires_grad_()) - k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xpu").normal_(mean=0.0, std=0.5).requires_grad_()) - v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xpu").normal_(mean=0.0, std=0.5).requires_grad_()) + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, 
device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) sm_scale = 0.5 dout = torch.randn_like(q) # reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX), device="xpu")) + M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE)) p = torch.matmul(q, k.transpose(2, 3)) * sm_scale if causal: p[:, :, M == 0] = float("-inf") @@ -600,7 +602,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): @triton.testing.perf_report(configs) -def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="xpu"): +def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device=DEVICE): assert mode in ["fwd", "bwd"] dtype = torch.float16 if "triton" in provider: diff --git a/python/tutorials/07-extern-functions.py b/python/tutorials/07-extern-functions.py index 45e4c697c4..6c1007befe 100644 --- a/python/tutorials/07-extern-functions.py +++ b/python/tutorials/07-extern-functions.py @@ -25,6 +25,8 @@ from pathlib import Path +DEVICE = triton.runtime.driver.active.get_active_torch_device() + @triton.jit def asin_kernel( @@ -49,8 +51,8 @@ def asin_kernel( torch.manual_seed(0) size = 98432 -x = torch.rand(size, device='xpu') -output_triton = torch.zeros(size, device='xpu') +x = torch.rand(size, device=DEVICE) +output_triton = torch.zeros(size, device=DEVICE) output_torch = torch.asin(x) n_elements = output_torch.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) diff --git a/python/tutorials/08-grouped-gemm.py b/python/tutorials/08-grouped-gemm.py index 8814187230..6f55fd9dcc 100644 --- a/python/tutorials/08-grouped-gemm.py +++ b/python/tutorials/08-grouped-gemm.py @@ -31,6 +31,8 @@ import triton import triton.language as tl +DEVICE = triton.runtime.driver.active.get_active_torch_device() + def is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda" @@ -145,7 +147,6 @@ def grouped_matmul_kernel( def group_gemm_fn(group_A, group_B): - device = torch.device('xpu') assert len(group_A) == len(group_B) group_size = len(group_A) @@ -161,7 +162,7 @@ def group_gemm_fn(group_A, group_B): assert A.shape[1] == B.shape[0] M, K = A.shape K, N = B.shape - C = torch.empty((M, N), device=device, dtype=A.dtype) + C = torch.empty((M, N), device=DEVICE, dtype=A.dtype) group_C.append(C) A_addrs.append(A.data_ptr()) B_addrs.append(B.data_ptr()) @@ -170,11 +171,11 @@ def group_gemm_fn(group_A, group_B): g_lds += [A.stride(0), B.stride(0), C.stride(0)] # note these are device tensors - d_a_ptrs = torch.tensor(A_addrs, device=device) - d_b_ptrs = torch.tensor(B_addrs, device=device) - d_c_ptrs = torch.tensor(C_addrs, device=device) - d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device) - d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device) + d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) + d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) + d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) # we use a fixed number of CTA, and it's auto-tunable grid = lambda META: (META['NUM_SM'], ) grouped_matmul_kernel[grid]( @@ -201,8 +202,8 @@ def group_gemm_fn(group_A, group_B): M = group_m[i] N = group_n[i] K = group_k[i] - A = 
torch.rand((M, K), device="xpu", dtype=torch.float16) - B = torch.rand((K, N), device="xpu", dtype=torch.float16) + A = torch.rand((M, K), device=DEVICE, dtype=torch.float16) + B = torch.rand((K, N), device=DEVICE, dtype=torch.float16) group_A.append(A) group_B.append(B) @@ -264,9 +265,9 @@ def benchmark(N, provider): g_lds = [] group_C = [] for i in range(group_size): - A = torch.rand((N, N), device="xpu", dtype=torch.float16) - B = torch.rand((N, N), device="xpu", dtype=torch.float16) - C = torch.empty((N, N), device="xpu", dtype=torch.float16) + A = torch.rand((N, N), device=DEVICE, dtype=torch.float16) + B = torch.rand((N, N), device=DEVICE, dtype=torch.float16) + C = torch.empty((N, N), device=DEVICE, dtype=torch.float16) group_A.append(A) group_B.append(B) group_C.append(C) @@ -276,11 +277,11 @@ def benchmark(N, provider): g_sizes += [N, N, N] g_lds += [N, N, N] - d_a_ptrs = torch.tensor(A_addrs, device="xpu") - d_b_ptrs = torch.tensor(B_addrs, device="xpu") - d_c_ptrs = torch.tensor(C_addrs, device="xpu") - d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="xpu") - d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="xpu") + d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) + d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) + d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) quantiles = [0.5, 0.2, 0.8] if provider == ref_lib.lower(): diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index 9a9764d992..bdf6db4e69 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s +// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s #mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> #dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> @@ -27,3 +27,191 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr tt.return } } + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_f8_mfma32 + tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + + // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] + // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] + + // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) + + // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] + // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] + // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + + // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> 
+ + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]] + + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]] + + // Input (8 values): (vec0, vec1) + // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): + // resVec0 resVec1 + // lanes 0-31: (vec0 , vec0 >> 32) (mask0=1) + // lanes 32-63: (vec1 >> 32, vec1 ) (mask0=0) + + // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]] + // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]] + + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) + // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> + + // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] + // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] + + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> + tt.return + } +} + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_bf8_mfma32 + tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: rocdl.ds_bpermute + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> + tt.return + } +} + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_f8_mfma16 + tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + + // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] + // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] + + // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) + // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32) + // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) + + // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] + // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + + // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] + // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] + + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]] + // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] + // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], 
[[c64]] + + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] + // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + + // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> + + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]] + // CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]] + + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]] + + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]] + + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]] + // CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]] + + // Input (8 values): (vec0, vec1) + // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): + // resVec0 resVec1 + // lanes 0-15: (vec0 , vec0 >> 16) (mask0=1, mask1=1) + // lanes 16-31: (vec0 >> 16, vec0 >> 32) (mask0=1, mask1=0) + // lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1) + // lanes 48-63: (vec1 >> 48, vec1 ) (mask0=0, mask1=0) + + // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8> + // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8> + // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> + + // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8> + // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8> + // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> + + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) + // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> + + // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] + // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] + + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> + tt.return + } +} + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_bf8_mfma16 + tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: 
rocdl.ds_bpermute + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> + tt.return + } +} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 61911356af..f834d726d2 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -469,6 +469,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: llvm.extractvalue // CHECK-NEXT: llvm.extractvalue // CHECK-NEXT: llvm.extractvalue + // CHECK-NEXT: llvm.add + // CHECK-NEXT: llvm.add // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 // CHECK-NEXT: llvm.mul // CHECK-NEXT: llvm.add diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 81b07f2e7d..64668a04e8 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -17,7 +17,8 @@ def min_dot_size(target: GPUTarget): # CDNA 3.0 supports k==8 in all mfma variants except for int8 # (where the smallest `k` supported is 16) if "gfx94" in arch_str: - return lambda lhsType, rhsType: (16, 16, 16) if (lhsType.is_int8() or rhsType.is_int8()) else (16, 16, 8) + return lambda lhsType, rhsType: (16, 16, 16) if (lhsType.scalar.is_int8() or rhsType.scalar.is_int8()) else ( + 16, 16, 8) # CDNA 2.0 always supports `k==8` if "gfx9" in arch_str: return lambda lhsType, rhsType: (16, 16, 8) diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 99e5509eca..537604d8d4 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -502,6 +502,11 @@ def get_current_target(self): warp_size = device_properties['warpSize'] return GPUTarget("hip", arch.split(':')[0], warp_size) + def get_active_torch_device(self): + import torch + # when using hip devices, the device string in pytorch is "cuda" + return torch.device("cuda", self.get_current_device()) + def get_benchmarker(self): from triton.testing import do_bench return do_bench diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 208483beb8..3b61fb8cc4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -116,6 +116,158 @@ struct LocalLoadOpConversion } }; +struct ConvertLayoutOpMFMAToDotOpConversion + : public ConvertOpToLLVMPattern { +public: + explicit ConvertLayoutOpMFMAToDotOpConversion( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = cast(op.getSrc().getType()); + auto dstType = cast(op.getType()); + + if (!matchMFMAAndDotOperandShuffleCase(srcType, dstType)) + return failure(); + + auto loc = op.getLoc(); + + SmallVector inVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + if (inVals.empty() || inVals.size() % 8 != 0) + return failure(); + + auto mfmaLayout = dyn_cast(srcType.getEncoding()); + assert((mfmaLayout.getMDim() == 16 || mfmaLayout.getMDim() == 32) && + "Expected MFMA size 16 or 32"); + assert(triton::gpu::getWarpSize(mfmaLayout) == 64 && + "Expected warp size 64 for MFMA"); + + auto elemTy = int_ty(8); + auto vecTy = vec_ty(elemTy, 4); + + Value 
c16 = i32_val(16); + Value c32 = i32_val(32); + Value c48 = i32_val(48); + Value c64 = i32_val(64); + + Value threadId = tid_val(); + Value laneId = urem(threadId, c64); + + Value mask0 = icmp_slt(laneId, c32); + Value mask1 = icmp_slt(urem(laneId, c32), c16); + + Value addrShift16 = urem(add(laneId, c16), c64); + Value addrShift32 = urem(add(laneId, c32), c64); + Value addrShift48 = urem(add(laneId, c48), c64); + + SmallVector outVals; + for (size_t startIdx = 0; startIdx < inVals.size(); startIdx += 8) { + Value vec0 = undef(vecTy); + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + vec0 = + insert_element(vecTy, vec0, inVals[startIdx + vIdx], i32_val(vIdx)); + } + Value vec1 = undef(vecTy); + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + vec1 = insert_element(vecTy, vec1, inVals[startIdx + vIdx + 4], + i32_val(vIdx)); + } + + Value resVec0, resVec1; + if (mfmaLayout.getMDim() == 32) { + /* + Using wave shuffle to convert layouts (32x32x16 case): + 1) Input MMA layout (32x32, fp8, 16 values): + _____________________________________________________________ + |(t0 v0 v1 v2 v3) (t32 v0 v1 v2 v3) ... (t32 v12 v13 v14 v15)| + | ... ... | + |(t31 v0 v1 v2 v3) (t63 v0 v1 v2 v3) ... (t63 v12 v13 v14 v15)| + |_____________________________________________________________| + + 2) Output Dot operand layout (two 32x16 tiles, fp8, 8 values each): + ____________________________________________________________ ___ + |(t0 v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7) || + | ... ... ||... + |(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) || + |____________________________________________________________||___ + */ + + Value shflVec0 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32), + vecTy); + Value shflVec1 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32), + vecTy); + + resVec0 = select(mask0, vec0, shflVec1); + resVec1 = select(mask0, shflVec0, vec1); + } else if (mfmaLayout.getMDim() == 16) { + /* + 16x16x32 case: + 1) Input MMA layout (two 16x16, fp8, 4 values each): + _________________________________________________________ ___________ + |(t0 v0 v1 v2 v3) (t16 v0 v1 v2 v3) ... (t48 v0 v1 v2 v3)||(t0 v4 ... + | ... ... || ... + |(t15 v0 v1 v2 v3) (t31 v0 v1 v2 v3) ... (t63 v0 v1 v2 v3)||(t15 v4 ... + |_________________________________________________________||___________ + + 2) Output Dot operand layout (16x32 tile, fp8, 8 values): + ________________________________________________________________ + |(t0 v0 v1 v2 v3 v4 v5 v6 v7) ... (t48 v0 v1 v2 v3 v4 v5 v6 v7) | + | ... ... | + |(t15 v0 v1 v2 v3 v4 v5 v6 v7) ... 
(t63 v0 v1 v2 v3 v4 v5 v6 v7) | + |________________________________________________________________| + */ + + Value shflVec0_16 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec0, int_ty(32)), addrShift16), + vecTy); + Value shflVec0_32 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32), + vecTy); + Value shflVec1_32 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32), + vecTy); + Value shflVec1_48 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec1, int_ty(32)), addrShift48), + vecTy); + + resVec0 = select(mask0, select(mask1, vec0, shflVec0_16), + select(mask1, shflVec1_32, shflVec1_48)); + resVec1 = select(mask0, select(mask1, shflVec0_16, shflVec0_32), + select(mask1, shflVec1_48, vec1)); + } + + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + outVals.push_back(extract_element(elemTy, resVec0, i32_val(vIdx))); + } + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + outVals.push_back(extract_element(elemTy, resVec1, i32_val(vIdx))); + } + } + + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + +protected: + const TargetInfoBase &targetInfo; +}; + } // namespace namespace mlir::triton::AMD { @@ -124,5 +276,7 @@ void populateConvertLayoutOpToLLVMPatterns( RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, + benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index 9510e73b9c..f2a237c35b 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -558,6 +558,10 @@ def get_current_target(self): warp_size = 32 return GPUTarget("xpu", dev_property, warp_size) + def get_active_torch_device(self): + import torch + return torch.device("xpu", self.get_current_device()) + def get_device_interface(self): import torch return torch.xpu diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp index 826eb46092..1f84013007 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp @@ -23,12 +23,11 @@ using namespace mlir::triton::gpu; // blocked -> shared. // Swizzling in shared memory to avoid bank conflict. Normally used for // A/B operands of dots. 
-void lowerDistributedToShared(Location loc, Value src, Value dst, - Value adaptorSrc, - const SharedMemoryObject &smemObj, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, - const TargetInfoBase &targetInfo) { +void lowerDistributedToShared( + Location loc, Value src, Value dst, Value adaptorSrc, + const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, + std::pair *const llvmOpCount = nullptr) { auto srcTy = cast(src.getType()); auto dstTy = cast(dst.getType()); auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); @@ -38,8 +37,8 @@ void lowerDistributedToShared(Location loc, Value src, Value dst, auto dstStrides = smemObj.getStrides(); auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); mlir::triton::intel::storeDistributedToShared(dstTy, srcTy, elemTy, inVals, - smemBase, dstStrides, loc, - rewriter, targetInfo); + smemObj, loc, rewriter, + targetInfo, llvmOpCount); } struct LocalAllocOpConversion @@ -235,13 +234,16 @@ struct LocalStoreOpConversion public: using ConvertTritonGPUOpToLLVMPattern< triton::gpu::LocalStoreOp>::ConvertTritonGPUOpToLLVMPattern; + using BackendCallbackType = + decltype(BackendCallbacks::localStoreOpConversion); LocalStoreOpConversion(const LLVMTypeConverter &converter, const TargetInfoBase &targetInfo, + BackendCallbackType backendCallback, PatternBenefit benefit = 1) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), - targetInfo(targetInfo) {} + targetInfo(targetInfo), backendCallback(backendCallback) {} LogicalResult matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, @@ -251,24 +253,36 @@ struct LocalStoreOpConversion getTypeConverter()->convertType(op.getDst().getType().getElementType()); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + + std::pair llvmOpCount; lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(), adaptor.getSrc(), smemObj, getTypeConverter(), - rewriter, targetInfo); + rewriter, targetInfo, &llvmOpCount); + + if (backendCallback) + (backendCallback)(op, llvmOpCount.first, llvmOpCount.second); + rewriter.eraseOp(op); return success(); } private: const TargetInfoBase &targetInfo; + BackendCallbackType backendCallback; }; } // namespace void mlir::triton::intel::populateMemoryOpToLLVMPattern( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, PatternBenefit benefit) { + RewritePatternSet &patterns, PatternBenefit benefit, + std::optional backendCallbacks) { patterns.add(typeConverter, targetInfo, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, targetInfo, benefit); - patterns.add(typeConverter, targetInfo, benefit); + + auto backendCall = + backendCallbacks ? 
backendCallbacks->localStoreOpConversion : nullptr; + patterns.add(typeConverter, targetInfo, backendCall, + benefit); } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h index dd361daf71..32658825d9 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -5,6 +5,7 @@ #include "TritonGPUToLLVMBase.h" #include "intel/include/Analysis/AxisInfo.h" #include "intel/include/TritonIntelGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" namespace mlir::triton::intel { @@ -83,10 +84,10 @@ void populatePrintOpToLLVMPattern( RewritePatternSet &patterns, const TargetInfoBase &targetInfo, PatternBenefit benefit); -void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, - PatternBenefit benefit); +void populateMemoryOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit, + std::optional backendCallbacks = std::nullopt); void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h index e8ec3eef6e..5d4b597d2b 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h @@ -770,17 +770,17 @@ inline DenseMap getSwizzledSharedPtrs( inline SmallVector loadSharedToDistributed(RankedTensorType dstTy, triton::gpu::MemDescType srcTy, - Type elemLlvmTy, SharedMemoryObject &memObj, + Type elemLlvmTy, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target) { SmallVector ret; bool success = emitTransferBetweenRegistersAndShared( - dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, memObj.getBase(), - memObj.getStrides(), loc, rewriter, target, - [&](VectorType vecTy, Value vecAddr) { + dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, + rewriter, target, [&](VectorType vecTy, Value vecAddr) { auto vecVal = load(vecTy, vecAddr); vecVal.setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); + for (int v = 0; v < vecTy.getNumElements(); v++) { ret.push_back(extract_element(elemLlvmTy, vecVal, i32_val(v))); } @@ -791,15 +791,15 @@ loadSharedToDistributed(RankedTensorType dstTy, triton::gpu::MemDescType srcTy, return ret; } -inline void storeDistributedToShared(triton::gpu::MemDescType dstTy, - RankedTensorType srcTy, Type elemLlvmTy, - ArrayRef srcVals, Value smemBase, - ArrayRef dstStrides, Location loc, - RewriterBase &rewriter, - const TargetInfoBase &target) { +inline void +storeDistributedToShared(triton::gpu::MemDescType dstTy, RankedTensorType srcTy, + Type elemLlvmTy, ArrayRef srcVals, + const SharedMemoryObject &smemObj, Location loc, + RewriterBase &rewriter, const TargetInfoBase &target, + std::pair *const llvmOpCount) { bool success = emitTransferBetweenRegistersAndShared( - srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase, - dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { + srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, + rewriter, target, [&](VectorType vecTy, Value vecAddr) { ArrayRef vals = 
srcVals.take_front(vecTy.getNumElements()); srcVals = srcVals.drop_front(vecTy.getNumElements()); @@ -810,7 +810,12 @@ inline void storeDistributedToShared(triton::gpu::MemDescType dstTy, store(vec, vecAddr) .setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); + if (llvmOpCount) { + ++(llvmOpCount->first); + llvmOpCount->second = vecTy; + } }); + if (!success) llvm::report_fatal_error("Failed to emit transfer from register to shared"); } diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index d94be93872..137fef4bd0 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -17,7 +17,17 @@ def min_dot_size(target: GPUTarget): - return lambda lhsType, rhsType: (16, 32, 16) if lhsType.is_int8() else (16, 16, 16) + + def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m, n, k] + lhs_bitwidth = lhs_type.scalar.primitive_bitwidth + rhs_bitwidth = rhs_type.scalar.primitive_bitwidth + assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same" + if lhs_bitwidth == 8: + return (16, 16, 32) + else: + return (16, 16, 16) + + return check_dot_compatibility @functools.lru_cache() diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index edeab969ab..e41b4a1386 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -504,6 +504,10 @@ def get_current_target(self): warp_size = 32 return GPUTarget("cuda", capability, warp_size) + def get_active_torch_device(self): + import torch + return torch.device("cuda", self.get_current_device()) + def get_device_interface(self): import torch return torch.cuda diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index d2cef405eb..f21120dbca 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -935,8 +935,8 @@ struct AsyncCopyGlobalToLocalOpConversion VectorType vecTy; SmallVector shmemAddrs; bool ok = emitTransferBetweenRegistersAndShared( - srcTy, dstTy, resElemTy, maxVec, smemObj.base, smemObj.strides, loc, - rewriter, targetInfo, [&](VectorType vecTy_, Value shmemAddr) { + srcTy, dstTy, resElemTy, maxVec, smemObj, loc, rewriter, targetInfo, + [&](VectorType vecTy_, Value shmemAddr) { vecTy = vecTy_; shmemAddrs.push_back(shmemAddr); });
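
The mDim == 32 path of `ConvertLayoutOpMFMAToDotOpConversion` above boils down to: lanes 0-31 keep their own `vec0` and fetch `vec0` from lane `laneId + 32`, while lanes 32-63 keep their own `vec1` and fetch `vec1` from lane `laneId - 32`. Below is a standalone Python model of that select/shuffle step, using made-up tagged values; it is not the actual lowering, which packs each 4 x i8 vector into an i32 and routes it through `targetInfo.shuffleIdx` (ds_bpermute on AMD).

```python
# Model of the mDim == 32 shuffle/select logic; values are tagged as
# (source lane, source vector, element) so the data movement is visible.
WARP_SIZE = 64

vec0 = [[(lane, 0, e) for e in range(4)] for lane in range(WARP_SIZE)]
vec1 = [[(lane, 1, e) for e in range(4)] for lane in range(WARP_SIZE)]

res0, res1 = [], []
for lane in range(WARP_SIZE):
    mask0 = lane < 32                 # icmp_slt(laneId, c32)
    src = (lane + 32) % WARP_SIZE     # addrShift32
    shfl_vec0 = vec0[src]             # shuffleIdx(vec0, addrShift32)
    shfl_vec1 = vec1[src]             # shuffleIdx(vec1, addrShift32)
    res0.append(vec0[lane] if mask0 else shfl_vec1)
    res1.append(shfl_vec0 if mask0 else vec1[lane])

# Lanes 0..31 keep vec0 and pick up vec0 of lane+32; lanes 32..63 keep vec1
# and pick up vec1 of lane-32, matching the ASCII diagrams in the pattern.
assert res0[0] == vec0[0] and res1[0] == vec0[32]
assert res0[40] == vec1[8] and res1[40] == vec1[40]
```

The mDim == 16 path follows the same idea but splits the warp into four groups of 16 lanes, which is why it needs the extra `mask1` and the 16/48-lane shuffles.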
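
The Intel `LocalStoreOp` change threads a `std::pair<size_t, Type>` through `storeDistributedToShared` so an optional backend callback can observe how many vector stores were emitted and at what vector type. A minimal Python model of that bookkeeping (an assumption-level sketch, not the real lowering: one increment per emitted vector store, with the last store's vector type recorded):

```python
def store_distributed_to_shared(src_vals, vec_width):
    """Returns (num_vector_stores, last_vector_type) like llvmOpCount."""
    op_count, vec_ty = 0, None
    while src_vals:
        chunk, src_vals = src_vals[:vec_width], src_vals[vec_width:]
        # ... the real code emits one aligned vector store of `chunk` here ...
        op_count += 1
        vec_ty = f"vector<{len(chunk)} x i8>"
    return op_count, vec_ty

# 32 distributed i8 values stored 8 at a time -> the callback sees 4 stores.
assert store_distributed_to_shared(list(range(32)), 8) == (4, "vector<8 x i8>")
```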
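
For reference, the NVIDIA `min_dot_size` hook now keys off the scalar bitwidth instead of an `is_int8()` check, so any 8-bit operand (int8 or fp8) requires k >= 32. A self-contained restatement with stub types (`_Scalar`/`_BlockTy` are stand-ins, not Triton's real type objects):

```python
# Stub types that only expose the attribute the hook reads.
class _Scalar:
    def __init__(self, bits): self.primitive_bitwidth = bits

class _BlockTy:
    def __init__(self, bits): self.scalar = _Scalar(bits)

def check_dot_compatibility(lhs_type, rhs_type):  # returns (m, n, k)
    lhs_bw = lhs_type.scalar.primitive_bitwidth
    rhs_bw = rhs_type.scalar.primitive_bitwidth
    assert lhs_bw == rhs_bw, "lhs and rhs bitwidth must be the same"
    return (16, 16, 32) if lhs_bw == 8 else (16, 16, 16)

assert check_dot_compatibility(_BlockTy(8), _BlockTy(8)) == (16, 16, 32)
assert check_dot_compatibility(_BlockTy(16), _BlockTy(16)) == (16, 16, 16)
```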
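
The new `get_active_torch_device()` driver method gives backend-agnostic code a `torch.device` for whatever device Triton currently targets: the HIP backend deliberately reuses the `"cuda"` device string that PyTorch exposes for AMD GPUs, while the Intel backend returns an `"xpu"` device. A hypothetical usage sketch, assuming the usual `triton.runtime.driver.active` accessor and a machine with a supported GPU:

```python
import torch
import triton

# Allocate a buffer on the device the active Triton backend will launch on,
# without hard-coding "cuda" or "xpu" in the caller.
dev = triton.runtime.driver.active.get_active_torch_device()
x = torch.empty(1024, device=dev)
assert x.device == dev
```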