Fix tt.dot codegen when input operands are integers (#204)
Fix codegen for `tt.dot` so that:
- operands that do not have the same type as the result type are promoted correctly
- the FMA instruction is generated only when the operands have a floating-point type
- LLVM `mul` + `add` instructions are generated when the operands have an integer type

Fixes issue #191.

---------

Signed-off-by: Tiotto, Ettore <[email protected]>
etiotto authored Jan 15, 2024
1 parent 7b680e6 commit acf7752
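
For illustration, the kind of kernel this fix targets is one whose `tt.dot` operands are integers. The sketch below is not the reproducer from issue #191; the kernel name, tile size, and device string are assumptions, and it simply shows an int8 x int8 dot accumulating into int32, which previously hit the floating-point-only FMA lowering.

# Minimal sketch, assuming a Triton build containing this fix; not the
# reproducer from issue #191. Kernel name, tile size, and the "xpu" device
# string are illustrative assumptions.
import torch
import triton
import triton.language as tl


@triton.jit
def int8_dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Load one BLOCK x BLOCK tile of each int8 operand.
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # Integer-typed dot: the FMA fallback must emit mul + add, not llvm.fmuladd.
    c = tl.dot(a, b, out_dtype=tl.int32)
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)


a = torch.randint(-8, 8, (16, 16), dtype=torch.int8, device="xpu")
b = torch.randint(-8, 8, (16, 16), dtype=torch.int8, device="xpu")
c = torch.empty((16, 16), dtype=torch.int32, device="xpu")
int8_dot_kernel[(1,)](a, b, c, BLOCK=16)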
Showing 4 changed files with 83 additions and 38 deletions.
86 changes: 62 additions & 24 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp
@@ -33,26 +33,54 @@ static Value convertIfRequired(Value val, Type tgtTy, Location loc,
   if (valTy == tgtTy)
     return val;
 
-  assert(isa<FloatType>(tgtTy) && valTy.isIntOrFloat() &&
+  assert(tgtTy.isIntOrFloat() && valTy.isIntOrFloat() &&
          "Unexpected tgtTy or valTy types");
-  const unsigned tgtBitWidth = tgtTy.getIntOrFloatBitWidth(),
-                 valBitWidth = valTy.getIntOrFloatBitWidth();
-
-  return llvm::TypeSwitch<Type, Value>(valTy)
-      .Case<FloatType>([&](FloatType ty) {
-        Operation *castOp =
-            (valBitWidth <= tgtBitWidth)
-                ? rewriter.create<LLVM::FPExtOp>(loc, tgtTy, val)
-                : rewriter.create<LLVM::FPTruncOp>(loc, tgtTy, val);
-        return castOp->getResult(0);
-      })
-      .Case<IntegerType>([&](IntegerType ty) {
-        Operation *castOp =
-            (ty.isSigned() || ty.isSignless())
-                ? rewriter.create<LLVM::SIToFPOp>(loc, tgtTy, val)
-                : rewriter.create<LLVM::UIToFPOp>(loc, tgtTy, val);
-        return castOp->getResult(0);
-      });
+
+  auto convertToFloat = [&](Type valTy, FloatType tgtTy) -> Value {
+    unsigned tgtBitWidth = tgtTy.getIntOrFloatBitWidth(),
+             valBitWidth = valTy.getIntOrFloatBitWidth();
+
+    return llvm::TypeSwitch<Type, Value>(valTy)
+        .Case<FloatType>([&](FloatType ty) {
+          Operation *castOp =
+              (valBitWidth <= tgtBitWidth)
+                  ? rewriter.create<LLVM::FPExtOp>(loc, tgtTy, val)
+                  : rewriter.create<LLVM::FPTruncOp>(loc, tgtTy, val);
+          return castOp->getResult(0);
+        })
+        .Case<IntegerType>([&](IntegerType ty) {
+          Operation *castOp =
+              (ty.isSigned() || ty.isSignless())
+                  ? rewriter.create<LLVM::SIToFPOp>(loc, tgtTy, val)
+                  : rewriter.create<LLVM::UIToFPOp>(loc, tgtTy, val);
+          return castOp->getResult(0);
+        });
+  };
+
+  auto convertToInteger = [&](Type valTy, IntegerType tgtTy) -> Value {
+    unsigned tgtBitWidth = tgtTy.getIntOrFloatBitWidth(),
+             valBitWidth = valTy.getIntOrFloatBitWidth();
+
+    return llvm::TypeSwitch<Type, Value>(valTy)
+        .Case<FloatType>([&](FloatType ty) {
+          Operation *castOp =
+              (tgtTy.isSigned() || tgtTy.isSignless())
+                  ? rewriter.create<LLVM::FPToSIOp>(loc, tgtTy, val)
+                  : rewriter.create<LLVM::FPToUIOp>(loc, tgtTy, val);
+          return castOp->getResult(0);
+        })
+        .Case<IntegerType>([&](IntegerType ty) {
+          Operation *castOp =
+              (valBitWidth <= tgtBitWidth)
+                  ? rewriter.create<LLVM::SExtOp>(loc, tgtTy, val)
+                  : rewriter.create<LLVM::TruncOp>(loc, tgtTy, val);
+          return castOp->getResult(0);
+        });
+  };
+
+  return llvm::TypeSwitch<Type, Value>(tgtTy)
+      .Case<FloatType>([&](auto ty) { return convertToFloat(valTy, ty); })
+      .Case<IntegerType>([&](auto ty) { return convertToInteger(valTy, ty); });
 }
 
 LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
@@ -118,11 +146,21 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
           int z = isCRow
                       ? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx
                       : nIdx * M / mShapePerCTATile * nSizePerThread + mIdx;
-          Value opA = convertIfRequired(has[{m + mm, k}], ret[z].getType(),
-                                        loc, rewriter);
-          Value opB = convertIfRequired(hbs[{n + nn, k}], ret[z].getType(),
-                                        loc, rewriter);
-          ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, opA, opB, ret[z]);
+          Type tgtTy = ret[z].getType();
+          Value opA =
+              convertIfRequired(has[{m + mm, k}], tgtTy, loc, rewriter);
+          Value opB =
+              convertIfRequired(hbs[{n + nn, k}], tgtTy, loc, rewriter);
+
+          llvm::TypeSwitch<Type>(tgtTy)
+              .Case<FloatType>([&](auto) {
+                ret[z] =
+                    rewriter.create<LLVM::FMulAddOp>(loc, opA, opB, ret[z]);
+              })
+              .Case<IntegerType>([&](auto) {
+                ret[z] = rewriter.create<LLVM::AddOp>(
+                    loc, rewriter.create<LLVM::MulOp>(loc, opA, opB), ret[z]);
+              });
         }
       }
 
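To make the intent of the new lowering concrete, here is an illustrative pure-Python reference of the multiply-accumulate the FMA fallback performs per output element; the function name and structure are invented for exposition and are not part of the commit. Per the change above, a floating-point accumulator maps the product-and-add pair to a single llvm.fmuladd, while an integer accumulator now maps it to separate llvm.mul and llvm.add instructions.

# Illustrative reference only; fma_dot_reference is a made-up name and this
# code does not appear in the commit.
def fma_dot_reference(a, b, acc):
    """Reference semantics of the FMA dot fallback: acc[m][n] += a[m][k] * b[k][n].

    In the generated IR, the multiply-accumulate is llvm.fmuladd when acc has a
    floating-point element type and a separate llvm.mul + llvm.add when it has
    an integer type (operands are first converted to acc's element type).
    """
    M, K, N = len(a), len(b), len(b[0])
    for m in range(M):
        for n in range(N):
            for k in range(K):
                acc[m][n] = a[m][k] * b[k][n] + acc[m][n]
    return acc


# Tiny usage example with 2x2 integer tiles.
print(fma_dot_reference([[1, 2], [3, 4]], [[5, 6], [7, 8]], [[0, 0], [0, 0]]))
# -> [[19, 22], [43, 50]]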
10 changes: 1 addition & 9 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
@@ -390,12 +390,12 @@ class TritonLLVMConversionTarget : public ConversionTarget {
     switch (target) {
     case Target::NVVM:
       addLegalDialect<NVVM::NVVMDialect>();
-      addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
       break;
     case Target::GENX:
       addLegalDialect<GENX::GENXDialect>();
       break;
     }
+    addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
     addIllegalDialect<triton::TritonDialect>();
     addIllegalDialect<triton::gpu::TritonGPUDialect>();
     addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
@@ -905,14 +905,6 @@ struct ConvertTritonGPUToLLVM
       }
     });
   }
-
-  static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
-                              Type promotedType) {
-    Type tensorPromotedType =
-        operand.getType().cast<RankedTensorType>().cloneWith(std::nullopt,
-                                                             promotedType);
-    return builder.create<triton::FpToFpOp>(loc, tensorPromotedType, operand);
-  }
 };
 
 } // anonymous namespace
21 changes: 19 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -6,6 +6,7 @@
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Tools/Sys/GetEnv.hpp"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <memory>
 
@@ -337,10 +338,26 @@ class BlockedToMMA : public mlir::RewritePattern {
 
 static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
                             Type promotedType) {
-  Type tensorPromotedType =
+  auto tensorPromotedType =
       operand.getType().cast<RankedTensorType>().cloneWith(std::nullopt,
                                                            promotedType);
-  return builder.create<tt::FpToFpOp>(loc, tensorPromotedType, operand);
+  Type elemType = tensorPromotedType.getElementType();
+  return llvm::TypeSwitch<Type, Value>(elemType)
+      .Case<FloatType>([&](auto) {
+        return builder.create<tt::FpToFpOp>(loc, tensorPromotedType, operand);
+      })
+      .Case<IntegerType>([&](auto) {
+        unsigned tgtBitWidth = elemType.getIntOrFloatBitWidth(),
+                 valBitWidth = operand.getType()
+                                   .cast<RankedTensorType>()
+                                   .getElementTypeBitWidth();
+        Operation *castOp = (valBitWidth <= tgtBitWidth)
+                                ? builder.create<arith::ExtSIOp>(
+                                      loc, tensorPromotedType, operand)
+                                : builder.create<arith::TruncIOp>(
+                                      loc, tensorPromotedType, operand);
+        return castOp->getResult(0);
+      });
 }
 
 // promote operands of dot op if the existing combination is not natively
4 changes: 1 addition & 3 deletions python/test/unit/language/test_line_info.py
@@ -125,9 +125,7 @@ def check_file_lines(file_lines, file_name, lineno, should_contain=True):
     return not should_contain
 
 
-# TODO: dot_combine fails to compile.
-# func_types = ["single", "call", "call_noinline", "multi_files", "autotune", "dot_combine"]
-func_types = ["single", "call", "call_noinline", "multi_files", "autotune"]
+func_types = ["single", "call", "call_noinline", "multi_files", "autotune", "dot_combine"]
 
 @pytest.mark.parametrize("func", func_types)
 def test_line_info(func: str):
