-
Notifications
You must be signed in to change notification settings - Fork 114
Add model MobileNetV1, ResNet50 and SqueezeNet #441
base: develop
Are you sure you want to change the base?
Conversation
Thanks for your contribution! |
} | ||
|
||
void PaddleModelToProgram::AddOpMapper_reshape2() { | ||
op_mappers_["reshape2"] = [&](const paddle::cpp::OpDesc& op_desc) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reshape2?改下名字,下同
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thanks
cinn/hlir/op/nn.cc
Outdated
@@ -1252,6 +1262,12 @@ std::vector<std::vector<int>> InferShapeForPool2d(const std::vector<std::vector< | |||
(inputs_shape[0][width_axis] - kernel_size[1] + padding_size[1] + padding_size[3]) / stride_size[1] + 1; | |||
} | |||
|
|||
if (adaptive) { | |||
kernel_size = std::get<std::vector<int>>(attr_store["kernel_size"]); | |||
if (kernel_size.size() == 1) kernel_size.push_back(kernel_size[0]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check kernel_size.size(),防止数组越界
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
} | ||
} | ||
if (axis < 0) axis += inputs_shape[0].size(); | ||
std::vector<int> output_shape = inputs_shape[0]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对inputs_shape进行校验,两者size需要相同以及axis不能越界
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
/** | ||
* Reshape a tensor. | ||
*/ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释具体点?下同
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForReshape2) | ||
.set_attr("infershape", std::function(cinn::hlir::op::InferShapeForReshape2)) | ||
.set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForReshape2)) | ||
.set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
所有新增op需要添加inferlayout函数
std::vector<int> output_shape; | ||
for (auto &iter : attrs.attr_store) { | ||
if (iter.first == "shape") { | ||
output_shape = std::get<std::vector<int>>(iter.second); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
break
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
output_shape = std::get<std::vector<int>>(iter.second); | ||
} | ||
} | ||
int tensor_size = 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
新增op添加单测
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
int flag_index = -1; | ||
for (int i = 0; i < output_shape.size(); i++) { | ||
if (output_shape[i] > 0) { | ||
CHECK_EQ(tensor_size % output_shape[i], 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle里output_shape[i]可能为0吧?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
for (auto &iter : attrs.attr_store) { | ||
if (iter.first == "axis") { | ||
axis = std::get<int>(iter.second); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
break
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
std::vector<Expr> new_expr_shape; | ||
std::vector<Expr> A_expr_shape = A->shape; | ||
for (auto& i : new_shape) { | ||
new_expr_shape.push_back(Expr(i)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
增强shape参数的校验
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
for (auto& i : new_shape) { | ||
new_expr_shape.push_back(Expr(i)); | ||
} | ||
auto res = Compute( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的new_shape[i]是不是也有可能是-1和0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不会,这里的shape是通过infershape得到的
python/tests/test_frontend.py
Outdated
class TestLoadPaddleModel_FC(unittest.TestCase): | ||
def setUp(self): | ||
if enable_gpu == "ON": | ||
self.target = DefaultNVGPUTarget() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些为啥删了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
误删,已恢复
b9ad508
to
d1f6e55
Compare
87833c2
to
d1f6e55
Compare
This pr supports 2 new Models:
MobileNetV1
andResNet50
. And their corresponding tests are added, too.To achieve this, a new op
reshape2
is added and some bugs are fixed.