From 9bd7bd42641799b992147f7ff221a2d1a6625462 Mon Sep 17 00:00:00 2001
From: Whitney Tsang
Date: Fri, 20 Dec 2024 22:02:11 +0000
Subject: [PATCH] [Intel] Improve dot support to target FMA

Signed-off-by: Whitney Tsang
---
 .../TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp | 147 +++++++++---------
 1 file changed, 72 insertions(+), 75 deletions(-)

diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp
index 366c8ff1fb..100e08d055 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp
@@ -4,30 +4,35 @@
 
 using namespace mlir;
 using namespace mlir::triton;
+using namespace ::mlir::triton::gpu;
 
-using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::LLVM::linearize;
+using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
+using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
 using ::mlir::triton::gpu::getShapePerCTA;
-using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
+using ::mlir::triton::gpu::getSizePerThread;
 
-using ValueTableFMA = std::map<std::pair<int, int>, Value>;
+using ValueTableFMA = std::map<std::tuple<unsigned, unsigned, unsigned>, Value>;
 
-namespace {
-static ValueTableFMA getValueTableFromStructFMA(
-    Value val, int K, int n0, int shapePerCTATile, int sizePerThread,
-    ConversionPatternRewriter &rewriter, Location loc,
-    TritonIntelGPUToLLVMTypeConverter *typeConverter, Type type) {
+static ValueTableFMA
+getValueTableFromStructFMA(Value val, ArrayRef<unsigned> perTileShape,
+                           unsigned kDim, unsigned nonKDim,
+                           ConversionPatternRewriter &rewriter, Location loc,
+                           ArrayRef<unsigned> order) {
   ValueTableFMA res;
   auto elems = unpackLLElements(loc, val, rewriter);
-  int index = 0;
-  for (unsigned k = 0; k < K; ++k) {
-    for (unsigned m = 0; m < n0; m += shapePerCTATile)
-      for (unsigned mm = 0; mm < sizePerThread; ++mm) {
-        res[{m + mm, k}] = elems[index++];
-      }
+  assert(perTileShape.size() == 3);
+  assert(elems.size() == product(perTileShape));
+  assert(kDim == 1 || kDim == 2);
+  assert(nonKDim == 1 || nonKDim == 2);
+  const unsigned bDim = 0;
+
+  for (unsigned idx = 0; idx < elems.size(); ++idx) {
+    auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order);
+    res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx];
   }
   return res;
 }
-} // namespace
 
 namespace fma_details {
 LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
@@ -37,81 +42,73 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
   auto loc = op.getLoc();
 
   auto A = op.getA();
-  auto B = op.getB();
-  auto C = op.getC();
   auto D = op.getResult();
 
   auto aTensorTy = cast<RankedTensorType>(A.getType());
-  auto bTensorTy = cast<RankedTensorType>(B.getType());
   auto dTensorTy = cast<RankedTensorType>(D.getType());
 
-  auto aShapePerCTA = getShapePerCTA(aTensorTy);
-  auto bShapePerCTA = getShapePerCTA(bTensorTy);
+  SmallVector<int64_t> aShapePerCTA =
+      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy)));
+  auto dShapePerCTA =
+      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy)));
 
   BlockedEncodingAttr dLayout =
       cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
-  auto order = dLayout.getOrder();
+  auto order = expandMatrixOrderWithBatch(dLayout.getOrder());
   auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);
 
   Value llA = adaptor.getA();
   Value llB = adaptor.getB();
 
-  auto sizePerThread = getSizePerThread(dLayout);
-  auto shapePerCTATile = getShapePerCTATile(dLayout);
-
-  int K = aShapePerCTA[1];
-  int M = aShapePerCTA[0];
-  int N = bShapePerCTA[1];
-
-  int mShapePerCTATile =
-      order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
-  int mSizePerThread =
-      order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
-  int nShapePerCTATile =
-      order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
-  int nSizePerThread =
-      order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
-
-  auto has =
-      getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread,
-                                 rewriter, loc, typeConverter, aTensorTy);
-  auto hbs =
-      getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread,
-                                 rewriter, loc, typeConverter, bTensorTy);
-
-  SmallVector<Value> ret = cc;
-  bool isCRow = order[0] == 1;
-
-  for (unsigned k = 0; k < K; k++) {
-    for (unsigned m = 0; m < M; m += mShapePerCTATile)
-      for (unsigned n = 0; n < N; n += nShapePerCTATile)
-        for (unsigned mm = 0; mm < mSizePerThread; ++mm)
-          for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
-            int mIdx = m / mShapePerCTATile * mSizePerThread + mm;
-            int nIdx = n / nShapePerCTATile * nSizePerThread + nn;
-
-            int z = isCRow
-                        ? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx
-                        : nIdx * M / mShapePerCTATile * nSizePerThread + mIdx;
-            Type tgtTy = ret[z].getType();
-            Value opA = has[{m + mm, k}];
-            Value opB = hbs[{n + nn, k}];
-            assert(opA.getType() == tgtTy);
-            assert(opB.getType() == tgtTy);
-
-            llvm::TypeSwitch<Type>(tgtTy)
-                .Case<FloatType>([&](auto) {
-                  ret[z] =
-                      rewriter.create<LLVM::FMulAddOp>(loc, opA, opB, ret[z]);
-                })
-                .Case<IntegerType>([&](auto) {
-                  ret[z] = rewriter.create<LLVM::AddOp>(
-                      loc, rewriter.create<LLVM::MulOp>(loc, opA, opB), ret[z]);
-                });
-          }
+  auto sizePerThread =
+      expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
+  auto shapePerCTATile =
+      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));
+
+  unsigned K = aShapePerCTA[2];
+
+  unsigned perThreadShape[3];
+  for (int i = 0; i < 3; ++i) {
+    unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i];
+    numRep = std::max(static_cast<unsigned>(1), numRep);
+    perThreadShape[i] = numRep * sizePerThread[i];
   }
 
-  auto res = packLLElements(loc, typeConverter, ret, rewriter, dTensorTy);
+  auto has = getValueTableFromStructFMA(
+      llA, {perThreadShape[0], perThreadShape[1], K},
+      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order);
+  auto hbs = getValueTableFromStructFMA(
+      llB, {perThreadShape[0], K, perThreadShape[2]},
+      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order);
+
+  SmallVector<Value> acc = cc;
+
+  for (unsigned b = 0; b < perThreadShape[0]; ++b)
+    for (unsigned m = 0; m < perThreadShape[1]; ++m)
+      for (unsigned n = 0; n < perThreadShape[2]; ++n) {
+        SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
+        unsigned linearAccumIdx =
+            linearize(multiDimAccumIdx, perThreadShape, order);
+        for (unsigned k = 0; k < K; ++k) {
+          Type tgtTy = acc[linearAccumIdx].getType();
+          Value opA = has[{b, m, k}];
+          Value opB = hbs[{b, n, k}];
+          assert(opA.getType() == tgtTy);
+          assert(opB.getType() == tgtTy);
+          llvm::TypeSwitch<Type>(tgtTy)
+              .Case<FloatType>([&](auto) {
+                acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
+                    loc, opA, opB, acc[linearAccumIdx]);
+              })
+              .Case<IntegerType>([&](auto) {
+                acc[linearAccumIdx] = rewriter.create<LLVM::AddOp>(
+                    loc, rewriter.create<LLVM::MulOp>(loc, opA, opB),
+                    acc[linearAccumIdx]);
+              });
+        }
+      }
+
+  auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
   rewriter.replaceOp(op, res);
 
   return success();
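
Note on the new indexing scheme (explanatory sketch, not part of the patch):
the refactor replaces the hand-rolled m/n bookkeeping with one
delinearize/linearize pair over the per-thread tile (batch, M, N), driven by
the layout `order`. The plain C++ program below illustrates that arithmetic
in isolation; `delinearize` and `linearize` here are local stand-ins assumed
to match the behavior of mlir::LLVM::delinearize / mlir::LLVM::linearize
(with `order[0]` as the fastest-varying dimension, the Triton convention),
and the tile shape is a hypothetical example.

  // Standalone sketch of the register-index arithmetic (no MLIR needed).
  #include <array>
  #include <cassert>
  #include <cstdio>

  using Idx3 = std::array<unsigned, 3>;

  // Split a flat register index into per-dimension coordinates;
  // order[0] is peeled off first, i.e. it varies fastest.
  static Idx3 delinearize(unsigned linear, const Idx3 &shape,
                          const Idx3 &order) {
    Idx3 coords{};
    for (unsigned dim : order) {
      coords[dim] = linear % shape[dim];
      linear /= shape[dim];
    }
    return coords;
  }

  // Fold coordinates back into a flat index; exact inverse of delinearize().
  static unsigned linearize(const Idx3 &coords, const Idx3 &shape,
                            const Idx3 &order) {
    unsigned linear = 0;
    for (auto it = order.rbegin(); it != order.rend(); ++it) // slowest first
      linear = linear * shape[*it] + coords[*it];
    return linear;
  }

  int main() {
    // Hypothetical per-thread tile: 1 batch x 2 (M) x 4 (N), order {2, 1, 0},
    // i.e. N contiguous -- the row-major case after batch expansion.
    const Idx3 shape = {1, 2, 4};
    const Idx3 order = {2, 1, 0};
    for (unsigned idx = 0; idx < 1 * 2 * 4; ++idx) {
      Idx3 c = delinearize(idx, shape, order);
      assert(linearize(c, shape, order) == idx); // exact round trip
      std::printf("reg %u -> (b=%u, m=%u, n=%u)\n", idx, c[0], c[1], c[2]);
    }
    return 0;
  }

Because the same `order` drives both the operand value tables and the
accumulator index, `acc[linearAccumIdx]` lines up with the (b, m, n) element
for row- and column-major layouts alike, which is what lets the patch drop
the separate `isCRow` code paths.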