This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit 02d8f7c: Polish some unittests.
Xreki committed Oct 29, 2021
1 parent a97fe62 commit 02d8f7c
Showing 6 changed files with 59 additions and 19 deletions.
build.sh (2 changes: 1 addition & 1 deletion)
@@ -120,7 +120,7 @@ function cmake_ {
echo "set(WITH_CUDNN $cudnn_config)" >> $build_dir/config.cmake
echo "set(WITH_MKL_CBLAS ON)" >> $build_dir/config.cmake
cd $build_dir
cmake .. -DPUBLISH_LIBS=ON -DWITH_TESTING=ON -DPY_VERSION=3.6
cmake ${workspace} -DPUBLISH_LIBS=ON -DWITH_TESTING=ON -DPY_VERSION=3.6
}

function _download_and_untar {
cinn/frontend/decomposer/CMakeLists.txt (3 changes: 1 addition & 2 deletions)
@@ -11,8 +11,7 @@ gather_srcs(cinnapi_src SRCS
cc_test(test_activation_decomposer SRCS activation_test.cc DEPS cinncore)
cc_test(test_elementwise_decomposer SRCS elementwise_test.cc DEPS cinncore)
cc_test(test_broadcast_decomposer SRCS broadcast_test.cc DEPS cinncore)

#cc_test(test_batch_norm_decomposer SRCS batch_norm_test.cc DEPS cinncore)
cc_test(test_batch_norm_decomposer SRCS batch_norm_test.cc DEPS cinncore)
if(WITH_CUDNN)
cc_test(test_conv2d_grad_decomposer SRCS conv2d_grad_test.cc DEPS cinncore)
endif()
cinn/frontend/decomposer/batch_norm_test.cc (45 changes: 34 additions & 11 deletions)
@@ -41,6 +41,19 @@ namespace cinn {
namespace frontend {
namespace {

struct Offset {
int n;
int c;
int h;
int w;

Offset(int arg_n, int arg_c, int arg_h, int arg_w) : n(arg_n), c(arg_c), h(arg_h), w(arg_w) {}

int operator()(int idx_n, int idx_c, int idx_h, int idx_w) const {
return idx_n * c * h * w + idx_c * h * w + idx_h * w + idx_w;
}
};

template <typename FuncType>
void loop(FuncType func, const int n, const int c, const int h, const int w) {
for (int idx = 0; idx < n; ++idx) {
@@ -71,12 +84,13 @@ void cpu_run_batch_norm_train(const std::vector<T>& x,
std::vector<T>* variance,
const float epsilon = 1e-5,
const float momentum = 0.9f) {
Offset offset(n, c, h, w);

// sum
memset(mean->data(), 0, sizeof(T) * c);
auto func_sum = [=](int idx, int idy, int idz, int ida) {
mean->at(idy) += x[idx * c * h * w + idy * h * w + idz * w + ida];
};
auto func_sum = [=](int idx, int idy, int idz, int ida) { mean->at(idy) += x[offset(idx, idy, idz, ida)]; };
loop(func_sum, n, c, h, w);

// mean
for (int idx = 0; idx < c; ++idx) {
mean->at(idx) /= float(n * h * w);
@@ -85,26 +99,31 @@ void cpu_run_batch_norm_train(const std::vector<T>& x,
// square
std::vector<float> square_mean(c, 0);
auto func_sum_square = [&](int idx, int idy, int idz, int ida) {
square_mean.at(idy) +=
x[idx * c * h * w + idy * h * w + idz * w + ida] * x[idx * c * h * w + idy * h * w + idz * w + ida];
square_mean.at(idy) += x[offset(idx, idy, idz, ida)] * x[offset(idx, idy, idz, ida)];
};
loop(func_sum_square, n, c, h, w);
//

for (int idx = 0; idx < c; ++idx) {
square_mean[idx] /= float(n * h * w);
}

std::vector<float> std_variance(c);
// sum diff2
std::vector<float> std_variance(c);
for (int idx = 0; idx < c; ++idx) {
variance->at(idx) = square_mean[idx] - (mean->at(idx) * mean->at(idx));
std_variance[idx] = sqrt(variance->at(idx) + epsilon);
}

// compute output
std::vector<float> normalization(n * c * h * w);
auto func_normalization = [&](int idx, int idy, int idz, int ida) {
normalization[offset(idx, idy, idz, ida)] =
((x[offset(idx, idy, idz, ida)] - mean->at(idy)) * scale[idy]) / std_variance[idy];
};
loop(func_normalization, n, c, h, w);

auto func_y = [&](int idx, int idy, int idz, int ida) {
y->at(idx * c * h * w + idy * h * w + idz * w + ida) =
(x[idx * c * h * w + idy * h * w + idz * w + ida] - mean->at(idy)) / std_variance[idy] * scale[idy] + bias[idy];
y->at(offset(idx, idy, idz, ida)) = normalization[offset(idx, idy, idz, ida)] + bias[idy];
};
loop(func_y, n, c, h, w);

@@ -194,9 +213,13 @@ TEST(nn, BATCH_NORM_TRAIN) {
std::vector<float> data(tensor->shape().numel());
CopyToVector(tensor, &data);

LOG(INFO) << output.first << " " << tensor->shape().numel();
LOG(INFO) << output.first << ", shape=" << tensor->shape().numel();
for (int idx = 0; idx < tensor->shape().numel(); ++idx) {
ASSERT_LT(abs((data[idx] - output.second[idx]) / data[idx]), 1e-3);
float diff = abs((data[idx] - output.second[idx]) / data[idx]);
if (diff > 1e-5) {
LOG(INFO) << "i=" << idx << ", " << data[idx] << " vs " << output.second[idx] << ", diff=" << diff;
}
ASSERT_LT(diff, 1e-5);
}
}
}
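The CPU reference above computes batch-norm training statistics per channel of an NCHW tensor: the Offset functor maps coordinates (n, c, h, w) to the flat index n*C*H*W + c*H*W + h*W + w, the mean and variance are reduced over the N, H and W axes, and the output is y = (x - mean_c) * scale_c / sqrt(variance_c + epsilon) + bias_c. The same computation can be sketched in NumPy as follows; this is only an illustration with made-up shapes, not code from this commit:

import numpy as np

def batch_norm_train_reference(x, scale, bias, epsilon=1e-5):
    # x has layout NCHW; scale and bias hold one value per channel.
    mean = x.mean(axis=(0, 2, 3))                 # per-channel E[x]
    square_mean = (x * x).mean(axis=(0, 2, 3))    # per-channel E[x^2]
    variance = square_mean - mean * mean          # Var[x] = E[x^2] - (E[x])^2
    std_variance = np.sqrt(variance + epsilon)

    def per_channel(v):
        # Reshape a length-C vector so it broadcasts over the N, H, W axes.
        return v.reshape(1, -1, 1, 1)

    y = (x - per_channel(mean)) / per_channel(std_variance)
    y = y * per_channel(scale) + per_channel(bias)
    return y, mean, variance

x = np.random.uniform(size=(4, 16, 4, 4)).astype("float32")
scale = np.random.uniform(size=16).astype("float32")
bias = np.random.uniform(size=16).astype("float32")
y, mean, variance = batch_norm_train_reference(x, scale, bias)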
cinn/frontend/decomposer/broadcast.cc (4 changes: 4 additions & 0 deletions)
@@ -78,6 +78,10 @@ void elementwise_add(const Instruction& instr, const DecomposerContext& context)
bcast_y = builder->BroadcastTo(y, output->shape, bcast_axes_y);
}

auto numel = [](const std::vector<int>& v) {
return std::accumulate(v.begin(), v.end(), 1, [](int a, int b) { return a * b; });
};

out = builder->Add(bcast_x, bcast_y);

// map the output of the decomposed operator to the original.
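The numel lambda added above folds a shape vector into its element count by multiplying all dimensions together with std::accumulate. As a plain illustration (not code from this commit), the same reduction in Python is:

from functools import reduce

def numel(shape):
    # Product of all dimensions, i.e. the number of elements in the tensor.
    return reduce(lambda a, b: a * b, shape, 1)

assert numel([4, 16, 4, 4]) == 1024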
python/tests/ops/op_test.py (20 changes: 17 additions & 3 deletions)
@@ -93,14 +93,28 @@ def check_outputs_and_grads(self):
self.check_results(self.paddle_grads, self.cinn_grads)

def check_results(self, expect_res, actual_res):
def _max_relative_error(i, expect, actual, max_relative_error=1e-5):
diff = (np.abs(expect - actual) / np.abs(expect)).flatten()
max_diff = np.max(diff)
offset = np.argmax(diff > max_relative_error)
error_message = "The maximum relative error of %d-th output is %e, offset=%d, shape=%s, %e vs %e." % (
i, max_diff, offset, str(expect.shape),
expect.flatten()[offset], actual.flatten()[offset])
return error_message

self.assertEqual(len(expect_res), len(actual_res))
for i in range(len(expect_res)):
if expect_res[i] is None:
continue

logger.debug("Check the %d -th Result..." % i)
self.assertTrue(
np.allclose(expect_res[i], actual_res[i], atol=1e-6))
if isinstance(expect_res[i], paddle.Tensor):
expect = expect_res[i].numpy()
else:
expect = expect_res[i]
actual = actual_res[i]
is_allclose = np.allclose(expect, actual, atol=1e-6)
self.assertTrue(is_allclose, _max_relative_error(
i, expect, actual))


class OpTestTool:
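The updated check_results first converts an expected paddle.Tensor to a NumPy array, compares expected and actual values with np.allclose, and on failure reports the first element whose relative error exceeds the threshold. A self-contained sketch of that comparison logic, using a hypothetical check_close helper rather than the framework's actual API, could look like this:

import numpy as np

def check_close(expect, actual, atol=1e-6, max_relative_error=1e-5):
    expect = np.asarray(expect)
    actual = np.asarray(actual)
    if np.allclose(expect, actual, atol=atol):
        return
    diff = (np.abs(expect - actual) / np.abs(expect)).flatten()
    # Index of the first element whose relative error exceeds the threshold.
    offset = int(np.argmax(diff > max_relative_error))
    raise AssertionError(
        "The maximum relative error is %e, offset=%d, shape=%s, %e vs %e." %
        (np.max(diff), offset, str(expect.shape),
         expect.flatten()[offset], actual.flatten()[offset]))

check_close(np.ones(8), np.ones(8) + 1e-7)  # within tolerance, no assertion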
python/tests/ops/test_batch_norm_op.py (4 changes: 2 additions & 2 deletions)
@@ -42,8 +42,8 @@ def _random(shape, dtype):

def config(self):
self.dtype = "float32"
self.x_shape = [128, 64, 112, 112]
self.param_shape = [64]
self.x_shape = [4, 16, 4, 4]
self.param_shape = [16]
self.epsilon = 1e-05
self.momentum = 0.9
self.data_format = "NCHW"
