Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Polish the implementation of batch_norm_grad decomposer and add the C…
Browse files Browse the repository at this point in the history
…++ unittest back.
  • Loading branch information
Xreki committed Nov 3, 2021
1 parent 7d11ea7 commit 72c0d26
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 267 deletions.
112 changes: 59 additions & 53 deletions cinn/frontend/decomposer/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ void batch_norm_grad(const Instruction& instr, const DecomposerContext& context)
CHECK_EQ(instr->outputs.size(), 3UL) << " The number of the given outputs is not equal to the required"
<< instr->op_type;

auto& dy = instr->inputs[0];
auto& y_grad = instr->inputs[0];
auto& x = instr->inputs[1];
auto& scale = instr->inputs[2];
auto& save_mean = instr->inputs[3];
Expand All @@ -139,70 +139,76 @@ void batch_norm_grad(const Instruction& instr, const DecomposerContext& context)
LOG(FATAL) << layout << " setting is not support!";
}

/*****************batch norm grad*********************
* grad_bias = reduce_sum(dy)
* grad_scale = reduce_sum(dy * (diff/std_var))
* grad_std_norm = dy * scale
* grad_diff = grad_std_norm / std_var
* grad_std_var = -grad_std_norm * diff / var
* grad_var = 0.5 * grad_std_var / std_var
* grad_mean = grad_var * -2 * mean - reduce(grad_diff)
* grad_mean_square = grad_var
* grad_diff += grad_mean
* grad_x = grad_diff + 2 * x * grad_mean_square + grad_mean
*/

// batch norm grad
// std_norm = (x - saved_mean) / std_variance
// y = scale * std_norm + bias
// ==>
// bias_grad = reduce_sum(y_grad)
// scale_grad = reduce_sum(y_grad * std_norm)
// std_norm_grad = y_grad * scale
//
// x_mean_diff = x - saved_mean
// std_norm = x_mean_diff / std_variance
// ==>
// x_mean_diff_grad = std_norm_grad / std_variance
// std_variance_grad = - std_norm_grad * x_mean_diff / variance
//
// variance_grad = 0.5 * std_variance_grad / std_variance
// mean_grad = variance_grad * -2 * mean - reduce(x_mean_diff_grad)
// mean_square_grad = variance_grad
// x_mean_diff_grad += mean_grad
// x_grad = x_mean_diff_grad + 2 * x * mean_square_grad + mean_grad

// bias_grad = reduce_sum(dy), shape = [c]
auto bias_grad = builder->Reduce(y_grad, ReduceKind::kSum, reduce_dim);

// std_norm = (x - saved_mean) / std_variance
// scale_grad = y_grad * std_norm, shape = [c]
auto epsilon_1d = builder->BroadcastTo(builder->ConstScalar(epsilon, common::UniqName("epsilon")), scale->shape, {0});
auto element_count_1d = builder->BroadcastTo(
builder->ConstScalar(1.0f / element_count, common::UniqName("element_count")), scale->shape, {0});
// grad bias = reduce(dy), shape = [c]
auto grad_bias = builder->Reduce(dy, ReduceKind::kSum, reduce_dim);

// grad scale = dy * (x - mean)/var, shape = [c]
auto mean_4d = builder->BroadcastTo(save_mean, x->shape, {channel_dim});
auto variance_1d = builder->Add(save_variance, epsilon_1d);
auto variance_4d = builder->BroadcastTo(variance_1d, x->shape, {channel_dim});
// std variance
auto variance_1d = builder->Add(save_variance, epsilon_1d);
auto variance_4d = builder->BroadcastTo(variance_1d, x->shape, {channel_dim});
auto std_variance_1d = builder->Sqrt(variance_1d);
auto std_variance_4d = builder->BroadcastTo(std_variance_1d, x->shape, {channel_dim});

auto diff = builder->Sub(x, mean_4d);
// grad scale = dy * (diff/std_var), shape = [c]
auto grad_scale =
builder->Reduce(builder->Mul(dy, builder->Div(diff, std_variance_4d)), ReduceKind::kSum, reduce_dim);
auto mean_4d = builder->BroadcastTo(save_mean, x->shape, {channel_dim});
auto x_mean_diff = builder->Sub(x, mean_4d);
auto scale_grad =
builder->Div(builder->Reduce(builder->Mul(y_grad, x_mean_diff), ReduceKind::kSum, reduce_dim), std_variance_1d);

// grad [(x - mean)/std_var] = dy * scale, shape = [n,c,h,w]
// std_norm_grad = y_grad * scale, shape = [n,c,h,w]
auto scale_4d = builder->BroadcastTo(scale, x->shape, {channel_dim});
auto grad_std_norm = builder->Mul(dy, scale_4d);
auto std_norm_grad = builder->Mul(y_grad, scale_4d);

// grad [diff=(x - mean)] = dstd/std_var, shape = [n,c,h,w]
auto grad_diff = builder->Div(grad_std_norm, std_variance_4d);
// x_mean_diff_grad = std_norm_grad / std_variance, shape = [n,c,h,w]
auto x_mean_diff_grad = builder->Div(std_norm_grad, std_variance_4d); // a portion of x_grad

// grad std var = -1 * reduce((grad_std * diff) / (var), shape = [c])
auto grad_std_variance_1d = builder->Negative(
builder->Reduce(builder->Div(builder->Mul(grad_std_norm, diff), variance_4d), ReduceKind::kSum, reduce_dim));
// std_variance_grad_1d = - reduce_sum(std_norm_grad * x_mean_diff / variance), shape = [c])
auto std_variance_grad_1d = builder->Negative(builder->Reduce(
builder->Div(builder->Mul(std_norm_grad, x_mean_diff), variance_4d), ReduceKind::kSum, reduce_dim));

// grad var = 1/2 * dy / std_var, do not mul 0.5 first
auto grad_variance_1d_without_mul = builder->Div(grad_std_variance_1d, std_variance_1d);
// variance = std_variance * std_variance
// variance_grad = 1/2 * std_variance_grad / std_variance
auto variance_grad_1d_without_mul = builder->Div(std_variance_grad_1d, std_variance_1d);

// grad_x0 = broadcastTo(grad_variance_1d_without_mul * 0.5 /element_count) * 2 * x
auto grad_x0 = builder->Mul(
x, builder->BroadcastTo(builder->Mul(grad_variance_1d_without_mul, element_count_1d), x->shape, {channel_dim}));
// x_grad_0 = (variance_grad_1d_without_mul * 0.5 / element_count) * 2 * x
auto element_count_1d =
builder->BroadcastTo(builder->ConstScalar(element_count, common::UniqName("element_count")), scale->shape, {0});
auto x_grad_0 = builder->Mul(
x, builder->BroadcastTo(builder->Div(variance_grad_1d_without_mul, element_count_1d), x->shape, {channel_dim}));

// -1.0 * grad_mean = ( -1.0 * reduce(grad_diff) + -1.0 * grad_variance_1d_without_mul * 0.5 * 2 * mean) /
// -1.0 * mean_grad = ((-1.0 * reduce(x_mean_diff_grad)) + (-1.0 * variance_grad_1d_without_mul * 0.5 * 2 * mean)) /
// element_count_1d
auto minus_grad_mean = builder->Mul(element_count_1d,
builder->Add(builder->Reduce(grad_diff, ReduceKind::kSum, reduce_dim),
builder->Mul(grad_variance_1d_without_mul, save_mean)));

// grad_x = grad_diff + boradcastTo(grad_mean) + grad_x0
auto grad_x =
builder->Sub(builder->Add(grad_diff, grad_x0), builder->BroadcastTo(minus_grad_mean, x->shape, {channel_dim}));

// set output
context.MapOutToOrigin(grad_x, instr->outputs[0]);
context.MapOutToOrigin(grad_scale, instr->outputs[1]);
context.MapOutToOrigin(grad_bias, instr->outputs[2]);
auto minus_mean_grad = builder->Div(builder->Add(builder->Reduce(x_mean_diff_grad, ReduceKind::kSum, reduce_dim),
builder->Mul(variance_grad_1d_without_mul, save_mean)),
element_count_1d);
auto minus_mean_grad_4d = builder->BroadcastTo(minus_mean_grad, x->shape, {channel_dim});

// x_grad = x_mean_diff_grad + mean_grad + x_grad_0
auto x_grad = builder->Sub(builder->Add(x_mean_diff_grad, x_grad_0), minus_mean_grad_4d);

context.MapOutToOrigin(x_grad, instr->outputs[0]);
context.MapOutToOrigin(scale_grad, instr->outputs[1]);
context.MapOutToOrigin(bias_grad, instr->outputs[2]);
}

} // namespace decomposer
Expand Down
Loading

0 comments on commit 72c0d26

Please sign in to comment.