From c44a4a46ce4fed5cc79afede8e1b49789872ad05 Mon Sep 17 00:00:00 2001
From: "zhibin.xin"
Date: Sat, 30 Nov 2024 00:17:40 -0500
Subject: [PATCH] Revert "merge matmul pattern"

This reverts commit fb13cd2ec23fabee89dbbbc56af57e7cecdeb9f7.

Change-Id: Ie03e675ad7bf90a487005728aee14546785155bb
---
 .vscode/settings.json                         |  16 +-
 include/tpu_mlir/Support/Module.h             |   3 -
 lib/Dialect/Tpu/Interfaces/Common/MatMul.cpp  |   3 +-
 .../Tpu/Transforms/CoreParallel/CoreMatch.cpp |  50 ++-
 .../ProcessorOptimize/OptimizeBM1684X.cpp     | 285 ------------------
 lib/Support/Module.cpp                        |  44 ---
 python/test/test_tpulang.py                   |   1 -
 7 files changed, 49 insertions(+), 353 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 7612e0591..99b5683a9 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -95,21 +95,7 @@
         "__functional_base": "cpp",
         "__functional_base_03": "cpp",
         "__tuple": "cpp",
-        "__target_macros": "cpp",
-        "any": "cpp",
-        "barrier": "cpp",
-        "charconv": "cpp",
-        "coroutine": "cpp",
-        "csetjmp": "cpp",
-        "csignal": "cpp",
-        "cuchar": "cpp",
-        "netfwd": "cpp",
-        "source_location": "cpp",
-        "rope": "cpp",
-        "slist": "cpp",
-        "latch": "cpp",
-        "scoped_allocator": "cpp",
-        "syncstream": "cpp"
+        "__target_macros": "cpp"
     },
     "files.autoSave": "afterDelay",
     "files.trimTrailingWhitespace": true,
diff --git a/include/tpu_mlir/Support/Module.h b/include/tpu_mlir/Support/Module.h
index b16a97a85..5dc4994d3 100644
--- a/include/tpu_mlir/Support/Module.h
+++ b/include/tpu_mlir/Support/Module.h
@@ -348,9 +348,6 @@ bool startsWith(const std::string &fullString,
 bool endsWith(const std::string &fullString,
               const std::string &suffix);
 bool IsRightMat(Value v);
-bool isOpSameCalc(Operation *op0, Operation *op1);
-bool isOpSameCalc(const std::vector<Operation *> &ops);
-
 bool isInMatMulGrpOp(Operation *op);
 
 } // namespace module
diff --git a/lib/Dialect/Tpu/Interfaces/Common/MatMul.cpp b/lib/Dialect/Tpu/Interfaces/Common/MatMul.cpp
index ba9dc8023..d3117352d 100644
--- a/lib/Dialect/Tpu/Interfaces/Common/MatMul.cpp
+++ b/lib/Dialect/Tpu/Interfaces/Common/MatMul.cpp
@@ -129,8 +129,7 @@ matmul_attr_t tpu::MatMulOp::parseParam() {
       for (int i = 1; i < a_dims - 2; i++) {
         a_temp *= a_s[i];
       }
-      // consider left_transpose
-      p.M = a_s[o_dims - 2 + p.left_transpose] * a_temp;
+      p.M = a_s[o_dims - 2] * a_temp;
     }
   }
   return p;
 }
diff --git a/lib/Dialect/Tpu/Transforms/CoreParallel/CoreMatch.cpp b/lib/Dialect/Tpu/Transforms/CoreParallel/CoreMatch.cpp
index 1089360e6..144f993a9 100644
--- a/lib/Dialect/Tpu/Transforms/CoreParallel/CoreMatch.cpp
+++ b/lib/Dialect/Tpu/Transforms/CoreParallel/CoreMatch.cpp
@@ -18,6 +18,50 @@
 namespace tpu_mlir {
 namespace tpu {
 
+static bool isOpSameCalc(Operation *op0, Operation *op1) {
+  auto compare = [&](mlir::ValueRange left, mlir::ValueRange right) -> bool {
+    for (auto it : llvm::zip(left, right)) {
+      auto left = std::get<0>(it);
+      auto right = std::get<1>(it);
+      if (module::isNone(left) || module::isNone(right)) {
+        continue;
+      }
+      auto l_s = module::getShape(left);
+      auto r_s = module::getShape(right);
+      if (l_s != r_s) {
+        return false;
+      }
+    }
+    return true;
+  };
+  if (op0 == op1) {
+    // can't be the same op
+    return false;
+  }
+  if (op0->getName() != op1->getName()) {
+    return false;
+  }
+  if (false == compare(op0->getOperands(), op1->getOperands())) {
+    return false;
+  }
+  if (false == compare(op0->getResults(), op1->getResults())) {
+    return false;
+  }
+  return true;
+}
+
+static bool isOpSameCalc(const std::vector<Operation *> &ops) {
+  if (ops.size() < 2) {
+    return false;
+  }
+  for (int i = 1; i < ops.size(); i++) {
+    if (!isOpSameCalc(ops[0], ops[i])) {
+      return false;
+    }
+  }
+  return true;
+}
+
 static Operation *cloneOp(PatternRewriter &rewriter, Operation *op,
                           llvm::ArrayRef<int64_t> new_shape,
                           llvm::StringRef suffix) {
@@ -304,7 +348,7 @@ static void common_match(PatternRewriter &rewriter,
       }
       next_ops.push_back(next_op);
     }
-    if (next_ops.size() != num_ops || !module::isOpSameCalc(next_ops)) {
+    if (next_ops.size() != num_ops || !isOpSameCalc(next_ops)) {
       next_is_same = false;
       continue;
     }
@@ -402,7 +446,7 @@ struct FuncInputMatch : public OpRewriterPatternEx<top::InputOp> {
      if (module::isOpInBlock(right_op)) {
        continue;
      }
-      if (module::isOpSameCalc(left_op, right_op)) {
+      if (isOpSameCalc(left_op, right_op)) {
        ops.push_back(right_op);
      }
      if (ops.size() == num_core) {
@@ -527,7 +571,7 @@ struct CommonMatch : public OpRewriterPatternEx3 {
      if (module::isOpInBlock(right_op)) {
        continue;
      }
-      if (module::isOpSameCalc(left_op, right_op)) {
+      if (isOpSameCalc(left_op, right_op)) {
        ops.push_back(right_op);
      }
      if (ops.size() == num_core) {
diff --git a/lib/Dialect/Tpu/Transforms/ProcessorOptimize/OptimizeBM1684X.cpp b/lib/Dialect/Tpu/Transforms/ProcessorOptimize/OptimizeBM1684X.cpp
index a7832325c..929158b81 100644
--- a/lib/Dialect/Tpu/Transforms/ProcessorOptimize/OptimizeBM1684X.cpp
+++ b/lib/Dialect/Tpu/Transforms/ProcessorOptimize/OptimizeBM1684X.cpp
@@ -440,288 +440,6 @@ class MatMulRemoveReshapePattern : public OpRewriterPatternEx<tpu::MatMulOp> {
   }
 };
 
-class MatMulMergeAddConstPattern : public OpRewriterPatternEx<tpu::AddConstOp> {
-public:
-  // using OpRewriterPatternEx::OpRewriterPatternEx;
-  MatMulMergeAddConstPattern(mlir::MLIRContext *context, int benifit)
-      : OpRewriterPatternEx<tpu::AddConstOp>(
-            context, "MatMulMergeAddConstPattern", 10) {}
-  LogicalResult matchAndRewriteImpl(tpu::AddConstOp op,
-                                    PatternRewriter &rewriter) const override {
-    // temp result
-    if (auto matmulIpt = dyn_cast<tpu::MatMulOp>(op.getInput().getDefiningOp())) {
-      op.replaceAllUsesWith(matmulIpt.getOperation());
-      rewriter.eraseOp(op);
-      return success();
-    }
-    return failure();
-  }
-};
-
-
-/**
- *
- * A @ B = (B^T @ A^T)^T
- *
- * original input shape = (1, 1, M)
- * original right shape = (1, M, N)
- * original result shape = (1, 1, N)
- * step1. Matmul(input, right, bias) -> Permute(Matmul(right, Permute(input), bias))
- *
- *
- * after apply pattern:
- * new right shape = (1, N, M)
- * new input shape = (1, M, 1)
- * new result shape = (1, N, 1) = (1, 1, N)
- */
-class MatmulUsePermutePattern : public OpRewriterPatternEx<tpu::MatMulOp> {
-public:
-  MatmulUsePermutePattern(mlir::MLIRContext *context, int benefit)
-      : OpRewriterPatternEx<tpu::MatMulOp>(
-            context, "MatmulUsePermutePattern", benefit) {}
-
-  LogicalResult matchAndRewriteImpl(tpu::MatMulOp op,
-                                    PatternRewriter &rewriter) const override {
-    Value input = op.getInput();
-    std::vector<tpu::MatMulOp> sameMatmuls;
-    auto inputShape = module::getShape(input);
-    if (inputShape.size() != 3 || inputShape[0] != 1 || inputShape[1] != 1) {
-      return failure();
-    }
-
-    // Find all MatMulOps with the same input
-    for (auto user : input.getUsers()) {
-      if (auto matmulOp = dyn_cast<tpu::MatMulOp>(user)) {
-        if (matmulOp.getInput() == input) {
-          sameMatmuls.push_back(matmulOp);
-          // mmWeights.push_back(matmulOp.getRight());
-          if (isa<tpu::ReshapeOp>(*(matmulOp->getUsers().begin()))) {
-            return failure();
-          }
-        }
-      } else {
-        return failure();
-      }
-    }
-
-    if (sameMatmuls.size() == 0) {
-      return failure();
-    }
-
-    for (auto matmul : sameMatmuls) {
-      auto right = matmul.getRight();
-      auto rightShape = module::getShape(right);
-      auto newRightType = RankedTensorType::get(rightShape, module::getElementType(right));
-      right.setType(newRightType);
-      matmul.setOperand(0, right);
-    }
-
-    std::vector<NamedAttribute> attrs;
-    std::vector<int64_t> order = {0, 2, 1};
-    attrs.push_back(rewriter.getNamedAttr("order", rewriter.getI64ArrayAttr(order)));
-    rewriter.setInsertionPointAfter(input.getDefiningOp());
-    auto permute_op = rewriter.create<tpu::PermuteOp>(
-        NameLoc::get(rewriter.getStringAttr(module::getName(input) + "_permute")),
-        RankedTensorType::get({inputShape[0], inputShape[2], inputShape[1]}, module::getElementType(input)),
-        ValueRange{input, module::getNoneOp(op)}, attrs);
-
-
-    for (auto matmul : sameMatmuls) {
-      matmul.setOperand(1, permute_op.getOutput());
-
-
-      auto resultShape = module::getShape(matmul.getResult());
-      // auto oriType = matmul.getResult().getType();
-      matmul.getResult().setType(RankedTensorType::get({resultShape[0], resultShape[2], resultShape[1]}, module::getElementType(matmul.getResult())));
-
-      rewriter.setInsertionPointAfter(matmul);
-
-      auto reshapeType = RankedTensorType::get(resultShape, module::getElementType(matmul.getResult()));
-      auto reshapeOp = rewriter.create<tpu::ReshapeOp>(
-          NameLoc::get(rewriter.getStringAttr(module::getName(matmul.getOutput()) + "_reshape")),
-          reshapeType,
-          ValueRange{matmul.getOutput()});
-
-      matmul.getOutput().replaceAllUsesExcept(reshapeOp.getOutput(), {reshapeOp});
-    }
-
-
-    return success();
-  }
-};
-
-
-
-/**
- * Matmul(input, right1, bias1) - \                                       / ...
- * Matmul(input, right2, bias2) - | -> A X concat(B1..Bk) -> Slice x k -> - ...
- * Matmul(input, right3, bias3) - /                                       \ ...
- *
- */
-class MultipleSameLeftMatmulMergePattern : public OpRewriterPatternEx<tpu::MatMulOp> {
-public:
-  MultipleSameLeftMatmulMergePattern(mlir::MLIRContext *context, int benefit)
-      : OpRewriterPatternEx<tpu::MatMulOp>(
-            context, "MultipleSameLeftMatmulMergePattern", benefit) {}
-
-  LogicalResult matchAndRewriteImpl(tpu::MatMulOp op,
-                                    PatternRewriter &rewriter) const override {
-    auto none = module::getNoneOp(op);
-    if (op->getUsers().empty() || isa(*(op->getUsers().begin()))) {
-      return failure();
-    }
-
-    Value input = op.getOperands()[1];
-    std::vector<tpu::MatMulOp> sameMatmuls;
-    std::vector<Value> mmWeights;
-    std::vector<Value> mmBiases;
-
-    std::vector<std::shared_ptr<std::vector<uint16_t>>> mmWeightOps_fp16;
-    std::vector<std::shared_ptr<std::vector<uint16_t>>> mmBiasOps_fp16;
-    // Find all MatMulOps with the same input
-    for (auto user : input.getUsers()) {
-      if (auto matmulOp = dyn_cast<tpu::MatMulOp>(user)) {
-        if (matmulOp.getRight() == input && !op.getBias().getType().isa<NoneType>()) {
-          // TODO detect is Same Op
-          if (auto op = dyn_cast<top::WeightOp>((matmulOp.getOperands()[0]).getDefiningOp())) {
-            mmWeightOps_fp16.push_back(op.read<uint16_t>());
-            mmBiasOps_fp16.push_back(cast<top::WeightOp>((matmulOp.getBias()).getDefiningOp()).read<uint16_t>());
-            sameMatmuls.push_back(matmulOp);
-            mmWeights.push_back(matmulOp.getOperands()[0]);
-            mmBiases.push_back(matmulOp.getBias());
-          } else {
-            return failure();
-          }
-
-        } else {
-          return failure();
-        }
-      } else {
-        return failure();
-      }
-    }
-
-    if (sameMatmuls.size() <= 1) {
-      return failure();
-    }
-
-    auto inputShape = module::getShape(input);
-    if (inputShape[2] != 1) {
-      return failure();
-    }
-
-    auto weightShape = module::getShape(op.getOperands()[0]); // actual weight
-    // auto biasShape = module::getShape(op.getBias());
-
-    rewriter.setInsertionPointAfter(input.getDefiningOp());
-    // step1. concat weight, get shape (size*K) x N
-    std::string weightName = module::getName(op.getOperands()[0]).str() + "_merge_right";
-    long newweight_row_size = weightShape[1] * sameMatmuls.size();
-    auto weightType = RankedTensorType::get({1, newweight_row_size, weightShape[0]}, module::getElementType(op.getOperands()[0]));
-    auto weight_size = weightType.getNumElements();
-    auto weightCoeff = std::make_shared<std::vector<uint16_t>>(weight_size, 0);
-    // TOTO :use efficient copy method ; check relative order of coeff
-    for (int row_idx = 0; row_idx < newweight_row_size; row_idx++) {
-      for (int col_idx = 0; col_idx < weightShape[0]; col_idx++) {
-        weightCoeff->at(row_idx * weightShape[0] + col_idx) =
-            mmWeightOps_fp16[row_idx / weightShape[1]]->at(col_idx * weightShape[1] + row_idx % weightShape[1]);
-      }
-    }
-    auto wret =
-        module::weightFile().addTensor(weightName, (uint16_t *)weightCoeff->data(), weightType);
-    assert(succeeded(wret));
-    auto weight_op = rewriter.create<top::WeightOp>(
-        NameLoc::get(rewriter.getStringAttr(weightName)),
-        weightType,
-        ValueRange{});
-
-    // step2. concat bias, get shape (size*K)
-    std::string biasName = module::getName(op.getBias()).str() + "_merge_bias";
-    auto biasType = RankedTensorType::get({newweight_row_size}, module::getElementType(op.getBias()));
-    auto bias_size = biasType.getNumElements();
-    auto biasCoeff = std::make_shared<std::vector<uint16_t>>(bias_size, 0);
-    for (int col_idx = 0; col_idx < bias_size; col_idx++) {
-      biasCoeff->at(col_idx) = mmBiasOps_fp16[col_idx / weightShape[1]]->at(col_idx % weightShape[1]);
-    }
-    auto bret =
-        module::weightFile().addTensor(biasName, (uint16_t *)biasCoeff->data(), biasType);
-    assert(succeeded(bret));
-    auto bias_op = rewriter.create<top::WeightOp>(
-        NameLoc::get(rewriter.getStringAttr(biasName)),
-        biasType,
-        ValueRange{});
-
-    // step3. create new large MatMulOp, W(size*K x N) @ Ipt(1 x N x 1) + B(size*K) => R(1 x size*K)
-    auto newMatmulOp = rewriter.create<tpu::MatMulOp>(
-        op.getLoc(),
-        RankedTensorType::get({1, newweight_row_size, 1},
-                              module::getElementType(op.getOutput())),
-        ValueRange{weight_op.getResult(), input, none, none, none},
-        op->getAttrs());
-
-    auto resultShape = module::getShape(newMatmulOp.getResult());
-    rewriter.setInsertionPointAfter(newMatmulOp);
-    auto reshapeType = RankedTensorType::get({resultShape[0], resultShape[2], resultShape[1]}, module::getElementType(newMatmulOp.getResult()));
-    auto reshapeOp = rewriter.create<tpu::ReshapeOp>(
-        NameLoc::get(rewriter.getStringAttr(module::getName(newMatmulOp.getOutput()) + "_new_reshape")),
-        reshapeType,
-        ValueRange{newMatmulOp.getOutput()});
-
-    auto add_loc = NameLoc::get(
-        rewriter.getStringAttr(module::getName(reshapeOp.getOperation()).str() + "_add"));
-    auto add_op = rewriter.create<tpu::AddOp>(
-        add_loc, reshapeType,
-        mlir::ValueRange{reshapeOp.getOutput(), bias_op.getOutput()});
-    // step4. slice each original MatMulOp
-    int64_t sliceOffset = 0;
-    std::vector<Value> operands;
-    std::vector<Operation *> reshape_ops;
-    auto lastSliceOp = none.getOperation();
-    for (size_t i = 0; i < sameMatmuls.size(); ++i) {
-      auto originalOp = sameMatmuls[i];
-      auto originalShape = module::getShape(originalOp.getOutput());
-
-      std::vector<NamedAttribute> attrs;
-      attrs.push_back(rewriter.getNamedAttr("axes", rewriter.getI64ArrayAttr({0, 1, 2})));
-      attrs.push_back(rewriter.getNamedAttr("ends", rewriter.getI64ArrayAttr({originalShape[0], originalShape[2], sliceOffset})));
-      attrs.push_back(rewriter.getNamedAttr("offset", rewriter.getI64ArrayAttr({0, 0, sliceOffset})));
-      attrs.push_back(rewriter.getNamedAttr("steps", rewriter.getI64ArrayAttr({1, 1, 1})));
-      rewriter.setInsertionPointAfter(add_op);
-
-      auto slice_type = RankedTensorType::get({1, 1, originalShape[1]}, module::getElementType(add_op.getOperands()[0]));
-      auto sliceOp = rewriter.create<tpu::SliceOp>(
-          originalOp.getLoc(),
-          slice_type,
-          ValueRange{add_op.getResult(), none, none, none, none},
-          attrs);
-      module::setLocSuffix(sliceOp, std::to_string(i));
-      lastSliceOp = sliceOp;
-
-      auto reshape_op = dyn_cast_or_null<tpu::ReshapeOp>(*originalOp.getOutput().getUsers().begin());
-      reshape_op.getOutput().replaceAllUsesWith(sliceOp.getResult());
-      reshape_ops.insert(reshape_ops.end(), originalOp->getUsers().begin(), originalOp->getUsers().end());
-      sliceOffset += originalShape[1];
-    }
-    for (auto op : reshape_ops) {
-      rewriter.eraseOp(op);
-    }
-
-    for (auto op : sameMatmuls) {
-      rewriter.eraseOp(op);
-    }
-
-    for (auto operand : mmWeights) {
-      rewriter.eraseOp(operand.getDefiningOp());
-    }
-
-    for (auto operand : mmBiases) {
-      rewriter.eraseOp(operand.getDefiningOp());
-    }
-
-    return success();
-  }
-};
-
-
-
 // transform group conv to normal conv, when int8/f16/bf16 &&
 // input_c<=ic_parallel && isBM1684XFamily()
 class GroupConv2NormalConv : public OpRewriterPatternEx<tpu::Conv2DOp> {
@@ -4967,7 +4685,6 @@ void populateOptimizeBM1684XPatterns(RewritePatternSet *patterns) {
                 PermutePadSwap,
                 FitPermute2Hdim,
                 ErasePermuteAroundAdd,
-                MatMulMergeAddConstPattern,
                 PermuteMulconstSwap,
                 MatMulActiveMatMulPattern,
                 RotaryPosEmbPattern,
@@ -4991,8 +4708,6 @@ void populateOptimizeBM1684XPatterns(RewritePatternSet *patterns) {
   patterns->add(ctx, 7);
   patterns->add(ctx, 3);
   patterns->add(ctx, 4);
-  patterns->add<MatmulUsePermutePattern>(ctx, 4);
-  patterns->add<MultipleSameLeftMatmulMergePattern>(ctx, 3);
 }
 
 } // namespace tpu
diff --git a/lib/Support/Module.cpp b/lib/Support/Module.cpp
index 5bef7b391..a3a5429aa 100644
--- a/lib/Support/Module.cpp
+++ b/lib/Support/Module.cpp
@@ -2064,50 +2064,6 @@ bool IsRightMat(Value v) {
   return false;
 }
 
-bool isOpSameCalc(Operation *op0, Operation *op1) {
-  auto compare = [&](mlir::ValueRange left, mlir::ValueRange right) -> bool {
-    for (auto it : llvm::zip(left, right)) {
-      auto left = std::get<0>(it);
-      auto right = std::get<1>(it);
-      if (module::isNone(left) || module::isNone(right)) {
-        continue;
-      }
-      auto l_s = module::getShape(left);
-      auto r_s = module::getShape(right);
-      if (l_s != r_s) {
-        return false;
-      }
-    }
-    return true;
-  };
-  if (op0 == op1) {
-    // can't be the same op
-    return false;
-  }
-  if (op0->getName() != op1->getName()) {
-    return false;
-  }
-  if (false == compare(op0->getOperands(), op1->getOperands())) {
-    return false;
-  }
-  if (false == compare(op0->getResults(), op1->getResults())) {
-    return false;
-  }
-  return true;
-}
-
-bool isOpSameCalc(const std::vector<Operation *> &ops) {
-  if (ops.size() < 2) {
-    return false;
-  }
-  for (int i = 1; i < ops.size(); i++) {
-    if (!isOpSameCalc(ops[0], ops[i])) {
-      return false;
-    }
-  }
-  return true;
-}
-
 bool isInMatMulGrpOp(Operation *op) {
   if (isa(op)) {
diff --git a/python/test/test_tpulang.py b/python/test/test_tpulang.py
index 3e6570ecc..5b2df4a4b 100755
--- a/python/test/test_tpulang.py
+++ b/python/test/test_tpulang.py
@@ -4520,7 +4520,6 @@ def test_all(tester: TPULANG_IR_TESTER):
     f, s = test_all_base(tester)
     if f:
         return f
-    print("====== start no save test ======")
     tester.no_save = True
     f, s = test_all_base(tester)
     return f
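
Note (not part of the patch): the reverted MatmulUsePermutePattern relied on the identity A @ B = (B^T @ A^T)^T. With an activation of shape (1, 1, M) and a right operand of shape (1, M, N), the transposed product has shape (1, N, 1), so the trailing transpose is layout-free and the pattern could finish with a Reshape instead of a second Permute. A minimal standalone check of that identity for this case, under the stated shapes; all names here are hypothetical:

#include <cassert>
#include <vector>

int main() {
  const int M = 3, N = 4;
  // A plays the (1, 1, M) activation, B the (1, M, N) weight,
  // both stored row-major with the leading batch dim dropped.
  std::vector<double> A = {1, 2, 3};
  std::vector<double> B(M * N);
  for (int i = 0; i < M * N; ++i)
    B[i] = i * 0.5;

  // lhs[n] = (A @ B)[n], the original 1 x N result.
  std::vector<double> lhs(N, 0);
  for (int n = 0; n < N; ++n)
    for (int m = 0; m < M; ++m)
      lhs[n] += A[m] * B[m * N + n];

  // rhs[n] = (B^T @ A^T)[n], an N x 1 column vector; transposing it back
  // to 1 x N touches no data, which is why a Reshape suffices.
  std::vector<double> rhs(N, 0);
  for (int n = 0; n < N; ++n)
    for (int m = 0; m < M; ++m)
      rhs[n] += B[m * N + n] * A[m];

  for (int n = 0; n < N; ++n)
    assert(lhs[n] == rhs[n]);
  return 0;
}

The same observation explains the pattern's doc comment: (1, N, 1) and (1, 1, N) hold identical contiguous data, so only the type annotation changes.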