[Intel] Improve dot support to target FMA
Signed-off-by: Whitney Tsang <[email protected]>
whitneywhtsang committed Dec 20, 2024
1 parent 34ad801 commit 9bd7bd4
147 changes: 72 additions & 75 deletions third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp
@@ -4,30 +4,35 @@

using namespace mlir;
using namespace mlir::triton;
using namespace ::mlir::triton::gpu;

-using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::LLVM::linearize;
+using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
+using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
using ::mlir::triton::gpu::getShapePerCTA;
-using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
using ::mlir::triton::gpu::getSizePerThread;

-using ValueTableFMA = std::map<std::pair<int, int>, Value>;
+using ValueTableFMA = std::map<std::tuple<int, int, int>, Value>;

namespace {
-static ValueTableFMA getValueTableFromStructFMA(
-    Value val, int K, int n0, int shapePerCTATile, int sizePerThread,
-    ConversionPatternRewriter &rewriter, Location loc,
-    TritonIntelGPUToLLVMTypeConverter *typeConverter, Type type) {
+static ValueTableFMA
+getValueTableFromStructFMA(Value val, ArrayRef<unsigned> perTileShape,
+                           unsigned kDim, unsigned nonKDim,
+                           ConversionPatternRewriter &rewriter, Location loc,
+                           ArrayRef<unsigned> order) {
  ValueTableFMA res;
  auto elems = unpackLLElements(loc, val, rewriter);
-  int index = 0;
-  for (unsigned k = 0; k < K; ++k) {
-    for (unsigned m = 0; m < n0; m += shapePerCTATile)
-      for (unsigned mm = 0; mm < sizePerThread; ++mm) {
-        res[{m + mm, k}] = elems[index++];
-      }
-  }
+  assert(perTileShape.size() == 3);
+  assert(elems.size() == product(perTileShape));
+  assert(kDim == 1 || kDim == 2);
+  assert(nonKDim == 1 || nonKDim == 2);
+  const unsigned bDim = 0;

+  for (unsigned idx = 0; idx < elems.size(); ++idx) {
+    auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order);
+    res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx];
+  }
  return res;
}
} // namespace
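
As background for the rewritten helper above: instead of walking K, M and a per-thread offset with nested loops, it takes the per-thread tile shape and splits each element's flat index into (batch, non-K, K) coordinates via mlir::LLVM::delinearize. Below is a minimal standalone sketch of that index math, assuming order[0] names the fastest-varying dimension (Triton's layout convention); delinearize3 is a stand-in written for this example, not the MLIR helper itself, and the tile shape is made up:

#include <array>
#include <cassert>
#include <cstdio>

// Stand-in for mlir::LLVM::delinearize: split a flat index into per-dimension
// coordinates, walking dimensions from fastest-varying (order[0]) to slowest.
static std::array<unsigned, 3> delinearize3(unsigned idx,
                                            const std::array<unsigned, 3> &shape,
                                            const std::array<unsigned, 3> &order) {
  std::array<unsigned, 3> coords{};
  for (unsigned dim : order) {
    coords[dim] = idx % shape[dim];
    idx /= shape[dim];
  }
  return coords;
}

int main() {
  // Made-up per-thread A tile: [batch=1, M=2, K=4] with K fastest-varying,
  // mirroring the kDim = 2, nonKDim = 1 call used for operand A.
  const std::array<unsigned, 3> perTileShape{1, 2, 4};
  const std::array<unsigned, 3> order{2, 1, 0};
  for (unsigned idx = 0; idx < 8; ++idx) {
    auto c = delinearize3(idx, perTileShape, order);
    std::printf("elems[%u] -> b=%u m=%u k=%u\n", idx, c[0], c[1], c[2]);
  }
  // For example, element 5 lands at (b=0, m=1, k=1).
  assert((delinearize3(5, perTileShape, order) ==
          std::array<unsigned, 3>{0, 1, 1}));
  return 0;
}

Because every element is addressed by its own coordinates, one helper now serves both operands: the A table is built with kDim = 2, nonKDim = 1, and the B table with kDim = 1, nonKDim = 2.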

namespace fma_details {
LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
@@ -37,81 +42,73 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
auto loc = op.getLoc();

auto A = op.getA();
auto B = op.getB();
auto C = op.getC();
auto D = op.getResult();

auto aTensorTy = cast<RankedTensorType>(A.getType());
auto bTensorTy = cast<RankedTensorType>(B.getType());
auto dTensorTy = cast<RankedTensorType>(D.getType());

-  auto aShapePerCTA = getShapePerCTA(aTensorTy);
-  auto bShapePerCTA = getShapePerCTA(bTensorTy);
+  SmallVector<int64_t> aShapePerCTA =
+      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy)));
+  auto dShapePerCTA =
+      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy)));

BlockedEncodingAttr dLayout =
cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
-  auto order = dLayout.getOrder();
+  auto order = expandMatrixOrderWithBatch(dLayout.getOrder());
auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);

Value llA = adaptor.getA();
Value llB = adaptor.getB();

-  auto sizePerThread = getSizePerThread(dLayout);
-  auto shapePerCTATile = getShapePerCTATile(dLayout);

-  int K = aShapePerCTA[1];
-  int M = aShapePerCTA[0];
-  int N = bShapePerCTA[1];

-  int mShapePerCTATile =
-      order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
-  int mSizePerThread =
-      order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
-  int nShapePerCTATile =
-      order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
-  int nSizePerThread =
-      order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];

-  auto has =
-      getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread,
-                                 rewriter, loc, typeConverter, aTensorTy);
-  auto hbs =
-      getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread,
-                                 rewriter, loc, typeConverter, bTensorTy);

-  SmallVector<Value> ret = cc;
-  bool isCRow = order[0] == 1;

-  for (unsigned k = 0; k < K; k++) {
-    for (unsigned m = 0; m < M; m += mShapePerCTATile)
-      for (unsigned n = 0; n < N; n += nShapePerCTATile)
-        for (unsigned mm = 0; mm < mSizePerThread; ++mm)
-          for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
-            int mIdx = m / mShapePerCTATile * mSizePerThread + mm;
-            int nIdx = n / nShapePerCTATile * nSizePerThread + nn;

-            int z = isCRow
-                        ? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx
-                        : nIdx * M / mShapePerCTATile * nSizePerThread + mIdx;
-            Type tgtTy = ret[z].getType();
-            Value opA = has[{m + mm, k}];
-            Value opB = hbs[{n + nn, k}];
-            assert(opA.getType() == tgtTy);
-            assert(opB.getType() == tgtTy);

-            llvm::TypeSwitch<Type>(tgtTy)
-                .Case<FloatType>([&](auto) {
-                  ret[z] =
-                      rewriter.create<LLVM::FMulAddOp>(loc, opA, opB, ret[z]);
-                })
-                .Case<IntegerType>([&](auto) {
-                  ret[z] = rewriter.create<LLVM::AddOp>(
-                      loc, rewriter.create<LLVM::MulOp>(loc, opA, opB), ret[z]);
-                });
-          }
-  }
+  auto sizePerThread =
+      expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
+  auto shapePerCTATile =
+      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));

+  unsigned K = aShapePerCTA[2];

+  unsigned perThreadShape[3];
+  for (int i = 0; i < 3; ++i) {
+    unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i];
+    numRep = std::max(static_cast<unsigned>(1), numRep);
+    perThreadShape[i] = numRep * sizePerThread[i];
+  }

+  auto has = getValueTableFromStructFMA(
+      llA, {perThreadShape[0], perThreadShape[1], K},
+      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order);
+  auto hbs = getValueTableFromStructFMA(
+      llB, {perThreadShape[0], K, perThreadShape[2]},
+      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order);

+  SmallVector<Value> acc = cc;

+  for (unsigned b = 0; b < perThreadShape[0]; ++b)
+    for (unsigned m = 0; m < perThreadShape[1]; ++m)
+      for (unsigned n = 0; n < perThreadShape[2]; ++n) {
+        SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
+        unsigned linearAccumIdx =
+            linearize(multiDimAccumIdx, perThreadShape, order);
+        for (unsigned k = 0; k < K; ++k) {
+          Type tgtTy = acc[linearAccumIdx].getType();
+          Value opA = has[{b, m, k}];
+          Value opB = hbs[{b, n, k}];
+          assert(opA.getType() == tgtTy);
+          assert(opB.getType() == tgtTy);
+          llvm::TypeSwitch<Type>(tgtTy)
+              .Case<FloatType>([&](auto) {
+                acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
+                    loc, opA, opB, acc[linearAccumIdx]);
+              })
+              .Case<IntegerType>([&](auto) {
+                acc[linearAccumIdx] = rewriter.create<LLVM::AddOp>(
+                    loc, rewriter.create<LLVM::MulOp>(loc, opA, opB),
+                    acc[linearAccumIdx]);
+              });
+        }
+      }

-  auto res = packLLElements(loc, typeConverter, ret, rewriter, dTensorTy);
+  auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
rewriter.replaceOp(op, res);

return success();
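
To make the shape bookkeeping in the rewritten convertFMADot concrete: each thread owns a [batch, M, N] register tile with perThreadShape[i] = max(1, dShapePerCTA[i] / shapePerCTATile[i]) * sizePerThread[i] accumulators, and for each accumulator (b, m, n) the loop nest chains K multiply-adds. The sketch below is a scalar reference model with invented tile sizes, not the MLIR lowering; it also uses plain row-major accumulator indexing where the real code calls linearize(multiDimAccumIdx, perThreadShape, order):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Invented shapes for illustration: per-CTA D tile [batch=1, M=64, N=64],
  // CTA tile [1, 32, 64], each thread holding [1, 1, 4] elements, K = 16.
  const unsigned dShapePerCTA[3] = {1, 64, 64};
  const unsigned shapePerCTATile[3] = {1, 32, 64};
  const unsigned sizePerThread[3] = {1, 1, 4};
  const unsigned K = 16;

  // Same computation as the diff: repetitions per dimension times elements
  // per thread, clamped to at least one repetition.
  unsigned perThreadShape[3];
  for (int i = 0; i < 3; ++i) {
    unsigned numRep = std::max(1u, dShapePerCTA[i] / shapePerCTATile[i]);
    perThreadShape[i] = numRep * sizePerThread[i]; // -> {1, 2, 4}
  }
  const unsigned batch = perThreadShape[0];
  const unsigned M = perThreadShape[1];
  const unsigned N = perThreadShape[2];

  // Scalar model of the emitted FMA chains:
  // acc[b][m][n] += A[b][m][k] * B[b][n][k] for every k.
  std::vector<float> A(batch * M * K, 1.0f);
  std::vector<float> B(batch * N * K, 2.0f);
  std::vector<float> acc(batch * M * N, 0.0f);
  for (unsigned b = 0; b < batch; ++b)
    for (unsigned m = 0; m < M; ++m)
      for (unsigned n = 0; n < N; ++n)
        for (unsigned k = 0; k < K; ++k)
          acc[(b * M + m) * N + n] +=
              A[(b * M + m) * K + k] * B[(b * N + n) * K + k];

  // 8 accumulators per thread; each ends up as 16 * 1 * 2 = 32.
  std::printf("accumulators: %zu, acc[0] = %g\n", acc.size(), acc[0]);
}

As in the diff, a float accumulator lowers to an llvm.fmuladd chain, while integer accumulators fall back to an explicit multiply followed by an add.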