From acf775287421efb853e92edc859d5746f53e2df4 Mon Sep 17 00:00:00 2001
From: Ettore Tiotto
Date: Mon, 15 Jan 2024 13:49:57 -0500
Subject: [PATCH] Fix `tt.dot` codegen when input operands are integers (#204)

Fix codegen for `tt.dot` so that:

- operands that do not have the same type as the result type are
  promoted correctly
- an FMA instruction is generated only when the operands have
  floating-point type
- LLVM `mul` + `add` instructions are generated when the operands have
  integer type

Fixes issue #191.

---------

Signed-off-by: Tiotto, Ettore
---
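Note: the FMA fallback now chooses the arithmetic from the accumulator
element type, so an integer `tt.dot` lowers to integer mul + add instead
of `llvm.fmuladd`. A kernel of the kind that exercises this path looks
as follows (an illustrative sketch only, not the exact reproducer from
issue #191; `int8_dot_kernel` and its row-major pointer layout are made
up for this note):

    import triton
    import triton.language as tl

    @triton.jit
    def int8_dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        # Load two BLOCK x BLOCK int8 tiles (row-major).
        a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
        b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
        # The int8 operands do not match the int32 accumulator type, so
        # codegen must promote them and emit integer mul + add, not FMA.
        c = tl.dot(a, b)
        tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)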
 .../TritonGPUToLLVM/DotOpToLLVM/FMA.cpp       | 86 +++++++++++++------
 .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp   | 10 +--
 .../TritonGPU/Transforms/AccelerateMatmul.cpp | 21 ++++-
 python/test/unit/language/test_line_info.py   |  4 +-
 4 files changed, 83 insertions(+), 38 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp
index 2833a8c60e..ab7a46cfc1 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp
@@ -33,26 +33,54 @@ static Value convertIfRequired(Value val, Type tgtTy, Location loc,
   if (valTy == tgtTy)
     return val;
 
-  assert(isa<FloatType>(tgtTy) && valTy.isIntOrFloat() &&
+  assert(tgtTy.isIntOrFloat() && valTy.isIntOrFloat() &&
          "Unexpected tgtTy or valTy types");
-  const unsigned tgtBitWidth = tgtTy.getIntOrFloatBitWidth(),
-                 valBitWidth = valTy.getIntOrFloatBitWidth();
-
-  return llvm::TypeSwitch<Type, Value>(valTy)
-      .Case<FloatType>([&](FloatType ty) {
-        Operation *castOp =
-            (valBitWidth <= tgtBitWidth)
-                ? rewriter.create<LLVM::FPExtOp>(loc, tgtTy, val)
-                : rewriter.create<LLVM::FPTruncOp>(loc, tgtTy, val);
-        return castOp->getResult(0);
-      })
-      .Case<IntegerType>([&](IntegerType ty) {
-        Operation *castOp =
-            (ty.isSigned() || ty.isSignless())
-                ? rewriter.create<LLVM::SIToFPOp>(loc, tgtTy, val)
-                : rewriter.create<LLVM::UIToFPOp>(loc, tgtTy, val);
-        return castOp->getResult(0);
-      });
+
+  auto convertToFloat = [&](Type valTy, FloatType tgtTy) -> Value {
+    unsigned tgtBitWidth = tgtTy.getIntOrFloatBitWidth(),
+             valBitWidth = valTy.getIntOrFloatBitWidth();
+
+    return llvm::TypeSwitch<Type, Value>(valTy)
+        .Case<FloatType>([&](FloatType ty) {
+          Operation *castOp =
+              (valBitWidth <= tgtBitWidth)
+                  ? rewriter.create<LLVM::FPExtOp>(loc, tgtTy, val)
+                  : rewriter.create<LLVM::FPTruncOp>(loc, tgtTy, val);
+          return castOp->getResult(0);
+        })
+        .Case<IntegerType>([&](IntegerType ty) {
+          Operation *castOp =
+              (ty.isSigned() || ty.isSignless())
+                  ? rewriter.create<LLVM::SIToFPOp>(loc, tgtTy, val)
+                  : rewriter.create<LLVM::UIToFPOp>(loc, tgtTy, val);
+          return castOp->getResult(0);
+        });
+  };
+
+  auto convertToInteger = [&](Type valTy, IntegerType tgtTy) -> Value {
+    unsigned tgtBitWidth = tgtTy.getIntOrFloatBitWidth(),
+             valBitWidth = valTy.getIntOrFloatBitWidth();
+
+    return llvm::TypeSwitch<Type, Value>(valTy)
+        .Case<FloatType>([&](FloatType ty) {
+          Operation *castOp =
+              (tgtTy.isSigned() || tgtTy.isSignless())
+                  ? rewriter.create<LLVM::FPToSIOp>(loc, tgtTy, val)
+                  : rewriter.create<LLVM::FPToUIOp>(loc, tgtTy, val);
+          return castOp->getResult(0);
+        })
+        .Case<IntegerType>([&](IntegerType ty) {
+          Operation *castOp =
+              (valBitWidth <= tgtBitWidth)
+                  ? rewriter.create<LLVM::SExtOp>(loc, tgtTy, val)
+                  : rewriter.create<LLVM::TruncOp>(loc, tgtTy, val);
+          return castOp->getResult(0);
+        });
+  };
+
+  return llvm::TypeSwitch<Type, Value>(tgtTy)
+      .Case<FloatType>([&](auto ty) { return convertToFloat(valTy, ty); })
+      .Case<IntegerType>([&](auto ty) { return convertToInteger(valTy, ty); });
 }
 
 LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
@@ -118,11 +146,21 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
               int z = isCRow
                           ? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx
                           : nIdx * M / mShapePerCTATile * nSizePerThread + mIdx;
-              Value opA = convertIfRequired(has[{m + mm, k}], ret[z].getType(),
-                                            loc, rewriter);
-              Value opB = convertIfRequired(hbs[{n + nn, k}], ret[z].getType(),
-                                            loc, rewriter);
-              ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, opA, opB, ret[z]);
+              Type tgtTy = ret[z].getType();
+              Value opA =
+                  convertIfRequired(has[{m + mm, k}], tgtTy, loc, rewriter);
+              Value opB =
+                  convertIfRequired(hbs[{n + nn, k}], tgtTy, loc, rewriter);
+
+              llvm::TypeSwitch<Type>(tgtTy)
+                  .Case<FloatType>([&](auto) {
+                    ret[z] =
+                        rewriter.create<LLVM::FMulAddOp>(loc, opA, opB, ret[z]);
+                  })
+                  .Case<IntegerType>([&](auto) {
+                    ret[z] = rewriter.create<LLVM::AddOp>(
+                        loc, rewriter.create<LLVM::MulOp>(loc, opA, opB),
+                        ret[z]);
+                  });
             }
           }

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
index 5463997d51..c9f328d497 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
@@ -390,12 +390,12 @@ class TritonLLVMConversionTarget : public ConversionTarget {
     switch (target) {
     case Target::NVVM:
       addLegalDialect<NVVM::NVVMDialect>();
+      addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
       break;
     case Target::GENX:
       addLegalDialect<GENX::GENXDialect>();
       break;
     }
-    addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
     addIllegalDialect<triton::TritonDialect>();
     addIllegalDialect<triton::gpu::TritonGPUDialect>();
     addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
@@ -905,14 +905,6 @@ struct ConvertTritonGPUToLLVM
       }
     });
   }
-
-  static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
-                              Type promotedType) {
-    Type tensorPromotedType =
-        operand.getType().cast<RankedTensorType>().cloneWith(std::nullopt,
-                                                             promotedType);
-    return builder.create<triton::FpToFpOp>(loc, tensorPromotedType, operand);
-  }
 };
 
 } // anonymous namespace

diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 890fd96bd8..c940ba7132 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -6,6 +6,7 @@
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Tools/Sys/GetEnv.hpp"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <memory>
@@ -337,10 +338,26 @@ class BlockedToMMA : public mlir::RewritePattern {
 
   static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
                               Type promotedType) {
-    Type tensorPromotedType =
+    auto tensorPromotedType =
         operand.getType().cast<RankedTensorType>().cloneWith(std::nullopt,
                                                              promotedType);
-    return builder.create<triton::FpToFpOp>(loc, tensorPromotedType, operand);
+    Type elemType = tensorPromotedType.getElementType();
+    return llvm::TypeSwitch<Type, Value>(elemType)
+        .Case<FloatType>([&](auto) {
+          return builder.create<triton::FpToFpOp>(loc, tensorPromotedType,
+                                                  operand);
+        })
+        .Case<IntegerType>([&](auto) {
+          unsigned tgtBitWidth = elemType.getIntOrFloatBitWidth(),
+                   valBitWidth = operand.getType()
+                                     .cast<RankedTensorType>()
+                                     .getElementTypeBitWidth();
+          Operation *castOp = (valBitWidth <= tgtBitWidth)
+                                  ? builder.create<arith::ExtSIOp>(
+                                        loc, tensorPromotedType, operand)
+                                  : builder.create<arith::TruncIOp>(
+                                        loc, tensorPromotedType, operand);
+          return castOp->getResult(0);
+        });
   }
 
   // promote operands of dot op if the existing combination is not natively

diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py
index bf2d84bb31..5992fe44e7 100644
--- a/python/test/unit/language/test_line_info.py
+++ b/python/test/unit/language/test_line_info.py
@@ -125,9 +125,7 @@ def check_file_lines(file_lines, file_name, lineno, should_contain=True):
     return not should_contain
 
 
-# TODO: dot_combine fails to compile.
-# func_types = ["single", "call", "call_noinline", "multi_files", "autotune", "dot_combine"] -func_types = ["single", "call", "call_noinline", "multi_files", "autotune"] +func_types = ["single", "call", "call_noinline", "multi_files", "autotune", "dot_combine"] @pytest.mark.parametrize("func", func_types) def test_line_info(func: str):