diff --git a/cinnamon/include/cinm-mlir/Dialect/Cinm/Transforms/Passes.td b/cinnamon/include/cinm-mlir/Dialect/Cinm/Transforms/Passes.td
index e024c56..db43824 100644
--- a/cinnamon/include/cinm-mlir/Dialect/Cinm/Transforms/Passes.td
+++ b/cinnamon/include/cinm-mlir/Dialect/Cinm/Transforms/Passes.td
@@ -27,4 +27,15 @@ def CinmTilingPass: Pass<"cinm-tiling"> {
   ];
 }
 
+def SoftmaxToCinmPass: Pass<"softmax-to-cinm"> {
+  let summary = "Converts the linalg.softmax op to cinm ops";
+  let description = [{Rewrites linalg.softmax into the body of a cinm.compute op, using cinm reduce/sub/div ops together with an element-wise linalg.exp.}];
+
+  let dependentDialects = [
+    "mlir::LLVM::LLVMDialect",
+    "mlir::linalg::LinalgDialect",
+    "mlir::tensor::TensorDialect",
+  ];
+}
+
 #endif // CINM_TRANSFORM_PASSES
diff --git a/cinnamon/lib/Dialect/Cinm/Transforms/CMakeLists.txt b/cinnamon/lib/Dialect/Cinm/Transforms/CMakeLists.txt
index c1a27ce..8f66e85 100644
--- a/cinnamon/lib/Dialect/Cinm/Transforms/CMakeLists.txt
+++ b/cinnamon/lib/Dialect/Cinm/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(CinmTransforms
+  SoftmaxToCinmPass.cpp
   TilingPass.cpp
 
 DEPENDS
diff --git a/cinnamon/lib/Dialect/Cinm/Transforms/SoftmaxToCinmPass.cpp b/cinnamon/lib/Dialect/Cinm/Transforms/SoftmaxToCinmPass.cpp
new file mode 100644
index 0000000..9860eb9
--- /dev/null
+++ b/cinnamon/lib/Dialect/Cinm/Transforms/SoftmaxToCinmPass.cpp
@@ -0,0 +1,89 @@
+#include "cinm-mlir/Dialect/Cinm/IR/CinmAttributes.h"
+#include "cinm-mlir/Dialect/Cinm/IR/CinmBase.h"
+#include "cinm-mlir/Dialect/Cinm/IR/CinmOps.h"
+#include "cinm-mlir/Dialect/Cinm/Interfaces/TilingInterface.h"
+#include "cinm-mlir/Dialect/Cinm/Transforms/Passes.h"
+
+#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
+#include <mlir/Dialect/Linalg/IR/Linalg.h>
+#include <mlir/Dialect/Tensor/IR/Tensor.h>
+#include <mlir/IR/Builders.h>
+#include <mlir/IR/BuiltinTypes.h>
+#include <mlir/IR/Location.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/IR/TypeRange.h>
+#include <mlir/IR/Value.h>
+#include <mlir/IR/ValueRange.h>
+#include <mlir/Pass/Pass.h>
+#include <mlir/Support/LogicalResult.h>
+#include <mlir/Transforms/DialectConversion.h>
+
+namespace mlir::cinm {
+
+//===- Generated passes ---------------------------------------------------===//
+
+#define GEN_PASS_DEF_SOFTMAXTOCINMPASS
+#include "cinm-mlir/Dialect/Cinm/Transforms/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+
+/// Rewrites linalg.softmax into the body of a cinm.compute op computing the
+/// numerically stable softmax: exp(x - max(x)) / sum(exp(x - max(x))).
+struct SoftmaxToCinmPattern : public OpConversionPattern<linalg::SoftmaxOp> {
+  using OpConversionPattern<linalg::SoftmaxOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(linalg::SoftmaxOp op,
+                  OpConversionPattern<linalg::SoftmaxOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto loc = op.getLoc();
+    const auto input = op.getInput();
+    const ShapedType inputType = input.getType();
+    cinm::ComputeOp computeOp =
+        rewriter.replaceOpWithNewOp<cinm::ComputeOp>(op, op.getResultTypes());
+
+    rewriter.setInsertionPointToEnd(&computeOp.getBody().emplaceBlock());
+    // Subtract the maximum before exponentiating for numerical stability.
+    const Value max = rewriter.create<cinm::ReduceOp>(
+        loc, inputType.getElementType(), ReduceMethod::MAX, input);
+    const Value t = rewriter.create<cinm::SubOp>(loc, input, max);
+    const Value init = rewriter.create<tensor::EmptyOp>(
+        loc, inputType.getShape(), inputType.getElementType());
+    // e = exp(t), computed element-wise through linalg.exp.
+    const Value e =
+        rewriter
+            .create<linalg::ExpOp>(
+                loc,
+                TypeRange{RankedTensorType::get(inputType.getShape(),
+                                                inputType.getElementType())},
+                ValueRange{t}, ValueRange{init})
+            .getResult(0);
+    // result = e / sum(e)
+    const Value s = rewriter.create<cinm::ReduceOp>(
+        loc, inputType.getElementType(), ReduceMethod::ADD, e);
+    const Value result = rewriter.create<cinm::DivOp>(loc, e, s);
+    rewriter.create<cinm::YieldOp>(loc, ValueRange{result});
+    return success();
+  }
+};
+
+struct SoftmaxToCinmPass
+    : public impl::SoftmaxToCinmPassBase<SoftmaxToCinmPass> {
+  using Base::Base;
+
+  void runOnOperation() final {
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<SoftmaxToCinmPattern>(&getContext());
+
+    // Everything is legal except linalg.softmax, which must be converted.
+    ConversionTarget target(getContext());
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+    target.addIllegalOp<linalg::SoftmaxOp>();
+
+    if (applyPartialConversion(getOperation(), target, std::move(patterns))
+            .failed())
+      signalPassFailure();
+  }
+};
+
+} // namespace mlir::cinm
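
Reviewer note: a minimal sketch of driving the new pass from C++, assuming the tablegen-generated factory mlir::cinm::createSoftmaxToCinmPass() is declared by Passes.h via the usual GEN_PASS_DECL path; the helper name lowerSoftmax is hypothetical and only for illustration:

#include "cinm-mlir/Dialect/Cinm/Transforms/Passes.h"

#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/PassManager.h>

// Runs --softmax-to-cinm over a module; returns failure if the
// partial conversion (and therefore the pass) failed.
mlir::LogicalResult lowerSoftmax(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::cinm::createSoftmaxToCinmPass());
  return pm.run(module);
}

The same rewrite should also be reachable from the project's opt-like driver, e.g. with --softmax-to-cinm on an input containing linalg.softmax, provided the driver registers the Cinm transform passes.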