Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

op unittest for repeat/arange/reverse/elementwise_add_grad/flip #1514

Merged
merged 11 commits into from
Jun 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -816,11 +816,7 @@ Variable NetBuilder::Arange(const float start, const float stop, const float ste
}

Variable NetBuilder::Flip(const Variable& operand, const std::vector<int>& axes) {
Instruction instr("flip", {operand});
instr.SetAttr("axes", axes);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
return CustomInstr("reverse", {operand}, {{"axis", utils::GetPositiveAxes(axes, operand->shape.size())}}).front();
}

Variable NetBuilder::Matmul(const Variable& x, const Variable& y, bool trans_x, bool trans_y, float alpha) {
Expand Down
5 changes: 4 additions & 1 deletion cinn/frontend/net_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,10 @@ class NetBuilder {
const std::string& padding_algorithm = "EXPLICIT");

/**
* This API flipes the Variable x along the given axis.
* @brief This API reverse the Variable x along the given axis.
* @param x An N-D variable.
* @param axis Specify the axis to operate on the input reverse.
* @return A reversed variable with the same data type as x.
*/
Variable Flip(const Variable& operand, const std::vector<int>& axes);

Expand Down
70 changes: 0 additions & 70 deletions cinn/frontend/net_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -984,76 +984,6 @@ TEST(net_build, program_execute_arange_int) {
}
}

TEST(net_build, program_execute_flip) {
const int C = 2;
const int H = 2;
const int W = 2;
const std::vector<int> axes{0};

NetBuilder builder("net_builder");
Placeholder input = builder.CreateInput(Float(32), {C, H, W}, "Img");
Variable output = builder.Flip(input, axes);
auto program = builder.Build();

#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
std::unordered_set<std::string> fetch_ids;
auto graph = Optimize(&program, fetch_ids, target);

auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input_tensor = scope->GetTensor(std::string(input.id()));
SetRandData<float>(input_tensor, target);
std::vector<float> input_data = GetTensorData<float>(input_tensor, target);

runtime_program->Execute();
auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
EXPECT_EQ(output_tensor->type(), Float(32));
EXPECT_EQ(output_shape.size(), 3UL);
EXPECT_EQ(output_shape[0], C);
EXPECT_EQ(output_shape[1], H);
EXPECT_EQ(output_shape[2], W);

std::vector<float> output_data = GetTensorData<float>(output_tensor, target);
VLOG(6) << "Visualize flip input_data";
for (int c = 0; c < C; c++) {
for (int h = 0; h < H; h++) {
std::string line;
for (int w = 0; w < W; w++) {
int index = c * (H * W) + h * W + w;
line += (std::to_string(index) + ": " + std::to_string(input_data[index]) + ", ");
}
VLOG(6) << line;
}
}

VLOG(6) << "Visualize flip output_data";
for (int c = 0; c < C; c++) {
int flip_c = std::find(axes.begin(), axes.end(), 0) == axes.end() ? c : C - c - 1;
for (int h = 0; h < H; h++) {
std::string line;
int flip_h = std::find(axes.begin(), axes.end(), 1) == axes.end() ? h : H - h - 1;
for (int w = 0; w < W; w++) {
int flip_w = std::find(axes.begin(), axes.end(), 2) == axes.end() ? w : W - w - 1;
int flip_index = flip_c * H * W + flip_h * W + flip_w;
int index = c * (H * W) + h * W + w;
line += (std::to_string(index) + ": " + std::to_string(output_data[index]) + ", ");
EXPECT_EQ(input_data[index], output_data[flip_index]);
}
VLOG(6) << line;
}
}
}

TEST(net_build, program_argmax_case1) {
const int N = 4;
const int IN_C = 3;
Expand Down
2 changes: 0 additions & 2 deletions cinn/hlir/op/contrib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ core_gather_headers()

gather_srcs(cinnapi_src SRCS
gather_nd.cc
flip.cc
sort.cc
argmin.cc
argmax.cc
Expand All @@ -24,7 +23,6 @@ cc_test(test_gather_nd SRCS gather_nd_test.cc DEPS cinncore)
cc_test(test_sort SRCS sort_test.cc DEPS cinncore)
cc_test(test_argmin SRCS argmin_test.cc DEPS cinncore)
cc_test(test_argmax SRCS argmax_test.cc DEPS cinncore)
cc_test(test_flip SRCS flip_test.cc DEPS cinncore)
cc_test(test_repeat SRCS repeat_test.cc DEPS cinncore)
cc_test(test_one_hot SRCS one_hot_test.cc DEPS cinncore)
cc_test(test_lookup_table SRCS lookup_table_test.cc DEPS cinncore)
Expand Down
118 changes: 0 additions & 118 deletions cinn/hlir/op/contrib/flip.cc

This file was deleted.

32 changes: 0 additions & 32 deletions cinn/hlir/op/contrib/flip.h

This file was deleted.

67 changes: 0 additions & 67 deletions cinn/hlir/op/contrib/flip_test.cc

This file was deleted.

9 changes: 0 additions & 9 deletions cinn/hlir/op/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,6 @@ std::shared_ptr<OpStrategy> StrategyForReverse(const framework::NodeAttr &attrs,
std::vector<int> axis;
if (attrs.attr_store.find("axis") != attrs.attr_store.end()) {
axis = absl::get<std::vector<int>>(attrs.attr_store.at("axis"));
CHECK(!axis.empty()) << "axis is empty! Please check setting.\n";
for (auto &e : axis) {
if (e >= static_cast<int>(output_shapes[0].size()) || e < -1 * static_cast<int>(output_shapes[0].size())) {
LOG(FATAL) << "axis is not in [0, n_dim), Please check.";
Expand All @@ -840,8 +839,6 @@ std::shared_ptr<OpStrategy> StrategyForReverse(const framework::NodeAttr &attrs,
e += output_shapes[0].size();
}
}
} else {
LOG(FATAL) << "axis is not be set! Please check.";
}

framework::CINNCompute reverse_compute([=](lang::Args args, lang::RetValue *ret) {
Expand Down Expand Up @@ -875,7 +872,6 @@ std::vector<framework::shape_t> InferShapeForReverse(const std::vector<framework
std::vector<framework::shape_t> res{inputs_shape[0]};
if (attrs.find("axis") != attrs.end()) {
auto axis = absl::get<std::vector<int>>(attrs.at("axis"));
CHECK(!axis.empty()) << "axis is empty! Please check setting.\n";
for (auto &e : axis) {
if (e >= static_cast<int>(inputs_shape[0].size()) || e < -1 * static_cast<int>(inputs_shape[0].size())) {
LOG(FATAL) << "axis is not in [-n_dim, n_dim), Please check.";
Expand All @@ -884,8 +880,6 @@ std::vector<framework::shape_t> InferShapeForReverse(const std::vector<framework
e += inputs_shape[0].size();
}
}
} else {
LOG(FATAL) << "axis is not be set! Please check.";
}
return res;
}
Expand All @@ -896,14 +890,11 @@ std::vector<std::vector<std::string>> InferLayoutForReverse(const std::vector<fr
const Target &target) {
if (attrs.attr_store.find("axis") != attrs.attr_store.end()) {
auto axis = absl::get<std::vector<int>>(attrs.attr_store.at("axis"));
CHECK(!axis.empty()) << "axis is empty! Please check setting.\n";
for (auto &e : axis) {
if (e >= static_cast<int>(input_shapes[0].size()) || e < -1 * static_cast<int>(input_shapes[0].size())) {
LOG(FATAL) << "axis is not in [-n_dim, n_dim), Please check.";
}
}
} else {
LOG(FATAL) << "axis is not be set! Please check.";
}
CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again.";
return {input_layouts, input_layouts};
Expand Down
1 change: 0 additions & 1 deletion cinn/hlir/op/use_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ CINN_USE_REGISTER(argmin_ops)
CINN_USE_REGISTER(argmax_ops)
CINN_USE_REGISTER(reduce_ops)
CINN_USE_REGISTER(custom_call_op)
CINN_USE_REGISTER(flip_ops)
CINN_USE_REGISTER(repeat_ops)
CINN_USE_REGISTER(one_hot_ops)
CINN_USE_REGISTER(lookup_table_ops)
Expand Down
Loading