Skip to content

Commit

Permalink
Revert "merge matmul pattern"
Browse files Browse the repository at this point in the history
This reverts commit fb13cd2.

Change-Id: Ie03e675ad7bf90a487005728aee14546785155bb
  • Loading branch information
charlesxzb committed Nov 30, 2024
1 parent ce5a0aa commit c44a4a4
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 353 deletions.
16 changes: 1 addition & 15 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,7 @@
"__functional_base": "cpp",
"__functional_base_03": "cpp",
"__tuple": "cpp",
"__target_macros": "cpp",
"any": "cpp",
"barrier": "cpp",
"charconv": "cpp",
"coroutine": "cpp",
"csetjmp": "cpp",
"csignal": "cpp",
"cuchar": "cpp",
"netfwd": "cpp",
"source_location": "cpp",
"rope": "cpp",
"slist": "cpp",
"latch": "cpp",
"scoped_allocator": "cpp",
"syncstream": "cpp"
"__target_macros": "cpp"
},
"files.autoSave": "afterDelay",
"files.trimTrailingWhitespace": true,
Expand Down
3 changes: 0 additions & 3 deletions include/tpu_mlir/Support/Module.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,6 @@ bool startsWith(const std::string &fullString,
bool endsWith(const std::string &fullString, const std::string &suffix);
bool IsRightMat(Value v);

bool isOpSameCalc(Operation *op0, Operation *op1);
bool isOpSameCalc(const std::vector<Operation *> &ops);

bool isInMatMulGrpOp(Operation *op);
} // namespace module

Expand Down
3 changes: 1 addition & 2 deletions lib/Dialect/Tpu/Interfaces/Common/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ matmul_attr_t tpu::MatMulOp::parseParam() {
for (int i = 1; i < a_dims - 2; i++) {
a_temp *= a_s[i];
}
// consider left_transpose
p.M = a_s[o_dims - 2 + p.left_transpose] * a_temp;
p.M = a_s[o_dims - 2] * a_temp;
}
}
return p;
Expand Down
50 changes: 47 additions & 3 deletions lib/Dialect/Tpu/Transforms/CoreParallel/CoreMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,50 @@
namespace tpu_mlir {
namespace tpu {

static bool isOpSameCalc(Operation *op0, Operation *op1) {
auto compare = [&](mlir::ValueRange left, mlir::ValueRange right) -> bool {
for (auto it : llvm::zip(left, right)) {
auto left = std::get<0>(it);
auto right = std::get<1>(it);
if (module::isNone(left) || module::isNone(right)) {
continue;
}
auto l_s = module::getShape(left);
auto r_s = module::getShape(right);
if (l_s != r_s) {
return false;
}
}
return true;
};
if (op0 == op1) {
// can't be the same op
return false;
}
if (op0->getName() != op1->getName()) {
return false;
}
if (false == compare(op0->getOperands(), op1->getOperands())) {
return false;
}
if (false == compare(op0->getResults(), op1->getResults())) {
return false;
}
return true;
}

static bool isOpSameCalc(const std::vector<Operation *> &ops) {
if (ops.size() < 2) {
return false;
}
for (int i = 1; i < ops.size(); i++) {
if (!isOpSameCalc(ops[0], ops[i])) {
return false;
}
}
return true;
}

static Operation *cloneOp(PatternRewriter &rewriter, Operation *op,
llvm::ArrayRef<int64_t> new_shape,
llvm::StringRef suffix) {
Expand Down Expand Up @@ -304,7 +348,7 @@ static void common_match(PatternRewriter &rewriter,
}
next_ops.push_back(next_op);
}
if (next_ops.size() != num_ops || !module::isOpSameCalc(next_ops)) {
if (next_ops.size() != num_ops || !isOpSameCalc(next_ops)) {
next_is_same = false;
continue;
}
Expand Down Expand Up @@ -402,7 +446,7 @@ struct FuncInputMatch : public OpRewriterPatternEx<FuncOp> {
if (module::isOpInBlock(right_op)) {
continue;
}
if (module::isOpSameCalc(left_op, right_op)) {
if (isOpSameCalc(left_op, right_op)) {
ops.push_back(right_op);
}
if (ops.size() == num_core) {
Expand Down Expand Up @@ -527,7 +571,7 @@ struct CommonMatch : public OpRewriterPatternEx3 {
if (module::isOpInBlock(right_op)) {
continue;
}
if (module::isOpSameCalc(left_op, right_op)) {
if (isOpSameCalc(left_op, right_op)) {
ops.push_back(right_op);
}
if (ops.size() == num_core) {
Expand Down
Loading

0 comments on commit c44a4a4

Please sign in to comment.