Skip to content

Commit

Permalink
fix matmul's shape inference
Browse files Browse the repository at this point in the history
Change-Id: I6c611de1c9693abfc16ef0fc20ccbdbf2a3a447d
  • Loading branch information
chang.zhao authored and HarmonyHu committed Aug 16, 2023
1 parent 4224585 commit f735aff
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
16 changes: 6 additions & 10 deletions lib/Dialect/Top/Interfaces/MatMul.cpp
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,15 @@ void top::MatMulOp::shape_inference() {
} else if (in1_shape[k_idx] == k) {
if (module::getPlatform() == module::Platform::CAFFE) {
// for caffe case
auto sum = 1;
for (int i = 0; i < in0_dims; i++) {
sum *= out_shape[i];
}
// shape case:[1, 1, 1, 4832] * [4832, 126] = [1, 126]
if (sum == k) {
while (out_shape.size() > 1) {
out_shape.pop_back();
// shape case:[8, 1, 1, 4832] * [4832, 136] = [8, 136]
for (int i = 1; i < out_shape.size(); i++) {
if(out_shape[i] == 1){
out_shape.erase(out_shape.begin()+i);
i--;
}
out_shape.push_back(n);
} else {
out_shape[in0_dims - 1] = n;
}
out_shape[out_shape.size() - 1] = n;
} else {
out_shape[in0_dims - 1] = n;
}
Expand Down
4 changes: 2 additions & 2 deletions python/transform/CaffeConverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# ==============================================================================

from .MLIRImporter import MLIRImporter
from .MLIRImporter import MLIRImporter, Platform
from .BaseConverter import BaseConverter
import numpy as np

Expand Down Expand Up @@ -178,7 +178,7 @@ def init_MLIRImporter(self):
break
output_shapes.append(self.getShape(_name))
# init importer
self.mlir = MLIRImporter(input_shapes, output_shapes, self.model_name)
self.mlir = MLIRImporter(input_shapes, output_shapes, self.model_name, platform=Platform.CAFFE)
self.weight_file = self.mlir.weight_file

def layerType(self, layer):
Expand Down

0 comments on commit f735aff

Please sign in to comment.