Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Experimental PR for the first OP to clean old schedule (#1524)
Browse files Browse the repository at this point in the history
This PR tried to use the new schedule replace old schedule completely.
  • Loading branch information
zhhsplendid authored Jun 16, 2023
1 parent 725ddd5 commit bf52680
Showing 1 changed file with 26 additions and 37 deletions.
63 changes: 26 additions & 37 deletions cinn/hlir/op/contrib/argmin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,15 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgmin(const framework::NodeAt
framework::CINNCompute argmin_compute([=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of argmin compute is empty! Please check.";
common::CINNValuePack pack_args = args[0];
std::string tensor_name = UniqName("Argmin_out");
CHECK_GE(pack_args.size(), 1U) << "There should be 1 input args for argmax compute";
Expr in_expr = pack_args[0];
CHECK(in_expr.as_tensor());
Tensor in_tensor = in_expr.as_tensor_ref();
auto stages = CreateStages({in_tensor});
if (FLAGS_cinn_ir_schedule) {
CHECK_EQ(pack_args.size(), 2U);
CHECK(pack_args[1].is_string());
tensor_name = pack_args[1].operator std::string();
}
auto out_tensor = Argmin(in_tensor, target, stages, axis, keep_dims, tensor_name);
CHECK_EQ(pack_args.size(), 2U);
CHECK(pack_args[1].is_string());
std::string tensor_name = pack_args[1].operator std::string();
auto out_tensor = Argmin(in_tensor, target, stages, axis, keep_dims, tensor_name);

stages->InsertLazily(out_tensor[0]);
std::vector<CINNValue> cinn_values{
Expand All @@ -133,38 +130,30 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgmin(const framework::NodeAt
});

framework::CINNSchedule argmin_schedule([=](lang::Args args, lang::RetValue *ret) {
if (FLAGS_cinn_ir_schedule) {
CHECK(!args.empty()) << "The input argument of arange_schedule is empty! Please check.\n";
common::CINNValuePack arg_pack = args[0];
std::vector<Expr> vec_ast;
for (int i = 0; i < arg_pack.size(); i++) {
if (arg_pack[i].is_expr()) {
Expr temp = arg_pack[i];
vec_ast.emplace_back(temp);
}
}
CHECK(!vec_ast.empty());
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
ir_sch.MergeExprs();
auto blocks = ir_sch.GetAllBlocks();
// TODO: It needs to be rewritten according to the reduction_min operator to improve performance.
// Do not use local variables, because the size will exceed the limit.
ir_sch.SetBuffer(blocks[0], "local");
ir_sch.SetBuffer(blocks[1], "local");
long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies<int>());
if (prod_size > 1 && target.arch == Target::Arch::X86) {
pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
CHECK(!args.empty()) << "The input argument of arange_schedule is empty! Please check.\n";
common::CINNValuePack arg_pack = args[0];
std::vector<Expr> vec_ast;
for (int i = 0; i < arg_pack.size(); i++) {
if (arg_pack[i].is_expr()) {
Expr temp = arg_pack[i];
vec_ast.emplace_back(temp);
}
std::vector<common::CINNValue> res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
*ret = common::CINNValuePack{res};
} else {
CHECK(!args.empty()) << "The input argument of arange_schedule is empty! Please check.\n";
common::CINNValuePack arg_pack = args[0];
Expr out = arg_pack[0];
CHECK(out.as_tensor());
*ret = arg_pack;
}
CHECK(!vec_ast.empty());
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
ir_sch.MergeExprs();
auto blocks = ir_sch.GetAllBlocks();
// TODO: It needs to be rewritten according to the reduction_min operator to improve performance.
// Do not use local variables, because the size will exceed the limit.
ir_sch.SetBuffer(blocks[0], "local");
ir_sch.SetBuffer(blocks[1], "local");
long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies<int>());
if (prod_size > 1 && target.arch == Target::Arch::X86) {
pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
}
std::vector<common::CINNValue> res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
*ret = common::CINNValuePack{res};
});

auto strategy = std::make_shared<framework::OpStrategy>();
Expand Down

0 comments on commit bf52680

Please sign in to comment.