From a27d6624e2de483adb2d8948d1824b11efeeef83 Mon Sep 17 00:00:00 2001 From: Georg Kunze Date: Mon, 16 Sep 2024 19:20:42 +0200 Subject: [PATCH] add support for floating point operations --- .../lib/Conversion/CinmToCnm/CinmToCnm.cpp | 88 ++++++++++++------- .../Cinm/IR/CinmTilingImplementations.cpp | 13 ++- .../Cinm/Interfaces/TilingInterface.cpp | 74 +++++++++++----- 3 files changed, 123 insertions(+), 52 deletions(-) diff --git a/cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp b/cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp index d41d059..e963f27 100644 --- a/cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp +++ b/cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp @@ -711,38 +711,64 @@ struct ConvertCinmReduceToCnm : public OpConversionPattern { op.getResult().getType(), builder.getZeroAttr(op.getResult().getType())); + const bool isFloatOp = op.getType() + .cast() + .getElementType() + .dyn_cast() != nullptr; + llvm::SmallVector newResults; - if (convertCinmToCnm(builder, op, workgroup.getResult(), computeBlock, {}, - adaptor.getOperands(), ValueRange{outputInit}, - op->getResults(), newResults, - [&](ImplicitLocOpBuilder &builder, ValueRange inputs, - ValueRange outputs) { - builder.create( - inputs, outputs, ArrayRef{0}, - [&](OpBuilder &builder, Location loc, - ValueRange inputs) -> void { - Value result; - switch (op.getMethod()) { - case mlir::cinm::ReduceMethod::ADD: { - result = builder.create( - loc, inputs[0], inputs[1]); - } break; - case mlir::cinm::ReduceMethod::MUL: { - result = builder.create( - loc, inputs[0], inputs[1]); - } break; - case mlir::cinm::ReduceMethod::MAX: { - result = builder.create( - loc, inputs[0], inputs[1]); - } break; - case mlir::cinm::ReduceMethod::MIN: { - result = builder.create( - loc, inputs[0], inputs[1]); - } break; - } - builder.create(loc, result); - }); - }) + if (convertCinmToCnm( + builder, op, workgroup.getResult(), computeBlock, {}, + adaptor.getOperands(), ValueRange{outputInit}, op->getResults(), + newResults, + [&](ImplicitLocOpBuilder &builder, ValueRange inputs, + ValueRange outputs) { + builder.create( + inputs, outputs, ArrayRef{0}, + [&](OpBuilder &builder, Location loc, + ValueRange inputs) -> void { + Value result; + switch (op.getMethod()) { + case mlir::cinm::ReduceMethod::ADD: { + if (isFloatOp) { + result = builder.create(loc, inputs[0], + inputs[1]); + } else { + result = builder.create(loc, inputs[0], + inputs[1]); + } + } break; + case mlir::cinm::ReduceMethod::MUL: { + if (isFloatOp) { + result = builder.create(loc, inputs[0], + inputs[1]); + } else { + result = builder.create(loc, inputs[0], + inputs[1]); + } + } break; + case mlir::cinm::ReduceMethod::MAX: { + if (isFloatOp) { + result = builder.create( + loc, inputs[0], inputs[1]); + } else { + result = builder.create(loc, inputs[0], + inputs[1]); + } + } break; + case mlir::cinm::ReduceMethod::MIN: { + if (isFloatOp) { + result = builder.create( + loc, inputs[0], inputs[1]); + } else { + result = builder.create(loc, inputs[0], + inputs[1]); + } + } break; + } + builder.create(loc, result); + }); + }) .failed()) { return failure(); } diff --git a/cinnamon/lib/Dialect/Cinm/IR/CinmTilingImplementations.cpp b/cinnamon/lib/Dialect/Cinm/IR/CinmTilingImplementations.cpp index 4d68883..16a3541 100644 --- a/cinnamon/lib/Dialect/Cinm/IR/CinmTilingImplementations.cpp +++ b/cinnamon/lib/Dialect/Cinm/IR/CinmTilingImplementations.cpp @@ -95,7 +95,18 @@ TilingResult2 GemmOp::convertToTiledOps(OpBuilder &builder, const ValueRange resultDynamicOffsets = parIndices; auto reductionAccTy = RankedTensorType::get({p0, p1}, eltTy); - auto zeros = DenseIntElementsAttr::get(reductionAccTy, {0}); + DenseElementsAttr zeros; + if (auto floatType = + reductionAccTy.getElementType().dyn_cast()) { + zeros = DenseElementsAttr::get( + reductionAccTy, + {APFloat::getZero(floatType.getFloatSemantics())}); + } else { + zeros = DenseElementsAttr::get( + reductionAccTy, + {APInt::getZero(reductionAccTy.getElementTypeBitWidth())}); + } + Value cst0 = builder.create(loc, zeros); // this is the reduction loop diff --git a/cinnamon/lib/Dialect/Cinm/Interfaces/TilingInterface.cpp b/cinnamon/lib/Dialect/Cinm/Interfaces/TilingInterface.cpp index 05832fa..3c7e4e6 100644 --- a/cinnamon/lib/Dialect/Cinm/Interfaces/TilingInterface.cpp +++ b/cinnamon/lib/Dialect/Cinm/Interfaces/TilingInterface.cpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include #include #include @@ -207,48 +209,80 @@ Value createVectorReduce(OpBuilder &builder, Location loc, Value vector, return builder.create(loc, sum.getResult(0), ValueRange{}); } -// todo these impls should work also for float types - Value createVectorReduceAdd(OpBuilder &builder, Location loc, Value vector, int64_t clusterSize) { const Type elementType = vector.getType().cast().getElementType(); - const Value init = builder.create( - loc, builder.getIntegerAttr(elementType, 0)); - return createVectorReduce(builder, loc, vector, init, - clusterSize); + if (FloatType floatType = elementType.dyn_cast()) { + const TypedAttr zeroAttr = FloatAttr::get( + elementType, APFloat::getZero(floatType.getFloatSemantics())); + const Value init = builder.create(loc, zeroAttr); + return createVectorReduce(builder, loc, vector, init, + clusterSize); + } else { + const TypedAttr zeroAttr = IntegerAttr::get(elementType, 0); + const Value init = builder.create(loc, zeroAttr); + return createVectorReduce(builder, loc, vector, init, + clusterSize); + } } Value createVectorReduceMul(OpBuilder &builder, Location loc, Value vector, int64_t clusterSize) { const Type elementType = vector.getType().cast().getElementType(); - const Value init = builder.create( - loc, builder.getIntegerAttr(elementType, 1)); - return createVectorReduce(builder, loc, vector, init, - clusterSize); + if (FloatType floatType = elementType.dyn_cast()) { + const TypedAttr oneAttr = + FloatAttr::get(elementType, APFloat(floatType.getFloatSemantics(), 1)); + const Value init = builder.create(loc, oneAttr); + return createVectorReduce(builder, loc, vector, init, + clusterSize); + } else { + const TypedAttr oneAttr = IntegerAttr::get(elementType, 1); + const Value init = builder.create(loc, oneAttr); + return createVectorReduce(builder, loc, vector, init, + clusterSize); + } } Value createVectorReduceMin(OpBuilder &builder, Location loc, Value vector, int64_t clusterSize) { const Type elementType = vector.getType().cast().getElementType(); - const Value init = builder.create( - loc, builder.getIntegerAttr(elementType, - std::numeric_limits::max())); - return createVectorReduce(builder, loc, vector, init, - clusterSize); + if (FloatType floatType = elementType.dyn_cast()) { + const TypedAttr maxValAttr = FloatAttr::get( + elementType, APFloat::getInf(floatType.getFloatSemantics())); + const Value init = builder.create(loc, maxValAttr); + return createVectorReduce(builder, loc, vector, init, + clusterSize); + } else { + const TypedAttr maxValAttr = IntegerAttr::get( + elementType, + APInt::getSignedMaxValue(elementType.getIntOrFloatBitWidth())); + const Value init = builder.create(loc, maxValAttr); + return createVectorReduce(builder, loc, vector, init, + clusterSize); + } } Value createVectorReduceMax(OpBuilder &builder, Location loc, Value vector, int64_t clusterSize) { const Type elementType = vector.getType().cast().getElementType(); - const Value init = builder.create( - loc, builder.getIntegerAttr(elementType, - std::numeric_limits::min())); - return createVectorReduce(builder, loc, vector, init, - clusterSize); + if (FloatType floatType = elementType.dyn_cast()) { + const TypedAttr minValAttr = FloatAttr::get( + elementType, -APFloat::getInf(floatType.getFloatSemantics())); + const Value init = builder.create(loc, minValAttr); + return createVectorReduce(builder, loc, vector, init, + clusterSize); + } else { + const TypedAttr minValAttr = IntegerAttr::get( + elementType, + APInt::getSignedMinValue(elementType.getIntOrFloatBitWidth())); + const Value init = builder.create(loc, minValAttr); + return createVectorReduce(builder, loc, vector, init, + clusterSize); + } } } // namespace mlir::cinm