diff --git a/cinn/hlir/op/contrib/argmin.cc b/cinn/hlir/op/contrib/argmin.cc index 51214f30eb..7ad8b3e76d 100644 --- a/cinn/hlir/op/contrib/argmin.cc +++ b/cinn/hlir/op/contrib/argmin.cc @@ -113,18 +113,15 @@ std::shared_ptr StrategyForArgmin(const framework::NodeAt framework::CINNCompute argmin_compute([=](lang::Args args, lang::RetValue *ret) { CHECK(!args.empty()) << "The input argument of argmin compute is empty! Please check."; common::CINNValuePack pack_args = args[0]; - std::string tensor_name = UniqName("Argmin_out"); CHECK_GE(pack_args.size(), 1U) << "There should be 1 input args for argmax compute"; Expr in_expr = pack_args[0]; CHECK(in_expr.as_tensor()); Tensor in_tensor = in_expr.as_tensor_ref(); auto stages = CreateStages({in_tensor}); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2U); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - } - auto out_tensor = Argmin(in_tensor, target, stages, axis, keep_dims, tensor_name); + CHECK_EQ(pack_args.size(), 2U); + CHECK(pack_args[1].is_string()); + std::string tensor_name = pack_args[1].operator std::string(); + auto out_tensor = Argmin(in_tensor, target, stages, axis, keep_dims, tensor_name); stages->InsertLazily(out_tensor[0]); std::vector cinn_values{ @@ -133,38 +130,30 @@ std::shared_ptr StrategyForArgmin(const framework::NodeAt }); framework::CINNSchedule argmin_schedule([=](lang::Args args, lang::RetValue *ret) { - if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of arange_schedule is empty! Please check.\n"; - common::CINNValuePack arg_pack = args[0]; - std::vector vec_ast; - for (int i = 0; i < arg_pack.size(); i++) { - if (arg_pack[i].is_expr()) { - Expr temp = arg_pack[i]; - vec_ast.emplace_back(temp); - } - } - CHECK(!vec_ast.empty()); - ir::ModuleExpr mod_expr(vec_ast); - ir::IRSchedule ir_sch(mod_expr); - ir_sch.MergeExprs(); - auto blocks = ir_sch.GetAllBlocks(); - // TODO: It needs to be rewritten according to the reduction_min operator to improve performance. - // Do not use local variables, because the size will exceed the limit. - ir_sch.SetBuffer(blocks[0], "local"); - ir_sch.SetBuffer(blocks[1], "local"); - long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); - if (prod_size > 1 && target.arch == Target::Arch::X86) { - pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true); + CHECK(!args.empty()) << "The input argument of arange_schedule is empty! Please check.\n"; + common::CINNValuePack arg_pack = args[0]; + std::vector vec_ast; + for (int i = 0; i < arg_pack.size(); i++) { + if (arg_pack[i].is_expr()) { + Expr temp = arg_pack[i]; + vec_ast.emplace_back(temp); } - std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; - *ret = common::CINNValuePack{res}; - } else { - CHECK(!args.empty()) << "The input argument of arange_schedule is empty! Please check.\n"; - common::CINNValuePack arg_pack = args[0]; - Expr out = arg_pack[0]; - CHECK(out.as_tensor()); - *ret = arg_pack; } + CHECK(!vec_ast.empty()); + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + ir_sch.MergeExprs(); + auto blocks = ir_sch.GetAllBlocks(); + // TODO: It needs to be rewritten according to the reduction_min operator to improve performance. + // Do not use local variables, because the size will exceed the limit. + ir_sch.SetBuffer(blocks[0], "local"); + ir_sch.SetBuffer(blocks[1], "local"); + long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); + if (prod_size > 1 && target.arch == Target::Arch::X86) { + pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true); + } + std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + *ret = common::CINNValuePack{res}; }); auto strategy = std::make_shared();