Skip to content

Commit

Permalink
Merge pull request #17 from ge0mk/float_fix
Browse files Browse the repository at this point in the history
add support for floating point operations
  • Loading branch information
ge0mk authored Sep 28, 2024
2 parents dcd57f8 + a27d662 commit 8508233
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 52 deletions.
88 changes: 57 additions & 31 deletions cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,38 +711,64 @@ struct ConvertCinmReduceToCnm : public OpConversionPattern<cinm::ReduceOp> {
op.getResult().getType(),
builder.getZeroAttr(op.getResult().getType()));

const bool isFloatOp = op.getType()
.cast<ShapedType>()
.getElementType()
.dyn_cast<FloatType>() != nullptr;

llvm::SmallVector<Value, 1> newResults;
if (convertCinmToCnm(builder, op, workgroup.getResult(), computeBlock, {},
adaptor.getOperands(), ValueRange{outputInit},
op->getResults(), newResults,
[&](ImplicitLocOpBuilder &builder, ValueRange inputs,
ValueRange outputs) {
builder.create<linalg::ReduceOp>(
inputs, outputs, ArrayRef<int64_t>{0},
[&](OpBuilder &builder, Location loc,
ValueRange inputs) -> void {
Value result;
switch (op.getMethod()) {
case mlir::cinm::ReduceMethod::ADD: {
result = builder.create<arith::AddIOp>(
loc, inputs[0], inputs[1]);
} break;
case mlir::cinm::ReduceMethod::MUL: {
result = builder.create<arith::MulIOp>(
loc, inputs[0], inputs[1]);
} break;
case mlir::cinm::ReduceMethod::MAX: {
result = builder.create<arith::MaxSIOp>(
loc, inputs[0], inputs[1]);
} break;
case mlir::cinm::ReduceMethod::MIN: {
result = builder.create<arith::MinSIOp>(
loc, inputs[0], inputs[1]);
} break;
}
builder.create<linalg::YieldOp>(loc, result);
});
})
if (convertCinmToCnm(
builder, op, workgroup.getResult(), computeBlock, {},
adaptor.getOperands(), ValueRange{outputInit}, op->getResults(),
newResults,
[&](ImplicitLocOpBuilder &builder, ValueRange inputs,
ValueRange outputs) {
builder.create<linalg::ReduceOp>(
inputs, outputs, ArrayRef<int64_t>{0},
[&](OpBuilder &builder, Location loc,
ValueRange inputs) -> void {
Value result;
switch (op.getMethod()) {
case mlir::cinm::ReduceMethod::ADD: {
if (isFloatOp) {
result = builder.create<arith::AddFOp>(loc, inputs[0],
inputs[1]);
} else {
result = builder.create<arith::AddIOp>(loc, inputs[0],
inputs[1]);
}
} break;
case mlir::cinm::ReduceMethod::MUL: {
if (isFloatOp) {
result = builder.create<arith::MulFOp>(loc, inputs[0],
inputs[1]);
} else {
result = builder.create<arith::MulIOp>(loc, inputs[0],
inputs[1]);
}
} break;
case mlir::cinm::ReduceMethod::MAX: {
if (isFloatOp) {
result = builder.create<arith::MaximumFOp>(
loc, inputs[0], inputs[1]);
} else {
result = builder.create<arith::MaxSIOp>(loc, inputs[0],
inputs[1]);
}
} break;
case mlir::cinm::ReduceMethod::MIN: {
if (isFloatOp) {
result = builder.create<arith::MinimumFOp>(
loc, inputs[0], inputs[1]);
} else {
result = builder.create<arith::MinSIOp>(loc, inputs[0],
inputs[1]);
}
} break;
}
builder.create<linalg::YieldOp>(loc, result);
});
})
.failed()) {
return failure();
}
Expand Down
13 changes: 12 additions & 1 deletion cinnamon/lib/Dialect/Cinm/IR/CinmTilingImplementations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,18 @@ TilingResult2 GemmOp::convertToTiledOps(OpBuilder &builder,
const ValueRange resultDynamicOffsets = parIndices;

auto reductionAccTy = RankedTensorType::get({p0, p1}, eltTy);
auto zeros = DenseIntElementsAttr::get(reductionAccTy, {0});
DenseElementsAttr zeros;
if (auto floatType =
reductionAccTy.getElementType().dyn_cast<FloatType>()) {
zeros = DenseElementsAttr::get(
reductionAccTy,
{APFloat::getZero(floatType.getFloatSemantics())});
} else {
zeros = DenseElementsAttr::get(
reductionAccTy,
{APInt::getZero(reductionAccTy.getElementTypeBitWidth())});
}

Value cst0 = builder.create<arith::ConstantOp>(loc, zeros);

// this is the reduction loop
Expand Down
74 changes: 54 additions & 20 deletions cinnamon/lib/Dialect/Cinm/Interfaces/TilingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <cstdint>
#include <functional>
#include <limits>
#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/APInt.h>
#include <llvm/ADT/SmallVector.h>
#include <mlir/IR/BuiltinTypeInterfaces.h>
#include <numeric>
Expand Down Expand Up @@ -207,48 +209,80 @@ Value createVectorReduce(OpBuilder &builder, Location loc, Value vector,
return builder.create<tensor::ExtractOp>(loc, sum.getResult(0), ValueRange{});
}

// todo these impls should work also for float types

Value createVectorReduceAdd(OpBuilder &builder, Location loc, Value vector,
int64_t clusterSize) {
const Type elementType =
vector.getType().cast<RankedTensorType>().getElementType();
const Value init = builder.create<arith::ConstantOp>(
loc, builder.getIntegerAttr(elementType, 0));
return createVectorReduce<arith::AddIOp>(builder, loc, vector, init,
clusterSize);
if (FloatType floatType = elementType.dyn_cast<FloatType>()) {
const TypedAttr zeroAttr = FloatAttr::get(
elementType, APFloat::getZero(floatType.getFloatSemantics()));
const Value init = builder.create<arith::ConstantOp>(loc, zeroAttr);
return createVectorReduce<arith::AddFOp>(builder, loc, vector, init,
clusterSize);
} else {
const TypedAttr zeroAttr = IntegerAttr::get(elementType, 0);
const Value init = builder.create<arith::ConstantOp>(loc, zeroAttr);
return createVectorReduce<arith::AddIOp>(builder, loc, vector, init,
clusterSize);
}
}

Value createVectorReduceMul(OpBuilder &builder, Location loc, Value vector,
int64_t clusterSize) {
const Type elementType =
vector.getType().cast<RankedTensorType>().getElementType();
const Value init = builder.create<arith::ConstantOp>(
loc, builder.getIntegerAttr(elementType, 1));
return createVectorReduce<arith::MulIOp>(builder, loc, vector, init,
clusterSize);
if (FloatType floatType = elementType.dyn_cast<FloatType>()) {
const TypedAttr oneAttr =
FloatAttr::get(elementType, APFloat(floatType.getFloatSemantics(), 1));
const Value init = builder.create<arith::ConstantOp>(loc, oneAttr);
return createVectorReduce<arith::MulFOp>(builder, loc, vector, init,
clusterSize);
} else {
const TypedAttr oneAttr = IntegerAttr::get(elementType, 1);
const Value init = builder.create<arith::ConstantOp>(loc, oneAttr);
return createVectorReduce<arith::MulIOp>(builder, loc, vector, init,
clusterSize);
}
}

Value createVectorReduceMin(OpBuilder &builder, Location loc, Value vector,
int64_t clusterSize) {
const Type elementType =
vector.getType().cast<RankedTensorType>().getElementType();
const Value init = builder.create<arith::ConstantOp>(
loc, builder.getIntegerAttr(elementType,
std::numeric_limits<uint64_t>::max()));
return createVectorReduce<arith::MinUIOp>(builder, loc, vector, init,
clusterSize);
if (FloatType floatType = elementType.dyn_cast<FloatType>()) {
const TypedAttr maxValAttr = FloatAttr::get(
elementType, APFloat::getInf(floatType.getFloatSemantics()));
const Value init = builder.create<arith::ConstantOp>(loc, maxValAttr);
return createVectorReduce<arith::MinimumFOp>(builder, loc, vector, init,
clusterSize);
} else {
const TypedAttr maxValAttr = IntegerAttr::get(
elementType,
APInt::getSignedMaxValue(elementType.getIntOrFloatBitWidth()));
const Value init = builder.create<arith::ConstantOp>(loc, maxValAttr);
return createVectorReduce<arith::MinSIOp>(builder, loc, vector, init,
clusterSize);
}
}

Value createVectorReduceMax(OpBuilder &builder, Location loc, Value vector,
int64_t clusterSize) {
const Type elementType =
vector.getType().cast<RankedTensorType>().getElementType();
const Value init = builder.create<arith::ConstantOp>(
loc, builder.getIntegerAttr(elementType,
std::numeric_limits<uint64_t>::min()));
return createVectorReduce<arith::MaxUIOp>(builder, loc, vector, init,
clusterSize);
if (FloatType floatType = elementType.dyn_cast<FloatType>()) {
const TypedAttr minValAttr = FloatAttr::get(
elementType, -APFloat::getInf(floatType.getFloatSemantics()));
const Value init = builder.create<arith::ConstantOp>(loc, minValAttr);
return createVectorReduce<arith::MaximumFOp>(builder, loc, vector, init,
clusterSize);
} else {
const TypedAttr minValAttr = IntegerAttr::get(
elementType,
APInt::getSignedMinValue(elementType.getIntOrFloatBitWidth()));
const Value init = builder.create<arith::ConstantOp>(loc, minValAttr);
return createVectorReduce<arith::MaxSIOp>(builder, loc, vector, init,
clusterSize);
}
}

} // namespace mlir::cinm

0 comments on commit 8508233

Please sign in to comment.