diff --git a/lib/Dialect/Top/Interfaces/MatMul.cpp b/lib/Dialect/Top/Interfaces/MatMul.cpp old mode 100644 new mode 100755 index e68ec2bd5..ebc072940 --- a/lib/Dialect/Top/Interfaces/MatMul.cpp +++ b/lib/Dialect/Top/Interfaces/MatMul.cpp @@ -118,19 +118,15 @@ void top::MatMulOp::shape_inference() { } else if (in1_shape[k_idx] == k) { if (module::getPlatform() == module::Platform::CAFFE) { // for caffe case - auto sum = 1; - for (int i = 0; i < in0_dims; i++) { - sum *= out_shape[i]; - } // shape case:[1, 1, 1, 4832] * [4832, 126] = [1, 126] - if (sum == k) { - while (out_shape.size() > 1) { - out_shape.pop_back(); + // shape case:[8, 1, 1, 4832] * [4832, 136] = [8, 136] + for (int i = 1; i < out_shape.size(); i++) { + if(out_shape[i] == 1){ + out_shape.erase(out_shape.begin()+i); + i--; } - out_shape.push_back(n); - } else { - out_shape[in0_dims - 1] = n; } + out_shape[out_shape.size() - 1] = n; } else { out_shape[in0_dims - 1] = n; } diff --git a/python/transform/CaffeConverter.py b/python/transform/CaffeConverter.py index 5fabca91e..dbb192411 100644 --- a/python/transform/CaffeConverter.py +++ b/python/transform/CaffeConverter.py @@ -5,7 +5,7 @@ # # ============================================================================== -from .MLIRImporter import MLIRImporter +from .MLIRImporter import MLIRImporter, Platform from .BaseConverter import BaseConverter import numpy as np @@ -178,7 +178,7 @@ def init_MLIRImporter(self): break output_shapes.append(self.getShape(_name)) # init importer - self.mlir = MLIRImporter(input_shapes, output_shapes, self.model_name) + self.mlir = MLIRImporter(input_shapes, output_shapes, self.model_name, platform=Platform.CAFFE) self.weight_file = self.mlir.weight_file def layerType(self, layer):