add linalg::softmax to cinm conversion pass
ge0mk committed Sep 16, 2024
1 parent 2a23a9d commit adc708b
Showing 3 changed files with 92 additions and 0 deletions.
10 changes: 10 additions & 0 deletions cinnamon/include/cinm-mlir/Dialect/Cinm/Transforms/Passes.td
@@ -27,4 +27,14 @@ def CinmTilingPass: Pass<"cinm-tiling"> {
];
}

def SoftmaxToCinmPass: Pass<"softmax-to-cinm"> {
  let summary = "convert linalg.softmax ops to cinm";
  let description = [{
    Lowers each `linalg.softmax` op to a `cinm.compute` block containing the
    numerically stable softmax decomposition: max-reduce, subtract, exp,
    sum-reduce, and element-wise divide.
  }];

  let dependentDialects = [
    "mlir::LLVM::LLVMDialect",
    "mlir::linalg::LinalgDialect",
  ];
}

#endif // CINM_TRANSFORM_PASSES
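
The TableGen entry above only declares the pass; MLIR's pass backend generates the base class, factory, and registration hooks from Passes.h.inc. As a rough sketch of how the new pass could be driven programmatically, assuming the generated factory is named createSoftmaxToCinmPass() (the exact symbol depends on how GEN_PASS_DECL is set up in the project's Passes.h, so treat the names below as assumptions rather than part of this commit):

// Sketch only: the factory name and the lowerSoftmax() wrapper are
// assumptions for illustration, not taken from this commit.
#include "cinm-mlir/Dialect/Cinm/Transforms/Passes.h"

#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/PassManager.h>

mlir::LogicalResult lowerSoftmax(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // The pass is registered under the "softmax-to-cinm" flag declared above.
  pm.addPass(mlir::cinm::createSoftmaxToCinmPass());
  return pm.run(module);
}

From an opt-style tool that links CinmTransforms, the same rewrite should be reachable via the --softmax-to-cinm flag.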
1 change: 1 addition & 0 deletions cinnamon/lib/Dialect/Cinm/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(CinmTransforms
  SoftmaxToCinmPass.cpp
  TilingPass.cpp

  DEPENDS
81 changes: 81 additions & 0 deletions cinnamon/lib/Dialect/Cinm/Transforms/SoftmaxToCinmPass.cpp
@@ -0,0 +1,81 @@
#include "cinm-mlir/Dialect/Cinm/IR/CinmAttributes.h"
#include "cinm-mlir/Dialect/Cinm/IR/CinmBase.h"
#include "cinm-mlir/Dialect/Cinm/IR/CinmOps.h"
#include "cinm-mlir/Dialect/Cinm/Interfaces/TilingInterface.h"
#include "cinm-mlir/Dialect/Cinm/Transforms/Passes.h"

#include <cmath>
#include <cstdint>
#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/SmallVector.h>
#include <mlir/Conversion/AffineToStandard/AffineToStandard.h>
#include <mlir/Conversion/LLVMCommon/TypeConverter.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/BuiltinTypeInterfaces.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Transforms/DialectConversion.h>

namespace mlir::cinm {

//===- Generated passes ---------------------------------------------------===//

#define GEN_PASS_DEF_SOFTMAXTOCINMPASS
#include "cinm-mlir/Dialect/Cinm/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//

struct SoftmaxToCinmPattern : public OpConversionPattern<linalg::SoftmaxOp> {
  using OpConversionPattern<linalg::SoftmaxOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(linalg::SoftmaxOp op,
                  OpConversionPattern<linalg::SoftmaxOp>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto loc = op.getLoc();
    const auto input = op.getInput();
    const ShapedType inputType = input.getType();

    // Replace the softmax op with a cinm.compute block and build the
    // decomposition inside its body.
    cinm::ComputeOp computeOp =
        rewriter.replaceOpWithNewOp<cinm::ComputeOp>(op, op.getResultTypes());
    rewriter.setInsertionPointToEnd(&computeOp.getBody().emplaceBlock());

    // Numerically stable softmax: exp(x - max(x)) / sum(exp(x - max(x))).
    const Value max = rewriter.create<cinm::ReduceOp>(
        loc, inputType.getElementType(), ReduceMethod::MAX, input);
    const Value t = rewriter.create<cinm::SubsOp>(loc, input, max);
    const Value init = rewriter.create<tensor::EmptyOp>(
        loc, inputType.getShape(), inputType.getElementType());
    const Value e =
        rewriter
            .create<linalg::ExpOp>(
                loc,
                TypeRange{RankedTensorType::get(inputType.getShape(),
                                                inputType.getElementType())},
                ValueRange{t}, ValueRange{init})
            .getResult(0);
    const Value s = rewriter.create<cinm::ReduceOp>(
        loc, inputType.getElementType(), ReduceMethod::ADD, e);
    const Value result = rewriter.create<cinm::DivsOp>(loc, e, s);
    rewriter.create<cinm::YieldOp>(loc, ValueRange{result});
    return success();
  }
};

struct SoftmaxToCinmPass
    : public impl::SoftmaxToCinmPassBase<SoftmaxToCinmPass> {
  using Base::Base;

  void runOnOperation() final {
    RewritePatternSet patterns(&getContext());
    patterns.insert<SoftmaxToCinmPattern>(&getContext());

    // Everything except linalg.softmax stays legal, so the partial conversion
    // only rewrites the ops matched by the pattern above.
    ConversionTarget target(getContext());
    target.markUnknownOpDynamicallyLegal([](...) { return true; });
    target.addIllegalOp<linalg::SoftmaxOp>();

    if (applyPartialConversion(getOperation(), target, std::move(patterns))
            .failed())
      signalPassFailure();
  }
};

} // namespace mlir::cinm
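
For reference, the body that SoftmaxToCinmPattern builds inside cinm.compute is the numerically stable softmax decomposition; each term maps onto one of the ops created above (ReduceOp MAX, SubsOp, linalg::ExpOp, ReduceOp ADD, DivsOp):

\[ \operatorname{softmax}(x)_i = \frac{\exp(x_i - \max(x))}{\sum_j \exp(x_j - \max(x))} \]

Subtracting max(x) first leaves the result unchanged but keeps the intermediate exponentials from overflowing for large inputs.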
