diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 1390d67615..540deb414b 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -335,18 +335,12 @@ SmallVector delinearize(RewriterBase &rewriter, Location loc, SmallVector delinearize(RewriterBase &rewriter, Location loc, Value linear, ArrayRef shape); -SmallVector delinearize(unsigned linear, ArrayRef shape, - ArrayRef order); - Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, ArrayRef shape, ArrayRef order); Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, ArrayRef shape); -size_t linearize(ArrayRef multiDim, ArrayRef shape, - ArrayRef order); - Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, StringRef content); @@ -501,24 +495,6 @@ inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef offsets, return ret; } -/// Extend 2d shared object to 3d. -/// -/// If tensor has 3 dimensions, returns original shared object. -/// If tensor shape is [M, N], return shared object describing shape [1, M, N] -/// -/// This Function is used to simplify processing of 2d and 3d dot operands, -/// particularly in the conversion of local_load operation. -/// -/// \param rewriter -/// \param loc -/// \param smemObj -/// \param shape shape of a tensor represented by smemObj -/// \returns shared object describing 3d tensor -SharedMemoryObject -getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, - SharedMemoryObject smemObj, - ArrayRef shape); - // ----------------------------------------------------------------------- // Blocked layout indices // ----------------------------------------------------------------------- diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index b81ecf103a..e592a9d6d1 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -234,12 +234,6 @@ void dumpHWLayout(RankedTensorType tensorType); // Return a string representation of the layout of the tensor. 
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView); -template -llvm::SmallVector expandMatrixShapeWithBatch(llvm::ArrayRef s); - -llvm::SmallVector -expandMatrixOrderWithBatch(llvm::ArrayRef o); - } // namespace gpu } // namespace triton } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index ae86fb7588..4914fd712b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -1,6 +1,5 @@ #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" using ValueTable = std::map, Value>; using ::mlir::LLVM::delinearize; @@ -8,8 +7,6 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::LLVM::getStridesFromShapeAndOrder; using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::DotOperandEncodingAttr; -using ::mlir::triton::gpu::expandMatrixOrderWithBatch; -using ::mlir::triton::gpu::expandMatrixShapeWithBatch; using ::mlir::triton::gpu::getContigPerThread; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; @@ -18,6 +15,47 @@ using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::MemDescType; using ::mlir::triton::gpu::SharedEncodingAttr; +SmallVector +getThreadIds(Value threadId, ArrayRef shapePerCTATile, + ArrayRef sizePerThread, ArrayRef order, + ConversionPatternRewriter &rewriter, Location loc) { + int dim = order.size(); + SmallVector threadIds(dim); + for (unsigned k = 0; k < dim - 1; k++) { + Value dimK = i32_val(shapePerCTATile[order[k]] / sizePerThread[order[k]]); + Value rem = urem(threadId, dimK); + threadId = udiv(threadId, dimK); + threadIds[order[k]] = rem; + } + Value dimK = i32_val(shapePerCTATile[order[dim - 1]]); + threadIds[order[dim - 1]] = urem(threadId, dimK); + return threadIds; +} + +// Get shapePerCTATile for M or N axis. +int getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) { + auto order = layout.getOrder(); + auto shapePerCTATile = getShapePerCTATile(layout); + + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + return isM ? mShapePerCTATile : nShapePerCTATile; +} + +// Get sizePerThread for M or N axis. +int getSizePerThreadForMN(BlockedEncodingAttr layout, bool isM) { + auto order = layout.getOrder(); + auto sizePerThread = getSizePerThread(layout); + + int mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + return isM ? 
mSizePerThread : nSizePerThread; +} + Value getStructFromValueTable(ArrayRef vals, ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter *typeConverter, @@ -33,329 +71,154 @@ Value getStructFromValueTable(ArrayRef vals, return packLLElements(loc, typeConverter, elems, rewriter, structTy); } -bool isSwizzled(SharedEncodingAttr layout) { return layout.getMaxPhase() != 1; } - -SmallVector swizzleIndices(ConversionPatternRewriter &rewriter, - Location loc, SmallVector rawIndices, - SharedEncodingAttr layout) { - const auto &order = layout.getOrder(); - auto rank = order.size(); - - if (!isSwizzled(layout)) - return rawIndices; - - auto vec = i32_val(layout.getVec()); - auto perPhase = i32_val(layout.getPerPhase()); - auto maxPhase = i32_val(layout.getMaxPhase()); - - auto fastIdx = rawIndices[order[0]]; - auto secondIdx = rawIndices[order[1]]; - // Original algorithm taken from getSwizzledSharedPtrs function - // (TritonGPUToLLVMBase.h) - // - // phase = (secondIdx // perPhase) % maxPhase - // swizzledGroup = ((fastIdx // vec) ^ phase) * vec - // groupRemainder = fastIdx % vec - // colOff = swizzledGroup + groupRemainder - auto phase = urem(udiv(secondIdx, perPhase), maxPhase); - auto swizzledGroup = mul(xor_(udiv(fastIdx, vec), phase), vec); - auto groupRemainder = urem(fastIdx, vec); - auto colOff = add(swizzledGroup, groupRemainder); - - SmallVector swizzledIndices = rawIndices; - swizzledIndices[order[0]] = colOff; - - return swizzledIndices; -} - -struct DimIdx { - unsigned batch; - unsigned k; - unsigned nonK; -}; - -/// Put elements from Value vec to appropriate indexes in opValues array. -/// -/// This function maps elements of 3d sub-tensor in linear array. -/// Axes are arranged in an order provided "opOrder" argument -void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc, - SmallVector &opValues, Value vec, - ArrayRef perThreadTileShape, - unsigned kIdx, unsigned nonKIdx, unsigned bIdx, - const DimIdx &dim, int vecDim, - ArrayRef opOrder) { - auto vecTy = cast(vec.getType()); - auto vectorSize = vecTy.getNumElements(); - auto elemTy = vecTy.getElementType(); - for (int elem = 0; elem < vectorSize; ++elem) { - unsigned spatialIdx[3] = {}; - spatialIdx[dim.batch] = bIdx; - spatialIdx[dim.k] = kIdx; - spatialIdx[dim.nonK] = nonKIdx; - spatialIdx[vecDim] += elem; - - unsigned linearIdx = linearize(spatialIdx, perThreadTileShape, opOrder); - opValues[linearIdx] = extract_element(elemTy, vec, i32_val(elem)); +ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA, + int sizePerThread, + ConversionPatternRewriter &rewriter, + Location loc, + const LLVMTypeConverter *typeConverter, + Type type) { + ValueTable res; + auto elems = unpackLLElements(loc, val, rewriter); + int index = 0; + for (unsigned k = 0; k < K; ++k) { + for (unsigned m = 0; m < n0; m += shapePerCTA) + for (unsigned mm = 0; mm < sizePerThread; ++mm) { + res[{m + mm, k}] = elems[index++]; + } } + return res; } -void verifyCTALayout(CTALayoutAttr ctaLayout) { - auto ctaSplit = ctaLayout.getCTASplitNum(); - for (auto split : ctaSplit) { - if (split != 1) - llvm::report_fatal_error("tensors splited in CGA(thread group clusters) " - "are not supported in FMA dot yet."); - } -} +Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, + Location loc, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto aTensorTy = cast(A.getType()); + auto aLayout = cast(aTensorTy.getEncoding()); + auto aShapePerCTA = 
getShapePerCTA(aTensorTy); -/// Get a linear offset of first element loaded by thread. -/// -/// In unswizzled case offset of any element computed with formula: -/// smem.base + first_element_offset + constant_offset. -/// -/// first_element_offset depends on lane Id and warp Id -/// constant_offset depends on value number, which is same for all threads. -/// \returns first_element_offset -Value getUnswizzledFirstElemOffset(ConversionPatternRewriter &rewriter, - Location loc, unsigned B, unsigned NonK, - Value bTileOffset, Value nonKTileOffset, - Value bStride, Value nonKStride) { - auto bOffset = mul(urem(bTileOffset, i32_val(B)), bStride); - auto nonKOffset = mul(urem(nonKTileOffset, i32_val(NonK)), nonKStride); - Value threadIdDependantOffset = add(bOffset, nonKOffset); - return threadIdDependantOffset; -} + auto aOrder = aLayout.getOrder(); + auto order = dLayout.getOrder(); -/// \returns number of elements stored by one thread across each dimension -SmallVector getElemsPerThreadInOp(ArrayRef opTensorShape, - ArrayRef shapePerCTATile, - ArrayRef sizePerThread) { - int rank = opTensorShape.size(); - SmallVector elemsPerThread(rank); - for (int d = 0; d < rank; ++d) { - auto numReps = - ceil(static_cast(opTensorShape[d]), shapePerCTATile[d]); - elemsPerThread[d] = numReps * sizePerThread[d]; - } - return elemsPerThread; -} + bool isARow = aOrder[0] == 1; -struct Indexes { - unsigned bTile; - unsigned b; - unsigned k; - unsigned nonKTile; - unsigned nonK; -}; - -/// Computes a linear memory offset of a given element relative to -/// beginning of shared memory object. -Value computeSwizzledOffset(ConversionPatternRewriter &rewriter, Location loc, - const Indexes &i, const DimIdx &dim, - Value bTileOffset, Value nonKTileOffset, - unsigned shapePerCTABTile, - unsigned shapePerCTANonKTile, - SharedEncodingAttr sharedLayout, - ArrayRef opTensorShape, - ArrayRef strides) { - Value offset = i32_val(0); - // Compute unswizzled multi dim coordinates in shared memmory object - SmallVector elemMultiDimIndices(3); - elemMultiDimIndices[dim.batch] = - add(bTileOffset, i32_val(i.bTile * shapePerCTABTile + i.b)); - elemMultiDimIndices[dim.nonK] = - add(nonKTileOffset, i32_val(i.nonKTile * shapePerCTANonKTile + i.nonK)); - elemMultiDimIndices[dim.k] = i32_val(i.k); - - // Apply swizzling pattern to fastest dimension - SmallVector swizzledIndices = - swizzleIndices(rewriter, loc, elemMultiDimIndices, sharedLayout); - - // Linearize shared mem object dimensions into flat offset - for (int d = 0; d < 3; ++d) { - // wrap index if it is larger than tensor - auto wrappedDimIndex = urem(swizzledIndices[d], i32_val(opTensorShape[d])); - auto dimOffset = mul(wrappedDimIndex, strides[d]); - offset = add(offset, dimOffset); + auto aSmem = getSharedMemoryObjectFromStruct( + loc, llA, typeConverter->convertType(aTensorTy.getElementType()), + rewriter); + Value strideAM = aSmem.strides[0]; + Value strideAK = aSmem.strides[1]; + Value strideA0 = isARow ? strideAK : strideAM; + Value strideA1 = isARow ? strideAM : strideAK; + int aNumPtr = 8; + int K = aShapePerCTA[1]; + int M = aShapePerCTA[0]; + + auto shapePerCTATile = getShapePerCTATile(dLayout); + auto sizePerThread = getSizePerThread(dLayout); + + Value _0 = i32_val(0); + + Value mContig = i32_val(sizePerThread[order[1]]); + + // threadId in blocked layout + auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, + rewriter, loc); + Value threadIdM = threadIds[0]; + + Value offA0 = isARow ? 
_0 : mul(threadIdM, mContig); + Value offA1 = isARow ? mul(threadIdM, mContig) : _0; + SmallVector aOff(aNumPtr); + for (int i = 0; i < aNumPtr; ++i) { + aOff[i] = add(mul(offA0, strideA0), mul(offA1, strideA1)); } - return offset; -} + auto elemTy = typeConverter->convertType(aTensorTy.getElementType()); + + Type ptrTy = aSmem.base.getType(); + SmallVector aPtrs(aNumPtr); + for (int i = 0; i < aNumPtr; ++i) + aPtrs[i] = gep(ptrTy, elemTy, aSmem.base, aOff[i]); + + SmallVector vas; + + int mShapePerCTATile = getShapePerCTATileForMN(dLayout, true /*isM*/); + int mSizePerThread = getSizePerThreadForMN(dLayout, true /*isM*/); + + for (unsigned k = 0; k < K; ++k) + for (unsigned m = 0; m < M; m += mShapePerCTATile) + for (unsigned mm = 0; mm < mSizePerThread; ++mm) { + Value offset = + add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK)); + Value pa = gep(ptrTy, elemTy, aPtrs[0], offset); + Value va = load(elemTy, pa); + vas.emplace_back(va); + } -/// Computes memory offset of a given element relative to the -/// first element loaded by a thread. -Value computeNonSwizzledOffset(ConversionPatternRewriter &rewriter, - Location loc, const Indexes &i, - const DimIdx &dim, ArrayRef tensorShape, - unsigned shapePerCTABTile, - unsigned shapePerCTANonKTile, - ArrayRef strides) { - SmallVector offsetIndices(3); - offsetIndices[dim.batch] = - i32_val((i.bTile * shapePerCTABTile + i.b) % tensorShape[dim.batch]); - offsetIndices[dim.nonK] = i32_val( - (i.nonKTile * shapePerCTANonKTile + i.nonK) % tensorShape[dim.nonK]); - offsetIndices[dim.k] = i32_val(i.k); - - Value offset = i32_val(0); - for (int d = 0; d < 3; ++d) - offset = add(offset, mul(offsetIndices[d], strides[d])); - return offset; + return getStructFromValueTable(vas, rewriter, loc, typeConverter, elemTy); } -/// Generates llvm IR for loading FMA dot operand from shared memory. -/// -/// \param srcVal triton_gpu MemDescType value -/// \param llVal llvm IR values corresponding to srcVal -/// \param dLayout parent dot operand layout -/// \param thread thread id -/// \param loc -/// \param typeConverter -/// \param rewriter -/// \param dotOpNo -/// \returns llvm value with loaded elements -Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, - Value thread, Location loc, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, const int dotOpNo) { - verifyCTALayout(dLayout.getCTALayout()); - - DimIdx dim; - dim.batch = 0; - dim.k = dotOpNo == 0 ? 2 : 1; - dim.nonK = dotOpNo == 0 ? 
1 : 2; - auto opTensorTy = cast(srcVal.getType()); - auto opTensorShape = expandMatrixShapeWithBatch(opTensorTy.getShape()); - auto sharedLayout = cast(opTensorTy.getEncoding()); - - auto opOrder = expandMatrixOrderWithBatch(dLayout.getOrder()); - - auto origSmem = getSharedMemoryObjectFromStruct( - loc, llVal, typeConverter->convertType(opTensorTy.getElementType()), +Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, + Location loc, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto bTensorTy = cast(B.getType()); + auto bLayout = cast(bTensorTy.getEncoding()); + auto bShapePerCTA = getShapePerCTA(bTensorTy); + + auto bOrder = bLayout.getOrder(); + auto order = dLayout.getOrder(); + + bool isBRow = bOrder[0] == 1; + + auto bSmem = getSharedMemoryObjectFromStruct( + loc, llB, typeConverter->convertType(bTensorTy.getElementType()), rewriter); - auto smem = getExpandedSharedMemoryObject(rewriter, loc, origSmem, - opTensorTy.getShape()); - auto strides = smem.strides; - int B = opTensorShape[dim.batch]; - int K = opTensorShape[dim.k]; - int NonK = opTensorShape[dim.nonK]; - - auto shapePerCTATile = - expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout))); - shapePerCTATile[dim.k] = K; - auto sizePerThread = - expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout))); - sizePerThread[dim.k] = K; - auto threadsPerWarp = - expandMatrixShapeWithBatch(ArrayRef(dLayout.getThreadsPerWarp())); - auto warpsPerCTA = - expandMatrixShapeWithBatch(ArrayRef(dLayout.getWarpsPerCTA())); - - auto warpSize = i32_val(triton::gpu::getWarpSize(dLayout)); - auto laneId = urem(thread, warpSize); - auto warpId = udiv(thread, warpSize); - auto laneIds = - mlir::LLVM::delinearize(rewriter, loc, laneId, threadsPerWarp, opOrder); - auto warpIds = - mlir::LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, opOrder); - auto sizePerWarpB = sizePerThread[dim.batch] * threadsPerWarp[dim.batch]; - auto sizePerWarpNonK = sizePerThread[dim.nonK] * threadsPerWarp[dim.nonK]; - - Value bTileOffset = - mul(laneIds[dim.batch], i32_val(sizePerThread[dim.batch])); - bTileOffset = - add(bTileOffset, mul(warpIds[dim.batch], i32_val(sizePerWarpB))); - Value nonKTileOffset = - mul(laneIds[dim.nonK], i32_val(sizePerThread[dim.nonK])); - nonKTileOffset = - add(nonKTileOffset, mul(warpIds[dim.nonK], i32_val(sizePerWarpNonK))); - - auto elemTy = typeConverter->convertType(opTensorTy.getElementType()); - Type ptrTy = smem.base.getType(); - - auto sharedOrder = expandMatrixOrderWithBatch(sharedLayout.getOrder()); - // compute contiguity of fastest dimension in shared layout. 
- unsigned vectorSize = sizePerThread[sharedOrder[0]]; - vectorSize = std::min(vectorSize, 128 / elemTy.getIntOrFloatBitWidth()); - - bool swizzlePath = isSwizzled(sharedLayout); - - if (swizzlePath) - vectorSize = std::min(vectorSize, sharedLayout.getVec()); - auto vecTy = vec_ty(elemTy, vectorSize); - // loop increments depend on fastest dim - unsigned dimStep[3] = {1, 1, 1}; - dimStep[sharedOrder[0]] = vectorSize; - - auto shapePerCTABTile = shapePerCTATile[dim.batch]; - auto shapePerCTANonKTile = shapePerCTATile[dim.nonK]; - auto sizeBPerThread = sizePerThread[dim.batch]; - auto sizeNonKPerThread = sizePerThread[dim.nonK]; - auto numBTiles = std::max(1u, B / shapePerCTABTile); - auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile); - - auto perThreadShape = - getElemsPerThreadInOp(opTensorShape, shapePerCTATile, sizePerThread); - - SmallVector opValues(numBTiles * sizeBPerThread * K * numNonKTiles * - sizeNonKPerThread); - - // In swizzled memory case basePtr stores pointer to the beginning of shared - // memmory object. - // - // If memory is not swizzled, algorithm breaks element offset pointer into - // constant and non-constant part. Non-constant (depends on thread id) part is - // the offset of the first element of the thread, which is same for all - // elements of the thread. It is computed only once. basePtr stores this - // non-constant part - Value basePtr; - if (swizzlePath) { - basePtr = smem.base; - } else { - auto laneOffset = getUnswizzledFirstElemOffset( - rewriter, loc, B, NonK, bTileOffset, nonKTileOffset, strides[dim.batch], - strides[dim.nonK]); - basePtr = gep(ptrTy, elemTy, smem.base, laneOffset); + Value strideBN = bSmem.strides[1]; + Value strideBK = bSmem.strides[0]; + Value strideB0 = isBRow ? strideBN : strideBK; + Value strideB1 = isBRow ? strideBK : strideBN; + int bNumPtr = 8; + int K = bShapePerCTA[0]; + int N = bShapePerCTA[1]; + + auto shapePerCTATile = getShapePerCTATile(dLayout); + auto sizePerThread = getSizePerThread(dLayout); + + Value _0 = i32_val(0); + + Value nContig = i32_val(sizePerThread[order[0]]); + + // threadId in blocked layout + auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, + rewriter, loc); + Value threadIdN = threadIds[1]; + + Value offB0 = isBRow ? mul(threadIdN, nContig) : _0; + Value offB1 = isBRow ? _0 : mul(threadIdN, nContig); + SmallVector bOff(bNumPtr); + for (int i = 0; i < bNumPtr; ++i) { + bOff[i] = add(mul(offB0, strideB0), mul(offB1, strideB1)); } + auto elemTy = typeConverter->convertType(bTensorTy.getElementType()); + + Type ptrTy = bSmem.base.getType(); + SmallVector bPtrs(bNumPtr); + for (int i = 0; i < bNumPtr; ++i) + bPtrs[i] = gep(ptrTy, elemTy, bSmem.base, bOff[i]); + + SmallVector vbs; + + int nShapePerCTATile = getShapePerCTATileForMN(dLayout, false /*isM*/); + int nSizePerThread = getSizePerThreadForMN(dLayout, false /*isM*/); + + for (unsigned k = 0; k < K; ++k) + for (unsigned n = 0; n < N; n += nShapePerCTATile) + for (unsigned nn = 0; nn < nSizePerThread; ++nn) { + Value offset = + add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK)); + Value pb = gep(ptrTy, elemTy, bPtrs[0], offset); + Value vb = load(elemTy, pb); + vbs.emplace_back(vb); + } - // This loop nest iterates over all values loaded in one thread across batch, - // k and nonK dimensions. Blocked dot operand layout allocates data in tiles - // of size ** for batch and nonK - // dimensions. If tensor shape is larger than tile, pattern repeats. 
To take - // these repeats into account iterations for batch and nonK are split into - // "intra tile" + "inter tile" indexes: b + bTile, nonK + nonKTile - for (unsigned bTile = 0; bTile < numBTiles; ++bTile) - for (unsigned b = 0; b < sizeBPerThread; b += dimStep[dim.batch]) - for (unsigned k = 0; k < K; k += dimStep[dim.k]) - for (unsigned nonKTile = 0; nonKTile < numNonKTiles; ++nonKTile) - for (unsigned nonK = 0; nonK < sizeNonKPerThread; - nonK += dimStep[dim.nonK]) { - Value offset = i32_val(0); - Indexes idx = {bTile, b, k, nonKTile, nonK}; - - // swizzled variant is more general, but it limits optimization of - // address computation, - if (swizzlePath) { - offset = computeSwizzledOffset( - rewriter, loc, idx, dim, bTileOffset, nonKTileOffset, - shapePerCTABTile, shapePerCTANonKTile, sharedLayout, - opTensorShape, strides); - } else { - offset = computeNonSwizzledOffset(rewriter, loc, idx, dim, - opTensorShape, shapePerCTABTile, - shapePerCTANonKTile, strides); - } - - Value elemAddr = gep(ptrTy, elemTy, basePtr, offset); - Value vec = load(vecTy, elemAddr); - storeValuesInLinearVector( - rewriter, loc, opValues, vec, perThreadShape, /*kIdx*/ k, - /*nonKIdx*/ nonKTile * sizeNonKPerThread + nonK, - /*bIdx*/ bTile * sizeBPerThread + b, dim, sharedOrder[0], - opOrder); - } - - return getStructFromValueTable(opValues, rewriter, loc, typeConverter, - elemTy); + return getStructFromValueTable(vbs, rewriter, loc, typeConverter, elemTy); } namespace SharedToDotOperandFMA { @@ -363,7 +226,9 @@ Value convertLayout(int opIdx, Value val, Value llVal, BlockedEncodingAttr dLayout, Value thread, Location loc, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter) { - return loadFMAOp(val, llVal, dLayout, thread, loc, typeConverter, rewriter, - opIdx); + if (opIdx == 0) + return loadAFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); + else + return loadBFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); } } // namespace SharedToDotOperandFMA diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp index e32b3e0d6e..afb5bf01d4 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -1,36 +1,28 @@ #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" using namespace mlir; using namespace mlir::triton; -using namespace ::mlir::triton::gpu; -using ::mlir::LLVM::linearize; -using ::mlir::triton::gpu::expandMatrixOrderWithBatch; -using ::mlir::triton::gpu::expandMatrixShapeWithBatch; +using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; -using ValueTableFMA = std::map, Value>; +using ValueTableFMA = std::map, Value>; static ValueTableFMA -getValueTableFromStructFMA(Value val, ArrayRef perTileShape, - unsigned kDim, unsigned nonKDim, +getValueTableFromStructFMA(Value val, int K, int n0, int shapePerCTATile, + int sizePerThread, ConversionPatternRewriter &rewriter, Location loc, - ArrayRef order) { + const LLVMTypeConverter *typeConverter, Type type) { ValueTableFMA res; auto elems = unpackLLElements(loc, val, rewriter); - assert(perTileShape.size() == 3); - assert(elems.size() == product(perTileShape)); - assert(kDim == 1 || kDim == 2); - assert(nonKDim == 1 
|| nonKDim == 2); - const unsigned bDim = 0; - - for (unsigned idx = 0; idx < elems.size(); ++idx) { - auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order); - res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx]; + int index = 0; + for (unsigned k = 0; k < K; ++k) { + for (unsigned m = 0; m < n0; m += shapePerCTATile) + for (unsigned mm = 0; mm < sizePerThread; ++mm) { + res[{m + mm, k}] = elems[index++]; + } } return res; } @@ -42,60 +34,68 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, auto loc = op.getLoc(); auto A = op.getA(); + auto B = op.getB(); + auto C = op.getC(); auto D = op.getResult(); auto aTensorTy = cast(A.getType()); + auto bTensorTy = cast(B.getType()); auto dTensorTy = cast(D.getType()); - SmallVector aShapePerCTA = - expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy))); - auto dShapePerCTA = - expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy))); + auto aShapePerCTA = getShapePerCTA(aTensorTy); + auto bShapePerCTA = getShapePerCTA(bTensorTy); BlockedEncodingAttr dLayout = cast(dTensorTy.getEncoding()); - auto order = expandMatrixOrderWithBatch(dLayout.getOrder()); + auto order = dLayout.getOrder(); auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); Value llA = adaptor.getA(); Value llB = adaptor.getB(); - auto sizePerThread = - expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout))); - auto shapePerCTATile = - expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout))); - - unsigned K = aShapePerCTA[2]; - - unsigned perThreadShape[3]; - for (int i = 0; i < 3; ++i) { - unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i]; - numRep = std::max(static_cast(1), numRep); - perThreadShape[i] = numRep * sizePerThread[i]; + auto sizePerThread = getSizePerThread(dLayout); + auto shapePerCTATile = getShapePerCTATile(dLayout); + + int K = aShapePerCTA[1]; + int M = aShapePerCTA[0]; + int N = bShapePerCTA[1]; + + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + + auto has = + getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread, + rewriter, loc, typeConverter, aTensorTy); + auto hbs = + getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread, + rewriter, loc, typeConverter, bTensorTy); + + SmallVector ret = cc; + bool isCRow = order[0] == 1; + + for (unsigned k = 0; k < K; k++) { + for (unsigned m = 0; m < M; m += mShapePerCTATile) + for (unsigned n = 0; n < N; n += nShapePerCTATile) + for (unsigned mm = 0; mm < mSizePerThread; ++mm) + for (unsigned nn = 0; nn < nSizePerThread; ++nn) { + int mIdx = m / mShapePerCTATile * mSizePerThread + mm; + int nIdx = n / nShapePerCTATile * nSizePerThread + nn; + + int z = isCRow + ? 
mIdx * N / nShapePerCTATile * mSizePerThread + nIdx + : nIdx * M / mShapePerCTATile * nSizePerThread + mIdx; + ret[z] = rewriter.create(loc, has[{m + mm, k}], + hbs[{n + nn, k}], ret[z]); + } } - auto has = getValueTableFromStructFMA( - llA, {perThreadShape[0], perThreadShape[1], K}, - /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order); - auto hbs = getValueTableFromStructFMA( - llB, {perThreadShape[0], K, perThreadShape[2]}, - /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order); - - SmallVector acc = cc; - - for (unsigned b = 0; b < perThreadShape[0]; ++b) - for (unsigned m = 0; m < perThreadShape[1]; ++m) - for (unsigned n = 0; n < perThreadShape[2]; ++n) { - SmallVector multiDimAccumIdx = {b, m, n}; - unsigned linearAccumIdx = - linearize(multiDimAccumIdx, perThreadShape, order); - for (unsigned k = 0; k < K; ++k) { - acc[linearAccumIdx] = rewriter.create( - loc, has[{b, m, k}], hbs[{b, n, k}], acc[linearAccumIdx]); - } - } - - auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy); + auto res = packLLElements(loc, typeConverter, ret, rewriter, dTensorTy); rewriter.replaceOp(op, res); return success(); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 8a84229069..eb2c82cfc9 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -634,19 +634,6 @@ SmallVector delinearize(RewriterBase &rewriter, Location loc, return multiDim; } -SmallVector delinearize(unsigned linear, ArrayRef shape, - ArrayRef order) { - auto rank = shape.size(); - assert(order.size() == rank); - SmallVector multiDim(rank); - for (auto dim : order) { - multiDim[dim] = linear % shape[dim]; - linear /= shape[dim]; - } - assert(linear == 0); - return multiDim; -} - Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, ArrayRef shape, ArrayRef order) { return linearize(rewriter, loc, applyPermutation(multiDim, order), @@ -668,14 +655,6 @@ Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, return linear; } -size_t linearize(ArrayRef multiDim, ArrayRef shape, - ArrayRef order) { - size_t linear = 0; - for (unsigned dim : llvm::reverse(order)) - linear = linear * shape[dim] + multiDim[dim]; - return linear; -} - Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, StringRef content) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); @@ -913,23 +892,4 @@ Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, }; } // namespace LLVM - -SharedMemoryObject -getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, - SharedMemoryObject smemObj, - ArrayRef shape) { - assert(shape.size() == 2 || shape.size() == 3); - auto strides = smemObj.getStrides(); - auto offsets = smemObj.getOffsets(); - auto rank = strides.size(); - assert(rank == shape.size()); - if (rank == 3) - return smemObj; - strides.insert(strides.begin(), i32_val(shape[0] * shape[1])); - offsets.insert(offsets.begin(), i32_val(0)); - auto expandedSmemObj = SharedMemoryObject( - smemObj.getBase(), smemObj.getBaseElemType(), strides, offsets); - return expandedSmemObj; -} - } // namespace mlir diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 67ab63beb7..464b150dc1 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -227,11 +227,6 @@ struct TritonDotPattern : 
public OpConversionPattern { retSizePerThread[rank - 1] = 4; retSizePerThread[rank - 2] = 4; } - retSizePerThread[rank - 1] = std::min( - retSizePerThread[rank - 1], static_cast(origShape[rank - 1])); - retSizePerThread[rank - 2] = std::min( - retSizePerThread[rank - 2], static_cast(origShape[rank - 2])); - SmallVector retOrder(rank); for (unsigned i = 0; i < rank; ++i) retOrder[i] = rank - 1 - i; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 052e23fa42..504dfc8941 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1093,26 +1093,29 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, } } if (auto blockedLayout = mlir::dyn_cast(getParent())) { - auto shapePerCTA = - expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(*this, shape))); - auto shapePerCTATile = - expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(blockedLayout))); - auto sizePerThread = - expandMatrixShapeWithBatch(ArrayRef(blockedLayout.getSizePerThread())); - - int batchDim = 0; - int kDim = getOpIdx() == 0 ? 2 : 1; - int nonKDim = getOpIdx() == 0 ? 1 : 2; - - int batchSize = - std::max(shapePerCTA[batchDim] / shapePerCTATile[batchDim], 1) * - sizePerThread[batchDim]; - int kSize = shapePerCTA[kDim]; - int nonKSize = - std::max(shapePerCTA[nonKDim] / shapePerCTATile[nonKDim], 1) * - sizePerThread[nonKDim]; - - return batchSize * kSize * nonKSize; + auto shapePerCTA = getShapePerCTA(*this, shape); + auto shapePerCTATile = getShapePerCTATile(blockedLayout); + auto order = blockedLayout.getOrder(); + auto sizePerThread = blockedLayout.getSizePerThread(); + + int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0]; + int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0]; + + bool isM = getOpIdx() == 0; + + int mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int sizePerThreadMN = isM ? mSizePerThread : nSizePerThread; + + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int shapePerCTAMNTile = isM ? 
mShapePerCTATile : nShapePerCTATile; + + return K * std::max(otherDim / shapePerCTAMNTile, 1) * sizePerThreadMN; } llvm_unreachable("unknown dot operand parent layout"); return 0; @@ -3357,36 +3360,6 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType, return layoutStr; } -template -llvm::SmallVector -mlir::triton::gpu::expandMatrixShapeWithBatch(llvm::ArrayRef s) { - auto rank = s.size(); - assert(rank == 2 || rank == 3); - if (rank == 3) - return llvm::SmallVector(s); - return {1, s[0], s[1]}; -} - -template llvm::SmallVector -mlir::triton::gpu::expandMatrixShapeWithBatch( - llvm::ArrayRef s); - -template llvm::SmallVector -mlir::triton::gpu::expandMatrixShapeWithBatch( - llvm::ArrayRef s); - -llvm::SmallVector -mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef o) { - int rank = o.size(); - assert(rank == 2 || rank == 3); - if (rank == 3) - return llvm::SmallVector(o); - llvm::SmallVector expanded(3, 0); - for (int i = 0; i < rank; ++i) - expanded[i] += o[i] + 1; - return expanded; -} - std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView) { auto layout = tensorType.getEncoding(); diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index 30537a462e..f0ab578cbd 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -389,22 +389,14 @@ def test_min_dot_size(dtype): else: error_msg = "M >= 16, N >= 16 and K >= 16" elif is_hip_mi300(): - if dtype == tl.float16: - pytest.skip("fp16 FMA path supports all sizes") - elif dtype == tl.int8: + if dtype.is_int8(): error_msg += "M >= 16, N >= 16 and K >= 16" else: error_msg += "M >= 16, N >= 16 and K >= 8" elif is_hip_mi200(): - if dtype == tl.float16: - pytest.skip("fp16 FMA path supports all sizes") - else: - error_msg += "M >= 16, N >= 16 and K >= 8" + error_msg += "M >= 16, N >= 16 and K >= 8" elif is_hip(): - if dtype == tl.float16: - pytest.skip("fp16 FMA path supports all sizes") - else: - error_msg = "M >= 16, N >= 16 and K >= 16" + error_msg = "M >= 16, N >= 16 and K >= 16" else: pytest.skip("Test only supported on CUDA and HIP") diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 05eb7c4f38..addba58ae8 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3213,21 +3213,13 @@ def convert_fp8_to_fp32(x, device, dtype_str): ([(16, 16, 8, 4, False, False, 'None', 'ieee', 'float32', 'float32', 1), (32, 16, 8, 4, False, False, 'None', 'ieee', 'float16', 'float16', 1)] if "gfx9" in get_arch() else []) + [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1) - for float8_type in ["float8e5", "float8e4nv"]] + - [(*shape_nw, False, False, epilogue, 'ieee', in_dtype, out_dtype, 1) - for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4)] - for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'] - for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]]) + for float8_type in ["float8e5", "float8e4nv"]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, num_ctas, device): if is_interpreter(): - if M < 16 or N < 16 or K < 16: - pytest.skip("small dots are supported only on HIP at the moment") if in_dtype == 'bfloat16': pytest.xfail("bfloat16 is not supported in the interpreter") else: - 
if not is_hip() and (M < 16 or N < 16 or K < 16): - pytest.skip("small dots are supported only on HIP at the moment") if is_cuda(): capability = torch.cuda.get_device_capability() @@ -3673,14 +3665,7 @@ def make_finite(x, dtype): for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')]] + # Large block sizes - [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')] + - # Small block sizes - [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) - for B in [1, 2, 8] - for num_warps in [1, 2, 4] - for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8)] - for M, N, K in [(32, 32, 32)] - for in_dtype_str, out_dtype_str in [('float16', 'float16'), ('float32', 'float32')]]) + [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')]) def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str, device): if is_hip(): # hip does not support tf32 precision, so use ieee for all tests @@ -3693,8 +3678,6 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_ pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") else: input_precision = "tf32" if (is_cuda() or is_xpu()) and in_dtype_str == 'float32' else "ieee" - if BLOCK_M < 16 or BLOCK_N < 16: - pytest.skip("small dots are supported only on HIP at the moment") if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32": if not is_interpreter() and torch.cuda.is_available( @@ -3755,10 +3738,6 @@ def kernel( if in_dtype_str == 'int8': out = numpy_random((B, M, N), dtype_str='int32', rs=rs) else: - if is_hip() and (BLOCK_M < 16 or BLOCK_N < 16) and out_dtype_str == 'float16': - # float16 accumulator in FMA dot loose precision too fast - x *= 0.1 - y *= 0.1 out = numpy_random((B, M, N), dtype_str=out_dtype_str, rs=rs) x_tri = to_triton(x, device=device) diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir index 76fbe584cb..260dddb954 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir @@ -18,20 +18,3 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ tt.return } } - -// ----- - -// Verify that we use FMA when the N dimension is too small for any mma. 
-// CHECK-NOT: #ttg.amd_mfma
-// CHECK-LABEL: small_n_size
-#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 64], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func @small_n_size(
-    %a: tensor<4x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
-    %b: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>)
-    -> tensor<4x128xf32, #blocked> {
-    %zero_f32 = arith.constant dense<0.000000e+00> : tensor<4x128xf32, #blocked>
-    %result = tt.dot %a, %b, %zero_f32 : tensor<4x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<4x128xf32, #blocked>
-    tt.return %result : tensor<4x128xf32, #blocked>
-  }
-}
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index 0dcd20a1d5..5e6ad77eb9 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -13,30 +13,17 @@
 
 
 def min_dot_size(target: GPUTarget):
-
-    def is_fma_supported(lhsType, rhsType):
-        return lhsType == rhsType and (lhsType.is_fp16() or lhsType.is_fp32())
-
-    def get_gfx94_limits(lhsType, rhsType):
-        if is_fma_supported(lhsType.scalar, rhsType.scalar):
-            return (1, 1, 1)
-        # CDNA 3.0 supports k==8 in all mfma variants except for int8
-        # (where the smallest `k` supported is 16)
-        return (16, 16, 16) if (lhsType.scalar.is_int8() or rhsType.scalar.is_int8()) else (16, 16, 8)
-
-    def get_gfx9_limits(lhsType, rhsType):
-        if is_fma_supported(lhsType.scalar, rhsType.scalar):
-            return (1, 1, 1)
-        # CDNA 2.0 always supports `k==8`
-        return (16, 16, 8)
-
     arch_str = target.arch
+    # CDNA 3.0 supports k==8 in all mfma variants except for int8
+    # (where the smallest `k` supported is 16)
     if "gfx94" in arch_str:
-        return get_gfx94_limits
+        return lambda lhsType, rhsType: (16, 16, 16) if (lhsType.scalar.is_int8() or rhsType.scalar.is_int8()) else (
+            16, 16, 8)
+    # CDNA 2.0 always supports `k==8`
     if "gfx9" in arch_str:
-        return get_gfx9_limits
-    # gfx11 and gfx12 architectures will only support 16,16,16 with wmma instructions
-    return lambda lhsType, rhsType: (1, 1, 1) if is_fma_supported(lhsType.scalar, rhsType.scalar) else (16, 16, 16)
+        return lambda lhsType, rhsType: (16, 16, 8)
+    # Other architectures will only support 16,16,16
+    return lambda lhsType, rhsType: (16, 16, 16)
 
 
 @dataclass(frozen=True)
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
index 89de36cad5..d4a6eb09fd 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
@@ -350,8 +350,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
       unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
      SmallVector<unsigned> sharedOrder;
      int rank = order.size();
-      // TODO rework this when shared -> dotOperand conversions support
-      // arbitrary shared memory ordering
+      // TODO rework this when shared -> dotOp conversions support arbitrary
+      // shared memory ordering
       if (rank == 3) {
         // Move the batch dimension (dim #0) to be the last so that it will be
         // the slowest varying dimension.
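To make the effect of the restored min_dot_size resolver in the third_party/amd/backend/compiler.py hunk above easier to follow, here is a minimal, self-contained Python sketch of the same arch-string dispatch. It is illustrative only: FakeScalar and FakeOperand are hypothetical stand-ins for Triton's real element types and GPUTarget, and only the "gfx94"/"gfx9"/int8 checks mirror the patch.

# Sketch of the restored AMD min_dot_size dispatch (assumed stand-in types).
from dataclasses import dataclass


@dataclass
class FakeScalar:
    name: str

    def is_int8(self):
        return self.name == "int8"


@dataclass
class FakeOperand:
    scalar: FakeScalar


def min_dot_size(arch_str):
    # CDNA 3.0 (gfx94x): k == 8 for all mfma variants except int8, which needs k == 16.
    if "gfx94" in arch_str:
        return lambda lhs, rhs: (16, 16, 16) if (lhs.scalar.is_int8() or rhs.scalar.is_int8()) else (16, 16, 8)
    # CDNA 2.0 (other gfx9): k == 8 is always supported.
    if "gfx9" in arch_str:
        return lambda lhs, rhs: (16, 16, 8)
    # Everything else falls back to 16x16x16.
    return lambda lhs, rhs: (16, 16, 16)


if __name__ == "__main__":
    f16 = FakeOperand(FakeScalar("fp16"))
    i8 = FakeOperand(FakeScalar("int8"))
    print(min_dot_size("gfx942")(f16, f16))   # (16, 16, 8)
    print(min_dot_size("gfx942")(i8, i8))     # (16, 16, 16)
    print(min_dot_size("gfx90a")(f16, f16))   # (16, 16, 8)
    print(min_dot_size("gfx1100")(f16, f16))  # (16, 16, 16)

Note that, unlike the pre-revert code, there is no (1, 1, 1) FMA escape hatch: every architecture reports a hard minimum dot size, which is why the small-dot tests and the small_n_size lit test are removed elsewhere in this patch.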
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
index f2a6454071..b9aac96cbf 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
@@ -812,6 +812,23 @@ MemDescType getExpandedDesc(MemDescType descTy) {
   return expandedDesc;
 }
 
+SharedMemoryObject
+getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
+                              SharedMemoryObject smemObj,
+                              ArrayRef<int64_t> shape) {
+  auto strides = smemObj.getStrides();
+  auto offsets = smemObj.getOffsets();
+  auto rank = strides.size();
+  if (rank == 3)
+    return smemObj;
+  auto expandedStrides = insertValue(strides, 0, i32_val(shape[0] * shape[1]));
+  auto expandedOffsets = insertValue(offsets, 0, i32_val(0));
+  auto expandedSmemObj =
+      SharedMemoryObject(smemObj.getBase(), smemObj.getBaseElemType(),
+                         expandedStrides, expandedOffsets);
+  return expandedSmemObj;
+}
+
 namespace SharedToDotOperandMMAv2OrV3 {
 Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
                     Location loc, Value tensor, DotOperandEncodingAttr encoding,