Review notes (free-form preamble; text before the first "diff --git" header is ignored by git apply / patch):
  * Purpose of the patch: rename the batch_norm_grad decomposer variables to the *_grad convention, divide by
    element_count instead of multiplying by its reciprocal, and re-enable the BatchNormGrad unit test against a
    new CPU reference (ComputeBatchNormGradRef).
  * Line structure and template argument lists were reconstructed from a flattened copy; hunk line counts agree
    with the @@ headers, but column alignment and removed-side template arguments are best-effort.
  * Review fixes applied to added lines only (no hunk size changes):
      - batch_norm_test.cc: the tolerance branch compared against "y_grad", which is not a key of output_refs
        ("bias_grad", "scale_grad", "x_grad"), so the relaxed check was unreachable; it now tests "x_grad".
      - batch_norm_test.cc: reference lambdas capture by reference ([&]) instead of copying whole vectors ([=]),
        matching the other lambdas in ComputeBatchNormGradRef.
      - batch_norm.cc: comments corrected (reduce_sum(y_grad), scale_grad is a reduce_sum, stray ')').
  * NOTE(review): the x_grad branch enables the absolute-error check in CheckOutput; confirm the test still
    passes with that tolerance.

diff --git a/cinn/frontend/decomposer/batch_norm.cc b/cinn/frontend/decomposer/batch_norm.cc
index ac158cba23..a38ca088f6 100644
--- a/cinn/frontend/decomposer/batch_norm.cc
+++ b/cinn/frontend/decomposer/batch_norm.cc
@@ -112,7 +112,7 @@ void batch_norm_grad(const Instruction& instr, const DecomposerContext& context)
   CHECK_EQ(instr->outputs.size(), 3UL) << " The number of the given outputs is not equal to the required"
                                        << instr->op_type;
 
-  auto& dy = instr->inputs[0];
+  auto& y_grad = instr->inputs[0];
   auto& x = instr->inputs[1];
   auto& scale = instr->inputs[2];
   auto& save_mean = instr->inputs[3];
@@ -139,70 +139,76 @@ void batch_norm_grad(const Instruction& instr, const DecomposerContext& context)
     LOG(FATAL) << layout << " setting is not support!";
   }
 
-  /*****************batch norm grad*********************
-   * grad_bias = reduce_sum(dy)
-   * grad_scale = reduce_sum(dy * (diff/std_var))
-   * grad_std_norm = dy * scale
-   * grad_diff = grad_std_norm / std_var
-   * grad_std_var = -grad_std_norm * diff / var
-   * grad_var = 0.5 * grad_std_var / std_var
-   * grad_mean = grad_var * -2 * mean - reduce(grad_diff)
-   * grad_mean_square = grad_var
-   * grad_diff += grad_mean
-   * grad_x = grad_diff + 2 * x * grad_mean_square + grad_mean
-   */
-
+  // batch norm grad
+  // std_norm = (x - saved_mean) / std_variance
+  // y = scale * std_norm + bias
+  // ==>
+  // bias_grad = reduce_sum(y_grad)
+  // scale_grad = reduce_sum(y_grad * std_norm)
+  // std_norm_grad = y_grad * scale
+  //
+  // x_mean_diff = x - saved_mean
+  // std_norm = x_mean_diff / std_variance
+  // ==>
+  // x_mean_diff_grad = std_norm_grad / std_variance
+  // std_variance_grad = - std_norm_grad * x_mean_diff / variance
+  //
+  // variance_grad = 0.5 * std_variance_grad / std_variance
+  // mean_grad = variance_grad * -2 * mean - reduce(x_mean_diff_grad)
+  // mean_square_grad = variance_grad
+  // x_mean_diff_grad += mean_grad
+  // x_grad = x_mean_diff_grad + 2 * x * mean_square_grad + mean_grad
+
+  // bias_grad = reduce_sum(y_grad), shape = [c]
+  auto bias_grad = builder->Reduce(y_grad, ReduceKind::kSum, reduce_dim);
+
+  // std_norm = (x - saved_mean) / std_variance
+  // scale_grad = reduce_sum(y_grad * std_norm), shape = [c]
   auto epsilon_1d =
       builder->BroadcastTo(builder->ConstScalar(epsilon, common::UniqName("epsilon")), scale->shape, {0});
-  auto element_count_1d = builder->BroadcastTo(
-      builder->ConstScalar(1.0f / element_count, common::UniqName("element_count")), scale->shape, {0});
-  // grad bias = reduce(dy), shape = [c]
-  auto grad_bias = builder->Reduce(dy, ReduceKind::kSum, reduce_dim);
-
-  // grad scale = dy * (x - mean)/var, shape = [c]
-  auto mean_4d = builder->BroadcastTo(save_mean, x->shape, {channel_dim});
-  auto variance_1d = builder->Add(save_variance, epsilon_1d);
-  auto variance_4d = builder->BroadcastTo(variance_1d, x->shape, {channel_dim});
-  // std variance
+  auto variance_1d = builder->Add(save_variance, epsilon_1d);
+  auto variance_4d = builder->BroadcastTo(variance_1d, x->shape, {channel_dim});
   auto std_variance_1d = builder->Sqrt(variance_1d);
   auto std_variance_4d = builder->BroadcastTo(std_variance_1d, x->shape, {channel_dim});
-  auto diff = builder->Sub(x, mean_4d);
-  // grad scale = dy * (diff/std_var), shape = [c]
-  auto grad_scale =
-      builder->Reduce(builder->Mul(dy, builder->Div(diff, std_variance_4d)), ReduceKind::kSum, reduce_dim);
+  auto mean_4d = builder->BroadcastTo(save_mean, x->shape, {channel_dim});
+  auto x_mean_diff = builder->Sub(x, mean_4d);
+  auto scale_grad =
+      builder->Div(builder->Reduce(builder->Mul(y_grad, x_mean_diff), ReduceKind::kSum, reduce_dim), std_variance_1d);
 
-  // grad [(x - mean)/std_var] = dy * scale, shape = [n,c,h,w]
+  // std_norm_grad = y_grad * scale, shape = [n,c,h,w]
   auto scale_4d = builder->BroadcastTo(scale, x->shape, {channel_dim});
-  auto grad_std_norm = builder->Mul(dy, scale_4d);
+  auto std_norm_grad = builder->Mul(y_grad, scale_4d);
 
-  // grad [diff=(x - mean)] = dstd/std_var, shape = [n,c,h,w]
-  auto grad_diff = builder->Div(grad_std_norm, std_variance_4d);
+  // x_mean_diff_grad = std_norm_grad / std_variance, shape = [n,c,h,w]
+  auto x_mean_diff_grad = builder->Div(std_norm_grad, std_variance_4d);  // a portion of x_grad
 
-  // grad std var = -1 * reduce((grad_std * diff) / (var), shape = [c])
-  auto grad_std_variance_1d = builder->Negative(
-      builder->Reduce(builder->Div(builder->Mul(grad_std_norm, diff), variance_4d), ReduceKind::kSum, reduce_dim));
+  // std_variance_grad_1d = - reduce_sum(std_norm_grad * x_mean_diff / variance), shape = [c]
+  auto std_variance_grad_1d = builder->Negative(builder->Reduce(
+      builder->Div(builder->Mul(std_norm_grad, x_mean_diff), variance_4d), ReduceKind::kSum, reduce_dim));
 
-  // grad var = 1/2 * dy / std_var, do not mul 0.5 first
-  auto grad_variance_1d_without_mul = builder->Div(grad_std_variance_1d, std_variance_1d);
+  // variance = std_variance * std_variance
+  // variance_grad = 1/2 * std_variance_grad / std_variance
+  auto variance_grad_1d_without_mul = builder->Div(std_variance_grad_1d, std_variance_1d);
 
-  // grad_x0 = broadcastTo(grad_variance_1d_without_mul * 0.5 /element_count) * 2 * x
-  auto grad_x0 = builder->Mul(
-      x, builder->BroadcastTo(builder->Mul(grad_variance_1d_without_mul, element_count_1d), x->shape, {channel_dim}));
+  // x_grad_0 = (variance_grad_1d_without_mul * 0.5 / element_count) * 2 * x
+  auto element_count_1d =
+      builder->BroadcastTo(builder->ConstScalar(element_count, common::UniqName("element_count")), scale->shape, {0});
+  auto x_grad_0 = builder->Mul(
+      x, builder->BroadcastTo(builder->Div(variance_grad_1d_without_mul, element_count_1d), x->shape, {channel_dim}));
 
-  // -1.0 * grad_mean = ( -1.0 * reduce(grad_diff) + -1.0 * grad_variance_1d_without_mul * 0.5 * 2 * mean) /
+  // -1.0 * mean_grad = ((-1.0 * reduce(x_mean_diff_grad)) + (-1.0 * variance_grad_1d_without_mul * 0.5 * 2 * mean)) /
   // element_count_1d
-  auto minus_grad_mean = builder->Mul(element_count_1d,
-                                      builder->Add(builder->Reduce(grad_diff, ReduceKind::kSum, reduce_dim),
-                                                   builder->Mul(grad_variance_1d_without_mul, save_mean)));
-
-  // grad_x = grad_diff + boradcastTo(grad_mean) + grad_x0
-  auto grad_x =
-      builder->Sub(builder->Add(grad_diff, grad_x0), builder->BroadcastTo(minus_grad_mean, x->shape, {channel_dim}));
-
-  // set output
-  context.MapOutToOrigin(grad_x, instr->outputs[0]);
-  context.MapOutToOrigin(grad_scale, instr->outputs[1]);
-  context.MapOutToOrigin(grad_bias, instr->outputs[2]);
+  auto minus_mean_grad = builder->Div(builder->Add(builder->Reduce(x_mean_diff_grad, ReduceKind::kSum, reduce_dim),
+                                                   builder->Mul(variance_grad_1d_without_mul, save_mean)),
+                                      element_count_1d);
+  auto minus_mean_grad_4d = builder->BroadcastTo(minus_mean_grad, x->shape, {channel_dim});
+
+  // x_grad = x_mean_diff_grad + mean_grad + x_grad_0
+  auto x_grad = builder->Sub(builder->Add(x_mean_diff_grad, x_grad_0), minus_mean_grad_4d);
+
+  context.MapOutToOrigin(x_grad, instr->outputs[0]);
+  context.MapOutToOrigin(scale_grad, instr->outputs[1]);
+  context.MapOutToOrigin(bias_grad, instr->outputs[2]);
 }
 
 }  // namespace decomposer
diff --git a/cinn/frontend/decomposer/batch_norm_test.cc b/cinn/frontend/decomposer/batch_norm_test.cc
index 26498e2669..e3de6bdd56 100644
--- a/cinn/frontend/decomposer/batch_norm_test.cc
+++ b/cinn/frontend/decomposer/batch_norm_test.cc
@@ -143,7 +143,6 @@ void ComputeBatchNormTrainRef(const std::vector<T>& x,
 }
 
 TEST(Decomposer, BatchNormTrain) {
-  // parameter
   int n = 16, c = 32, h = 16, w = 16;
   float epsilon = 1e-5;
   float momentum = 0.9f;
@@ -151,21 +150,18 @@ TEST(Decomposer, BatchNormTrain) {
   NetBuilder net_builder("batch_norm_train");
   std::vector<std::string> output_names;
   {
-    // create input
     auto x = net_builder.CreateInput(Float(32), {n, c, h, w}, "x");
     auto scale = net_builder.CreateInput(Float(32), {c}, "scale");
     auto bias = net_builder.CreateInput(Float(32), {c}, "bias");
     auto moving_mean = net_builder.CreateInput(Float(32), {c}, "moving_mean");
     auto moving_variance = net_builder.CreateInput(Float(32), {c}, "moving_variance");
 
-    // add batch norm train
     auto outputs =
         net_builder.batch_norm_train(x, scale, bias, moving_mean, moving_variance, epsilon, momentum, data_layout);
     for (auto output : outputs) {
       output_names.push_back(output->id);
     }
   }
-  // build program
   auto program = net_builder.Build();
 
   auto target = GetTarget();
@@ -236,192 +232,119 @@ TEST(Decomposer, BatchNormTrain) {
   }
 }
 
-#if 0
 template <typename T>
-void cpu_batch_norm_grad(const std::vector<T>& x,
-                         const std::vector<T>& dy,
-                         const std::vector<T>& scale,
-                         const std::vector<T>& save_mean,
-                         const std::vector<T>& save_variance,
-                         const int n,
-                         const int c,
-                         const int h,
-                         const int w,
-                         std::vector<T>* dx,
-                         std::vector<T>* dscale,
-                         std::vector<T>* dbias,
-                         std::vector<T>* grad_std_norm,
-                         std::vector<T>* grad_diff,
-                         std::vector<T>* grad_std_variance_2d,
-                         std::vector<T>* grad_variance_2d_without_mul,
-                         std::vector<T>* grad_x0,
-                         std::vector<T>* minus_grad_mean,
-                         float epsilon = 1e-5) {
+void ComputeBatchNormGradRef(const std::vector<T>& y_grad,
+                             const std::vector<T>& x,
+                             const std::vector<T>& scale,
+                             const std::vector<T>& save_mean,
+                             const std::vector<T>& save_variance,
+                             const int n,
+                             const int c,
+                             const int h,
+                             const int w,
+                             std::vector<T>* x_grad,
+                             std::vector<T>* scale_grad,
+                             std::vector<T>* bias_grad,
+                             const float epsilon = 1e-5) {
   Offset offset(n, c, h, w);
 
-  std::vector<float> save_std_varance(c);
-  for (int idx = 0; idx < c; ++idx) {
-    save_std_varance[idx] = sqrt(save_variance[idx] + epsilon);
-  }
-  // grad bias
-  memset(dbias->data(), 0, sizeof(float) * c);
-  auto func_dbias = [=](int idx, int idy, int idz, int ida) {
-    dbias->at(idy) += dy[offset(idx, idy, idz, ida)];
-  };
-  Loop(func_dbias, n, c, h, w);
-
-  // grad scale
-  memset(dscale->data(), 0, sizeof(float) * c);
-  auto func_dscale = [=](int idx, int idy, int idz, int ida) {
-    dscale->at(idy) += dy[offset(idx, idy, idz, ida)] *
-                       ((x[offset(idx, idy, idz, ida)] - save_mean[idy]) / save_std_varance[idy]);
-  };
-  Loop(func_dscale, n, c, h, w);
-
-  // grad_std
-  auto func_grad_std_norm = [=](int idx, int idy, int idz, int ida) {
-    grad_std_norm->at(offset(idx, idy, idz, ida)) = dy[offset(idx, idy, idz, ida)] * scale[idy];
-  };
-  Loop(func_grad_std_norm, n, c, h, w);
+  // bias_grad
+  memset(bias_grad->data(), 0, sizeof(T) * c);
+  auto func_bias_grad = [&](int in, int ic, int ih, int iw) { bias_grad->at(ic) += y_grad[offset(in, ic, ih, iw)]; };
+  Loop(func_bias_grad, n, c, h, w);
 
-  auto func_grad_diff = [=](int idx, int idy, int idz, int ida) {
-    grad_diff->at(offset(idx, idy, idz, ida)) =
-        grad_std_norm->at(offset(idx, idy, idz, ida)) / save_std_varance[idy];
-  };
-  Loop(func_grad_diff, n, c, h, w);
-
-  memset(grad_std_variance_2d->data(), 0, sizeof(float) * c);
-  auto func_grad_std_variance_2d = [=](int idx, int idy, int idz, int ida) {
-    grad_std_variance_2d->at(idy) += -1 * grad_std_norm->at(offset(idx, idy, idz, ida)) *
-                                     (x[offset(idx, idy, idz, ida)] - save_mean[idy]) /
-                                     (save_variance[idy] + epsilon);
-  };
-  Loop(func_grad_std_variance_2d, n, c, h, w);
-
-  for (int idx = 0; idx < c; ++idx) {
-    grad_variance_2d_without_mul->at(idx) = 0.5 * grad_std_variance_2d->at(idx) / save_std_varance[idx];
+  // std_variance
+  std::vector<T> std_variance(c);
+  for (int ic = 0; ic < c; ++ic) {
+    std_variance[ic] = sqrt(save_variance[ic] + epsilon);
   }
 
-  auto func_grad_x0 = [=](int idx, int idy, int idz, int ida) {
-    grad_x0->at(offset(idx, idy, idz, ida)) =
-        2 * x[offset(idx, idy, idz, ida)] * grad_variance_2d_without_mul->at(idy) / (n * h * w);
-  };
-  Loop(func_grad_x0, n, c, h, w);
-  memset(minus_grad_mean->data(), 0, sizeof(float) * c);
-  auto func_minus_grad_mean = [=](int idx, int idy, int idz, int ida) {
-    minus_grad_mean->at(idy) += -1 * grad_diff->at(offset(idx, idy, idz, ida));
+  // scale_grad
+  memset(scale_grad->data(), 0, sizeof(T) * c);
+  auto func_scale_grad = [&](int in, int ic, int ih, int iw) {
+    int idx = offset(in, ic, ih, iw);
+    scale_grad->at(ic) += y_grad[idx] * (x[idx] - save_mean[ic]);
   };
-  Loop(func_minus_grad_mean, n, c, h, w);
-
-  for (int idx = 0; idx < c; ++idx) {
-    minus_grad_mean->at(idx) += -1 * 2 * grad_variance_2d_without_mul->at(idx) * save_mean.at(idx);
-    minus_grad_mean->at(idx) /= (n * h * w);
+  Loop(func_scale_grad, n, c, h, w);
+  for (int ic = 0; ic < c; ++ic) {
+    scale_grad->at(ic) /= std_variance[ic];
   }
 
-  auto func_grad_x = [=](int idx, int idy, int idz, int ida) {
-    dx->at(offset(idx, idy, idz, ida)) =
-        grad_diff->at(offset(idx, idy, idz, ida)) +
-        grad_x0->at(offset(idx, idy, idz, ida)) + minus_grad_mean->at(idy);
+  // std_norm_grad
+  std::vector<T> std_norm_grad(n * c * h * w);
+  auto func_std_norm_grad = [&](int in, int ic, int ih, int iw) {
+    int idx = offset(in, ic, ih, iw);
+    std_norm_grad[idx] = y_grad[idx] * scale[ic];
   };
-  Loop(func_grad_x, n, c, h, w);
-}
-
-void GradX(const std::vector<float>& grad_std_norm,
-           const std::vector<float>& x,
-           const std::vector<float>& mean,
-           const std::vector<float>& variance,
-           int n,
-           int c,
-           int h,
-           int w,
-           float epsilon = 1e-5) {
-  std::vector<float> std_variance(c);
-  for (int idx = 0; idx < c; ++idx) {
-    std_variance[idx] = sqrt(variance[idx] + epsilon);
-  }
+  Loop(func_std_norm_grad, n, c, h, w);
 
-  std::vector<float> grad_diff(n * c * h * w);
-  auto func_0 = [&](int idx, int idy, int idz, int ida) {
-    grad_diff[idx * c * h * w + idy * h * w + idz * w + ida] =
-        grad_std_norm[idx * c * h * w + idy * h * w + idz * w + ida] / std_variance[idy];
+  // x_mean_diff_grad
+  std::vector<T> x_mean_diff_grad(n * c * h * w);
+  auto func_x_mean_diff_grad = [&](int in, int ic, int ih, int iw) {
+    int idx = offset(in, ic, ih, iw);
+    x_mean_diff_grad[idx] = std_norm_grad[idx] / std_variance[ic];
   };
-  Loop(func_0, n, c, h, w);
-  for (auto value : grad_diff) {
-    std::cerr << value << " ";
-  }
-  std::cerr << std::endl;
+  Loop(func_x_mean_diff_grad, n, c, h, w);
 
-  std::vector<float> grad_std_variance(c, 0);
-  auto func_1 = [&](int idx, int idy, int idz, int ida) {
-    grad_std_variance[idy] += -1 * grad_std_norm[idx * c * h * w + idy * h * w + idz * w + ida] *
-                              (x[idx * c * h * w + idy * h * w + idz * w + ida] - mean[idy]) /
-                              (variance[idy] + epsilon);
+  // std_variance_grad
+  std::vector<T> std_variance_grad(c, 0);
+  auto func_std_variance_grad = [&](int in, int ic, int ih, int iw) {
+    int idx = offset(in, ic, ih, iw);
+    std_variance_grad[ic] += -1.0f * std_norm_grad[idx] * (x[idx] - save_mean[ic]) / (save_variance[ic] + epsilon);
   };
-  Loop(func_1, n, c, h, w);
-
-  std::vector<float> grad_variance(c);
-  for (int idx = 0; idx < c; ++idx) {
-    grad_variance[idx] = grad_std_variance[idx] * 0.5 / std_variance[idx];
-  }
+  Loop(func_std_variance_grad, n, c, h, w);
 
-  for (auto value : grad_variance) {
-    std::cerr << value << " ";
+  // variance_grad_without_mul
+  std::vector<T> variance_grad_without_mul(c);
+  for (int ic = 0; ic < c; ++ic) {
+    variance_grad_without_mul[ic] = std_variance_grad[ic] / std_variance[ic];
   }
-  std::cerr << std::endl;
-
-  std::vector<float> grad_square_diff(n * c * h * w);
-  auto func_11 = [&](int idx, int idy, int idz, int ida) {
-    grad_square_diff[idx * c * h * w + idy * h * w + idz * w + ida] =
-        grad_variance[idy] * 2 * (x[idx * c * h * w + idy * h * w + idz * w + ida] - mean[idy]) / float(n * h * w);
-  };
-  Loop(func_11, n, c, h, w);
 
-  auto func_2 = [&](int idx, int idy, int idz, int ida) {
-    grad_diff[idx * c * h * w + idy * h * w + idz * w + ida] +=
-        grad_square_diff[idx * c * h * w + idy * h * w + idz * w + ida];
+  // x_grad_0
+  float element_count = static_cast<float>(n * h * w);
+  std::vector<T> x_grad_0(n * c * h * w);
+  auto func_x_grad_0 = [&](int in, int ic, int ih, int iw) {
+    int idx = offset(in, ic, ih, iw);
+    x_grad_0[idx] = x[idx] * (variance_grad_without_mul[ic] / element_count);
   };
-  Loop(func_2, n, c, h, w);
+  Loop(func_x_grad_0, n, c, h, w);
 
-  std::vector<float> grad_mean(c, 0);
-  auto func_3 = [&](int idx, int idy, int idz, int ida) {
-    grad_mean[idy] += -1 * grad_diff[idx * c * h * w + idy * h * w + idz * w + ida];
+  // minus_mean_grad
+  std::vector<T> minus_mean_grad(c, 0);
+  auto func_minus_mean_grad = [&](int in, int ic, int ih, int iw) {
+    minus_mean_grad[ic] += x_mean_diff_grad[offset(in, ic, ih, iw)];
   };
-  Loop(func_3, n, c, h, w);
+  Loop(func_minus_mean_grad, n, c, h, w);
+  for (int ic = 0; ic < c; ++ic) {
+    minus_mean_grad[ic] += variance_grad_without_mul[ic] * save_mean[ic];
+    minus_mean_grad[ic] /= element_count;
+  }
 
-  std::vector<float> grad_x(n * c * h * w);
-  auto func_4 = [&](int idx, int idy, int idz, int ida) {
-    grad_x[idx * c * h * w + idy * h * w + idz * w + ida] =
-        grad_diff[idx * c * h * w + idy * h * w + idz * w + ida] + grad_mean[idy] / (float(n * h * w));
+  auto func_x_grad = [&](int in, int ic, int ih, int iw) {
+    int idx = offset(in, ic, ih, iw);
+    x_grad->at(idx) = x_mean_diff_grad[idx] + x_grad_0[idx] - minus_mean_grad[ic];
   };
-  Loop(func_4, n, c, h, w);
-
-  for (auto value : grad_x) {
-    std::cerr << value << " ";
-  }
-  std::cerr << std::endl;
+  Loop(func_x_grad, n, c, h, w);
 }
 
-TEST(nn, BATCH_NORM_GRAD) {
-  // parameter
-  int n = 4, c = 16, h = 4, w = 4;
-  int num = n * c * h * w;
-  NetBuilder net_builder("net_builder_batch_norm_grad");
+TEST(Decomposer, BatchNormGrad) {
+  int n = 16, c = 32, h = 16, w = 16;
+  int num = n * c * h * w;
+  float epsilon = 1e-5;
+  NetBuilder net_builder("batch_norm_grad");
   std::vector<std::string> output_names;
   {
-    // create input
-    auto dy = net_builder.CreateInput(Float(32), {n, c, h, w}, "dy");
-    auto x = net_builder.CreateInput(Float(32), {n, c, h, w}, "x");
-    auto scale = net_builder.CreateInput(Float(32), {c}, "scale");
-    auto save_mean = net_builder.CreateInput(Float(32), {c}, "save_mean");
-    auto save_variance = net_builder.CreateInput(Float(32), {c}, "save_variance");
-
-    // add batch norm train
-    auto outputs = net_builder.batch_norm_grad(dy, x, scale, save_mean, save_variance);
+    auto y_grad = net_builder.CreateInput(Float(32), {n, c, h, w}, "y_grad");
+    auto x = net_builder.CreateInput(Float(32), {n, c, h, w}, "x");
+    auto scale = net_builder.CreateInput(Float(32), {c}, "scale");
+    auto saved_mean = net_builder.CreateInput(Float(32), {c}, "saved_mean");
+    auto saved_variance = net_builder.CreateInput(Float(32), {c}, "saved_variance");
+
+    auto outputs = net_builder.batch_norm_grad(y_grad, x, scale, saved_mean, saved_variance, epsilon);
     for (auto output : outputs) {
       output_names.push_back(output->id);
     }
   }
-  // build program
   auto program = net_builder.Build();
 
   auto target = GetTarget();
@@ -430,29 +353,30 @@ TEST(nn, BATCH_NORM_GRAD) {
   auto graph = std::make_shared<hlir::framework::Graph>(program, target);
   auto scope = BuildScope(target, graph);
   hlir::framework::GraphCompiler gc(target, scope, graph);
-  hlir::framework::ApplyPass(graph.get(), "OpFusion");
+  // hlir::framework::ApplyPass(graph.get(), "OpFusion");
   auto run_program = gc.Build();
 
   // set input
-  std::vector<float> dy(num), x(num), scale(c), save_mean(c, 0), save_variance(c, 0);
-  InitRandomVector(&dy, num);
+  std::vector<float> y_grad(num), x(num), scale(c), saved_mean(c, 0), saved_variance(c, 0);
+  InitRandomVector(&y_grad, num);
   InitRandomVector(&x, num);
   InitRandomVector(&scale, c);
 
-  auto func_save_mean = [&](int idx, int idy, int idz, int ida) {
-    save_mean[idy] += x[idx * c * h * w + idy * h * w + idz * w + ida];
-    save_variance[idy] +=
-        x[idx * c * h * w + idy * h * w + idz * w + ida] * x[idx * c * h * w + idy * h * w + idz * w + ida];
+  Offset offset(n, c, h, w);
+  auto func_save_mean = [&](int in, int ic, int ih, int iw) {
+    int idx = offset(in, ic, ih, iw);
+    saved_mean[ic] += x[idx];
+    saved_variance[ic] += x[idx] * x[idx];
   };
   Loop(func_save_mean, n, c, h, w);
-  for (int idx = 0; idx < c; ++idx) {
-    save_mean[idx] /= float(n * h * w);
-    save_variance[idx] /= float(n * h * w);
-    save_variance[idx] -= (save_mean[idx] * save_mean[idx]);
+  float element_count = static_cast<float>(n * h * w);
+  for (int ic = 0; ic < c; ++ic) {
+    saved_mean[ic] /= element_count;
+    saved_variance[ic] = saved_variance[ic] / element_count - saved_mean[ic] * saved_mean[ic];
   }
 
   std::vector<std::pair<std::string, std::vector<float>>> inputs = {
-      {"dy", dy}, {"x", x}, {"scale", scale}, {"save_mean", save_mean}, {"save_variance", save_variance}};
+      {"y_grad", y_grad}, {"x", x}, {"scale", scale}, {"saved_mean", saved_mean}, {"saved_variance", saved_variance}};
   for (auto& input : inputs) {
     scope->Var<hlir::framework::Tensor>(input.first);
     auto tensor = scope->GetTensor(input.first);
@@ -460,51 +384,29 @@ TEST(nn, BATCH_NORM_GRAD) {
   }
   run_program->Execute();
 
-  std::vector<float> dx(num), dscale(c), dbias(c);
-  std::vector<float> grad_std_norm(num), grad_diff(num), grad_std_variance_1d(c), grad_variance_1d_without_mul(c),
-      grad_x0(num), minus_grad_mean(c);
-
-  cpu_batch_norm_grad(x,
-                      dy,
-                      scale,
-                      save_mean,
-                      save_variance,
-                      n,
-                      c,
-                      h,
-                      w,
-                      &dx,
-                      &dscale,
-                      &dbias,
-                      &grad_std_norm,
-                      &grad_diff,
-                      &grad_std_variance_1d,
-                      &grad_variance_1d_without_mul,
-                      &grad_x0,
-                      &minus_grad_mean);
-  // GradX(grad_std_norm, x, save_mean, save_variance, n, c, h , w);
-
-  std::vector<std::pair<std::string, std::vector<float>>> outputs = {
-      {output_names[2], dbias},
-      {output_names[1], dscale},
-      {output_names[0], dx},
-  };
+  std::vector<float> x_grad(num), scale_grad(c), bias_grad(c);
+  ComputeBatchNormGradRef(
+      y_grad, x, scale, saved_mean, saved_variance, n, c, h, w, &x_grad, &scale_grad, &bias_grad, epsilon);
+
+  std::unordered_map<std::string, std::pair<std::string, std::vector<float>>> output_refs = {
+      {"bias_grad", {output_names[2], bias_grad}},
+      {"scale_grad", {output_names[1], scale_grad}},
+      {"x_grad", {output_names[0], x_grad}}};
 
-  for (auto& output : outputs) {
+  for (auto& iter : output_refs) {
+    auto output = iter.second;
     auto tensor = scope->GetTensor(output.first);
     std::vector<float> data(tensor->shape().numel());
     CopyToVector(tensor, &data);
-    LOG(INFO) << output.first << " " << tensor->shape().numel();
-    for (int idx = 0; idx < tensor->shape().numel(); ++idx) {
-      float diff = abs((data[idx] - output.second[idx]) / data[idx]);
-      if (diff > 1e-5) {
-        LOG(INFO) << "i=" << idx << ", " << data[idx] << " vs " << output.second[idx] << ", diff=" << diff;
-      }
-// ASSERT_LT(diff, 1e-3);
+
+    LOG(INFO) << "output[" << iter.first << "], var_name=" << output.first << ", shape=" << tensor->shape().data();
+    if (iter.first == "x_grad") {
+      CheckOutput(data, output.second, 1e-2, true);
+    } else {
+      CheckOutput(data, output.second, 1e-5);
     }
   }
 }
-#endif
 
 }  // namespace
 }  // namespace frontend
diff --git a/cinn/frontend/decomposer/test_helper.h b/cinn/frontend/decomposer/test_helper.h
index d6c4fef644..fa052b9c7e 100644
--- a/cinn/frontend/decomposer/test_helper.h
+++ b/cinn/frontend/decomposer/test_helper.h
@@ -121,13 +121,13 @@ void CheckOutput(const std::vector<T>& results,
     }
     if ((relative_diff > max_relative_error) || (check_absolute_error && (absolute_diff > 1e-6))) {
       num_diffs += 1;
-      // LOG(INFO) << "- i=" << i << ", " << std::setprecision(8) << results[i] << " vs " << std::setprecision(8) <<
-      // references[i] << ", relative_diff=" << relative_diff << ", absolute_diff=" << absolute_diff;
+      VLOG(4) << "- i=" << i << ", " << std::setprecision(8) << results[i] << " vs " << std::setprecision(8)
+              << references[i] << ", relative_diff=" << relative_diff << ", absolute_diff=" << absolute_diff;
     }
   }
   LOG(INFO) << "- Total " << num_diffs << " different results, offset=" << offset << ", " << results[offset]
             << " vs " << references[offset] << ", maximum_relative_diff=" << max_diff
-            << "(absolute_diff=" << abs((results[offset] - references[offset])) << ")";
+            << " (absolute_diff=" << abs((results[offset] - references[offset])) << ")";
   ASSERT_EQ(num_diffs, 0);
   ASSERT_LT(max_diff, max_relative_error);
 }
diff --git a/cinn/frontend/decomposer/use_decomposer.h b/cinn/frontend/decomposer/use_decomposer.h
index e42ef31ad9..e015d7dfde 100644
--- a/cinn/frontend/decomposer/use_decomposer.h
+++ b/cinn/frontend/decomposer/use_decomposer.h
@@ -21,7 +21,6 @@ CINN_USE_REGISTER(activation_grad_decomposers)
 CINN_USE_REGISTER(elementwise_decomposers)
 CINN_USE_REGISTER(broadcast_decomposers)
 CINN_USE_REGISTER(broadcast_grad_decomposers)
-
 CINN_USE_REGISTER(batch_norm_train_decomposer)
 CINN_USE_REGISTER(batch_norm_grad_decomposer)
 