
Merge remote-tracking branch 'wangzhen/fix_conv2d_grad' into workspace
Xreki committed Nov 17, 2021
2 parents 2defab6 + 790f29d commit 8a77d5e
Showing 2 changed files with 29 additions and 17 deletions.
24 changes: 13 additions & 11 deletions cinn/frontend/decomposer/conv2d_grad.cc
@@ -26,17 +26,19 @@ void conv2d_grad(const Instruction& instr, const DecomposerContext& context) {

   CinnBuilder* builder = context.builder();
   // create backward data
-  auto dx = builder->Conv(w,
-                          dy,
-                          instr.GetAttrs<std::vector<int>>("strides"),
-                          instr.GetAttrs<std::vector<int>>("paddings"),
-                          instr.GetAttrs<std::vector<int>>("dilations"),
-                          instr.GetAttrs<int>("groups"),
-                          "backward_data",
-                          instr.GetAttrs<std::string>("data_format"),
-                          instr.GetAttrs<std::string>("padding_algorithm"),
-                          x->shape);
-  context.MapOutToOrigin(dx, instr->outputs[0]);
+  if (!instr->outputs[0].is_const()) {
+    auto dx = builder->Conv(w,
+                            dy,
+                            instr.GetAttrs<std::vector<int>>("strides"),
+                            instr.GetAttrs<std::vector<int>>("paddings"),
+                            instr.GetAttrs<std::vector<int>>("dilations"),
+                            instr.GetAttrs<int>("groups"),
+                            "backward_data",
+                            instr.GetAttrs<std::string>("data_format"),
+                            instr.GetAttrs<std::string>("padding_algorithm"),
+                            x->shape);
+    context.MapOutToOrigin(dx, instr->outputs[0]);
+  }
 
   // create backward filter
   auto dw = builder->Conv(x,
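
The guard above skips building the entire backward-data convolution when the d_x output slot is marked const, i.e. when the model never requested the input gradient. Below is a minimal, self-contained C++ sketch of that pattern; the Var and Builder types are hypothetical stand-ins for illustration, not the actual CINN classes.

// Sketch only: hypothetical Var/Builder types, not the CINN API.
#include <iostream>
#include <string>
#include <vector>

struct Var {
  std::string name;
  bool is_const = false;  // true => this gradient output was not requested
};

struct Builder {
  // Stands in for emitting an expensive backward convolution.
  Var Conv(const std::string& tag) {
    std::cout << "emitting " << tag << "\n";
    return Var{tag};
  }
};

void Conv2dGradDecompose(Builder* builder, std::vector<Var>& outputs) {
  // Mirror of the `if (!instr->outputs[0].is_const())` guard: only emit
  // the backward-data pass when d_x is actually wanted.
  if (!outputs[0].is_const) {
    outputs[0] = builder->Conv("backward_data");
  }
  // The backward-filter pass is emitted unconditionally.
  outputs[1] = builder->Conv("backward_filter");
}

int main() {
  Builder b;
  std::vector<Var> outs(2);
  outs[0].is_const = true;        // caller did not request d_x
  Conv2dGradDecompose(&b, outs);  // prints only "emitting backward_filter"
  return 0;
}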
22 changes: 16 additions & 6 deletions cinn/frontend/op_mappers/conv2d.cc
@@ -93,10 +93,14 @@ void Conv2dGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContex
   CHECK_EQ(op_desc.Input("Filter").size(), 1UL);
   auto w_name = op_desc.Input("Filter").front();
 
-  // get dx,dfilter
-  CHECK_EQ(op_desc.Output(paddle::GradVarName("Input")).size(), 1UL);
-  auto dx_name = op_desc.Output(paddle::GradVarName("Input")).front();
-
+  // get d_x
+  std::string dx_name;
+  bool has_dx = !op_desc.Output(paddle::GradVarName("Input")).empty();
+  if (has_dx) {
+    CHECK_EQ(op_desc.Output(paddle::GradVarName("Input")).size(), 1UL);
+    dx_name = op_desc.Output(paddle::GradVarName("Input")).front();
+  }
+  // get d_filter
   CHECK_EQ(op_desc.Output(paddle::GradVarName("Filter")).size(), 1UL);
   auto dw_name = op_desc.Output(paddle::GradVarName("Filter")).front();
 
@@ -118,8 +122,14 @@ void Conv2dGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContex

   auto out =
       ctx.Builder()->conv2d_grad(dy, x, weight, strides, paddings, dilations, groups, data_format, padding_algorithm);
-  ctx.AddVar(dx_name, out[0]);
-  ctx.AddVarModelToProgram(dx_name, out[0]->id);
+
+  if (has_dx) {
+    ctx.AddVar(dx_name, out[0]);
+    ctx.AddVarModelToProgram(dx_name, out[0]->id);
+  } else {
+    out[0].set_const(true);
+  }
+
   ctx.AddVar(dw_name, out[1]);
   ctx.AddVarModelToProgram(dw_name, out[1]->id);
 }
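
On the op-mapper side, the old code called front() behind an unconditional CHECK_EQ(size(), 1UL), which fails whenever the model does not ask for Input@GRAD at all; the fix tests empty() first, maps d_x only when it exists, and otherwise marks out[0] const so the decomposer above can skip it. A minimal sketch of that optional-output lookup, assuming a hypothetical OpDesc stand-in rather than Paddle's real class:

// Sketch only: hypothetical OpDesc stand-in, not paddle::cpp::OpDesc.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct OpDesc {
  std::map<std::string, std::vector<std::string>> outputs;
  // Returns the variable names bound to an output slot, or an empty
  // list when the model did not declare that slot.
  const std::vector<std::string>& Output(const std::string& key) const {
    static const std::vector<std::string> kEmpty;
    auto it = outputs.find(key);
    return it == outputs.end() ? kEmpty : it->second;
  }
};

int main() {
  OpDesc op;
  op.outputs["Filter@GRAD"] = {"dw"};  // "Input@GRAD" is absent

  // Checking empty() before front() is what keeps a model with no
  // requested d_x from tripping the size-must-be-1 check.
  std::string dx_name;
  bool has_dx = !op.Output("Input@GRAD").empty();
  if (has_dx) dx_name = op.Output("Input@GRAD").front();

  std::cout << (has_dx ? "map d_x to " + dx_name : std::string("skip d_x"))
            << "\n";  // prints "skip d_x"
  return 0;
}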
