diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5eeea1831834b..b30d8c5de626d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -47,7 +47,8 @@ repos:
         files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
         exclude: |
             (?x)^(
-                paddle/utils/.*
+                paddle/utils/.*|
+                paddle/cinn/utils/registry.h
             )$
 # For Python files
 -   repo: https://github.com/psf/black.git
diff --git a/paddle/cinn/auto_schedule/analysis/analyze_ir.cc b/paddle/cinn/auto_schedule/analysis/analyze_ir.cc
index b64449d703aab..cf4b17747ad3b 100644
--- a/paddle/cinn/auto_schedule/analysis/analyze_ir.cc
+++ b/paddle/cinn/auto_schedule/analysis/analyze_ir.cc
@@ -41,7 +41,7 @@ std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices) {
   for (const ir::Expr& e : indices) {
     // Whether we have to convert other types, like const numbers to Var?
     if (e.As<ir::_Var_>() != nullptr) {
-      ir::Expr copy_e    = optim::IRCopy(e);
+      ir::Expr copy_e = optim::IRCopy(e);
       ir::_Var_* var_ref = copy_e.As<ir::_Var_>();
       result.emplace_back(ir::Var(var_ref));
     }
@@ -58,26 +58,32 @@ void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block) {
     const ir::Load* load_expr = x->As<ir::Load>();
     if (load_expr != nullptr) {
       const ir::Tensor t = load_expr->tensor.as_tensor_ref();
-      sche_block->read_buffers.emplace_back(ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices)));
+      sche_block->read_buffers.emplace_back(
+          ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices)));
       return false;
     }
     const ir::Store* store_expr = x->As<ir::Store>();
     if (store_expr != nullptr) {
       const ir::Tensor t = store_expr->tensor.as_tensor_ref();
-      sche_block->write_buffers.emplace_back(ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices)));
+      sche_block->write_buffers.emplace_back(
+          ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices)));
       return false;
     }
     return false;
   });
 }
 
-bool ContainsNodeType(ir::Expr expr, const std::unordered_set<ir::IrNodeTy>& node_types) {
-  std::set<ir::Expr> collection = ir::CollectIRNodesWithoutTensor(
-      expr, [&](const Expr* x) { return node_types.find(x->node_type()) != node_types.end(); });
+bool ContainsNodeType(ir::Expr expr,
+                      const std::unordered_set<ir::IrNodeTy>& node_types) {
+  std::set<ir::Expr> collection =
+      ir::CollectIRNodesWithoutTensor(expr, [&](const Expr* x) {
+        return node_types.find(x->node_type()) != node_types.end();
+      });
   return !collection.empty();
 }
 
-std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector<ir::LoweredFunc>& lowered_funcs) {
+std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(
+    const std::vector<ir::LoweredFunc>& lowered_funcs) {
   std::unordered_set<std::string> result;
   for (const ir::LoweredFunc& func : lowered_funcs) {
     for (const ir::Argument& arg : func->args) {
@@ -90,18 +96,22 @@ std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector<
 }
 
 bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) {
-  const ir::ScheduleBlock* sche_block = sche_block_realize.schedule_block.As<ir::ScheduleBlock>();
-  if (sche_block->write_buffers.size() != 1 || sche_block->read_buffers.empty()) {
+  const ir::ScheduleBlock* sche_block =
+      sche_block_realize.schedule_block.As<ir::ScheduleBlock>();
+  if (sche_block->write_buffers.size() != 1 ||
+      sche_block->read_buffers.empty()) {
     return false;
   }
-  const ir::Expr& write_buffer = sche_block->write_buffers[0].As<ir::_BufferRange_>()->buffer;
+  const ir::Expr& write_buffer =
+      sche_block->write_buffers[0].As<ir::_BufferRange_>()->buffer;
 
   // Enumerate each read region, get the number of schedule block iter vars
   // which are not used to index the read region
   int total_unused_iter_vars = 0;
   for (const ir::Expr& read_buffer_expr : sche_block->read_buffers) {
-    const ir::_BufferRange_* read_buffer = read_buffer_expr.As<ir::_BufferRange_>();
+    const ir::_BufferRange_* read_buffer =
+        read_buffer_expr.As<ir::_BufferRange_>();
     // Skip the reduction buffer
     if (read_buffer->buffer == write_buffer) {
       continue;
@@ -133,18 +143,22 @@ bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) {
   return total_unused_iter_vars >= 1;
 }
 
-ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::LoweredFunc& old_func, ir::Expr& body) {
+ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
+                                      const ir::LoweredFunc& old_func,
+                                      ir::Expr& body) {
   ir::ModuleExpr mod_expr(std::vector<ir::Expr>({body}));
   ir::IRSchedule ir_sch(mod_expr);
 
   // temp_bufs may be deleted during auto tuning (such as auto inline),
   // we have to check from old temp bufs and set them as local buffer.
   for (const ir::Buffer& buf : old_func->temp_bufs) {
-    const std::string& buf_name              = buf->name;
+    const std::string& buf_name = buf->name;
     std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
     for (ir::Expr& e : all_block_realizes) {
-      const ir::ScheduleBlockRealize* sche_block_realize = e.As<ir::ScheduleBlockRealize>();
-      const std::string& sche_name = sche_block_realize->schedule_block.As<ir::ScheduleBlock>()->name;
+      const ir::ScheduleBlockRealize* sche_block_realize =
+          e.As<ir::ScheduleBlockRealize>();
+      const std::string& sche_name =
+          sche_block_realize->schedule_block.As<ir::ScheduleBlock>()->name;
       if (buf_name == "_" + sche_name) {
         VLOG(6) << "Set local buffer for temp buffer " << buf_name;
         ir_sch.SetBuffer(e, "local", true);
@@ -159,14 +173,17 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::Lo
 #endif
 
   // Get new temp bufs by analyzing.
-  std::vector<ir::Buffer> new_temp_bufs = lang::GetTempBuffers(old_func->args, updated_body);
-  ir::LoweredFunc new_func = ir::_LoweredFunc_::Make(old_func->name, old_func->args, updated_body, new_temp_bufs);
+  std::vector<ir::Buffer> new_temp_bufs =
+      lang::GetTempBuffers(old_func->args, updated_body);
+  ir::LoweredFunc new_func = ir::_LoweredFunc_::Make(
+      old_func->name, old_func->args, updated_body, new_temp_bufs);
 #ifdef CINN_WITH_CUDA
   if (target == common::DefaultNVGPUTarget()) {
     new_func->PrepareCudaAxisInfoFromBody();
   }
 #endif
-  new_func = optim::Optimize(Expr(new_func), target, false).as_lowered_func_ref();
+  new_func =
+      optim::Optimize(Expr(new_func), target, false).as_lowered_func_ref();
   new_func->PrepareBufferCastExprs(/*with_expr_gen_tensor = */ false);
 
   return new_func;
diff --git a/paddle/cinn/auto_schedule/analysis/analyze_ir.h b/paddle/cinn/auto_schedule/analysis/analyze_ir.h
index fdd8d9604ac29..4e48be04ee5fc 100644
--- a/paddle/cinn/auto_schedule/analysis/analyze_ir.h
+++ b/paddle/cinn/auto_schedule/analysis/analyze_ir.h
@@ -27,12 +27,14 @@ namespace auto_schedule {
 
 void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block);
 
-bool ContainsNodeType(ir::Expr expr, const std::unordered_set<ir::IrNodeTy>& node_types);
+bool ContainsNodeType(ir::Expr expr,
+                      const std::unordered_set<ir::IrNodeTy>& node_types);
 
 /**
  * Collects all input lowered_funcs and return names of all output arguments
  */
-std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector<ir::LoweredFunc>& lowered_funcs);
+std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(
+    const std::vector<ir::LoweredFunc>& lowered_funcs);
 
 /**
  * Determine whether a schedule block needs multileveltiling
@@ -42,7 +44,9 @@ bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize);
 /**
  * Update a LoweredFunc by regenerating related fields with a new function body
  */
-ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::LoweredFunc& old_func, ir::Expr& body);
+ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
+                                      const ir::LoweredFunc& old_func,
+                                      ir::Expr& body);
 
 }  // namespace auto_schedule
 }  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc b/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc
index e0e426575bedd..232a8a498ffbe 100644
--- a/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc
+++ b/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc
@@ -49,8 +49,9 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) {
   ir::Tensor B = lang::Compute(
       {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
 
-  poly::StageMap stages              = poly::CreateStages({A, B});
-  std::vector<ir::LoweredFunc> funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
+  poly::StageMap stages = poly::CreateStages({A, B});
+  std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
+      "SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
 
   ASSERT_FALSE(funcs.empty());
   ir::Expr ast_expr = funcs[0]->body;
@@ -65,8 +66,10 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) {
   std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
   ASSERT_EQ(all_block_realizes.size(), 1UL);
 
-  ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes[0].As<ir::ScheduleBlockRealize>();
-  ir::ScheduleBlock* sche_block                = sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
+  ir::ScheduleBlockRealize* sche_block_realize =
+      all_block_realizes[0].As<ir::ScheduleBlockRealize>();
+  ir::ScheduleBlock* sche_block =
+      sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
   AnalyzeScheduleBlockReadWriteBuffer(sche_block);
 
   /*
@@ -112,8 +115,9 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) {
   ir::Tensor C = lang::Compute(
       {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
 
-  poly::StageMap stages              = poly::CreateStages({C});
-  std::vector<ir::LoweredFunc> funcs = lang::LowerVec("AddDiffShape", stages, {C}, {}, {}, nullptr, target, true);
+  poly::StageMap stages = poly::CreateStages({C});
+  std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
+      "AddDiffShape", stages, {C}, {}, {}, nullptr, target, true);
 
   ir::Expr ast_expr = funcs[0]->body;
   VLOG(6) << "Expr before MultiLevelTiling: ";
@@ -126,8 +130,10 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) {
   std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
   ASSERT_EQ(all_block_realizes.size(), 1UL);
 
-  ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes[0].As<ir::ScheduleBlockRealize>();
-  ir::ScheduleBlock* sche_block                = sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
+  ir::ScheduleBlockRealize* sche_block_realize =
+      all_block_realizes[0].As<ir::ScheduleBlockRealize>();
+  ir::ScheduleBlock* sche_block =
+      sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
   AnalyzeScheduleBlockReadWriteBuffer(sche_block);
 
   VLOG(6) << "ScheduleBlockRealize: ";
@@ -163,8 +169,9 @@ TEST(AnalyzeIr, ContainsNodeType) {
   ir::Tensor B = lang::Compute(
       {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
 
-  poly::StageMap stages              = poly::CreateStages({A, B});
-  std::vector<ir::LoweredFunc> funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
+  poly::StageMap stages = poly::CreateStages({A, B});
+  std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
+      "SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
 
   ASSERT_FALSE(funcs.empty());
   ir::Expr ast_expr = funcs[0]->body;
@@ -172,9 +179,12 @@ TEST(AnalyzeIr, ContainsNodeType) {
   VLOG(6) << "Analyzing for Expr:";
   VLOG(6) << ast_expr;
 
-  ASSERT_TRUE(ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::Store}));
-  ASSERT_TRUE(ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::IfThenElse}));
-  ASSERT_FALSE(ContainsNodeType(ast_expr, {ir::IrNodeTy::IfThenElse, ir::IrNodeTy::Sum}));
+  ASSERT_TRUE(
+      ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::Store}));
+  ASSERT_TRUE(ContainsNodeType(ast_expr,
+                               {ir::IrNodeTy::Load, ir::IrNodeTy::IfThenElse}));
+  ASSERT_FALSE(ContainsNodeType(ast_expr,
+                                {ir::IrNodeTy::IfThenElse, ir::IrNodeTy::Sum}));
 }
 
 }  // namespace auto_schedule
diff --git a/paddle/cinn/auto_schedule/auto_tuner.cc b/paddle/cinn/auto_schedule/auto_tuner.cc
index f9c8a1da4fe72..68f5b6d199d49 100644
--- a/paddle/cinn/auto_schedule/auto_tuner.cc
+++ b/paddle/cinn/auto_schedule/auto_tuner.cc
@@ -38,13 +38,17 @@ namespace cinn {
 namespace auto_schedule {
 
-AutoTuner::AutoTuner(const common::Target& target, hlir::framework::Graph* graph) : target_(target), graph_(graph) {}
+AutoTuner::AutoTuner(const common::Target& target,
+                     hlir::framework::Graph* graph)
+    : target_(target), graph_(graph) {}
 
-void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler* graph_compiler) {
+void AutoTuner::Initialize(const Config& config,
+                           hlir::framework::GraphCompiler* graph_compiler) {
   // create builder, runner, and schedule measurer
-  builder_           = std::make_unique<SimpleBuilder>(graph_compiler);
-  runner_            = std::make_unique<SimpleRunner>(config.runner_repeat_times);
-  schedule_measurer_ = std::make_unique<ScheduleMeasurer>(builder_.get(), runner_.get());
+  builder_ = std::make_unique<SimpleBuilder>(graph_compiler);
+  runner_ = std::make_unique<SimpleRunner>(config.runner_repeat_times);
+  schedule_measurer_ =
+      std::make_unique<ScheduleMeasurer>(builder_.get(), runner_.get());
 
   // initialize database
   database_ = std::move(Database::Make(config.database_config));
@@ -53,29 +57,43 @@ void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler*
   TaskCreator task_creator;
   tasks_ = task_creator.CreateTuneTaskOpLevel(graph_);
 
-  const auto& dtype_dict = graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype");
-  const auto& shape_dict = graph_->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
+  const auto& dtype_dict =
+      graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
+          "inferdtype");
+  const auto& shape_dict = graph_->GetAttrs<
+      absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
 
-  op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target_);
+  op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(
+      dtype_dict, shape_dict, target_);
 
   InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
   for (auto i = 0; i < tasks_.size(); ++i) {
     auto&& task = tasks_[i];
     task.Initialize(shape_dict, dtype_dict, op_lowerer_.get());
     // Register the initial ModuleExpr corresponding to the task
-    task_registry->Regist(task.serialized_key, ir::ModuleExpr(task.GetLoweredFuncBodyExprs()));
-    VLOG(3) << "Add a task, id:" << i << ", serialized_key:\n" << task.serialized_key;
+    task_registry->Regist(task.serialized_key,
+                          ir::ModuleExpr(task.GetLoweredFuncBodyExprs()));
+    VLOG(3) << "Add a task, id:" << i << ", serialized_key:\n"
+            << task.serialized_key;
   }
 
   // create task optimizers
-  utils::LinearRandomEngine::StateType initial_seed = utils::LinearRandomEngine::GetDeviceRandomValue();
+  utils::LinearRandomEngine::StateType initial_seed =
+      utils::LinearRandomEngine::GetDeviceRandomValue();
   task_optimizers_.resize(tasks_.size());
-  std::transform(tasks_.begin(), tasks_.end(), task_optimizers_.begin(), [&](TuneTask& task) {
-    return std::make_unique<TaskOptimizer>(
-        &task, schedule_measurer_.get(), database_.get(), utils::ForkRandomState(&initial_seed));
-  });
+  std::transform(tasks_.begin(),
+                 tasks_.end(),
+                 task_optimizers_.begin(),
+                 [&](TuneTask& task) {
+                   return std::make_unique<TaskOptimizer>(
+                       &task,
+                       schedule_measurer_.get(),
+                       database_.get(),
+                       utils::ForkRandomState(&initial_seed));
+                 });
 
   // create task scheduler
-  task_scheduler_ = TaskScheduler::Make(tasks_, config.task_schedule_config, config.task_schedule_strategy);
+  task_scheduler_ = TaskScheduler::Make(
+      tasks_, config.task_schedule_config, config.task_schedule_strategy);
 }
 
 void PrintResult(std::shared_ptr<hlir::framework::Graph::Group> group) {
@@ -127,7 +145,8 @@ void PrintResult(const TuningResult& result) {
 
 TuningResult AutoTuner::Tune(const TuningOptions& options) {
   CHECK_GT(options.num_tuning_rounds, 0) << "Invalid config";
-  VLOG(3) << "Begin tuning with round num=" << options.num_tuning_rounds << ", tasks size=" << tasks_.size();
+  VLOG(3) << "Begin tuning with round num=" << options.num_tuning_rounds
+          << ", tasks size=" << tasks_.size();
 
   TuningResult result;
   result.subgraphs.resize(tasks_.size());
@@ -136,7 +155,7 @@ TuningResult AutoTuner::Tune(const TuningOptions& options) {
   // as default result of graph tuning, and that should be updated
   // once we support graph tuning.
   for (auto i = 0; i < tasks_.size(); ++i) {
-    auto&& task         = tasks_.at(i);
+    auto&& task = tasks_.at(i);
     result.subgraphs[i] = task.subgraph;
   }
 
@@ -146,7 +165,7 @@ TuningResult AutoTuner::Tune(const TuningOptions& options) {
   task_scheduler_->Reset();
   while ((run_id = task_scheduler_->NextTaskId()) != -1) {
     VLOG(3) << "Start tuning Task-" << run_id;
-    auto* opt           = task_optimizers_.at(run_id).get();
+    auto* opt = task_optimizers_.at(run_id).get();
     auto function_group = opt->Optimize(options);
     VLOG(3) << "Task-" << run_id << " finished, print optimized functions:\n";
     PrintResult(function_group);
diff --git a/paddle/cinn/auto_schedule/auto_tuner.h b/paddle/cinn/auto_schedule/auto_tuner.h
index 70dd824391aeb..1a4e3c8c60d8e 100644
--- a/paddle/cinn/auto_schedule/auto_tuner.h
+++ b/paddle/cinn/auto_schedule/auto_tuner.h
@@ -49,7 +49,8 @@ class AutoTuner {
   AutoTuner(const common::Target& target, hlir::framework::Graph* graph);
 
   // Initialize tuner with specific config and auxiliary objects.
-  void Initialize(const Config& config, hlir::framework::GraphCompiler* graph_compiler);
+  void Initialize(const Config& config,
+                  hlir::framework::GraphCompiler* graph_compiler);
 
   // Perform the tuning process and return the final result
   TuningResult Tune(const TuningOptions& options);
diff --git a/paddle/cinn/auto_schedule/auto_tuner_test.cc b/paddle/cinn/auto_schedule/auto_tuner_test.cc
index 10a417720cffe..65a4dcf919e80 100644
--- a/paddle/cinn/auto_schedule/auto_tuner_test.cc
+++ b/paddle/cinn/auto_schedule/auto_tuner_test.cc
@@ -73,14 +73,16 @@ class TestAutoTuner : public ::testing::Test {
     // AutoTuner is combined with new IR Schedule
     FLAGS_cinn_ir_schedule = true;
     std::unordered_set<std::string> fetch_ids;
-    auto program   = CreateAddReluProgram();
-    auto graph     = cinn::frontend::Optimize(&program, fetch_ids, target);
+    auto program = CreateAddReluProgram();
+    auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
     compiled_scope = BuildScope(target, graph);
-    graph_compiler = std::make_unique<GraphCompiler>(target, compiled_scope, graph);
-    tuner          = std::make_unique<AutoTuner>(target, graph.get());
+    graph_compiler =
+        std::make_unique<GraphCompiler>(target, compiled_scope, graph);
+    tuner = std::make_unique<AutoTuner>(target, graph.get());
   }
 
-  TuningResult InitializeAndTune(const AutoTuner::Config& config, const TuningOptions& options) {
+  TuningResult InitializeAndTune(const AutoTuner::Config& config,
+                                 const TuningOptions& options) {
     tuner->Initialize(config, graph_compiler.get());
     return tuner->Tune(options);
   }
@@ -108,7 +110,8 @@ class TestAutoTuner : public ::testing::Test {
     VLOG(6) << "Print lowered_funcs before building";
     VLOG(6) << compile_options.lowered_funcs[0][0];
     VLOG(6) << compile_options.lowered_funcs[1][0];
-    auto runtime_program = graph_compiler->Build(compile_options).runtime_program;
+    auto runtime_program =
+        graph_compiler->Build(compile_options).runtime_program;
     ASSERT_EQ(1, runtime_program->size());
     runtime_program->Execute();
   }
@@ -120,7 +123,7 @@ class TestAutoTuner : public ::testing::Test {
     TuningOptions tuning_options;
     tuning_options.num_measure_trials = 0;
 
-    auto result                       = InitializeAndTune(tuning_config, tuning_options);
+    auto result = InitializeAndTune(tuning_config, tuning_options);
     BasicCheckResult(result);
     ApplyTunedAndRun(result);
   }
@@ -131,7 +134,7 @@ class TestAutoTuner : public ::testing::Test {
     tuning_config.task_schedule_strategy = "round_robin";
 
     TuningOptions tuning_options;
-    tuning_options.num_measure_trials        = 4;
+    tuning_options.num_measure_trials = 4;
     tuning_options.num_samples_per_iteration = 2;
 
     auto result = InitializeAndTune(tuning_config, tuning_options);
diff --git a/paddle/cinn/auto_schedule/cost_model/expr_cost_model.cc b/paddle/cinn/auto_schedule/cost_model/expr_cost_model.cc
index fadee09ed1a1b..f433115036bc1 100644
--- a/paddle/cinn/auto_schedule/cost_model/expr_cost_model.cc
+++ b/paddle/cinn/auto_schedule/cost_model/expr_cost_model.cc
@@ -28,14 +28,15 @@ namespace cinn {
 namespace auto_schedule {
 
-float ExprCostModel::Predict(const ir::ModuleExpr& sample, const common::Target& target) const {
+float ExprCostModel::Predict(const ir::ModuleExpr& sample,
+                             const common::Target& target) const {
   if (trained_times_.load() == 0) {
     return SearchState::NOT_INIT_COST;
   }
   FeatureExtractor extractor;
-  Feature feature                    = extractor.Extract(sample, target);
+  Feature feature = extractor.Extract(sample, target);
   std::vector<float> feature_numbers = feature.ToFixedSizeVector();
-  std::vector<float> pred            = XgbCostModel::Predict({feature_numbers});
+  std::vector<float> pred = XgbCostModel::Predict({feature_numbers});
   return pred[0];
 }
 
@@ -44,12 +45,13 @@ void ExprCostModel::Train(const std::vector<const ir::ModuleExpr*>& samples,
                           const common::Target& target) {
   trained_times_.store(1);
   size_t total_size = samples.size();
-  CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels";
+  CHECK_EQ(total_size, labels.size())
+      << "Samples must have same size as labels";
   std::vector<std::vector<float>> train_feature_numbers(total_size);
   FeatureExtractor extractor;
   for (size_t i = 0; i < total_size; ++i) {
     CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr";
-    Feature feature          = extractor.Extract(*samples[i], target);
+    Feature feature = extractor.Extract(*samples[i], target);
     train_feature_numbers[i] = feature.ToFixedSizeVector();
   }
 
@@ -61,12 +63,13 @@ void ExprCostModel::Update(const std::vector<const ir::ModuleExpr*>& samples,
                            const common::Target& target) {
   ++trained_times_;
   size_t total_size = samples.size();
-  CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels";
+  CHECK_EQ(total_size, labels.size())
+      << "Samples must have same size as labels";
   std::vector<std::vector<float>> train_feature_numbers(total_size);
   FeatureExtractor extractor;
   for (size_t i = 0; i < total_size; ++i) {
     CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr";
-    Feature feature          = extractor.Extract(*samples[i], target);
+    Feature feature = extractor.Extract(*samples[i], target);
     train_feature_numbers[i] = feature.ToFixedSizeVector();
   }
diff --git a/paddle/cinn/auto_schedule/cost_model/expr_cost_model.h b/paddle/cinn/auto_schedule/cost_model/expr_cost_model.h
index 8aadec6f7ca3f..c0fe6ee899e43 100644
--- a/paddle/cinn/auto_schedule/cost_model/expr_cost_model.h
+++ b/paddle/cinn/auto_schedule/cost_model/expr_cost_model.h
@@ -29,7 +29,8 @@ namespace auto_schedule {
  */
 class ExprCostModel : public XgbCostModel {
  public:
-  virtual float Predict(const ir::ModuleExpr& sample, const common::Target& target) const;
+  virtual float Predict(const ir::ModuleExpr& sample,
+                        const common::Target& target) const;
   void Train(const std::vector<const ir::ModuleExpr*>& samples,
              const std::vector<float>& labels,
              const common::Target& target);
diff --git a/paddle/cinn/auto_schedule/cost_model/feature.cc b/paddle/cinn/auto_schedule/cost_model/feature.cc
index d2c7a89a8f3bd..f993ee256616a 100644
--- a/paddle/cinn/auto_schedule/cost_model/feature.cc
+++ b/paddle/cinn/auto_schedule/cost_model/feature.cc
@@ -49,7 +49,8 @@ Feature::Feature(const common::Target& target)
       parent_indices_(1, -1) {}
 
 std::vector<float> Feature::ToFixedSizeVector() {
-  std::vector<float> ret(LoopBlockFeature::kTotalSize + 1, 0);  // LoopBlockFeature::kTotalSize plus 1 for target
+  std::vector<float> ret(LoopBlockFeature::kTotalSize + 1,
+                         0);  // LoopBlockFeature::kTotalSize plus 1 for target
 
   if (target_ == common::DefaultNVGPUTarget()) {
     ret[0] = 1;
@@ -58,13 +59,13 @@ std::vector<float> Feature::ToFixedSizeVector() {
   // loop[i] feature count should multiply iter_multi_num[i]
   std::vector<int> iter_multi_num;
   for (size_t i = 0; i < stack_encoded_feature_.size(); ++i) {
-    int j                                = 1;
+    int j = 1;
     const LoopBlockFeature& loop_feature = stack_encoded_feature_[i];
-    int loop_prod                        = 1;
-    int parent_prod                      = 1;
+    int loop_prod = 1;
+    int parent_prod = 1;
     if (i != 0) {
       parent_prod = iter_multi_num[parent_indices_[i]];
-      loop_prod   = parent_prod * loop_feature.loop_length;
+      loop_prod = parent_prod * loop_feature.loop_length;
     }
     iter_multi_num.push_back(loop_prod);
 
@@ -165,11 +166,17 @@ void Feature::IntoLoopBlock() {
   current_loop_block_index_ = stack_encoded_feature_.size() - 1;
 }
 
-void Feature::ExitLoopBlock() { current_loop_block_index_ = parent_indices_[current_loop_block_index_]; }
+void Feature::ExitLoopBlock() {
+  current_loop_block_index_ = parent_indices_[current_loop_block_index_];
+}
 
-LoopBlockFeature& Feature::CurrentLoopBlock() { return stack_encoded_feature_[current_loop_block_index_]; }
+LoopBlockFeature& Feature::CurrentLoopBlock() {
+  return stack_encoded_feature_[current_loop_block_index_];
+}
 
-const LoopBlockFeature& Feature::CurrentLoopBlock() const { return stack_encoded_feature_[current_loop_block_index_]; }
+const LoopBlockFeature& Feature::CurrentLoopBlock() const {
+  return stack_encoded_feature_[current_loop_block_index_];
+}
 
 }  // namespace auto_schedule
 }  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/cost_model/feature.h b/paddle/cinn/auto_schedule/cost_model/feature.h
index 0c8aea6b9e9e3..8b1b59d92c6e3 100644
--- a/paddle/cinn/auto_schedule/cost_model/feature.h
+++ b/paddle/cinn/auto_schedule/cost_model/feature.h
@@ -24,10 +24,18 @@ namespace cinn {
 namespace auto_schedule {
 
 /* Loop feature enums */
-enum class ForOptimizeFeatureEnum : int { kNone, kGpuBind, kParallel, kUnroll, kVectorize };
+enum class ForOptimizeFeatureEnum : int {
+  kNone,
+  kGpuBind,
+  kParallel,
+  kUnroll,
+  kVectorize
+};
 
 /* function to scale feature numbers */
-inline float slog(float x) { return x < 0 ? std::log2(-x + 1) : std::log2(x + 1); }
+inline float slog(float x) {
+  return x < 0 ? std::log2(-x + 1) : std::log2(x + 1);
+}
 
 class LoopBlockFeature {
  public:
@@ -36,20 +44,20 @@ class LoopBlockFeature {
   // different bits, so we just distinguished int and float here
   /* Arithmetic features */
   int float_add_or_sub = 0;
-  int float_mul        = 0;
+  int float_mul = 0;
   int float_div_or_mod = 0;
-  int float_cmp        = 0;
-  int float_math_func  = 0;
+  int float_cmp = 0;
+  int float_math_func = 0;
   int float_other_call = 0;  // like simple assign, cast, etc.
 
   int int_add_or_sub = 0;
-  int int_mul        = 0;
+  int int_mul = 0;
   int int_div_or_mod = 0;
-  int int_cmp        = 0;
-  int int_math_func  = 0;
+  int int_cmp = 0;
+  int int_math_func = 0;
   int int_other_call = 0;  // like simple assign, cast, etc.
 
-  int bool_op   = 0;
+  int bool_op = 0;
   int select_op = 0;
 
   static constexpr int kArithSize = 6 * 2 + 2;
@@ -61,8 +69,8 @@ class LoopBlockFeature {
    * may be collect operand sizes (like alloc size, write size, or so)
    */
   int mem_alloc = 0;
-  int mem_free  = 0;
-  int mem_read  = 0;
+  int mem_free = 0;
+  int mem_read = 0;
   int mem_write = 0;
 
   static constexpr int kMemSize = 4;
@@ -71,16 +79,16 @@ class LoopBlockFeature {
    * Reduce and Broadcast features
    */
   int float_reduce_sum_or_sub = 0;
-  int float_reduce_mul        = 0;
-  int float_reduce_div        = 0;
+  int float_reduce_mul = 0;
+  int float_reduce_div = 0;
   int float_reduce_max_or_min = 0;
-  int float_broadcast         = 0;
+  int float_broadcast = 0;
 
   int int_reduce_sum_or_sub = 0;
-  int int_reduce_mul        = 0;
-  int int_reduce_div        = 0;
+  int int_reduce_mul = 0;
+  int int_reduce_div = 0;
   int int_reduce_max_or_min = 0;
-  int int_broadcast         = 0;
+  int int_broadcast = 0;
 
   static constexpr int kReduceBroadcastSize = 10;
@@ -95,18 +103,20 @@ class LoopBlockFeature {
   /* Thread features if loop is optimized by GPU or CPU parallelism.
    * Useless in other cases.
    */
-  int len_blockIdx_x   = 0;
-  int len_blockIdx_y   = 0;
-  int len_blockIdx_z   = 0;
-  int len_threadIdx_x  = 0;
-  int len_threadIdx_y  = 0;
-  int len_threadIdx_z  = 0;
-  int len_vthread      = 0;  // length of virtual thread
+  int len_blockIdx_x = 0;
+  int len_blockIdx_y = 0;
+  int len_blockIdx_z = 0;
+  int len_threadIdx_x = 0;
+  int len_threadIdx_y = 0;
+  int len_threadIdx_z = 0;
+  int len_vthread = 0;  // length of virtual thread
   int vectorize_factor = 0;
 
   static constexpr int kThreadFeatureSize = 8;
 
-  static constexpr int kTotalSize = kArithSize + kMemSize + kReduceBroadcastSize + kOptApplySize + kThreadFeatureSize;
+  static constexpr int kTotalSize = kArithSize + kMemSize +
+                                    kReduceBroadcastSize + kOptApplySize +
+                                    kThreadFeatureSize;
 
   /* Non-feature attributes, used to maintain during feature_extractor */
@@ -158,10 +168,11 @@ class Feature {
   //     some_compute_3
   //   }
   //
-  // We go through the code and push loops into stack, then the features are encoded as
-  // [loop_block_feature_0, loop_block_feature_1, loop_block_feature_2, loop_block_feature_3]
-  // where loop_block_feature_i stores the features of some_compute_i (such
-  // as number of arithmetic operations)
+  // We go through the code and push loops into stack, then the features are
+  // encoded as [loop_block_feature_0, loop_block_feature_1,
+  // loop_block_feature_2, loop_block_feature_3] where loop_block_feature_i
+  // stores the features of some_compute_i (such as number of arithmetic
+  // operations)
   //
   // loop_block_feature_0.num_sub_loops = 2
   // loop_block_feature_1.num_sub_loops = 1
diff --git a/paddle/cinn/auto_schedule/cost_model/feature_extractor.cc b/paddle/cinn/auto_schedule/cost_model/feature_extractor.cc
index 01e29c37c06f4..5f1e35cda0b64 100644
--- a/paddle/cinn/auto_schedule/cost_model/feature_extractor.cc
+++ b/paddle/cinn/auto_schedule/cost_model/feature_extractor.cc
@@ -47,7 +47,8 @@ FeatureExtractor::FeatureExtractor() {}
 
 void FeatureExtractor::Visit(const Expr *x) { IRVisitor::Visit(x); }
 
-Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr, const common::Target &target) {
+Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr,
+                                  const common::Target &target) {
   feature_ = Feature(target);
   for (const ir::Expr &e : mod_expr.GetExprs()) {
     Visit(&e);
@@ -85,19 +86,20 @@ VisitDoNothing(_BufferRange_);
 
 NotVisitExprFields(_Tensor_)
 
-#define VisitForDtypePattern(NodeType, member)                                                    \
-  void FeatureExtractor::Visit(const NodeType *x) {                                               \
-    if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { \
-      feature_.CurrentLoopBlock().float_##member += x->type().lanes();                            \
-    } else {                                                                                      \
-      feature_.CurrentLoopBlock().int_##member += x->type().lanes();                              \
-    }                                                                                             \
-    std::vector<Expr *> sub_exprs = x->expr_fields();                                             \
-    for (const Expr *e : sub_exprs) {                                                             \
-      if (e->defined()) {                                                                         \
-        Visit(e);                                                                                 \
-      }                                                                                           \
-    }                                                                                             \
+#define VisitForDtypePattern(NodeType, member)                          \
+  void FeatureExtractor::Visit(const NodeType *x) {                     \
+    if (x->type() == common::F32() || x->type() == common::F16() ||     \
+        x->type() == common::F64()) {                                   \
+      feature_.CurrentLoopBlock().float_##member += x->type().lanes();  \
+    } else {                                                            \
+      feature_.CurrentLoopBlock().int_##member += x->type().lanes();    \
+    }                                                                   \
+    std::vector<Expr *> sub_exprs = x->expr_fields();                   \
+    for (const Expr *e : sub_exprs) {                                   \
+      if (e->defined()) {                                               \
+        Visit(e);                                                       \
+      }                                                                 \
+    }                                                                   \
   }
 
 VisitForDtypePattern(Add, add_or_sub);
@@ -118,19 +120,21 @@ VisitForDtypePattern(PrimitiveNode, math_func);
 VisitForDtypePattern(Cast, other_call);
 VisitForDtypePattern(Let, other_call);
 
-#define VisitForMultiOperandsDtypePattern(NodeType, member)                                       \
-  void FeatureExtractor::Visit(const NodeType *x) {                                               \
-    if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { \
-      feature_.CurrentLoopBlock().float_##member += (x->operands().size() - 1);                   \
-    } else {                                                                                      \
-      feature_.CurrentLoopBlock().int_##member += (x->operands().size() - 1);                     \
-    }                                                                                             \
-    std::vector<Expr *> sub_exprs = x->expr_fields();                                             \
-    for (const Expr *e : sub_exprs) {                                                             \
-      if (e->defined()) {                                                                         \
-        Visit(e);                                                                                 \
-      }                                                                                           \
-    }                                                                                             \
+#define VisitForMultiOperandsDtypePattern(NodeType, member)                   \
+  void FeatureExtractor::Visit(const NodeType *x) {                           \
+    if (x->type() == common::F32() || x->type() == common::F16() ||           \
+        x->type() == common::F64()) {                                         \
+      feature_.CurrentLoopBlock().float_##member +=                           \
+          (x->operands().size() - 1);                                         \
+    } else {                                                                  \
+      feature_.CurrentLoopBlock().int_##member += (x->operands().size() - 1); \
+    }                                                                         \
+    std::vector<Expr *> sub_exprs = x->expr_fields();                         \
+    for (const Expr *e : sub_exprs) {                                         \
+      if (e->defined()) {                                                     \
+        Visit(e);                                                             \
+      }                                                                       \
+    }                                                                         \
   }
 
 VisitForMultiOperandsDtypePattern(Sum, add_or_sub);
@@ -166,23 +170,24 @@ void FeatureExtractor::Visit(const For *x) {
   LoopBlockFeature &loop_feature = feature_.CurrentLoopBlock();
   if (x->min.is_constant() && x->extent.is_constant()) {
-    loop_feature.loop_length = (x->extent.get_constant() - x->min.get_constant());
+    loop_feature.loop_length =
+        (x->extent.get_constant() - x->min.get_constant());
   } else {
     loop_feature.loop_length = -1;  // -1 represents unknown
   }
 
   if (x->is_parallel()) {
     loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kParallel;
-    loop_feature.len_vthread   = loop_feature.loop_length;
+    loop_feature.len_vthread = loop_feature.loop_length;
   } else if (x->is_unrolled()) {
     loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kUnroll;
   } else if (x->is_vectorized()) {
-    loop_feature.loop_opt_type    = ForOptimizeFeatureEnum::kVectorize;
+    loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kVectorize;
     loop_feature.vectorize_factor = x->vectorize_info().factor;
   } else if (x->is_binded()) {
     loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kGpuBind;
-    const BindInfo &bind_info  = x->bind_info();
-    int offset                 = bind_info.offset;
+    const BindInfo &bind_info = x->bind_info();
+    int offset = bind_info.offset;
     if (bind_info.for_type == ForType::GPUBlock) {
       if (offset == 0) {
         loop_feature.len_blockIdx_x = loop_feature.loop_length;
@@ -223,13 +228,16 @@ void FeatureExtractor::Visit(const PolyFor *x) {
 /* Visit for Reduce and Broadcast */
 
 void FeatureExtractor::Visit(const Reduce *x) {
-  if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) {
+  if (x->type() == common::F32() || x->type() == common::F16() ||
+      x->type() == common::F64()) {
     switch (x->reduce_type) {
       case Reduce::ReduceType::kSum:
-        feature_.CurrentLoopBlock().float_reduce_sum_or_sub += x->type().lanes();
+        feature_.CurrentLoopBlock().float_reduce_sum_or_sub +=
+            x->type().lanes();
         break;
      case Reduce::ReduceType::kSub:
-        feature_.CurrentLoopBlock().float_reduce_sum_or_sub += x->type().lanes();
+        feature_.CurrentLoopBlock().float_reduce_sum_or_sub +=
+            x->type().lanes();
         break;
       case Reduce::ReduceType::kDiv:
         feature_.CurrentLoopBlock().float_reduce_div += x->type().lanes();
@@ -238,10 +246,12 @@ void FeatureExtractor::Visit(const Reduce *x) {
         feature_.CurrentLoopBlock().float_reduce_mul += x->type().lanes();
         break;
       case Reduce::ReduceType::kMax:
-        feature_.CurrentLoopBlock().float_reduce_max_or_min += x->type().lanes();
+        feature_.CurrentLoopBlock().float_reduce_max_or_min +=
+            x->type().lanes();
         break;
       case Reduce::ReduceType::kMin:
-        feature_.CurrentLoopBlock().float_reduce_max_or_min += x->type().lanes();
+        feature_.CurrentLoopBlock().float_reduce_max_or_min +=
+            x->type().lanes();
         break;
     }
   } else {
diff --git a/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc b/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc
index 0cfee9415a611..51e68cb287901 100644
--- a/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc
+++ b/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc
@@ -48,9 +48,10 @@ TEST(FeatureExtractor, SimpleAssign) {
   ir::Tensor B = lang::Compute(
       {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
 
-  poly::StageMap stages              = poly::CreateStages({A, B});
-  std::vector<ir::LoweredFunc> funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
-  ir::Expr ast_expr                  = funcs[0]->body;
+  poly::StageMap stages = poly::CreateStages({A, B});
+  std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
+      "SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
+  ir::Expr ast_expr = funcs[0]->body;
   VLOG(6) << "Expr to test: " << ast_expr;
 
   std::vector<ir::Expr> vec_ast{ast_expr};
@@ -62,7 +63,8 @@ TEST(FeatureExtractor, SimpleAssign) {
 
   std::vector<float> to_check = feature.ToFixedSizeVector();
 
-  ASSERT_EQ(to_check.size(), static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
+  ASSERT_EQ(to_check.size(),
+            static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
   VLOG(6) << "Feature data before slog:";
   for (size_t i = 0; i < to_check.size(); ++i) {
     VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1);
@@ -77,9 +79,11 @@ TEST(FeatureExtractor, SimpleAssign) {
   ASSERT_EQ(to_check[0], 0);
 #endif
   // mem_read
-  ASSERT_EQ(to_check[17], slog(M.get_constant() * N.get_constant()));  // mem_read
+  ASSERT_EQ(to_check[17],
+            slog(M.get_constant() * N.get_constant()));  // mem_read
   // mem_write
-  ASSERT_EQ(to_check[18], slog(M.get_constant() * N.get_constant()));  // mem_write
+  ASSERT_EQ(to_check[18],
+            slog(M.get_constant() * N.get_constant()));  // mem_write
   // non-opt loops, including root block
   ASSERT_EQ(to_check[29], slog(3));
 }
@@ -101,16 +105,19 @@ TEST(FeatureExtractor, MatrixMultiply) {
   ir::Var k(K.as_int32(), "reduce_axis_k");
   ir::Tensor C = lang::Compute(
-      {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C");
+      {M, N},
+      [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); },
+      "C");
 
-  poly::StageMap stages              = poly::CreateStages({C});
-  std::vector<ir::LoweredFunc> funcs = lang::LowerVec("MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true);
+  poly::StageMap stages = poly::CreateStages({C});
+  std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
+      "MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true);
 
   std::vector<ir::Expr> vec_ast{funcs[0]->body};
   ir::ModuleExpr mod_expr(vec_ast);
   ir::IRSchedule ir_sch(mod_expr);
 
   std::vector<ir::Expr> blocks = ir_sch.GetAllBlocks();
-  std::vector<ir::Expr> loops  = ir_sch.GetLoops(blocks[0]);
+  std::vector<ir::Expr> loops = ir_sch.GetLoops(blocks[0]);
   ir_sch.Bind(loops.back(), "threadIdx.x");
 
   ir::Expr ast_expr = mod_expr.GetExprs()[0];
@@ -121,7 +128,8 @@ TEST(FeatureExtractor, MatrixMultiply) {
 
   std::vector<float> to_check = feature.ToFixedSizeVector();
 
-  ASSERT_EQ(to_check.size(), static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
+  ASSERT_EQ(to_check.size(),
+            static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
   std::unordered_set<size_t> non_zero_indice = {0, 1, 2, 17, 18, 29, 30, 37};
   for (size_t i = 0; i < to_check.size(); ++i) {
     VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1);
@@ -135,7 +143,7 @@ TEST(FeatureExtractor, MatrixMultiply) {
 #else
   ASSERT_EQ(to_check[0], 0);
 #endif
-  float out_loop   = M.get_constant() * N.get_constant();
+  float out_loop = M.get_constant() * N.get_constant();
   float total_loop = out_loop * K.get_constant();
   // float_mul
   ASSERT_EQ(to_check[1], slog(total_loop));
diff --git a/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.cc b/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.cc
index 5db35b19732fb..8697aaa42ee1c 100644
--- a/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.cc
+++ b/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.cc
@@ -57,7 +57,8 @@ pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) {
 
   Dtype* py_data = static_cast<Dtype*>(ret.mutable_data());
   for (size_t i = 0; i < vec.size(); ++i) {
-    assert(vec[i].size() == shape[1] && "Sub vectors must have same size in VectorToNumpy");
+    assert(vec[i].size() == shape[1] &&
+           "Sub vectors must have same size in VectorToNumpy");
     memcpy(py_data + (shape[1] * i), vec[i].data(), shape[1] * sizeof(Dtype));
   }
   return ret;
@@ -71,19 +72,23 @@ pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) {
 void AddDistPkgToPythonSysPath() {
   pybind11::module sys_py_mod = pybind11::module::import("sys");
   // short version such as "3.7", "3.8", ...
-  std::string py_short_version = sys_py_mod.attr("version").cast<std::string>().substr(0, 3);
+  std::string py_short_version =
+      sys_py_mod.attr("version").cast<std::string>().substr(0, 3);
 
-  std::string site_pkg_str = "/usr/local/lib/python" + py_short_version + "/dist-packages";
+  std::string site_pkg_str =
+      "/usr/local/lib/python" + py_short_version + "/dist-packages";
   sys_py_mod.attr("path").attr("append")(site_pkg_str);
 
   // TODO(zhhsplendid): warning to users if setuptools hasn't been installed
   DIR* site_pkg_dir = opendir(site_pkg_str.c_str());
   if (site_pkg_dir != nullptr) {
-    std::regex setuptool_regex("setuptools-.*-py" + py_short_version + "\\.egg");
+    std::regex setuptool_regex("setuptools-.*-py" + py_short_version +
+                               "\\.egg");
     struct dirent* entry = nullptr;
     while ((entry = readdir(site_pkg_dir)) != nullptr) {
       if (std::regex_match(entry->d_name, setuptool_regex)) {
-        sys_py_mod.attr("path").attr("append")(site_pkg_str + "/" + entry->d_name);
+        sys_py_mod.attr("path").attr("append")(site_pkg_str + "/" +
+                                               entry->d_name);
       }
     }
     closedir(site_pkg_dir);
@@ -96,40 +101,49 @@ XgbCostModel::XgbCostModel() {
   if (previous == 0) {
     AddDistPkgToPythonSysPath();
   }
-  xgb_module_  = pybind11::module::import("xgboost");
+  xgb_module_ = pybind11::module::import("xgboost");
   xgb_booster_ = xgb_module_.attr("Booster")();
 }
 
-void XgbCostModel::Train(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) {
-  update_samples_            = samples;
-  update_labels_             = labels;
+void XgbCostModel::Train(const std::vector<std::vector<float>>& samples,
+                         const std::vector<float>& labels) {
+  update_samples_ = samples;
+  update_labels_ = labels;
   pybind11::array np_samples = VectorToNumpy<float>(samples);
-  pybind11::array np_labels  = VectorToNumpy<float>(labels);
+  pybind11::array np_labels = VectorToNumpy<float>(labels);
 
   pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels);
-  xgb_booster_             = xgb_module_.attr("train")(pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
+  xgb_booster_ = xgb_module_.attr("train")(
+      pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
 }
 
-std::vector<float> XgbCostModel::Predict(const std::vector<std::vector<float>>& samples) const {
+std::vector<float> XgbCostModel::Predict(
+    const std::vector<std::vector<float>>& samples) const {
   pybind11::array np_samples = VectorToNumpy<float>(samples);
-  pybind11::object dmatrix   = xgb_module_.attr("DMatrix")(np_samples);
-  pybind11::array py_result  = xgb_booster_.attr("predict")(dmatrix);
+  pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples);
+  pybind11::array py_result = xgb_booster_.attr("predict")(dmatrix);
   return py_result.cast<std::vector<float>>();
 }
 
-void XgbCostModel::Update(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) {
+void XgbCostModel::Update(const std::vector<std::vector<float>>& samples,
+                          const std::vector<float>& labels) {
   update_samples_.insert(update_samples_.end(), samples.begin(), samples.end());
   update_labels_.insert(update_labels_.end(), labels.begin(), labels.end());
 
   pybind11::array np_samples = VectorToNumpy<float>(update_samples_);
-  pybind11::array np_labels  = VectorToNumpy<float>(update_labels_);
+  pybind11::array np_labels = VectorToNumpy<float>(update_labels_);
 
   pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels);
-  xgb_booster_             = xgb_module_.attr("train")(pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
+  xgb_booster_ = xgb_module_.attr("train")(
+      pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
 }
 
-void XgbCostModel::Save(const std::string& path) { xgb_booster_.attr("save_model")(pybind11::str(path)); }
+void XgbCostModel::Save(const std::string& path) {
+  xgb_booster_.attr("save_model")(pybind11::str(path));
+}
 
-void XgbCostModel::Load(const std::string& path) { xgb_booster_.attr("load_model")(pybind11::str(path)); }
+void XgbCostModel::Load(const std::string& path) {
+  xgb_booster_.attr("load_model")(pybind11::str(path));
+}
 
 }  // namespace auto_schedule
 }  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h b/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h
index 05c2ecc1f2df4..8b33ccc27323d 100644
--- a/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h
+++ b/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h
@@ -47,11 +47,14 @@ class XgbCostModel : public CostModel {
   XgbCostModel();
   ~XgbCostModel() = default;
 
-  void Train(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) override;
+  void Train(const std::vector<std::vector<float>>& samples,
+             const std::vector<float>& labels) override;
 
-  std::vector<float> Predict(const std::vector<std::vector<float>>& samples) const override;
+  std::vector<float> Predict(
+      const std::vector<std::vector<float>>& samples) const override;
 
-  void Update(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) override;
+  void Update(const std::vector<std::vector<float>>& samples,
+              const std::vector<float>& labels) override;
 
   void Save(const std::string& path) override;
diff --git a/paddle/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc b/paddle/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc
index c75210903e16b..c1381fad3260f 100644
--- a/paddle/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc
+++ b/paddle/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc
@@ -31,10 +31,11 @@ TEST(CostModel, Basic) {
 
   srand(time(NULL));
 
-  int batch_size   = 16;
+  int batch_size = 16;
   int feature_size = 8;
   std::vector<float> labels(batch_size, 1.0);
-  std::vector<std::vector<float>> samples(batch_size, std::vector<float>(feature_size));
+  std::vector<std::vector<float>> samples(batch_size,
+                                          std::vector<float>(feature_size));
   for (int i = 0; i < batch_size; ++i) {
     for (int j = 0; j < feature_size; ++j) {
       samples[i][j] = rand() % 10;
diff --git a/paddle/cinn/auto_schedule/database/database.cc b/paddle/cinn/auto_schedule/database/database.cc
index 4a1a075ae20ea..e8c14e58b33c0 100644
--- a/paddle/cinn/auto_schedule/database/database.cc
+++ b/paddle/cinn/auto_schedule/database/database.cc
@@ -26,7 +26,8 @@ namespace cinn {
 namespace auto_schedule {
 
-bool TuningRecord::Compare::operator()(const TuningRecord& lhs, const TuningRecord& rhs) const {
+bool TuningRecord::Compare::operator()(const TuningRecord& lhs,
+                                       const TuningRecord& rhs) const {
   return lhs.execution_cost < rhs.execution_cost;
 }
 
@@ -39,15 +40,18 @@ proto::TuningRecord TuningRecord::ToProto() const {
   return record_proto;
 }
 
-Database::Database(int capacity_per_task) : capacity_per_task_(capacity_per_task) {
-  CHECK_GT(capacity_per_task_, 0) << "capacity_per_task_ should be greater than 0";
+Database::Database(int capacity_per_task)
+    : capacity_per_task_(capacity_per_task) {
+  CHECK_GT(capacity_per_task_, 0)
+      << "capacity_per_task_ should be greater than 0";
 }
 
 std::unique_ptr<Database> Database::Make(const DatabaseConfig& config) {
   if (config.type == DatabaseType::kMemory) {
     return std::make_unique<Database>(config.capacity_per_task);
   } else if (config.type == DatabaseType::kJSONFile) {
-    return std::make_unique<JSONFileDatabase>(config.capacity_per_task, config.record_file_path, true);
+    return std::make_unique<JSONFileDatabase>(
+        config.capacity_per_task, config.record_file_path, true);
   }
 
   LOG(FATAL) << "Unimplemented database type.";
@@ -81,13 +85,16 @@ std::vector<TuningRecord> Database::LookUp(const std::string& task_key) {
   return results;
 }
 
-std::vector<TuningRecord> Database::GetTopK(const std::string& task_key, int k) {
+std::vector<TuningRecord> Database::GetTopK(const std::string& task_key,
+                                            int k) {
   auto fit = key2record_.find(task_key);
   if (fit == key2record_.end() || k <= 0) {
     return {};
   }
   if (k > capacity_per_task_) {
-    LOG(WARNING) << "Top k=" << k << " is greater than the capacity, will adjust k=" << capacity_per_task_;
+    LOG(WARNING) << "Top k=" << k
+                 << " is greater than the capacity, will adjust k="
+                 << capacity_per_task_;
     k = capacity_per_task_;
   }
 
@@ -103,10 +110,12 @@ std::vector<TuningRecord> Database::GetTopK(const std::string& task_key, int k)
 }
 
 size_t Database::Size() {
-  auto res =
-      std::accumulate(key2record_.begin(), key2record_.end(), size_t(0), [](size_t res, const auto& kv) -> size_t {
-        return std::move(res) + kv.second.size();
-      });
+  auto res = std::accumulate(key2record_.begin(),
+                             key2record_.end(),
+                             size_t(0),
+                             [](size_t res, const auto& kv) -> size_t {
+                               return std::move(res) + kv.second.size();
+                             });
   return res;
 }
diff --git a/paddle/cinn/auto_schedule/database/database.h b/paddle/cinn/auto_schedule/database/database.h
index 3d9a237ecf626..2893daa9e4a2a 100644
--- a/paddle/cinn/auto_schedule/database/database.h
+++ b/paddle/cinn/auto_schedule/database/database.h
@@ -39,7 +39,9 @@ struct TuningRecord {
         predicted_cost(record.predicted_cost()),
         trace(record.trace()),
         execution_cost(record.execution_cost()) {}
-  TuningRecord(const std::string& task_key, const SearchState& state, double execution_cost)
+  TuningRecord(const std::string& task_key,
+               const SearchState& state,
+               double execution_cost)
       : task_key(task_key),
         predicted_cost(state->predicted_cost),
         trace(state->ir_schedule.GetTraceDesc().ToProto()),
@@ -58,15 +60,15 @@ struct TuningRecord {
 enum class DatabaseType : int { kMemory, kJSONFile };
 
 struct DatabaseConfig {
-  DatabaseType type            = DatabaseType::kMemory;
-  int capacity_per_task        = 2;
+  DatabaseType type = DatabaseType::kMemory;
+  int capacity_per_task = 2;
   std::string record_file_path = "/tmp/tuning_record.json";
 };
 
-// A database supports insert or lookup historial tuning result with specified traits.
-// It can be implemented with a concrete storage to save/load underlying data,
-// such as memory, file, database server and so on, this base class can be regarded as
-// one using memory as its underlying storage medium.
+// A database supports insert or lookup historial tuning result with specified
+// traits. It can be implemented with a concrete storage to save/load underlying
+// data, such as memory, file, database server and so on, this base class can be
+// regarded as one using memory as its underlying storage medium.
 class Database {
  public:
   explicit Database(int capacity_per_task);
@@ -93,7 +95,9 @@ class Database {
   void Insert(const TuningRecord& record);
 
   // map task_key to its records
-  std::unordered_map<std::string, std::multiset<TuningRecord, TuningRecord::Compare>> key2record_;
+  std::unordered_map<std::string,
+                     std::multiset<TuningRecord, TuningRecord::Compare>>
+      key2record_;
   // the max number of candidates stored
   const int capacity_per_task_;
 };
diff --git a/paddle/cinn/auto_schedule/database/database_test.cc b/paddle/cinn/auto_schedule/database/database_test.cc
index 1b6f28e4d0a21..3fc55334b8ea5 100644
--- a/paddle/cinn/auto_schedule/database/database_test.cc
+++ b/paddle/cinn/auto_schedule/database/database_test.cc
@@ -57,8 +57,10 @@ TEST_F(TestDatabase, GetTopK) {
   ASSERT_TRUE(test_db.GetTopK("k5", 2).empty());
   ASSERT_EQ(test_db.GetTopK("k4", 3).size(), 1);
 
-  test_db.AddRecord(TuningRecord("k4", SearchState(ir::IRSchedule(), 1.2), 2.0));
-  test_db.AddRecord(TuningRecord("k4", SearchState(ir::IRSchedule(), 1.0), 3.0));
+  test_db.AddRecord(
+      TuningRecord("k4", SearchState(ir::IRSchedule(), 1.2), 2.0));
+  test_db.AddRecord(
+      TuningRecord("k4", SearchState(ir::IRSchedule(), 1.0), 3.0));
 
   auto records = test_db.GetTopK("k4", 3);
   ASSERT_EQ(records.size(), 2);
diff --git a/paddle/cinn/auto_schedule/database/jsonfile_database.cc b/paddle/cinn/auto_schedule/database/jsonfile_database.cc
index 023b2585caed6..a518dffce3449 100644
--- a/paddle/cinn/auto_schedule/database/jsonfile_database.cc
+++ b/paddle/cinn/auto_schedule/database/jsonfile_database.cc
@@ -35,7 +35,8 @@ void AppendLineToFile(const std::string& file_path, const std::string& line) {
 }
 
 // read lines from a json file
-std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool allow_new_file) {
+std::vector<std::string> ReadLinesFromFile(const std::string& file_path,
+                                           bool allow_new_file) {
   std::ifstream is(file_path);
   if (is.good()) {
     std::vector<std::string> json_strs;
@@ -51,20 +52,26 @@ std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool al
   return {};
 }
 
-JSONFileDatabase::JSONFileDatabase(int capacity_per_task, const std::string& record_file_path, bool allow_new_file)
+JSONFileDatabase::JSONFileDatabase(int capacity_per_task,
+                                   const std::string& record_file_path,
+                                   bool allow_new_file)
     : Database(capacity_per_task), record_file_path_(record_file_path) {
-  VLOG(3) << "Auto schedule will save/load tuning records on file:" << record_file_path;
+  VLOG(3) << "Auto schedule will save/load tuning records on file:"
+          << record_file_path;
   auto json_lines = ReadLinesFromFile(record_file_path_, allow_new_file);
 
-  std::vector<cinn::auto_schedule::proto::TuningRecord> all_records_proto(json_lines.size());
+  std::vector<cinn::auto_schedule::proto::TuningRecord> all_records_proto(
+      json_lines.size());
   // convert JSON string to proto object
   auto worker_fn = [this, &json_lines, &all_records_proto](int index) {
     cinn::auto_schedule::proto::TuningRecord record_proto;
-    auto status = google::protobuf::util::JsonStringToMessage(json_lines[index], &record_proto);
+    auto status = google::protobuf::util::JsonStringToMessage(json_lines[index],
+                                                              &record_proto);
     CHECK(status.ok()) << "Failed to parse JSON: " << json_lines[index];
     all_records_proto[index].Swap(&record_proto);
   };
-  utils::parallel_run(worker_fn, utils::SequenceDispatcher(0, json_lines.size()), -1);
+  utils::parallel_run(
+      worker_fn, utils::SequenceDispatcher(0, json_lines.size()), -1);
 
   InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
@@ -81,8 +88,10 @@ JSONFileDatabase::JSONFileDatabase(int capacity_per_task, const std::string& rec
 std::string JSONFileDatabase::RecordToJSON(const TuningRecord& record) {
   proto::TuningRecord record_proto = record.ToProto();
   std::string json_string;
-  auto status = google::protobuf::util::MessageToJsonString(record_proto, &json_string);
-  CHECK(status.ok()) << "Failed to serialize record to JSON, task key = " << record.task_key;
+  auto status =
+      google::protobuf::util::MessageToJsonString(record_proto, &json_string);
+  CHECK(status.ok()) << "Failed to serialize record to JSON, task key = "
+                     << record.task_key;
   VLOG(4) << "json_string = \n" << json_string;
 
   return json_string;
diff --git a/paddle/cinn/auto_schedule/database/jsonfile_database.h b/paddle/cinn/auto_schedule/database/jsonfile_database.h
index 2a9752224217b..de3020e7e7b8e 100644
--- a/paddle/cinn/auto_schedule/database/jsonfile_database.h
+++ b/paddle/cinn/auto_schedule/database/jsonfile_database.h
@@ -19,16 +19,20 @@ namespace cinn {
 namespace auto_schedule {
 
-// JSONFileDatabase is a database implemented by JSON file to save/load underlying data.
+// JSONFileDatabase is a database implemented by JSON file to save/load
+// underlying data.
 class JSONFileDatabase : public Database {
  public:
   /*!
   * \brief Build a JSONFileDatabase object from a json file.
   * \param capacity_per_task The max number of candidates stored.
   * \param record_file_path The path of the json file.
-   * \param allow_new_file Whether to create new file when the given path is not found.
+   * \param allow_new_file Whether to create new file when the given path is not
+   * found.
   */
-  JSONFileDatabase(int capacity_per_task, const std::string& record_file_path, bool allow_new_file);
+  JSONFileDatabase(int capacity_per_task,
+                   const std::string& record_file_path,
+                   bool allow_new_file);
   ~JSONFileDatabase() = default;
 
   // convert a TuningRecord object to string in JSON format
@@ -46,7 +50,8 @@ class JSONFileDatabase : public Database {
 void AppendLineToFile(const std::string& file_path, const std::string& line);
 
 // read lines from a json file
-std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool allow_new_file = true);
+std::vector<std::string> ReadLinesFromFile(const std::string& file_path,
+                                           bool allow_new_file = true);
 
 }  // namespace auto_schedule
 }  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc b/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc
index d60ce20e162a8..71674c18f6013 100644
--- a/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc
+++ b/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc
@@ -31,7 +31,8 @@ namespace cinn {
 namespace auto_schedule {
 
 // Return lowerd ir AST for example functions used in this test
-std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape, const Target& target) {
+std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape,
+                                          const Target& target) {
   CHECK(shape.size() == 2) << "shape should be 2";
   std::vector<Expr> domain;
   for (auto i = 0; i < shape.size(); ++i) {
@@ -46,11 +47,13 @@ std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape, const T
   C = Compute(
       domain, [&B](Var i, Var j) { return B(i, j); }, "C");
 
-  return cinn::lang::LowerVec("test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
+  return cinn::lang::LowerVec(
+      "test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
 }
 
 // Create a new IRSchedule with copied ir::LoweredFunc AST
-ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs, const std::string& task_key) {
+ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs,
+                              const std::string& task_key) {
   std::vector<Expr> exprs;
   for (auto&& func : lowered_funcs) {
     exprs.emplace_back(optim::IRCopy(func->body));
@@ -63,7 +66,9 @@ ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs,
 
 class TestJSONFileDatabase : public ::testing::Test {
  public:
-  TestJSONFileDatabase() : record_file_path("/tmp/test_record.json"), test_db(2, record_file_path, true) {}
+  TestJSONFileDatabase()
+      : record_file_path("/tmp/test_record.json"),
+        test_db(2, record_file_path, true) {}
 
   void SetUp() override { lowered_funcs = LowerCompute({32, 32}, target); }
@@ -91,55 +96,76 @@ class TestJSONFileDatabase : public ::testing::Test {
 
 TEST_F(TestJSONFileDatabase, Serialize) {
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "test");
-  auto fused            = ir_sch.Fuse("B", {0, 1});
+  auto fused = ir_sch.Fuse("B", {0, 1});
   VLOG(3) << "after Fuse, Expr: " << fused;
 
   TuningRecord record1("test", SearchState(std::move(ir_sch), 2.0), 1.0);
   std::string str = test_db.RecordToJSON(record1);
   VLOG(3) << "RecordToJSON: " << str;
 
-  // Because the serialization of protobuf does not guarantee the order, we give all possible results.
+  // Because the serialization of protobuf does not guarantee the order, we give
+  // all possible results.
std::string case1 = - "{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," - "\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":\"INTS\",\"ints\":[0,1]},{\"name\":\"block_" + "{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":" + "{\"steps\":[{\"type\":\"FuseWithName\"," + "\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":" + "\"INTS\",\"ints\":[0,1]},{\"name\":\"block_" "name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}"; std::string case2 = - "{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," - "\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":\"STRING\",\"s\":\"B\"},{\"name\":\"loops_" + "{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":" + "{\"steps\":[{\"type\":\"FuseWithName\"," + "\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":" + "\"STRING\",\"s\":\"B\"},{\"name\":\"loops_" "index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}"; EXPECT_EQ(true, strs[0] == case1 || strs[0] == case2); - EXPECT_EQ(strs[1], "{\"taskKey\":\"k2\",\"executionCost\":3,\"predictedCost\":3.5,\"trace\":{}}"); + EXPECT_EQ(strs[1], + "{\"taskKey\":\"k2\",\"executionCost\":3,\"predictedCost\":3.5," + "\"trace\":{}}"); } TEST_F(TestJSONFileDatabase, Basic) { - test_db.AddRecord(TuningRecord("k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0)); - test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); - test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0)); - test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 8.0), 3.0)); - test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 7.0), 4.0)); - test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 6.0), 5.0)); - test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 4.0)); + test_db.AddRecord(TuningRecord( + "k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0)); + test_db.AddRecord(TuningRecord( + "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); + test_db.AddRecord(TuningRecord( + "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0)); + test_db.AddRecord(TuningRecord( + "k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 8.0), 3.0)); + test_db.AddRecord(TuningRecord( + "k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 7.0), 4.0)); + test_db.AddRecord(TuningRecord( + "k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 6.0), 5.0)); + test_db.AddRecord(TuningRecord( + "k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 4.0)); ASSERT_EQ(test_db.Size(), 6); auto records = test_db.LookUp("k3"); @@ -152,15 +178,24 @@ TEST_F(TestJSONFileDatabase, Basic) { } TEST_F(TestJSONFileDatabase, GetTopK) { - test_db.AddRecord(TuningRecord("k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0)); - test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); - test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0)); - test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 3.0)); - test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 4.0)); - test_db.AddRecord(TuningRecord("k3", 
SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 5.0)); - test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 2.0), 4.0)); - test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.2), 2.0)); - test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 3.0)); + test_db.AddRecord(TuningRecord( + "k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0)); + test_db.AddRecord(TuningRecord( + "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); + test_db.AddRecord(TuningRecord( + "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0)); + test_db.AddRecord(TuningRecord( + "k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 3.0)); + test_db.AddRecord(TuningRecord( + "k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 4.0)); + test_db.AddRecord(TuningRecord( + "k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 5.0)); + test_db.AddRecord(TuningRecord( + "k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 2.0), 4.0)); + test_db.AddRecord(TuningRecord( + "k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.2), 2.0)); + test_db.AddRecord(TuningRecord( + "k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 3.0)); auto records = test_db.GetTopK("k4", 3); ASSERT_EQ(records.size(), 2); @@ -170,9 +205,11 @@ TEST_F(TestJSONFileDatabase, GetTopK) { TEST_F(TestJSONFileDatabase, Reload) { ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "k1"); - auto fused = ir_sch.Fuse("B", {0, 1}); - test_db.AddRecord(TuningRecord("k1", SearchState(std::move(ir_sch), 1.0), 1.0)); - test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); + auto fused = ir_sch.Fuse("B", {0, 1}); + test_db.AddRecord( + TuningRecord("k1", SearchState(std::move(ir_sch), 1.0), 1.0)); + test_db.AddRecord(TuningRecord( + "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); auto records = test_db.LookUp("k1"); ASSERT_EQ(records.size(), 1); @@ -184,11 +221,13 @@ TEST_F(TestJSONFileDatabase, Reload) { EXPECT_EQ(records[0].execution_cost, loaded_records[0].execution_cost); EXPECT_EQ(records[0].predicted_cost, loaded_records[0].predicted_cost); - // check the equality of trace info between original TuningRecord and the loaded TuningRecord + // check the equality of trace info between original TuningRecord and the + // loaded TuningRecord const auto& lhs_trace = records[0].trace; const auto& rhs_trace = loaded_records[0].trace; google::protobuf::util::MessageDifferencer dif; - static const google::protobuf::Descriptor* descriptor = cinn::ir::proto::ScheduleDesc_Step::descriptor(); + static const google::protobuf::Descriptor* descriptor = + cinn::ir::proto::ScheduleDesc_Step::descriptor(); dif.TreatAsSet(descriptor->FindFieldByName("attrs")); EXPECT_TRUE(dif.Compare(lhs_trace, rhs_trace)); @@ -203,8 +242,8 @@ TEST_F(TestJSONFileDatabase, Reload) { ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size()); for (auto i = 0; i < lhs_exprs.size(); ++i) { - std::string lhs = utils::GetStreamCnt(lhs_exprs.at(i)); - std::string rhs = utils::GetStreamCnt(rhs_exprs.at(i)); + std::string lhs = utils::GetStreamCnt(lhs_exprs.at(i)); + std::string rhs = utils::GetStreamCnt(rhs_exprs.at(i)); size_t remove_prefix_len = 28; ASSERT_EQ(lhs.erase(0, remove_prefix_len), rhs.erase(0, remove_prefix_len)); } diff --git a/paddle/cinn/auto_schedule/measure/measure.h b/paddle/cinn/auto_schedule/measure/measure.h index f03ec13b9fef9..36140580ee1b5 
100644 --- a/paddle/cinn/auto_schedule/measure/measure.h +++ b/paddle/cinn/auto_schedule/measure/measure.h @@ -53,7 +53,8 @@ struct MeasureResult { // The result of building with input schedule struct BuildResult { - // The scope that owns detail compilation infos of parameters in the runtime program + // The scope that owns detailed compilation info of parameters in the runtime + // program const hlir::framework::Scope* compiled_scope; // The executable program std::unique_ptr runtime_program; @@ -68,11 +69,13 @@ class ScheduleBuilder { virtual BuildResult Build(const MeasureInput& input) = 0; }; -// This interface defines how to run the built result. Like above ScheduleBuilder, -// a runner shoule be implemented with not bound to a specific task. +// This interface defines how to run the built result. Like the ScheduleBuilder +// above, a runner should be implemented without being bound to a specific +// task. class ScheduleRunner { public: - virtual MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) = 0; + virtual MeasureResult Run(const MeasureInput& input, + const BuildResult& build_result) = 0; }; } // namespace auto_schedule diff --git a/paddle/cinn/auto_schedule/measure/measurer_test.cc b/paddle/cinn/auto_schedule/measure/measurer_test.cc index 949875d14e510..94f3540e77d5e 100644 --- a/paddle/cinn/auto_schedule/measure/measurer_test.cc +++ b/paddle/cinn/auto_schedule/measure/measurer_test.cc @@ -62,22 +62,27 @@ class TestMeasurer : public ::testing::Test { Target target = common::DefaultHostTarget(); #endif std::unordered_set fetch_ids; - auto program = CreateAddReluProgram(); - auto graph = cinn::frontend::Optimize(&program, fetch_ids, target); - auto scope = BuildScope(target, graph); + auto program = CreateAddReluProgram(); + auto graph = cinn::frontend::Optimize(&program, fetch_ids, target); + auto scope = BuildScope(target, graph); graph_compiler = std::make_unique(target, scope, graph); TaskCreator task_creator; - tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); - const auto& dtype_dict = graph->GetAttrs>("inferdtype"); - const auto& shape_dict = graph->GetAttrs>("infershape"); - - auto op_lowerer = std::make_unique(dtype_dict, shape_dict, target); + tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); + const auto& dtype_dict = + graph->GetAttrs>( + "inferdtype"); + const auto& shape_dict = graph->GetAttrs< + absl::flat_hash_map>( + "infershape"); + + auto op_lowerer = std::make_unique( + dtype_dict, shape_dict, target); inputs.reserve(tasks.size()); for (int i = 0; i < tasks.size(); ++i) { auto* task = &tasks[i]; task->Initialize(shape_dict, dtype_dict, op_lowerer.get()); MeasureInput input; - input.task = task; + input.task = task; input.lowered_funcs = task->lowered_funcs; inputs.emplace_back(input); } @@ -95,30 +100,37 @@ class ThrowExceptionRunner : public ScheduleRunner { struct Exception : public std::exception { const char* what() const throw() { return "RunError"; } }; - MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) override { throw Exception(); } + MeasureResult Run(const MeasureInput& input, + const BuildResult& build_result) override { + throw Exception(); + } }; TEST_F(TestMeasurer, Basic) { - auto builder = std::make_unique(graph_compiler.get()); - auto runner = std::make_unique(1); - auto measurer = std::make_unique(builder.get(), runner.get()); + auto builder = std::make_unique(graph_compiler.get()); + auto runner = std::make_unique(1); + auto measurer = + std::make_unique(builder.get(),
runner.get()); std::vector results = measurer->Measure(inputs); ASSERT_EQ(inputs.size(), results.size()); } TEST_F(TestMeasurer, CatchException) { - auto builder = std::make_unique(graph_compiler.get()); - auto runner = std::make_unique(1); - auto throw_builder = std::make_unique(); - auto throw_runner = std::make_unique(); - auto measurer_with_build_error = std::make_unique(throw_builder.get(), runner.get(), 2); - std::vector results = measurer_with_build_error->Measure(inputs); + auto builder = std::make_unique(graph_compiler.get()); + auto runner = std::make_unique(1); + auto throw_builder = std::make_unique(); + auto throw_runner = std::make_unique(); + auto measurer_with_build_error = + std::make_unique(throw_builder.get(), runner.get(), 2); + std::vector results = + measurer_with_build_error->Measure(inputs); ASSERT_EQ(inputs.size(), results.size()); EXPECT_EQ(results[0].error_msg, "Build failed, error: BuildError\n"); // TODO(CtfGo): test parallel build after we support thread-safe compilation - auto measurer_with_run_error = std::make_unique(builder.get(), throw_runner.get(), 1); - results = measurer_with_run_error->Measure(inputs); + auto measurer_with_run_error = + std::make_unique(builder.get(), throw_runner.get(), 1); + results = measurer_with_run_error->Measure(inputs); ASSERT_EQ(inputs.size(), results.size()); EXPECT_EQ(results[0].error_msg, "Run failed, error: RunError\n"); } diff --git a/paddle/cinn/auto_schedule/measure/schedule_measurer.cc b/paddle/cinn/auto_schedule/measure/schedule_measurer.cc index 03b95ad26f184..bd8d5862140b2 100644 --- a/paddle/cinn/auto_schedule/measure/schedule_measurer.cc +++ b/paddle/cinn/auto_schedule/measure/schedule_measurer.cc @@ -21,10 +21,13 @@ namespace cinn { namespace auto_schedule { -ScheduleMeasurer::ScheduleMeasurer(ScheduleBuilder* builder, ScheduleRunner* runner, int num_threads) +ScheduleMeasurer::ScheduleMeasurer(ScheduleBuilder* builder, + ScheduleRunner* runner, + int num_threads) : builder_(builder), runner_(runner), num_threads_(num_threads) {} -std::vector ScheduleMeasurer::Measure(const std::vector& inputs) { +std::vector ScheduleMeasurer::Measure( + const std::vector& inputs) { if (inputs.empty()) { LOG(WARNING) << "inputs is empty"; return {}; @@ -33,41 +36,49 @@ std::vector ScheduleMeasurer::Measure(const std::vector results(inputs.size()); // define how to build a candidate with the specified index - auto build_fn = [builder = builder_, &inputs, &build_results, &results](int index) { - VLOG(6) << "Build candidate index: " << index; - auto m_start = std::chrono::steady_clock::now(); - try { - build_results[index] = builder->Build(inputs[index]); - } catch (std::exception& e) { - results[index].error_msg = utils::StringFormat("Build failed, error: %s\n", e.what()); - } - auto time_span = std::chrono::duration_cast(std::chrono::steady_clock::now() - m_start); - results[index].elapsed_time += static_cast(time_span.count()); - }; + auto build_fn = + [builder = builder_, &inputs, &build_results, &results](int index) { + VLOG(6) << "Build candidate index: " << index; + auto m_start = std::chrono::steady_clock::now(); + try { + build_results[index] = builder->Build(inputs[index]); + } catch (std::exception& e) { + results[index].error_msg = + utils::StringFormat("Build failed, error: %s\n", e.what()); + } + auto time_span = std::chrono::duration_cast( + std::chrono::steady_clock::now() - m_start); + results[index].elapsed_time += static_cast(time_span.count()); + }; // define how to run a candidate with the specified index - 
auto run_fn = [runner = runner_, &inputs, &build_results, &results](int index) { - VLOG(6) << "Run candidate index: " << index; - auto m_start = std::chrono::steady_clock::now(); - try { - // if error occurred in building, then skip running - if (results[index].error_msg.empty()) { - results[index] = runner->Run(inputs[index], build_results[index]); - } - } catch (std::exception& e) { - results[index].error_msg = utils::StringFormat("Run failed, error: %s\n", e.what()); - } - auto time_span = std::chrono::duration_cast(std::chrono::steady_clock::now() - m_start); - results[index].elapsed_time += static_cast(time_span.count()); - }; + auto run_fn = + [runner = runner_, &inputs, &build_results, &results](int index) { + VLOG(6) << "Run candidate index: " << index; + auto m_start = std::chrono::steady_clock::now(); + try { + // if an error occurred in building, then skip running + if (results[index].error_msg.empty()) { + results[index] = runner->Run(inputs[index], build_results[index]); + } + } catch (std::exception& e) { + results[index].error_msg = + utils::StringFormat("Run failed, error: %s\n", e.what()); + } + auto time_span = std::chrono::duration_cast( + std::chrono::steady_clock::now() - m_start); + results[index].elapsed_time += static_cast(time_span.count()); + }; // measure a candidate by calling build and run successively auto measure_fn = [&build_fn, &run_fn](int index) { build_fn(index); run_fn(index); }; - // default num_threads_ is 1 and in that case it will perform all measurements sequentially inplace. - utils::parallel_run(measure_fn, utils::SequenceDispatcher(0, inputs.size()), num_threads_); + // default num_threads_ is 1 and in that case it will perform all measurements + // sequentially in place. + utils::parallel_run( + measure_fn, utils::SequenceDispatcher(0, inputs.size()), num_threads_); VLOG(4) << "Measure " << inputs.size() << " candidates"; return results; diff --git a/paddle/cinn/auto_schedule/measure/schedule_measurer.h b/paddle/cinn/auto_schedule/measure/schedule_measurer.h index b95efc53ebe53..3a1aa7850a2ab 100644 --- a/paddle/cinn/auto_schedule/measure/schedule_measurer.h +++ b/paddle/cinn/auto_schedule/measure/schedule_measurer.h @@ -25,7 +25,9 @@ namespace auto_schedule { // which are building the input schedules and running the generated codes. class ScheduleMeasurer { public: - ScheduleMeasurer(ScheduleBuilder* builder, ScheduleRunner* runner, int num_threads = 1); + ScheduleMeasurer(ScheduleBuilder* builder, + ScheduleRunner* runner, + int num_threads = 1); // Measure a batch of inputs and return all results once.
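The build_fn/run_fn pair above keeps per-candidate failures local: an exception in either stage is recorded in that candidate's result instead of aborting the batch, and the elapsed time of both stages accumulates. A condensed, self-contained sketch of the same pattern, with hypothetical stage callbacks standing in for ScheduleBuilder/ScheduleRunner:

#include <chrono>
#include <functional>
#include <string>
#include <vector>

struct Result {
  std::string error_msg;
  double elapsed_time = 0.0;  // accumulated microseconds across stages
};

// Run build then run for each of n candidates; a stage that throws marks
// the candidate failed, and the run stage is skipped after a build failure.
std::vector<Result> MeasureAll(int n,
                               const std::function<void(int)>& build,
                               const std::function<void(int)>& run) {
  std::vector<Result> results(n);
  for (int i = 0; i < n; ++i) {
    auto timed = [&](const std::function<void(int)>& stage, const char* name) {
      auto start = std::chrono::steady_clock::now();
      try {
        if (results[i].error_msg.empty()) stage(i);
      } catch (const std::exception& e) {
        results[i].error_msg =
            std::string(name) + " failed, error: " + e.what() + "\n";
      }
      auto span = std::chrono::duration_cast<std::chrono::microseconds>(
          std::chrono::steady_clock::now() - start);
      results[i].elapsed_time += static_cast<double>(span.count());
    };
    timed(build, "Build");
    timed(run, "Run");  // no-op if the build stage already failed
  }
  return results;
}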
std::vector Measure(const std::vector& inputs); diff --git a/paddle/cinn/auto_schedule/measure/simple_builder.cc b/paddle/cinn/auto_schedule/measure/simple_builder.cc index 842acd47216e8..56005d5518b17 100644 --- a/paddle/cinn/auto_schedule/measure/simple_builder.cc +++ b/paddle/cinn/auto_schedule/measure/simple_builder.cc @@ -19,20 +19,24 @@ namespace auto_schedule { using hlir::framework::GraphCompiler; -SimpleBuilder::SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler) : graph_compiler_(graph_compiler) {} +SimpleBuilder::SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler) + : graph_compiler_(graph_compiler) {} BuildResult SimpleBuilder::Build(const MeasureInput& input) { - CHECK_NE(graph_compiler_, static_cast(nullptr)) << "empty handle to GraphCompiler"; + CHECK_NE(graph_compiler_, static_cast(nullptr)) + << "empty handle to GraphCompiler"; GraphCompiler::CompileOptions compile_options; compile_options.groups.emplace_back(input.task->subgraph); compile_options.lowered_funcs.emplace_back(input.lowered_funcs); compile_options.remove_unused_variables = false; - VLOG(5) << "call GraphCompiler to Build with Graph::Group size=" << compile_options.groups.size() - << ", lowered_funcs group size=" << compile_options.lowered_funcs.size(); - GraphCompiler::CompilationResult compiled_result = graph_compiler_->Build(compile_options); + VLOG(5) << "call GraphCompiler to Build with Graph::Group size=" + << compile_options.groups.size() << ", lowered_funcs group size=" + << compile_options.lowered_funcs.size(); + GraphCompiler::CompilationResult compiled_result = + graph_compiler_->Build(compile_options); BuildResult build_result; - build_result.compiled_scope = graph_compiler_->GetScope().get(); + build_result.compiled_scope = graph_compiler_->GetScope().get(); build_result.runtime_program = std::move(compiled_result.runtime_program); return build_result; } diff --git a/paddle/cinn/auto_schedule/measure/simple_runner.cc b/paddle/cinn/auto_schedule/measure/simple_runner.cc index 5d5621cc43b60..1871cfc82ae92 100644 --- a/paddle/cinn/auto_schedule/measure/simple_runner.cc +++ b/paddle/cinn/auto_schedule/measure/simple_runner.cc @@ -35,48 +35,64 @@ using hlir::framework::Tensor; // Parameters that needs to be initialized to 0. // Key is the Op name, and value is the index of the input parameter in the Op. 
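Further down in simple_runner.cc, PopulateRandomValue picks a distribution matching the tensor's element type and fills the raw buffer with std::generate_n. A stripped-down sketch of that dispatch pattern (FillRandom and the value ranges here are illustrative, not the file's exact code):

#include <algorithm>
#include <cstdint>
#include <limits>
#include <random>

// Fill `numel` elements of type T starting at raw_ptr, drawing each value
// from the supplied distribution.
template <typename T, typename Dist>
void FillRandom(void* raw_ptr, int numel, Dist dist) {
  std::random_device seed;
  std::default_random_engine engine(seed());
  T* ptr = reinterpret_cast<T*>(raw_ptr);
  std::generate_n(ptr, numel, [&] { return static_cast<T>(dist(engine)); });
}

// Type dispatch mirroring the branches in PopulateRandomValue.
void Populate(bool is_f32, void* buf, int numel) {
  if (is_f32) {
    FillRandom<float>(buf, numel,
                      std::uniform_real_distribution<float>(0.0f, 1.0f));
  } else {
    FillRandom<int32_t>(buf, numel,
                        std::uniform_int_distribution<int32_t>(
                            std::numeric_limits<int32_t>::min(),
                            std::numeric_limits<int32_t>::max()));
  }
}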
-static const std::unordered_map> kInitWithZeroParams = { - {"lookup_table", {1}}, - {"gather", {1}}, - {"gather_nd", {1}}, - {"scatter_assign", {2}}, - {"scatter_add", {2}}, +static const std::unordered_map> + kInitWithZeroParams = { + {"lookup_table", {1}}, + {"gather", {1}}, + {"gather_nd", {1}}, + {"scatter_assign", {2}}, + {"scatter_add", {2}}, }; // Generate random value and populate them to the output address of memory -static void PopulateRandomValue(const common::Type& type, const int numel, void* raw_ptr) { +static void PopulateRandomValue(const common::Type& type, + const int numel, + void* raw_ptr) { std::random_device seed; std::default_random_engine engine(seed()); if (type == common::Bool()) { auto* fmt_ptr = reinterpret_cast(raw_ptr); std::bernoulli_distribution dist(0.5); - std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); + std::generate_n( + fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); } else if (type == common::I32()) { auto* fmt_ptr = reinterpret_cast(raw_ptr); - std::uniform_int_distribution dist(std::numeric_limits::min(), std::numeric_limits::max()); - std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); + std::uniform_int_distribution dist(std::numeric_limits::min(), + std::numeric_limits::max()); + std::generate_n( + fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); } else if (type == common::I64()) { auto* fmt_ptr = reinterpret_cast(raw_ptr); - std::uniform_int_distribution dist(std::numeric_limits::min(), - std::numeric_limits::max()); - std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); + std::uniform_int_distribution dist( + std::numeric_limits::min(), + std::numeric_limits::max()); + std::generate_n( + fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); } else if (type == common::F32()) { auto* fmt_ptr = reinterpret_cast(raw_ptr); - std::uniform_real_distribution dist(std::numeric_limits::min(), std::numeric_limits::max()); - std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); + std::uniform_real_distribution dist( + std::numeric_limits::min(), std::numeric_limits::max()); + std::generate_n( + fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); } else { - CHECK_EQ(type.bytes(), 8) << "Unsupported type: " << type << ", type.bytes = " << type.bytes(); + CHECK_EQ(type.bytes(), 8) + << "Unsupported type: " << type << ", type.bytes = " << type.bytes(); auto* fmt_ptr = reinterpret_cast(raw_ptr); - std::uniform_int_distribution dist(std::numeric_limits::min(), - std::numeric_limits::max()); - std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); + std::uniform_int_distribution dist( + std::numeric_limits::min(), + std::numeric_limits::max()); + std::generate_n( + fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); } } -// Initialize a tensor with 0 if init_with_zero == true, otherwise initialize the tensor with random value. -static void InitTensorData(Tensor tensor, const common::Target& target, bool init_with_zero) { - int mem_size = tensor->shape().numel() * tensor->type().bytes(); +// Initialize a tensor with 0 if init_with_zero == true, otherwise initialize +// the tensor with random value. 
+static void InitTensorData(Tensor tensor, + const common::Target& target, + bool init_with_zero) { + int mem_size = tensor->shape().numel() * tensor->type().bytes(); auto* tensor_data = tensor->mutable_data(target, tensor->type()); #ifdef CINN_WITH_CUDA if (target == common::DefaultNVGPUTarget()) { @@ -101,17 +117,20 @@ static void InitTensorData(Tensor tensor, const common::Target& target, bool ini // Find all parameter names in the task corresponding to the MeasureInput // that need to be initialized to 0 when measuring. -static std::unordered_set ParamsNeedInitWithZero(const MeasureInput& input) { +static std::unordered_set ParamsNeedInitWithZero( + const MeasureInput& input) { std::unordered_set res; - std::vector nodes = input.task->subgraph->CollectNodes(); + std::vector nodes = + input.task->subgraph->CollectNodes(); for (auto* node : nodes) { if (kInitWithZeroParams.count(node->op()->name) != 0) { std::vector param_idxs = kInitWithZeroParams.at(node->op()->name); - const auto& inlinks = node->inlinks_in_order(); + const auto& inlinks = node->inlinks_in_order(); for (int param_idx : param_idxs) { CHECK_GT(inlinks.size(), param_idx); - auto& edge = inlinks.at(param_idx); - std::string param_name = edge->source()->as()->id(); + auto& edge = inlinks.at(param_idx); + std::string param_name = + edge->source()->as()->id(); VLOG(6) << "param needs to be init with 0: " << param_name; res.insert(param_name); } @@ -128,17 +147,19 @@ SimpleRunner::SimpleRunner(int repeat_times) : repeat_times_(repeat_times) { // Prepare execution arguments of all instructions to run, a argument // may be obtained from the input of measurement or allocating new buffer // with random value. -std::map SimpleRunner::PrepareArgs(const MeasureInput& input, - const BuildResult& build_result, - hlir::framework::Scope* temp_scope) { +std::map SimpleRunner::PrepareArgs( + const MeasureInput& input, + const BuildResult& build_result, + hlir::framework::Scope* temp_scope) { std::map result; - const auto& target = input.task->target; - const auto* input_args = input.execution_args; + const auto& target = input.task->target; + const auto* input_args = input.execution_args; const auto* compiled_scope = build_result.compiled_scope; - const auto& instructions = build_result.runtime_program->GetRunInstructions(); + const auto& instructions = build_result.runtime_program->GetRunInstructions(); - std::unordered_set params_need_init_with_zero = ParamsNeedInitWithZero(input); + std::unordered_set params_need_init_with_zero = + ParamsNeedInitWithZero(input); auto fill_arg_fn = [&](const std::string& param) { VLOG(6) << "Filling argument:" << param; @@ -169,7 +190,8 @@ std::map SimpleRunner::PrepareArgs(const MeasureI temp_tensor->Resize(compiled_tensor->shape()); temp_tensor->set_type(compiled_tensor->type()); temp_tensor->mutable_data(target, compiled_tensor->type()); - InitTensorData(temp_tensor, target, params_need_init_with_zero.count(param) != 0); + InitTensorData( + temp_tensor, target, params_need_init_with_zero.count(param) != 0); result.emplace(param, temp_tensor->buffer()); }; @@ -186,7 +208,8 @@ std::map SimpleRunner::PrepareArgs(const MeasureI return result; } -MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& build_result) { +MeasureResult SimpleRunner::Run(const MeasureInput& input, + const BuildResult& build_result) { MeasureResult result; auto t_start = std::chrono::steady_clock::now(); // prepare execution arguments @@ -195,7 +218,7 @@ MeasureResult SimpleRunner::Run(const MeasureInput& 
input, const BuildResult& bu auto execution_args = PrepareArgs(input, build_result, &temp_scope); // Execute each instruction repeatedly and take the average as cost. - result.execution_cost = 0; + result.execution_cost = 0; const auto& instructions = build_result.runtime_program->GetRunInstructions(); for (auto ct = 0; ct < instructions.size(); ++ct) { auto&& instr = instructions.at(ct); @@ -209,16 +232,18 @@ MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& bu CUDA_CALL(cudaDeviceSynchronize()); } #endif - auto time_span = - std::chrono::duration_cast(std::chrono::steady_clock::now() - run_start); + auto time_span = std::chrono::duration_cast( + std::chrono::steady_clock::now() - run_start); auto cost_avg = static_cast(time_span.count()) / repeat_times_; result.execution_cost += cost_avg; } - auto time_span = std::chrono::duration_cast(std::chrono::steady_clock::now() - t_start); + auto time_span = std::chrono::duration_cast( + std::chrono::steady_clock::now() - t_start); result.elapsed_time = static_cast(time_span.count()); - VLOG(4) << "A measurement done:repeat_times[" << repeat_times_ << "]total_elapsed_time[" << result.elapsed_time + VLOG(4) << "A measurement done:repeat_times[" << repeat_times_ + << "]total_elapsed_time[" << result.elapsed_time << "]us,execution_cost[" << result.execution_cost << "]us"; return result; } diff --git a/paddle/cinn/auto_schedule/measure/simple_runner.h b/paddle/cinn/auto_schedule/measure/simple_runner.h index de5ef2b152c62..d466c71b447d8 100644 --- a/paddle/cinn/auto_schedule/measure/simple_runner.h +++ b/paddle/cinn/auto_schedule/measure/simple_runner.h @@ -26,12 +26,14 @@ class SimpleRunner : public ScheduleRunner { public: SimpleRunner(int repeat_times); - MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) override; + MeasureResult Run(const MeasureInput& input, + const BuildResult& build_result) override; private: - std::map PrepareArgs(const MeasureInput& input, - const BuildResult& build_result, - hlir::framework::Scope* temp_scope); + std::map PrepareArgs( + const MeasureInput& input, + const BuildResult& build_result, + hlir::framework::Scope* temp_scope); private: // The repeat times of running instructions, diff --git a/paddle/cinn/auto_schedule/measure/simple_runner_test.cc b/paddle/cinn/auto_schedule/measure/simple_runner_test.cc index bfe93bfa0eb74..2181aa4a8d557 100644 --- a/paddle/cinn/auto_schedule/measure/simple_runner_test.cc +++ b/paddle/cinn/auto_schedule/measure/simple_runner_test.cc @@ -53,15 +53,16 @@ class TestSimpleRunner : public ::testing::Test { static frontend::Program CreateAddReluProgram(); void SetUp() override { std::unordered_set fetch_ids; - auto program = CreateAddReluProgram(); - auto graph = cinn::frontend::Optimize(&program, fetch_ids, target); - compiled_scope = BuildScope(target, graph); - graph_compiler = std::make_unique(target, compiled_scope, graph); - auto runtime_program = graph_compiler->Build(); + auto program = CreateAddReluProgram(); + auto graph = cinn::frontend::Optimize(&program, fetch_ids, target); + compiled_scope = BuildScope(target, graph); + graph_compiler = + std::make_unique(target, compiled_scope, graph); + auto runtime_program = graph_compiler->Build(); const auto& instructions = runtime_program->GetRunInstructions(); ASSERT_EQ(1, instructions.size()); - build_result.compiled_scope = compiled_scope.get(); + build_result.compiled_scope = compiled_scope.get(); build_result.runtime_program = std::move(runtime_program); task = 
std::make_unique(); @@ -71,7 +72,7 @@ class TestSimpleRunner : public ::testing::Test { task->target = common::DefaultHostTarget(); #endif task->subgraph = graph->fusion_groups.front(); - input.task = task.get(); + input.task = task.get(); } }; @@ -115,18 +116,22 @@ TEST_F(TestSimpleRunner, TimeMeasured) { BuildResult build_result; build_result.compiled_scope = nullptr; std::vector> instructions; - instructions.emplace_back( - new Instruction(common::DefaultHostTarget(), nullptr, {}, {"empty_placeholder"}, "sleep_fn")); + instructions.emplace_back(new Instruction(common::DefaultHostTarget(), + nullptr, + {}, + {"empty_placeholder"}, + "sleep_fn")); instructions.back()->SetLoweredFunc(reinterpret_cast(sleep_fn)); instructions.back()->Finalize(); - build_result.runtime_program.reset(new hlir::framework::Program(nullptr, std::move(instructions))); + build_result.runtime_program.reset( + new hlir::framework::Program(nullptr, std::move(instructions))); // to skip the condition check of params in Instruction::PreparePodArgs std::map preset_args; preset_args.emplace("empty_placeholder", cinn_pod_value_t()); input.execution_args = &preset_args; - auto runner = std::make_unique(2); + auto runner = std::make_unique(2); MeasureResult measure_result = runner->Run(input, build_result); // because the kernel function will sleep 100 us, // the cost time of execution and span in total must diff --git a/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.cc b/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.cc index ede3cce78edc5..40f40d06d9d7f 100644 --- a/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.cc +++ b/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.cc @@ -22,10 +22,12 @@ namespace cinn { namespace auto_schedule { -int ExtractNumThreads(const ir::IRSchedule& ir_schedule, const std::string& bind_axis) { +int ExtractNumThreads(const ir::IRSchedule& ir_schedule, + const std::string& bind_axis) { const ir::ScheduleDesc& trace = ir_schedule.GetTraceDesc(); for (auto&& step : trace.Steps()) { - if (step.type == "Bind" && step.attrs.find("thread_axis") != step.attrs.end() && + if (step.type == "Bind" && + step.attrs.find("thread_axis") != step.attrs.end() && absl::get(step.attrs.at("thread_axis")) == bind_axis) { CHECK_EQ(step.inputs.at("loop").size(), 1); return step.inputs.at("loop")[0].As()->extent.as_int32(); @@ -38,17 +40,21 @@ std::vector FindCandidates(const ir::ScheduleDesc& trace) { std::vector candidate_block_names; for (auto&& step : trace.Steps()) { if (step.type == "AnnotateIntAttr" && - absl::get(step.attrs.at("key")) == ir::attr::cooperative_process) { + absl::get(step.attrs.at("key")) == + ir::attr::cooperative_process) { candidate_block_names.push_back( - step.inputs.at("block")[0].As()->schedule_block.As()->name); + step.inputs.at("block")[0] + .As() + ->schedule_block.As() + ->name); } } return candidate_block_names; } bool CooperativeProcess::Apply(ir::IRSchedule* schedule) { - int num_threads = ExtractNumThreads(*schedule, "threadIdx.x"); - const ir::ScheduleDesc& trace = schedule->GetTraceDesc(); + int num_threads = ExtractNumThreads(*schedule, "threadIdx.x"); + const ir::ScheduleDesc& trace = schedule->GetTraceDesc(); std::vector candidate_block_names = FindCandidates(trace); for (auto&& candidate : candidate_block_names) { auto loop = schedule->GetLoops(candidate).back(); diff --git a/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h 
b/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h index 545d7078d39ed..7985afbf8dda6 100644 --- a/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h +++ b/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h @@ -20,8 +20,9 @@ namespace cinn { namespace auto_schedule { /* - * @brief Rewrite the cooperative_process annotation to actually bind the loop on threadIdx. - * This rule is used for collaborative data handling of multiple threads within the same block. + * @brief Rewrite the cooperative_process annotation to actually bind the loop + * on threadIdx. This rule is used for collaborative data handling of multiple + * threads within the same block. */ class CooperativeProcess : public PostScheduleRule { public: diff --git a/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process_test.cc b/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process_test.cc index e4cf2ab43aa64..a6e1db2a8b20e 100644 --- a/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process_test.cc +++ b/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process_test.cc @@ -31,57 +31,75 @@ class TestCooperativeProcess : public TestAutoGenRuleBase { }; TEST_F(TestCooperativeProcess, Matmul) { - default_input_names = {"X", "Y"}; - default_output_names = {"temp_matmul_out"}; - std::vector X_shape = {32, 32}; - std::vector Y_shape = {32, 32}; + default_input_names = {"X", "Y"}; + default_output_names = {"temp_matmul_out"}; + std::vector X_shape = {32, 32}; + std::vector Y_shape = {32, 32}; std::vector out_shape = {32, 32}; - int num_blocks_y = 2; - int num_blocks_x = 2; + int num_blocks_y = 2; + int num_blocks_x = 2; int num_threads_y = 8; int num_threads_x = 2; - int steps_k = 8; + int steps_k = 8; Initialize(common::DefaultNVGPUTarget()); - frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}}); - ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed); + frontend::Program matmul_op = + tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}}); + ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed); // split loops - std::vector loops = ir_schedule.GetLoops("temp_matmul_out"); + std::vector loops = ir_schedule.GetLoops("temp_matmul_out"); std::vector k_loops = ir_schedule.Split(loops[2], {steps_k, -1}); - std::vector j_loops = ir_schedule.Split(loops[1], {num_blocks_x, num_threads_x, -1}); - std::vector i_loops = ir_schedule.Split(loops[0], {num_blocks_y, num_threads_y, -1}); + std::vector j_loops = + ir_schedule.Split(loops[1], {num_blocks_x, num_threads_x, -1}); + std::vector i_loops = + ir_schedule.Split(loops[0], {num_blocks_y, num_threads_y, -1}); // reorder to "SSRRS": i0, j0, i1, j1, k0, k1, j2, i2 loops = ir_schedule.GetLoops("temp_matmul_out"); - ir_schedule.Reorder({loops[0], loops[3], loops[1], loops[4], loops[6], loops[7], loops[2], loops[5]}); + ir_schedule.Reorder({loops[0], + loops[3], + loops[1], + loops[4], + loops[6], + loops[7], + loops[2], + loops[5]}); // fuse and bind - loops = ir_schedule.GetLoops("temp_matmul_out"); + loops = ir_schedule.GetLoops("temp_matmul_out"); ir::Expr i1_j1_fused = ir_schedule.Fuse({loops[2], loops[3]}); ir::Expr i0_j0_fused = ir_schedule.Fuse({loops[0], loops[1]}); - loops = ir_schedule.GetLoops("temp_matmul_out"); + loops = ir_schedule.GetLoops("temp_matmul_out"); ir_schedule.Bind(loops[1], "threadIdx.x"); ir_schedule.Bind(loops[0], "blockIdx.x"); // cache read - ir::Expr out_block = 
ir_schedule.GetBlock("temp_matmul_out"); + ir::Expr out_block = ir_schedule.GetBlock("temp_matmul_out"); ir::Expr X_cache_block = ir_schedule.CacheRead(out_block, 1, "shared"); - std::string X_cache_block_name = - X_cache_block.As()->schedule_block.As()->name; + std::string X_cache_block_name = X_cache_block.As() + ->schedule_block.As() + ->name; loops = ir_schedule.GetLoops("temp_matmul_out"); ir_schedule.ComputeAt(X_cache_block, loops[2]); - std::vector X_cache_loops = ir_schedule.GetLoops(X_cache_block_name); + std::vector X_cache_loops = + ir_schedule.GetLoops(X_cache_block_name); ir_schedule.Fuse({X_cache_loops[3], X_cache_loops[4]}); - ir_schedule.Annotate(ir_schedule.GetBlock(X_cache_block_name), ir::attr::cooperative_process, 0); + ir_schedule.Annotate(ir_schedule.GetBlock(X_cache_block_name), + ir::attr::cooperative_process, + 0); - out_block = ir_schedule.GetBlock("temp_matmul_out"); + out_block = ir_schedule.GetBlock("temp_matmul_out"); ir::Expr Y_cache_block = ir_schedule.CacheRead(out_block, 2, "shared"); - std::string Y_cache_block_name = - Y_cache_block.As()->schedule_block.As()->name; + std::string Y_cache_block_name = Y_cache_block.As() + ->schedule_block.As() + ->name; loops = ir_schedule.GetLoops("temp_matmul_out"); ir_schedule.ComputeAt(Y_cache_block, loops[2]); - std::vector Y_cache_loops = ir_schedule.GetLoops(Y_cache_block_name); + std::vector Y_cache_loops = + ir_schedule.GetLoops(Y_cache_block_name); ir_schedule.Fuse({Y_cache_loops[3], Y_cache_loops[4]}); - ir_schedule.Annotate(ir_schedule.GetBlock(Y_cache_block_name), ir::attr::cooperative_process, 0); + ir_schedule.Annotate(ir_schedule.GetBlock(Y_cache_block_name), + ir::attr::cooperative_process, + 0); // apply CooperativeProcess CooperativeProcess cooperative_process; @@ -180,14 +198,15 @@ TEST_F(TestCooperativeProcess, Matmul) { ASSERT_EQ(ir, expected_ir); // build ir::Module and debug source code - auto ir_module = BuildIRModule(ir_schedule); + auto ir_module = BuildIRModule(ir_schedule); auto source_code = GenSourceCode(ir_module); VLOG(6) << "scheduled source code:\n" << source_code; // execute and check precision CheckResult( GenExecutableKernel(ir_module), - GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))), + GenExecutableKernel(BuildIRModule(MakeIRSchedule( + matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))), default_input_names, default_output_names, {X_shape, Y_shape}, diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc index fddc04b4f37f0..a07d4ffff20e8 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc @@ -29,37 +29,45 @@ static constexpr uint32_t kMaxBlocks = 256; bool IsSpatialLoop(const ir::For* for_node) { if (for_node->for_type() != ir::ForType::Serial) return false; const auto& loop_var = for_node->loop_var; - // collect cases where the loop_var used in one of reduce axis in underneath ScheduleBlock - auto used_for_reduce_axis = ir::CollectIRNodesWithoutTensor(for_node->body, [&loop_var](const Expr* x) { - const auto* block_realize = x->As(); - if (!block_realize) return false; - - const auto* schedule_block = block_realize->schedule_block.As(); - CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock"; - CHECK_EQ(block_realize->iter_values.size(), schedule_block->iter_vars.size()); - for (int i = 0; i < 
block_realize->iter_values.size(); ++i) { - const ir::Var& iter_var = schedule_block->iter_vars[i]; - const ir::Expr& binding = block_realize->iter_values[i]; - if (iter_var->is_reduce_axis || iter_var->name.substr(0, 6) == "reduce") { - auto used_exprs = ir::CollectIRNodesWithoutTensor(binding, [&loop_var](const Expr* x) { - const ir::_Var_* var = x->As(); - if (var && (x->same_as(loop_var) || var->name == loop_var->name)) { - return true; + // collect cases where the loop_var is used in one of the reduce axes in the + // underlying ScheduleBlock + auto used_for_reduce_axis = ir::CollectIRNodesWithoutTensor( + for_node->body, [&loop_var](const Expr* x) { + const auto* block_realize = x->As(); + if (!block_realize) return false; + + const auto* schedule_block = + block_realize->schedule_block.As(); + CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock"; + CHECK_EQ(block_realize->iter_values.size(), + schedule_block->iter_vars.size()); + for (int i = 0; i < block_realize->iter_values.size(); ++i) { + const ir::Var& iter_var = schedule_block->iter_vars[i]; + const ir::Expr& binding = block_realize->iter_values[i]; + if (iter_var->is_reduce_axis || + iter_var->name.substr(0, 6) == "reduce") { + auto used_exprs = ir::CollectIRNodesWithoutTensor( + binding, [&loop_var](const Expr* x) { + const ir::_Var_* var = x->As(); + if (var && + (x->same_as(loop_var) || var->name == loop_var->name)) { + return true; + } + return false; + }); + if (!used_exprs.empty()) return true; } - return false; - }); - if (!used_exprs.empty()) return true; - } - } + } - return false; - }); + return false; + }); if (!used_for_reduce_axis.empty()) return false; return true; } -// count the number of loops that can be binded from the input for_node to bottom +// count the number of loops that can be bound from the input for_node down to +// the bottom int CountLoopCanBinded(const ir::For* for_node) { int cnt = 0; while (for_node) { @@ -68,9 +76,11 @@ int CountLoopCanBinded(const ir::For* for_node) { cnt += 1; - CHECK(for_node->body.defined() && for_node->body.As()) << "Body is not defined"; + CHECK(for_node->body.defined() && for_node->body.As()) + << "Body is not defined"; const ir::Block* body = for_node->body.As(); - // terminate when body of this loop has more than one statement or the body is not a ir::For node + // terminate when the body of this loop has more than one statement or the + // body is not an ir::For node for_node = body->stmts.size() == 1 ?
body->stmts[0].As() : nullptr; } return cnt; @@ -82,14 +92,18 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule, int max_blocks, int max_threads_per_block) { auto all_loops = ir_schedule->GetLoops(block_name); - CHECK_LE(num_loops_to_bind, all_loops.size()) << "The number of loops to be bind is greater than size of all_loops"; - // check whether it is the case that threadIdx has been binded but blockIdx not, - // the threadIdx can only be binded in the first loop after num_loops_to_bind loops - // because we has excluded other cases in CountLoopCanBinded + CHECK_LE(num_loops_to_bind, all_loops.size()) + << "The number of loops to be bound is greater than the size of all_loops"; + // check whether threadIdx has been bound while blockIdx has not; the + // threadIdx can only be bound in the first loop after num_loops_to_bind + // loops because we have excluded other cases in CountLoopCanBinded bool gpu_thread_has_binded = - num_loops_to_bind < all_loops.size() && all_loops[num_loops_to_bind].As()->is_gpu_thread_binded(); - Expr fused_loop = ir_schedule->Fuse({all_loops.begin(), all_loops.begin() + num_loops_to_bind}); - int32_t extent = fused_loop.As()->extent.as_int32(); + num_loops_to_bind < all_loops.size() && + all_loops[num_loops_to_bind].As()->is_gpu_thread_binded(); + Expr fused_loop = ir_schedule->Fuse( + {all_loops.begin(), all_loops.begin() + num_loops_to_bind}); + int32_t extent = fused_loop.As()->extent.as_int32(); if (gpu_thread_has_binded) { ir_schedule->Bind(fused_loop, "blockIdx.x"); return; @@ -106,7 +120,8 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule, ir_schedule->Bind(splits[0], "blockIdx.x"); ir_schedule->Bind(splits[1], "threadIdx.x"); } else { - auto splits = ir_schedule->Split(fused_loop, {-1, max_blocks, max_threads_per_block}); + auto splits = + ir_schedule->Split(fused_loop, {-1, max_blocks, max_threads_per_block}); CHECK_EQ(splits.size(), 3); ir_schedule->Reorder({splits[1], splits[2], splits[0]}); all_loops = ir_schedule->GetLoops(block_name); @@ -126,31 +141,38 @@ RuleApplyType AutoBind::Init(ir::IRSchedule* ir_schedule) { } num_applicable_ = applicable_schedule_blocks_.size(); VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_; - return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; + return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules + : RuleApplyType::kCannotApply; } void AutoBind::Apply(int index) { - CHECK_LT(index, applicable_schedule_blocks_.size()) << "invalid apply index:" << index; + CHECK_LT(index, applicable_schedule_blocks_.size()) + << "invalid apply index:" << index; auto applied_block = applicable_schedule_blocks_.at(index); - auto all_loops = ir_schedule_->GetLoops(applied_block); + auto all_loops = ir_schedule_->GetLoops(applied_block); BindGPUIndex(ir_schedule_, - applied_block.As()->schedule_block.As()->name, + applied_block.As() + ->schedule_block.As() + ->name, CountLoopCanBinded(all_loops[0].As()), kMaxBlocks, target_->max_num_threads()); return; } -RuleApplyType AutoBind::AnalyseApplyType(SearchState state, const std::string& block_name) const { +RuleApplyType AutoBind::AnalyseApplyType(SearchState state, + const std::string& block_name) const { Expr block_expr = state->ir_schedule.GetBlock(block_name); - auto all_loops = state->ir_schedule.GetLoops(block_expr); - return CountLoopCanBinded(all_loops[0].As()) > 0 ?
RuleApplyType::kApplyAndPruneOtherRules - : RuleApplyType::kCannotApply; + auto all_loops = state->ir_schedule.GetLoops(block_expr); + return CountLoopCanBinded(all_loops[0].As()) > 0 + ? RuleApplyType::kApplyAndPruneOtherRules + : RuleApplyType::kCannotApply; } -std::vector AutoBind::ApplyOnBlock(SearchState state, const std::string& block_name) { +std::vector AutoBind::ApplyOnBlock(SearchState state, + const std::string& block_name) { SearchState new_state = state.Copy(); - auto all_loops = state->ir_schedule.GetLoops(block_name); + auto all_loops = state->ir_schedule.GetLoops(block_name); BindGPUIndex(&new_state->ir_schedule, block_name, CountLoopCanBinded(all_loops[0].As()), diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h index 8b05ec75b3e9b..e4dfb59e09ff0 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h @@ -36,9 +36,11 @@ class AutoBind : public AutoGenRule { std::string GetRuleName() const override { return "AutoBind"; } - RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; + RuleApplyType AnalyseApplyType(SearchState state, + const std::string& block_name) const override; - std::vector ApplyOnBlock(SearchState state, const std::string& block_name) override; + std::vector ApplyOnBlock(SearchState state, + const std::string& block_name) override; private: std::vector applicable_schedule_blocks_; diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind_test.cc index b15a2267add47..751c4f931d6d1 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind_test.cc @@ -28,17 +28,19 @@ namespace cinn { namespace auto_schedule { -static constexpr uint32_t kMaxBlocks = 256; +static constexpr uint32_t kMaxBlocks = 256; static constexpr uint32_t kMaxThreadsPerBlock = 1024; class TestAutoBind : public TestAutoGenRuleBase { public: - std::vector default_input_names = {"X", "Y"}; + std::vector default_input_names = {"X", "Y"}; std::vector default_output_names = {"temp_matmul_out"}; - void TestApplyOnElementWiseAdd(const std::vector& shape, const std::string& block_name) { + void TestApplyOnElementWiseAdd(const std::vector& shape, + const std::string& block_name) { Initialize(common::DefaultNVGPUTarget()); - auto test_program = tests::OpBuilder("elementwise_add").Build({{"X", shape}, {"Y", shape}}); + auto test_program = + tests::OpBuilder("elementwise_add").Build({{"X", shape}, {"Y", shape}}); // construct input parameter ir::IRSchedule ir_schedule = MakeIRSchedule(test_program); SearchState state(ir_schedule, 0, {}); @@ -48,15 +50,17 @@ class TestAutoBind : public TestAutoGenRuleBase { // apply AutoBind auto_bind(target_); - ASSERT_EQ(auto_bind.AnalyseApplyType(state, block_name), RuleApplyType::kApplyAndPruneOtherRules); - auto result = auto_bind.ApplyOnBlock(state, block_name)[0]; + ASSERT_EQ(auto_bind.AnalyseApplyType(state, block_name), + RuleApplyType::kApplyAndPruneOtherRules); + auto result = auto_bind.ApplyOnBlock(state, block_name)[0]; std::vector exprs = result->ir_schedule.GetModule().GetExprs(); EXPECT_EQ(exprs.size(), 1UL); VLOG(6) << "AutoBind applied Expr: " << exprs[0]; // check bind result auto all_loops = result->ir_schedule.GetLoops(block_name); - int 
total_num = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + int total_num = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); if (total_num <= kMaxThreadsPerBlock) { ASSERT_EQ(all_loops.size(), 1); EXPECT_EQ(all_loops[0].As()->extent.as_int32(), total_num); @@ -64,27 +68,33 @@ class TestAutoBind : public TestAutoGenRuleBase { } else if (total_num <= kMaxBlocks * kMaxThreadsPerBlock) { ASSERT_EQ(all_loops.size(), 2); EXPECT_EQ(all_loops[0].As()->extent.as_int32(), - static_cast(std::ceil(double(total_num) / kMaxThreadsPerBlock))); + static_cast( + std::ceil(double(total_num) / kMaxThreadsPerBlock))); EXPECT_TRUE(all_loops[0].As()->is_gpu_block_binded()); - EXPECT_EQ(all_loops[1].As()->extent.as_int32(), kMaxThreadsPerBlock); + EXPECT_EQ(all_loops[1].As()->extent.as_int32(), + kMaxThreadsPerBlock); EXPECT_TRUE(all_loops[1].As()->is_gpu_thread_binded()); } else { ASSERT_EQ(all_loops.size(), 3); EXPECT_EQ(all_loops[0].As()->extent.as_int32(), kMaxBlocks); EXPECT_TRUE(all_loops[0].As()->is_gpu_block_binded()); - EXPECT_EQ(all_loops[1].As()->extent.as_int32(), kMaxThreadsPerBlock); + EXPECT_EQ(all_loops[1].As()->extent.as_int32(), + kMaxThreadsPerBlock); EXPECT_TRUE(all_loops[1].As()->is_gpu_thread_binded()); EXPECT_EQ(all_loops[2].As()->extent.as_int32(), - static_cast(std::ceil(double(total_num) / (kMaxBlocks * kMaxThreadsPerBlock)))); + static_cast(std::ceil( + double(total_num) / (kMaxBlocks * kMaxThreadsPerBlock)))); EXPECT_FALSE(all_loops[2].As()->is_binded()); } // build and run - auto ir_module = BuildIRModule(result->ir_schedule); + auto ir_module = BuildIRModule(result->ir_schedule); auto source_code = GenSourceCode(ir_module); VLOG(6) << "Optimized source code:\n" << source_code; - auto manual_ir_module = BuildIRModule(MakeIRSchedule(test_program, /* apply_manual_schedule*/ true)); - VLOG(6) << "Manual-schedule compiled source code:\n" << GenSourceCode(manual_ir_module); + auto manual_ir_module = BuildIRModule( + MakeIRSchedule(test_program, /* apply_manual_schedule*/ true)); + VLOG(6) << "Manual-schedule compiled source code:\n" + << GenSourceCode(manual_ir_module); CheckResult(GenExecutableKernel(ir_module), GenExecutableKernel(manual_ir_module), default_input_names, @@ -97,16 +107,20 @@ class TestAutoBind : public TestAutoGenRuleBase { TEST_F(TestAutoBind, AnalyseApplyType) { Initialize(common::DefaultNVGPUTarget()); - ir::IRSchedule ir_schedule = MakeIRSchedule(tests::OpBuilder("matmul").Build({{"X", {32, 64}}, {"Y", {64, 32}}})); + ir::IRSchedule ir_schedule = MakeIRSchedule( + tests::OpBuilder("matmul").Build({{"X", {32, 64}}, {"Y", {64, 32}}})); SearchState state(ir_schedule, 0, {}); AutoBind auto_bind(target_); const std::string& applied_block_name = default_output_names.back(); // outer two loops of initial Expr are spatial loops, so it can be applied - EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name), RuleApplyType::kApplyAndPruneOtherRules); + EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name), + RuleApplyType::kApplyAndPruneOtherRules); state->ir_schedule.Fuse(applied_block_name, {0, 1}); - state->ir_schedule.Bind(state->ir_schedule.GetLoops(applied_block_name)[0], "threadIdx.x"); + state->ir_schedule.Bind(state->ir_schedule.GetLoops(applied_block_name)[0], + "threadIdx.x"); // after fuse and bind, there is no loops to be binded. 
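The three assertion branches above encode AutoBind's splitting policy for a flattened spatial loop. A small sketch of the expected loop extents as a function of the flattened element count, assuming the same kMaxBlocks = 256 and kMaxThreadsPerBlock = 1024 constants used by the test:

#include <cmath>
#include <cstdint>
#include <vector>

// Expected loop extents after AutoBind, outermost first: one threadIdx.x
// loop, or blockIdx.x + threadIdx.x, or blockIdx.x + threadIdx.x plus a
// serial remainder loop when even the grid limit is exceeded.
std::vector<int> ExpectedExtents(int64_t total_num) {
  constexpr int64_t kMaxBlocks = 256;
  constexpr int64_t kMaxThreadsPerBlock = 1024;
  if (total_num <= kMaxThreadsPerBlock) {
    return {static_cast<int>(total_num)};
  }
  if (total_num <= kMaxBlocks * kMaxThreadsPerBlock) {
    int blocks = static_cast<int>(
        std::ceil(static_cast<double>(total_num) / kMaxThreadsPerBlock));
    return {blocks, static_cast<int>(kMaxThreadsPerBlock)};
  }
  int serial = static_cast<int>(std::ceil(
      static_cast<double>(total_num) / (kMaxBlocks * kMaxThreadsPerBlock)));
  return {static_cast<int>(kMaxBlocks),
          static_cast<int>(kMaxThreadsPerBlock), serial};
}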
- EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name), RuleApplyType::kCannotApply); + EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name), + RuleApplyType::kCannotApply); } TEST_F(TestAutoBind, ApplyOnBlock) { diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.cc index af0d6a9e99638..fd417c0dfbb9a 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.cc @@ -27,12 +27,16 @@ namespace auto_schedule { AutoGenRule::AutoGenRule(const common::Target& target) : target_(&target) {} int AutoGenRule::NumberApplicable() const { - CHECK_GE(num_applicable_, 0) << "Call " << GetRuleName() << "::NumberApplicable() without initialization."; + CHECK_GE(num_applicable_, 0) + << "Call " << GetRuleName() + << "::NumberApplicable() without initialization."; return num_applicable_; } void AutoGenRule::ApplyRandomly() { - CHECK_GT(num_applicable_, 0) << "Call " << GetRuleName() << "::ApplyRandomly() with NumberApplicable() == 0"; + CHECK_GT(num_applicable_, 0) + << "Call " << GetRuleName() + << "::ApplyRandomly() with NumberApplicable() == 0"; int index = rand() % num_applicable_; return Apply(index); } diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h index 4bca9f34483bc..6b74861637c61 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h @@ -29,15 +29,18 @@ enum class RuleApplyType : int { // This rule cannot be applied to ModuleExpr. kCannotApply = 0, // This rule can be applied to ModuleExpr, - // and the original ModuleExpr will be retained for branching with other rules. + // and the original ModuleExpr will be retained for branching with other + // rules. kApply = 1, // This rule can be applied, but the original ModuleExpr will be deleted, - // so the branches with other rules applied on the original ModuleExpr will be pruned. + // so the branches with other rules applied on the original ModuleExpr will be + // pruned. kApplyAndPruneOtherRules = 2, }; /** - * Base class for rules of auto-generating schedule (like Ansor's sketch generation) + * Base class for rules of auto-generating schedule (like Ansor's sketch + * generation) * */ class AutoGenRule { @@ -46,7 +49,8 @@ class AutoGenRule { ~AutoGenRule() = default; // Initialize the AutoGenRule, it must be called before further actions. - // Returns false if the rule cannot be applied on the mod_expr, true otherwise. + // Returns false if the rule cannot be applied on the mod_expr, true + // otherwise. virtual RuleApplyType Init(ir::IRSchedule* ir_schedule) = 0; // CINN IRSchedule can contain many ScheduleBlock(s) and Loop(s), so @@ -65,11 +69,15 @@ class AutoGenRule { // Returns the name of the rule, used for debug. 
virtual std::string GetRuleName() const = 0; - // Analyze the ApplyType of the rule used for a block determined by a specific SearchState and block name - virtual RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const = 0; + // Analyze the ApplyType of the rule used for a block determined by a specific + // SearchState and block name + virtual RuleApplyType AnalyseApplyType( + SearchState state, const std::string& block_name) const = 0; - // Apply the rule to a block determined by a specific SearchState and block name - virtual std::vector ApplyOnBlock(SearchState state, const std::string& block_name) = 0; + // Apply the rule to a block determined by a specific SearchState and block + // name + virtual std::vector ApplyOnBlock( + SearchState state, const std::string& block_name) = 0; protected: // number of ScheduleBlock that can apply this auto gen rule diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc index 872e66e8928f5..c5efb376d12c4 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc @@ -34,31 +34,38 @@ namespace cinn { namespace auto_schedule { -AutoInline::AutoInline(const common::Target& target, const std::unordered_set& no_inline_output_names) +AutoInline::AutoInline( + const common::Target& target, + const std::unordered_set& no_inline_output_names) : AutoGenRule(target), no_inline_output_names_(no_inline_output_names) {} -bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const { - const ir::ScheduleBlockRealize* sche_block_realize = sche_block_realize_expr.As(); - const ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As(); - ir::Expr compute_body = sche_block->body; - ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr); +bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, + ir::IRSchedule* ir_sch) const { + const ir::ScheduleBlockRealize* sche_block_realize = + sche_block_realize_expr.As(); + const ir::ScheduleBlock* sche_block = + sche_block_realize->schedule_block.As(); + ir::Expr compute_body = sche_block->body; + ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr); // Check the schedule block to be inlined is not a reduce tensor. - std::set find_store = - ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { return x->As(); }); + std::set find_store = ir::CollectIRNodesWithoutTensor( + compute_body, [&](const Expr* x) { return x->As(); }); if (find_store.size() != 1UL) { return false; } ir::Expr tensor_expr = (*find_store.begin()).As()->tensor; - ir::Tensor tensor = tensor_expr.as_tensor_ref(); + ir::Tensor tensor = tensor_expr.as_tensor_ref(); if (tensor->is_reduce_tensor()) { return false; } // LoweredFunc output can be tensor name or tensor buffer name - if (no_inline_output_names_.find(tensor->name) != no_inline_output_names_.end() || - no_inline_output_names_.find(tensor->buffer->name) != no_inline_output_names_.end()) { + if (no_inline_output_names_.find(tensor->name) != + no_inline_output_names_.end() || + no_inline_output_names_.find(tensor->buffer->name) != + no_inline_output_names_.end()) { return false; } @@ -70,26 +77,32 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir:: // Check this schedule block is the only writer of the tensor. 
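CanInlineIntoConsumer below combines several structural checks: exactly one store, not a reduction, not a function output, and then sole-writer and no read/write overlap. A condensed sketch of the first group of checks, with BlockInfo as a hypothetical summary of the CollectIRNodesWithoutTensor queries used in auto_inline.cc:

#include <string>
#include <unordered_set>

// Hypothetical summary of a schedule block gathered by IR traversal.
struct BlockInfo {
  int store_count;            // number of ir::Store nodes in the body
  bool writes_reduce_tensor;  // whether the written tensor is a reduction
  std::string tensor_name;    // name of the written tensor
  std::string buffer_name;    // name of its backing buffer
};

// A block may be inlined only if it has exactly one store, does not write a
// reduce tensor, and does not produce one of the function's outputs (which
// may be recorded by tensor name or by buffer name, hence the double lookup).
bool CanInline(const BlockInfo& b,
               const std::unordered_set<std::string>& no_inline_outputs) {
  if (b.store_count != 1) return false;
  if (b.writes_reduce_tensor) return false;
  return no_inline_outputs.count(b.tensor_name) == 0 &&
         no_inline_outputs.count(b.buffer_name) == 0;
}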
find_store = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { - return x->As() && (x->As()->tensor).as_tensor_ref()->name == tensor->name; + return x->As() && + (x->As()->tensor).as_tensor_ref()->name == tensor->name; }); if (find_store.size() != 1UL) { return false; } - // Check there is no overlap between the buffers the schedule block reads and writes. - std::set find_load = ir::CollectIRNodesWithoutTensor( - compute_body, [&](const Expr* x) { return x->As() && x->As()->tensor == tensor_expr; }); + // Check there is no overlap between the buffers the schedule block reads and + // writes. + std::set find_load = + ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { + return x->As() && x->As()->tensor == tensor_expr; + }); if (!find_load.empty()) { return false; } ir::Expr store = *(find_store.begin()); - ir::ComputeInliner inliner(store.As()->tensor.as_tensor_ref(), store); + ir::ComputeInliner inliner(store.As()->tensor.as_tensor_ref(), + store); if (!inliner.BodyPatternAllowInline()) { return false; } - ir::LeafBlockRemovalPlan remove_plan(sche_block_realize_expr, &inliner.src_stmt, &inliner.tgt_stmt); + ir::LeafBlockRemovalPlan remove_plan( + sche_block_realize_expr, &inliner.src_stmt, &inliner.tgt_stmt); remove_plan(&root); if (!inliner.src_stmt.defined() || !inliner.tgt_stmt.defined()) { return false; @@ -99,16 +112,20 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir:: return true; } -AutoInlineType AutoInline::AnalyzeInlineType(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const { - const ir::ScheduleBlockRealize* sche_block_realize = sche_block_realize_expr.As(); - const ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As(); +AutoInlineType AutoInline::AnalyzeInlineType( + const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const { + const ir::ScheduleBlockRealize* sche_block_realize = + sche_block_realize_expr.As(); + const ir::ScheduleBlock* sche_block = + sche_block_realize->schedule_block.As(); // Inline if the block has only 1 write buffer if (sche_block->write_buffers.size() != 1) { return AutoInlineType::kCannotInline; } - std::unordered_set no_inline_node_types = {ir::IrNodeTy::IfThenElse}; + std::unordered_set no_inline_node_types = { + ir::IrNodeTy::IfThenElse}; if (ContainsNodeType(sche_block->body, no_inline_node_types)) { return AutoInlineType::kCannotInline; } @@ -125,31 +142,38 @@ AutoInlineType AutoInline::AnalyzeInlineType(const Expr& sche_block_realize_expr } RuleApplyType AutoInline::Init(ir::IRSchedule* ir_schedule) { - ir_schedule_ = ir_schedule; + ir_schedule_ = ir_schedule; all_block_realizes_ = ir_schedule_->GetAllBlocks(); apply_indices_and_type_.clear(); num_applicable_ = 0; for (size_t i = 0; i < all_block_realizes_.size(); ++i) { - ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As(); - AnalyzeScheduleBlockReadWriteBuffer(sche_block_realize->schedule_block.As()); - AutoInlineType type = AnalyzeInlineType(all_block_realizes_[i], ir_schedule_); + ir::ScheduleBlockRealize* sche_block_realize = + all_block_realizes_[i].As(); + AnalyzeScheduleBlockReadWriteBuffer( + sche_block_realize->schedule_block.As()); + AutoInlineType type = + AnalyzeInlineType(all_block_realizes_[i], ir_schedule_); if (type != AutoInlineType::kCannotInline) { ++num_applicable_; apply_indices_and_type_.push_back({i, type}); } } - return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; + return num_applicable_ > 0 ? 
RuleApplyType::kApplyAndPruneOtherRules + : RuleApplyType::kCannotApply; } void AutoInline::Apply(int index) { CHECK(ir_schedule_ != nullptr) << "Run AutoInline::Apply without Init"; - CHECK(num_applicable_ > 0 && apply_indices_and_type_.size() == num_applicable_) + CHECK(num_applicable_ > 0 && + apply_indices_and_type_.size() == num_applicable_) << "AutoInline::Apply pre-condition doesn't meet"; CHECK(index >= 0 && num_applicable_ > index) - << "Invalid index for AutoInline::Apply, the index needs 0 <= index && index < NumberApplicable(), " - << "Currently index = " << index << ", NumberApplicable() = " << num_applicable_; + << "Invalid index for AutoInline::Apply, the index needs 0 <= index && " + "index < NumberApplicable(), " + << "Currently index = " << index + << ", NumberApplicable() = " << num_applicable_; int apply_index = apply_indices_and_type_[index].first; Apply(ir_schedule_, all_block_realizes_[apply_index]); @@ -158,20 +182,25 @@ void AutoInline::Apply(int index) { std::string AutoInline::GetRuleName() const { return "AutoInline"; } -RuleApplyType AutoInline::AnalyseApplyType(SearchState state, const std::string& block_name) const { - Expr block_expr = state->ir_schedule.GetBlock(block_name); +RuleApplyType AutoInline::AnalyseApplyType( + SearchState state, const std::string& block_name) const { + Expr block_expr = state->ir_schedule.GetBlock(block_name); auto* block_realize = block_expr.As(); CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr; - AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As()); + AnalyzeScheduleBlockReadWriteBuffer( + block_realize->schedule_block.As()); AutoInlineType type = AnalyzeInlineType(block_expr, &state->ir_schedule); - return type == AutoInlineType::kCannotInline ? RuleApplyType::kCannotApply : RuleApplyType::kApplyAndPruneOtherRules; + return type == AutoInlineType::kCannotInline + ? 
RuleApplyType::kCannotApply + : RuleApplyType::kApplyAndPruneOtherRules; } -std::vector AutoInline::ApplyOnBlock(SearchState state, const std::string& block_name) { +std::vector AutoInline::ApplyOnBlock( + SearchState state, const std::string& block_name) { SearchState new_state = state.Copy(); - Expr block_expr = new_state->ir_schedule.GetBlock(block_name); + Expr block_expr = new_state->ir_schedule.GetBlock(block_name); Apply(&new_state->ir_schedule, block_expr); return {new_state}; @@ -181,7 +210,8 @@ void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) { auto* block_realize = block_expr.As(); CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr; - AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As()); + AnalyzeScheduleBlockReadWriteBuffer( + block_realize->schedule_block.As()); AutoInlineType type = AnalyzeInlineType(block_expr, ir_schedule); if (type == AutoInlineType::kInlineIntoConsumer) { @@ -202,10 +232,12 @@ void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) { // we need to re-analyze all_block_realizes_ = ir_schedule->GetAllBlocks(); for (size_t i = 0; i < all_block_realizes_.size(); ++i) { - ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As(); - ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As(); - sche_block->read_buffers = {}; - sche_block->write_buffers = {}; + ir::ScheduleBlockRealize* sche_block_realize = + all_block_realizes_[i].As(); + ir::ScheduleBlock* sche_block = + sche_block_realize->schedule_block.As(); + sche_block->read_buffers = {}; + sche_block->write_buffers = {}; AnalyzeScheduleBlockReadWriteBuffer(sche_block); } } diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h index db89070b4529d..02090467049a0 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h @@ -41,7 +41,8 @@ enum class AutoInlineType : int { class AutoInline : public AutoGenRule { public: - AutoInline(const common::Target& target, const std::unordered_set& no_inline_output_names); + AutoInline(const common::Target& target, + const std::unordered_set& no_inline_output_names); ~AutoInline() = default; RuleApplyType Init(ir::IRSchedule* ir_schedule) override; @@ -50,13 +51,17 @@ class AutoInline : public AutoGenRule { std::string GetRuleName() const override; - AutoInlineType AnalyzeInlineType(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const; + AutoInlineType AnalyzeInlineType(const Expr& sche_block_realize_expr, + ir::IRSchedule* ir_sch) const; - bool CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const; + bool CanInlineIntoConsumer(const Expr& sche_block_realize_expr, + ir::IRSchedule* ir_sch) const; - RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; + RuleApplyType AnalyseApplyType(SearchState state, + const std::string& block_name) const override; - std::vector ApplyOnBlock(SearchState state, const std::string& block_name) override; + std::vector ApplyOnBlock(SearchState state, + const std::string& block_name) override; private: void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr); diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc index 
dd54ed59e1f34..a4e54c0731987 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc @@ -63,7 +63,14 @@ TEST(AutoInline, SingleLoopInline) { poly::StageMap stages = CreateStages({A, B, C}); std::vector funcs = - lang::LowerVec("TestAutoInline_SingleLoopInline", stages, {A, C}, {}, {}, nullptr, target, true); + lang::LowerVec("TestAutoInline_SingleLoopInline", + stages, + {A, C}, + {}, + {}, + nullptr, + target, + true); VLOG(6) << "Expr after lowering:"; VLOG(6) << funcs[0]->body; @@ -74,7 +81,7 @@ TEST(AutoInline, SingleLoopInline) { */ ir::IRSchedule ir_sch(ir::ModuleExpr(std::vector{funcs[0]->body})); SearchState state(ir_sch, 0, {}); - ir::Expr block_b = ir_sch.GetBlock("B"); + ir::Expr block_b = ir_sch.GetBlock("B"); std::vector loops = ir_sch.GetLoops("C"); ir_sch.ComputeAt(block_b, loops[0]); @@ -90,12 +97,13 @@ TEST(AutoInline, SingleLoopInline) { EXPECT_EQ(exprs.size(), 1UL); // ApplyOnBlock - EXPECT_EQ(auto_inline.AnalyseApplyType(state, "B"), RuleApplyType::kApplyAndPruneOtherRules); + EXPECT_EQ(auto_inline.AnalyseApplyType(state, "B"), + RuleApplyType::kApplyAndPruneOtherRules); auto new_states = auto_inline.ApplyOnBlock(state, "B"); auto test_func = [](ir::IRSchedule* ir_sch) { ir::ModuleExpr mod_expr_after_inline = ir_sch->GetModule(); - std::vector exprs = mod_expr_after_inline.GetExprs(); + std::vector exprs = mod_expr_after_inline.GetExprs(); EXPECT_EQ(exprs.size(), 1UL); std::stringstream ss; @@ -130,7 +138,8 @@ TEST(AutoInline, SingleLoopInline) { // Cannot inline above expr again EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply); - EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "C"), RuleApplyType::kCannotApply); + EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "C"), + RuleApplyType::kCannotApply); } TEST(AutoInline, AddReluInline) { @@ -148,15 +157,20 @@ TEST(AutoInline, AddReluInline) { frontend::Program program = builder.Build(); FLAGS_cinn_ir_schedule = true; - auto graph = std::make_shared(program, target); + auto graph = std::make_shared(program, target); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - const auto& dtype_dict = graph->GetAttrs>("inferdtype"); - const auto& shape_dict = graph->GetAttrs>("infershape"); - auto op_lowerer = std::make_unique(dtype_dict, shape_dict, target); + const auto& dtype_dict = + graph->GetAttrs>( + "inferdtype"); + const auto& shape_dict = graph->GetAttrs< + absl::flat_hash_map>("infershape"); + auto op_lowerer = std::make_unique( + dtype_dict, shape_dict, target); EXPECT_EQ(graph->fusion_groups.size(), 1UL); - std::vector funcs = op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]); + std::vector funcs = + op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]); VLOG(6) << "Expr before auto inline: " << funcs[0]->body; @@ -170,7 +184,7 @@ TEST(AutoInline, AddReluInline) { auto_inline.Apply(1); ir::ModuleExpr mod_expr_after_inline = ir_sch.GetModule(); - std::vector exprs = mod_expr_after_inline.GetExprs(); + std::vector exprs = mod_expr_after_inline.GetExprs(); EXPECT_EQ(exprs.size(), 1UL); std::stringstream ss; @@ -186,15 +200,17 @@ TEST(AutoInline, AddReluInline) { auto_inline.Apply(0); // ApplyOnBlock - EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_1"), RuleApplyType::kApplyAndPruneOtherRules); + EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_1"), + RuleApplyType::kApplyAndPruneOtherRules); auto new_states = auto_inline.ApplyOnBlock(state, "var_1"); // Auto 
Inline again - EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_3"), RuleApplyType::kApplyAndPruneOtherRules); + EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_3"), + RuleApplyType::kApplyAndPruneOtherRules); new_states = auto_inline.ApplyOnBlock(new_states[0], "var_3"); auto test_func = [](ir::IRSchedule* ir_sch) { ir::ModuleExpr final_mod_expr = ir_sch->GetModule(); - auto exprs = final_mod_expr.GetExprs(); + auto exprs = final_mod_expr.GetExprs(); EXPECT_EQ(exprs.size(), 1UL); std::stringstream ss; @@ -238,7 +254,8 @@ TEST(AutoInline, AddReluInline) { // Cannot inline above expr again EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply); - EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_2"), RuleApplyType::kCannotApply); + EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_2"), + RuleApplyType::kCannotApply); } #ifdef CINN_WITH_CUDA @@ -246,14 +263,8 @@ class TestAutoInline : public TestAutoGenRuleBase {}; /* The single chain graph composed of multiple blocks can be inlined into one. * - * Before AutoInline: The output of the previous block is the input of another block. - * Loop1: - * x1 = Add() - * Loop2: - * x2 = Multiply(x1) - * Loop3: - * x3 = Add(x2) - * Loop4: + * Before AutoInline: The output of the previous block is the input of another + * block. Loop1: x1 = Add() Loop2: x2 = Multiply(x1) Loop3: x3 = Add(x2) Loop4: * x4 = Relu(x3) * * After AutoInline: All loops are inlined into a loop. @@ -263,18 +274,22 @@ class TestAutoInline : public TestAutoGenRuleBase {}; TEST_F(TestAutoInline, SingleChain) { Target target = common::DefaultNVGPUTarget(); Initialize(target); - std::vector input_names = {"bias", "conv_output", "bn_scale", "bn_offset"}; - std::vector output_names = {"var_6", "var_5", "var_1", "var", "var_0", "var_4", "var_3"}; + std::vector input_names = { + "bias", "conv_output", "bn_scale", "bn_offset"}; + std::vector output_names = { + "var_6", "var_5", "var_1", "var", "var_0", "var_4", "var_3"}; std::vector conv_output_shape = {1, 512, 56, 56}; - int32_t channel = conv_output_shape[1]; - std::vector inputs_varinfo({{"conv_output", conv_output_shape}, - {"bias", {channel, 1, 1}}, - {"bn_scale", {channel, 1, 1}}, - {"bn_offset", {channel, 1, 1}}}); + int32_t channel = conv_output_shape[1]; + std::vector inputs_varinfo( + {{"conv_output", conv_output_shape}, + {"bias", {channel, 1, 1}}, + {"bn_scale", {channel, 1, 1}}, + {"bn_offset", {channel, 1, 1}}}); // Construct the computation graph and convert it to ir::Expr Context::Global().ResetNameId(); - ir::IRSchedule ir_schedule = MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo)); + ir::IRSchedule ir_schedule = + MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo)); SearchState state(ir_schedule, 0, {}); std::vector func_bodys = ir_schedule.GetModule().GetExprs(); ASSERT_EQ(func_bodys.size(), 1UL); @@ -282,20 +297,23 @@ TEST_F(TestAutoInline, SingleChain) { // Apply AutoInline for every block that can be inline AutoInline auto_inline(target_, {output_names.front()}); - EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_3"), RuleApplyType::kApplyAndPruneOtherRules); + EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_3"), + RuleApplyType::kApplyAndPruneOtherRules); auto new_states = auto_inline.ApplyOnBlock(state, "var_3"); - std::vector inline_block_names({"var_4", "var_5", "var_6", "var", "var_0", "var_1"}); + std::vector inline_block_names( + {"var_4", "var_5", "var_6", "var", "var_0", "var_1"}); for (const auto& inline_block_name : 
inline_block_names) { new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name); } - std::vector exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); + std::vector exprs = + new_states[0]->ir_schedule.GetModule().GetExprs(); EXPECT_EQ(exprs.size(), 1UL); VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; // build ir::Module and debug source code auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); - auto build_module_manually = - BuildIRModule(MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo), -1, true)); + auto build_module_manually = BuildIRModule(MakeIRSchedule( + tests::BiasBnReLUBuilder().Build(inputs_varinfo), -1, true)); auto source_code_auto = GenSourceCode(build_module_auto); VLOG(6) << " auto-schedule source code:\n" << source_code_auto; auto source_code_manually = GenSourceCode(build_module_manually); @@ -305,7 +323,10 @@ TEST_F(TestAutoInline, SingleChain) { GenExecutableKernel(build_module_manually), input_names, output_names, - {{conv_output_shape[1], 1, 1}, conv_output_shape, conv_output_shape, conv_output_shape}, + {{conv_output_shape[1], 1, 1}, + conv_output_shape, + conv_output_shape, + conv_output_shape}, {conv_output_shape, {1}, {1}, {1}, {1}, {1}, {1}}, target); } @@ -328,14 +349,15 @@ TEST_F(TestAutoInline, SingleChain) { TEST_F(TestAutoInline, InlineToMultiConsumers) { Target target = common::DefaultNVGPUTarget(); Initialize(target); - std::vector input_names = {"x"}; + std::vector input_names = {"x"}; std::vector output_names = {"var_2", "var_1", "var_0"}; std::vector input_shape{256, 256}; std::vector inputs_varinfo({{"x", input_shape}}); // Construct the computation graph and convert it to ir::Expr Context::Global().ResetNameId(); - ir::IRSchedule ir_schedule = MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo)); + ir::IRSchedule ir_schedule = + MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo)); SearchState state(ir_schedule, 0, {}); std::vector func_bodys = ir_schedule.GetModule().GetExprs(); ASSERT_EQ(func_bodys.size(), 1UL); @@ -343,17 +365,19 @@ TEST_F(TestAutoInline, InlineToMultiConsumers) { // Apply AutoInline for every block that can be inline AutoInline auto_inline(target_, {output_names.front()}); - EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_0"), RuleApplyType::kApplyAndPruneOtherRules); - auto new_states = auto_inline.ApplyOnBlock(state, "var_1"); - new_states = auto_inline.ApplyOnBlock(state, "var_0"); - std::vector exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); + EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_0"), + RuleApplyType::kApplyAndPruneOtherRules); + auto new_states = auto_inline.ApplyOnBlock(state, "var_1"); + new_states = auto_inline.ApplyOnBlock(state, "var_0"); + std::vector exprs = + new_states[0]->ir_schedule.GetModule().GetExprs(); EXPECT_EQ(exprs.size(), 1UL); VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; // build ir::Module and debug source code auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); - auto build_module_manually = - BuildIRModule(MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo), -1, true)); + auto build_module_manually = BuildIRModule(MakeIRSchedule( + tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo), -1, true)); auto source_code_auto = GenSourceCode(build_module_auto); VLOG(6) << " auto-schedule source code:\n" << source_code_auto; auto source_code_manually = GenSourceCode(build_module_manually); @@ -386,15 +410,21 @@ 
TEST_F(TestAutoInline, InlineToMultiConsumers) { TEST_F(TestAutoInline, OnlySpatialOp) { Target target = common::DefaultNVGPUTarget(); Initialize(target); - std::vector input_names = {"x", "y"}; - std::vector output_names = { - "var_6", "var_4", "constant_idx_last", "constant_idx_first", "var_2", "var_5"}; + std::vector input_names = {"x", "y"}; + std::vector output_names = {"var_6", + "var_4", + "constant_idx_last", + "constant_idx_first", + "var_2", + "var_5"}; std::vector input_shape{256, 256}; - std::vector inputs_varinfo({{"x", input_shape}, {"y", input_shape}}); + std::vector inputs_varinfo( + {{"x", input_shape}, {"y", input_shape}}); // Construct the computation graph and convert it to ir::Expr Context::Global().ResetNameId(); - ir::IRSchedule ir_schedule = MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo)); + ir::IRSchedule ir_schedule = + MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo)); SearchState state(ir_schedule, 0, {}); std::vector func_bodys = ir_schedule.GetModule().GetExprs(); ASSERT_EQ(func_bodys.size(), 1UL); @@ -402,20 +432,23 @@ TEST_F(TestAutoInline, OnlySpatialOp) { // Apply AutoInline for every block that can be inline AutoInline auto_inline(target_, {output_names.front()}); - EXPECT_EQ(auto_inline.AnalyseApplyType(state, "constant_idx_first"), RuleApplyType::kApplyAndPruneOtherRules); + EXPECT_EQ(auto_inline.AnalyseApplyType(state, "constant_idx_first"), + RuleApplyType::kApplyAndPruneOtherRules); auto new_states = auto_inline.ApplyOnBlock(state, "constant_idx_first"); - std::vector inline_block_names({"constant_idx_last", "var_2", "var_5", "var_4"}); + std::vector inline_block_names( + {"constant_idx_last", "var_2", "var_5", "var_4"}); for (const auto& inline_block_name : inline_block_names) { new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name); } - std::vector exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); + std::vector exprs = + new_states[0]->ir_schedule.GetModule().GetExprs(); EXPECT_EQ(exprs.size(), 1UL); VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; // build ir::Module and debug source code auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); - auto build_module_manually = - BuildIRModule(MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo), -1, true)); + auto build_module_manually = BuildIRModule(MakeIRSchedule( + tests::GatherAddSubBuilder().Build(inputs_varinfo), -1, true)); auto source_code_auto = GenSourceCode(build_module_auto); VLOG(6) << " auto-schedule source code:\n" << source_code_auto; auto source_code_manually = GenSourceCode(build_module_manually); @@ -445,13 +478,14 @@ TEST_F(TestAutoInline, OnlySpatialOp) { TEST_F(TestAutoInline, NoReadBufferOp) { Target target = common::DefaultNVGPUTarget(); Initialize(target); - std::vector input_names = {"x"}; + std::vector input_names = {"x"}; std::vector output_names = {"var_0", "fill_constant"}; std::vector input_shape{256, 256}; std::vector inputs_varinfo({{"x", input_shape}}); // Construct the computation graph and convert it to ir::Expr - ir::IRSchedule ir_schedule = MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo)); + ir::IRSchedule ir_schedule = + MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo)); SearchState state(ir_schedule, 0, {}); std::vector func_bodys = ir_schedule.GetModule().GetExprs(); ASSERT_EQ(func_bodys.size(), 1UL); @@ -459,16 +493,18 @@ TEST_F(TestAutoInline, NoReadBufferOp) { // Apply AutoInline for every block that can be 
inline AutoInline auto_inline(target_, {output_names.front()}); - EXPECT_EQ(auto_inline.AnalyseApplyType(state, "fill_constant"), RuleApplyType::kApplyAndPruneOtherRules); - auto new_states = auto_inline.ApplyOnBlock(state, "fill_constant"); - std::vector exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); + EXPECT_EQ(auto_inline.AnalyseApplyType(state, "fill_constant"), + RuleApplyType::kApplyAndPruneOtherRules); + auto new_states = auto_inline.ApplyOnBlock(state, "fill_constant"); + std::vector exprs = + new_states[0]->ir_schedule.GetModule().GetExprs(); EXPECT_EQ(exprs.size(), 1UL); VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; // build ir::Module and debug source code auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); - auto build_module_manually = - BuildIRModule(MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo), -1, true)); + auto build_module_manually = BuildIRModule(MakeIRSchedule( + tests::FillConstantAddBuilder().Build(inputs_varinfo), -1, true)); auto source_code_auto = GenSourceCode(build_module_auto); VLOG(6) << " auto-schedule source code:\n" << source_code_auto; auto source_code_manually = GenSourceCode(build_module_manually); diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc index 3483992421d64..a42df90350790 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc @@ -33,11 +33,13 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const { auto has_reduce_iter = [](const Expr* x) { auto* block_realize = x->As(); if (block_realize) { - auto* schedule_block = block_realize->schedule_block.As(); + auto* schedule_block = + block_realize->schedule_block.As(); CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock"; for (auto&& var : schedule_block->iter_vars) { if (var->is_reduce_axis) { - VLOG(6) << "find ScheduleBlockRealize:" << *x << " has reduce_axis:" << var; + VLOG(6) << "find ScheduleBlockRealize:" << *x + << " has reduce_axis:" << var; return true; } } @@ -46,7 +48,8 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const { }; // whether has any for-loop with non-serial type auto has_nonserial_loop = [](const Expr* x) { - if (x->As() && x->As()->for_type() != ir::ForType::Serial) { + if (x->As() && + x->As()->for_type() != ir::ForType::Serial) { VLOG(6) << "find non-serial loop:" << *x; return true; } @@ -55,13 +58,15 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const { auto find_target_exprs = ir::CollectIRNodesWithoutTensor( schedule_block->body, - [&has_reduce_iter, &has_nonserial_loop](const Expr* x) { return has_reduce_iter(x) || has_nonserial_loop(x); }); + [&has_reduce_iter, &has_nonserial_loop](const Expr* x) { + return has_reduce_iter(x) || has_nonserial_loop(x); + }); return !find_target_exprs.empty(); } RuleApplyType AutoUnroll::Init(ir::IRSchedule* ir_schedule) { - ir_schedule_ = ir_schedule; + ir_schedule_ = ir_schedule; auto block_realizes = ir_schedule_->GetAllBlocks(); // A schedule block can perform `auto_unroll` rule should meet two conditions: @@ -71,47 +76,58 @@ RuleApplyType AutoUnroll::Init(ir::IRSchedule* ir_schedule) { std::set deduplicate_results; for (size_t i = 0; i < block_realizes.size(); ++i) { // find root block - Expr root_block = ir_schedule_->GetRootBlock(block_realizes[i]); + Expr 
root_block = ir_schedule_->GetRootBlock(block_realizes[i]); auto* block_realize = root_block.As(); CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block; - auto* schedule_block = block_realize->schedule_block.As(); - CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:" << Expr(block_realize); + auto* schedule_block = + block_realize->schedule_block.As(); + CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:" + << Expr(block_realize); if (MeetCondition(schedule_block)) { deduplicate_results.emplace(root_block); } } - applicable_schedule_blocks_ = {deduplicate_results.begin(), deduplicate_results.end()}; - num_applicable_ = applicable_schedule_blocks_.size(); + applicable_schedule_blocks_ = {deduplicate_results.begin(), + deduplicate_results.end()}; + num_applicable_ = applicable_schedule_blocks_.size(); VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_; - return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; + return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules + : RuleApplyType::kCannotApply; } void AutoUnroll::Apply(int index) { - CHECK_LT(index, applicable_schedule_blocks_.size()) << "invalid apply index:" << index; + CHECK_LT(index, applicable_schedule_blocks_.size()) + << "invalid apply index:" << index; auto applied_block = applicable_schedule_blocks_.at(index); - int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()]; - ir_schedule_->Annotate(applied_block, ir::attr::auto_unroll_max_step, max_step); + int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()]; + ir_schedule_->Annotate( + applied_block, ir::attr::auto_unroll_max_step, max_step); return; } -RuleApplyType AutoUnroll::AnalyseApplyType(SearchState state, const std::string& block_name) const { - Expr block_expr = state->ir_schedule.GetBlock(block_name); - Expr root_block = state->ir_schedule.GetRootBlock(block_expr); +RuleApplyType AutoUnroll::AnalyseApplyType( + SearchState state, const std::string& block_name) const { + Expr block_expr = state->ir_schedule.GetBlock(block_name); + Expr root_block = state->ir_schedule.GetRootBlock(block_expr); auto* block_realize = root_block.As(); CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block; auto* schedule_block = block_realize->schedule_block.As(); - CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:" << Expr(block_realize); + CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:" + << Expr(block_realize); - return MeetCondition(schedule_block) ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; + return MeetCondition(schedule_block) ? 
RuleApplyType::kApplyAndPruneOtherRules + : RuleApplyType::kCannotApply; } -std::vector AutoUnroll::ApplyOnBlock(SearchState state, const std::string& block_name) { +std::vector AutoUnroll::ApplyOnBlock( + SearchState state, const std::string& block_name) { SearchState new_state = state.Copy(); - Expr block_expr = new_state->ir_schedule.GetBlock(block_name); - Expr applied_block = new_state->ir_schedule.GetRootBlock(block_expr); - int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()]; - new_state->ir_schedule.Annotate(applied_block, ir::attr::auto_unroll_max_step, max_step); + Expr block_expr = new_state->ir_schedule.GetBlock(block_name); + Expr applied_block = new_state->ir_schedule.GetRootBlock(block_expr); + int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()]; + new_state->ir_schedule.Annotate( + applied_block, ir::attr::auto_unroll_max_step, max_step); return {new_state}; } diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h index b42c3eed78683..ee2f2f1ea42ac 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h @@ -24,10 +24,11 @@ namespace cinn { namespace auto_schedule { -// This rule can be applied in a ScheduleBlock has reduce axis or has loops with non-serial type. -// As a result, it will set a attribute with key named ir::attr::auto_unroll_max_step and value -// indicating max permitted unrolled step in the applied ScheduleBlock. Finally, UnrollLoop pass -// will do unroll based on actual situation. +// This rule can be applied to a ScheduleBlock that has a reduce axis or loops +// with a non-serial type. As a result, it sets an attribute whose key is +// ir::attr::auto_unroll_max_step and whose value indicates the max permitted +// unroll step in the applied ScheduleBlock. Finally, the UnrollLoop pass +// performs the actual unrolling based on this annotation. 
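// A minimal usage sketch, assuming `ast_expr` is the body of an already
// lowered function; this mirrors the pattern in auto_unroll_test.cc below
// rather than prescribing an API:
//
//   Target target = common::DefaultHostTarget();
//   AutoUnroll rule(target);
//   ir::IRSchedule sch(ir::ModuleExpr({ast_expr}));
//   if (rule.Init(&sch) != RuleApplyType::kCannotApply) {
//     rule.ApplyRandomly();  // annotates one applicable root block with
//                            // ir::attr::auto_unroll_max_step
//   }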
class AutoUnroll : public AutoGenRule { public: AutoUnroll(const common::Target& target) : AutoGenRule(target) {} @@ -39,9 +40,11 @@ class AutoUnroll : public AutoGenRule { std::string GetRuleName() const override { return "AutoUnroll"; } - RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; + RuleApplyType AnalyseApplyType(SearchState state, + const std::string& block_name) const override; - std::vector ApplyOnBlock(SearchState state, const std::string& block_name) override; + std::vector ApplyOnBlock(SearchState state, + const std::string& block_name) override; private: bool MeetCondition(const ir::ScheduleBlock* schedule_block) const; diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc index 4307ac7837376..7c4d313cc0ff0 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc @@ -39,7 +39,8 @@ TEST(AutoUnroll, Init) { Target target = common::DefaultHostTarget(); #endif auto stages = CreateStages({C}); - auto funcs = cinn::lang::LowerVec("test_init", stages, {A, B, C}, {}, {}, nullptr, target, true); + auto funcs = cinn::lang::LowerVec( + "test_init", stages, {A, B, C}, {}, {}, nullptr, target, true); auto ast_expr = funcs[0]->body; ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr})); @@ -58,7 +59,9 @@ TEST(AutoUnroll, UnrollableApply) { Placeholder B("B", {K, N}); Var k(K.as_int32(), "k0"); Tensor C = Compute( - {M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + {M, N}, + [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, + "C"); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -66,11 +69,14 @@ TEST(AutoUnroll, UnrollableApply) { Target target = common::DefaultHostTarget(); #endif auto stages = CreateStages({C}); - auto funcs = cinn::lang::LowerVec("test_unrollable", stages, {A, B, C}, {}, {}, nullptr, target, true); + auto funcs = cinn::lang::LowerVec( + "test_unrollable", stages, {A, B, C}, {}, {}, nullptr, target, true); - auto ast_expr = funcs[0]->body; - auto* init_block_realize = ast_expr.As()->stmts.front().As(); - auto* init_schedule_block = init_block_realize->schedule_block.As(); + auto ast_expr = funcs[0]->body; + auto* init_block_realize = + ast_expr.As()->stmts.front().As(); + auto* init_schedule_block = + init_block_realize->schedule_block.As(); ASSERT_NE(init_schedule_block, nullptr); ASSERT_TRUE(init_schedule_block->attrs.empty()); VLOG(6) << "Before auto-unroll:\n" << ast_expr; @@ -78,25 +84,34 @@ TEST(AutoUnroll, UnrollableApply) { AutoUnroll test_rule(target); ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr})); SearchState state(ir_schedule, 0, {}); - ASSERT_EQ(test_rule.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules); + ASSERT_EQ(test_rule.Init(&ir_schedule), + RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(test_rule.NumberApplicable(), 1); test_rule.ApplyRandomly(); // ApplyOnBlock - EXPECT_EQ(test_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules); - std::vector states = test_rule.ApplyOnBlock(state, "C"); + EXPECT_EQ(test_rule.AnalyseApplyType(state, "C"), + RuleApplyType::kApplyAndPruneOtherRules); + std::vector states = + test_rule.ApplyOnBlock(state, "C"); auto test_func = [](IRSchedule* ir_sch) { - Expr applied_expr = ir_sch->GetModule().GetExprs().front(); - auto* applied_block_realize 
= applied_expr.As()->stmts.front().As(); - auto* applied_schedule_block = applied_block_realize->schedule_block.As(); + Expr applied_expr = ir_sch->GetModule().GetExprs().front(); + auto* applied_block_realize = applied_expr.As() + ->stmts.front() + .As(); + auto* applied_schedule_block = + applied_block_realize->schedule_block.As(); ASSERT_FALSE(applied_schedule_block->attrs.empty()); - EXPECT_EQ(applied_schedule_block->attrs.count(ir::attr::auto_unroll_max_step), 1); - const auto& attr_value = applied_schedule_block->attrs.at(ir::attr::auto_unroll_max_step); - const int* max_step = absl::get_if(&attr_value); + EXPECT_EQ( + applied_schedule_block->attrs.count(ir::attr::auto_unroll_max_step), 1); + const auto& attr_value = + applied_schedule_block->attrs.at(ir::attr::auto_unroll_max_step); + const int* max_step = absl::get_if(&attr_value); EXPECT_NE(max_step, nullptr); EXPECT_LE(*max_step, 128); - VLOG(6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n" << ir_sch->GetModule().GetExprs().front(); + VLOG(6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n" + << ir_sch->GetModule().GetExprs().front(); }; test_func(&ir_schedule); diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/mix_rules_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/mix_rules_test.cc index e1779ad426b4b..2c2478d71723a 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/mix_rules_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/mix_rules_test.cc @@ -29,32 +29,35 @@ namespace auto_schedule { class TestMixRules : public TestAutoGenRuleBase { public: - std::vector default_input_names = {"X", "Y"}; + std::vector default_input_names = {"X", "Y"}; std::vector default_output_names = {"temp_matmul_out"}; }; TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) { - frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}); + frontend::Program matmul_op = + tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}); Initialize(common::DefaultNVGPUTarget()); - ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op); + ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op); std::vector func_bodys = ir_schedule.GetModule().GetExprs(); ASSERT_EQ(func_bodys.size(), 1UL); VLOG(6) << "Original Expr:\n" << func_bodys[0]; // Apply MultiLevelTiling - MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch)); + MultiLevelTiling multi_level_tiling( + target_, MultiLevelTiling::kConfigs.at(target_.arch)); multi_level_tiling.Init(&ir_schedule); ASSERT_EQ(multi_level_tiling.NumberApplicable(), 1); multi_level_tiling.ApplyRandomly(); VLOG(6) << "after MultiLevelTiling Expr:\n" << func_bodys[0]; // build ir::Module and debug source code - auto ir_module = BuildIRModule(ir_schedule); + auto ir_module = BuildIRModule(ir_schedule); auto source_code = GenSourceCode(ir_module); VLOG(6) << "scheduled source code:\n" << source_code; // execute and check precision CheckResult(GenExecutableKernel(ir_module), - GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, /* apply_manual_schedule */ true))), + GenExecutableKernel(BuildIRModule( + MakeIRSchedule(matmul_op, /* apply_manual_schedule */ true))), default_input_names, default_output_names, {{32, 32}, {32, 32}}, diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.cc index 9929e90393c8c..c35bed808a8e0 100644 --- 
a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.cc @@ -38,7 +38,8 @@ namespace cinn { namespace auto_schedule { -MultiLevelTiling::MultiLevelTiling(const common::Target& target, const Config& config) +MultiLevelTiling::MultiLevelTiling(const common::Target& target, + const Config& config) : AutoGenRule(target), config_(config) { for (int i = 0; i < config_.tile_struct.size(); ++i) { if (config_.tile_struct[i] == 'S') { @@ -51,25 +52,29 @@ MultiLevelTiling::MultiLevelTiling(const common::Target& target, const Config& c } } -bool MultiLevelTiling::MeetCondition(const ir::ScheduleBlockRealize& sche_block_realize) const { +bool MultiLevelTiling::MeetCondition( + const ir::ScheduleBlockRealize& sche_block_realize) const { return NeedsMultiLevelTiling(sche_block_realize); } RuleApplyType MultiLevelTiling::Init(ir::IRSchedule* ir_schedule) { - ir_schedule_ = ir_schedule; + ir_schedule_ = ir_schedule; all_block_realizes_ = ir_schedule_->GetAllBlocks(); applicable_indices_.clear(); num_applicable_ = 0; for (size_t i = 0; i < all_block_realizes_.size(); ++i) { - ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As(); - AnalyzeScheduleBlockReadWriteBuffer(sche_block_realize->schedule_block.As()); + ir::ScheduleBlockRealize* sche_block_realize = + all_block_realizes_[i].As(); + AnalyzeScheduleBlockReadWriteBuffer( + sche_block_realize->schedule_block.As()); if (MeetCondition(*sche_block_realize)) { ++num_applicable_; applicable_indices_.push_back(i); } } - return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; + return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules + : RuleApplyType::kCannotApply; } void MultiLevelTiling::Apply(int index) { @@ -77,12 +82,16 @@ void MultiLevelTiling::Apply(int index) { CHECK(num_applicable_ > 0 && applicable_indices_.size() == num_applicable_) << "MultiLevelTiling::Apply pre-condition doesn't meet"; CHECK(index >= 0 && num_applicable_ > index) - << "Invalid index for MultiLevelTiling::Apply, the index needs 0 <= index && index < NumberApplicable(), " - << "Currently index = " << index << ", NumberApplicable() = " << num_applicable_; + << "Invalid index for MultiLevelTiling::Apply, the index needs 0 <= " + "index && index < NumberApplicable(), " + << "Currently index = " << index + << ", NumberApplicable() = " << num_applicable_; int apply_index = applicable_indices_[index]; - std::string block_name = - all_block_realizes_[apply_index].As()->schedule_block.As()->name; + std::string block_name = all_block_realizes_[apply_index] + .As() + ->schedule_block.As() + ->name; Expr block_expr = all_block_realizes_[apply_index]; ApplyTiling(ir_schedule_, block_expr); block_expr = ir_schedule_->GetBlock(block_name); @@ -96,19 +105,24 @@ void MultiLevelTiling::Apply(int index) { std::string MultiLevelTiling::GetRuleName() const { return "MultiLevelTiling"; } -RuleApplyType MultiLevelTiling::AnalyseApplyType(SearchState state, const std::string& block_name) const { - Expr block_expr = state->ir_schedule.GetBlock(block_name); +RuleApplyType MultiLevelTiling::AnalyseApplyType( + SearchState state, const std::string& block_name) const { + Expr block_expr = state->ir_schedule.GetBlock(block_name); auto* block_realize = block_expr.As(); CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr; - AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As()); + 
AnalyzeScheduleBlockReadWriteBuffer( + block_realize->schedule_block.As()); - return NeedsMultiLevelTiling(*block_realize) ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; + return NeedsMultiLevelTiling(*block_realize) + ? RuleApplyType::kApplyAndPruneOtherRules + : RuleApplyType::kCannotApply; } -std::vector MultiLevelTiling::ApplyOnBlock(SearchState state, const std::string& block_name) { - SearchState new_state = state.Copy(); +std::vector MultiLevelTiling::ApplyOnBlock( + SearchState state, const std::string& block_name) { + SearchState new_state = state.Copy(); ir::IRSchedule* ir_sch = &new_state->ir_schedule; - Expr block_expr = ir_sch->GetBlock(block_name); + Expr block_expr = ir_sch->GetBlock(block_name); ApplyTiling(ir_sch, block_expr); block_expr = ir_sch->GetBlock(block_name); ApplyCacheRead(ir_sch, block_expr); @@ -119,14 +133,18 @@ std::vector MultiLevelTiling::ApplyOnBlock(SearchState state, const return {new_state}; } -void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) { - ir::ScheduleBlockRealize* sche_block_realize = block_expr.As(); - ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As(); +void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, + ir::Expr& block_expr) { + ir::ScheduleBlockRealize* sche_block_realize = + block_expr.As(); + ir::ScheduleBlock* sche_block = + sche_block_realize->schedule_block.As(); tile_loops_.clear(); tile_loops_.resize(config_.tile_struct.size()); std::vector for_exprs = ir_schedule->GetLoops(block_expr); - VLOG(5) << "The number of loops to split in MultiLevelTiling is " << for_exprs.size(); + VLOG(5) << "The number of loops to split in MultiLevelTiling is " + << for_exprs.size(); for (int i = for_exprs.size() - 1; i >= 0; --i) { ir::For* ir_for = for_exprs[i].As(); VLOG(6) << "Applying Split for MultiLevelTiling on: " << Expr(ir_for); @@ -141,8 +159,10 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ int num_split = idx->size(); if (num_split > 1) { - std::vector tile_split_factor = ir_schedule->SamplePerfectTile(Expr(ir_for), num_split, 64); - std::vector splited = ir_schedule->Split(Expr(ir_for), tile_split_factor); + std::vector tile_split_factor = + ir_schedule->SamplePerfectTile(Expr(ir_for), num_split, 64); + std::vector splited = + ir_schedule->Split(Expr(ir_for), tile_split_factor); VLOG(6) << "Finish Split for MultiLevelTiling on above loop"; for (int j = 0; j < num_split; ++j) { tile_loops_[idx->at(j)].push_back(splited[j]); @@ -159,7 +179,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ for (int i = 0; i < for_exprs.size(); ++i) { loop_var_name_to_idx[for_exprs[i].As()->loop_var->name] = i; } - CHECK(loop_var_name_to_idx.size() == for_exprs.size()) << "Loops contain duplicate loop var names after split"; + CHECK(loop_var_name_to_idx.size() == for_exprs.size()) + << "Loops contain duplicate loop var names after split"; std::vector splited_loops; for (auto& t : tile_loops_) { @@ -173,7 +194,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ } Expr reordered_expr = ir_schedule->Reorder(splited_loops); - VLOG(5) << "Finish Reorder in MultiLevelTiling, now do Fuse and Binding on the main loop chain"; + VLOG(5) << "Finish Reorder in MultiLevelTiling, now do Fuse and Binding on " + "the main loop chain"; int num_binds = std::min(config_.bind_axis.size(), tile_loops_.size()); for (int i = 0; i < num_binds; ++i) { @@ -182,7 +204,8 @@ void 
MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ for (int j = 0; j < for_exprs.size(); ++j) { loop_var_name_to_idx[for_exprs[j].As()->loop_var->name] = j; } - CHECK(loop_var_name_to_idx.size() == for_exprs.size()) << "Loops contain duplicate loop var names before Fusion"; + CHECK(loop_var_name_to_idx.size() == for_exprs.size()) + << "Loops contain duplicate loop var names before Fusion"; // Some loops extent may exceed the limited max factor (For example, // exceed the limit number of CUDA threads), here we check whether @@ -191,14 +214,14 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ // // If yes, we fuse those loops and bind the fused loop // If no, we bind the first loop whose extent is less than the factor. - int extent_prod = 1; + int extent_prod = 1; int first_idx_less_than_max_factor = -1; for (int j = 0; j < tile_loops_[i].size(); ++j) { const ir::For* tile_loop = tile_loops_[i][j].As(); CHECK(tile_loop) << "tiles store non For Expr"; - int idx = loop_var_name_to_idx[tile_loop->loop_var->name]; + int idx = loop_var_name_to_idx[tile_loop->loop_var->name]; tile_loops_[i][j] = for_exprs[idx]; - int extent = tile_loop->extent.as_int32(); // maybe int64? + int extent = tile_loop->extent.as_int32(); // maybe int64? extent_prod *= extent; if (first_idx_less_than_max_factor == -1 && extent <= max_factor_) { first_idx_less_than_max_factor = idx; @@ -209,7 +232,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ Expr fused = ir_schedule->Fuse(tile_loops_[i]); ir_schedule->Bind(fused, config_.bind_axis[i]); } else if (first_idx_less_than_max_factor != -1) { - ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor], config_.bind_axis[i]); + ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor], + config_.bind_axis[i]); } } @@ -229,13 +253,17 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ } } if (!other_loop_chain_schedule.defined()) { - LOG(WARNING) << "Has non-main loop chain, but not corresponding ScheduleBlock in MultiLevelTiling"; + LOG(WARNING) << "Has non-main loop chain, but not corresponding " + "ScheduleBlock in MultiLevelTiling"; continue; } std::string other_loop_schedule_name = - other_loop_chain_schedule.As()->schedule_block.As()->name; - VLOG(6) << "Found other_loop_schedule_name = " << other_loop_schedule_name; + other_loop_chain_schedule.As() + ->schedule_block.As() + ->name; + VLOG(6) << "Found other_loop_schedule_name = " + << other_loop_schedule_name; int fuse_index = 0; for (int i = 0; i < num_binds; ++i) { for_exprs = ir_schedule->GetLoops(other_loop_schedule_name); @@ -247,23 +275,26 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ // // If yes, we fuse those loops and bind the fused loop // If no, we bind the first loop whose extent is less than the factor. 
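// A concrete walk-through with hypothetical extents: given max_factor_ = 1024
// and tile loops of extents {8, 4, 16}, extent_prod = 512 <= 1024, so the
// three loops are fused and the fused loop is bound to config_.bind_axis[i];
// given extents {2048, 32}, the product exceeds the factor and only the
// first loop whose extent fits (here 32) is bound.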
- int extent_prod = 1; + int extent_prod = 1; int first_idx_less_than_max_factor = -1; for (int j = 0; j < tile_loops_[i].size(); ++j) { - int extent = for_exprs[fuse_index + j].As()->extent.as_int32(); + int extent = + for_exprs[fuse_index + j].As()->extent.as_int32(); extent_prod *= extent; if (first_idx_less_than_max_factor == -1 && extent <= max_factor_) { first_idx_less_than_max_factor = fuse_index + j; } } if (extent_prod <= max_factor_) { - std::vector loops_to_fuse(for_exprs.begin() + fuse_index, - for_exprs.begin() + fuse_index + tile_loops_[i].size()); + std::vector loops_to_fuse( + for_exprs.begin() + fuse_index, + for_exprs.begin() + fuse_index + tile_loops_[i].size()); Expr fused = ir_schedule->Fuse(loops_to_fuse); ir_schedule->Bind(fused, config_.bind_axis[i]); fuse_index += 1; } else if (first_idx_less_than_max_factor != -1) { - ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor], config_.bind_axis[i]); + ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor], + config_.bind_axis[i]); fuse_index += tile_loops_[i].size(); } } @@ -272,10 +303,13 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ } } -void MultiLevelTiling::ApplyCacheRead(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) { - ir::ScheduleBlockRealize* sch_block_realize = block_expr.As(); - ir::ScheduleBlock* sch_block = sch_block_realize->schedule_block.As(); - std::string block_name = sch_block->name; +void MultiLevelTiling::ApplyCacheRead(ir::IRSchedule* ir_schedule, + ir::Expr& block_expr) { + ir::ScheduleBlockRealize* sch_block_realize = + block_expr.As(); + ir::ScheduleBlock* sch_block = + sch_block_realize->schedule_block.As(); + std::string block_name = sch_block->name; // Analyze which buffers can be cached std::vector read_buffer_indexes; @@ -302,100 +336,125 @@ void MultiLevelTiling::ApplyCacheRead(ir::IRSchedule* ir_schedule, ir::Expr& blo } // 2.Do CacheRead and get the cache block - ir::Expr cache_block = ir_schedule->CacheRead(block_expr, read_buffer_index, config_.read_cache_memory_type); + ir::Expr cache_block = ir_schedule->CacheRead( + block_expr, read_buffer_index, config_.read_cache_memory_type); std::string cache_block_name = - cache_block.As()->schedule_block.As()->name; + cache_block.As() + ->schedule_block.As() + ->name; - std::string target_for_loop_name = loops.back().As()->loop_var->name; + std::string target_for_loop_name = + loops.back().As()->loop_var->name; // 3.Place the cache_block under target_for_loop // The original block expr is invalid after the CacheRead schedule, - // so we reacquire the block expr after the schedule according to the block name - block_expr = ir_schedule->GetBlock(block_name); + // so we reacquire the block expr after the schedule according to the + // block name + block_expr = ir_schedule->GetBlock(block_name); std::vector for_exprs = ir_schedule->GetLoops(block_expr); for (const Expr& for_expr : for_exprs) { - if (for_expr.As()->loop_var->name.find(target_for_loop_name) != std::string::npos) { + if (for_expr.As()->loop_var->name.find(target_for_loop_name) != + std::string::npos) { ir_schedule->ComputeAt(cache_block, for_expr, true); break; } } - // 4.Threads under the same block cooperative fetch data from global memory. - Expr new_cache_block = ir_schedule->GetBlock(cache_block_name); - auto cache_block_loops = ir_schedule->GetLoops(new_cache_block); + // 4.Threads under the same block cooperative fetch data from global + // memory. 
+ Expr new_cache_block = ir_schedule->GetBlock(cache_block_name); + auto cache_block_loops = ir_schedule->GetLoops(new_cache_block); std::vector compute_at_extra_var = utils::Split( - absl::get( - new_cache_block.As()->schedule_block.As()->attrs.at( - "compute_at_extra_var")), + absl::get(new_cache_block.As() + ->schedule_block.As() + ->attrs.at("compute_at_extra_var")), ","); std::vector buffer_loops; // int nthreads = 1; for (const Expr& for_expr : cache_block_loops) { if (std::find(compute_at_extra_var.begin(), compute_at_extra_var.end(), - for_expr.As()->loop_var->name) != compute_at_extra_var.end()) { + for_expr.As()->loop_var->name) != + compute_at_extra_var.end()) { buffer_loops.push_back(for_expr); } } auto fused_buffer_loop = ir_schedule->Fuse(buffer_loops); - // TODO(BiynXu): Implement vectorize fetching data and pass in vector length - ir_schedule->Annotate(ir_schedule->GetBlock(cache_block_name), ir::attr::cooperative_process, 0); + // TODO(BiynXu): Implement vectorize fetching data and pass in vector + // length + ir_schedule->Annotate(ir_schedule->GetBlock(cache_block_name), + ir::attr::cooperative_process, + 0); } } } -void MultiLevelTiling::ApplyCacheWrite(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) { - ir::Expr cache_block = ir_schedule->CacheWrite(block_expr, 0, config_.write_cache_memory_type); +void MultiLevelTiling::ApplyCacheWrite(ir::IRSchedule* ir_schedule, + ir::Expr& block_expr) { + ir::Expr cache_block = + ir_schedule->CacheWrite(block_expr, 0, config_.write_cache_memory_type); for (int level : config_.write_cache_levels) { const auto loops = tile_loops_.at(level - 1); if (loops.size() == 0) { continue; } - std::string target_for_loop_name = loops.back().As()->loop_var->name; - // Because the block name is changed in CacheWrite, we need to calculate the derived name - // according to the logic of CacheWrite and find the loop structure according to the derived name. + std::string target_for_loop_name = + loops.back().As()->loop_var->name; + // Because the block name is changed in CacheWrite, we need to calculate the + // derived name according to the logic of CacheWrite and find the loop + // structure according to the derived name. 
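// For example (block name hypothetical): applying CacheWrite to a block named
// "var_0" with write_cache_memory_type == "local" yields a derived block named
// "var_0_local_temp_buffer", which is the name passed to GetLoops below.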
const std::string original_block_name = - block_expr.As()->schedule_block.As()->name; - const std::string derivative_block_name = - original_block_name + "_" + config_.write_cache_memory_type + "_temp_buffer"; + block_expr.As() + ->schedule_block.As() + ->name; + const std::string derivative_block_name = original_block_name + "_" + + config_.write_cache_memory_type + + "_temp_buffer"; std::vector for_exprs = ir_schedule->GetLoops(derivative_block_name); for (const Expr& for_expr : for_exprs) { - if (for_expr.As()->loop_var->name.find(target_for_loop_name) != std::string::npos) { - ir_schedule->ReverseComputeAt(ir_schedule->GetBlock(original_block_name), for_expr, true); + if (for_expr.As()->loop_var->name.find(target_for_loop_name) != + std::string::npos) { + ir_schedule->ReverseComputeAt( + ir_schedule->GetBlock(original_block_name), for_expr, true); } } - const std::string reduce_init_block_name = original_block_name + "__reduce_init"; - for_exprs = ir_schedule->GetLoops(derivative_block_name); + const std::string reduce_init_block_name = + original_block_name + "__reduce_init"; + for_exprs = ir_schedule->GetLoops(derivative_block_name); for (const Expr& for_expr : for_exprs) { - if (for_expr.As()->loop_var->name.find(target_for_loop_name) != std::string::npos && + if (for_expr.As()->loop_var->name.find(target_for_loop_name) != + std::string::npos && ir_schedule->HasBlock(reduce_init_block_name)) { - ir_schedule->SimpleComputeAt(ir_schedule->GetBlock(reduce_init_block_name), for_expr); + ir_schedule->SimpleComputeAt( + ir_schedule->GetBlock(reduce_init_block_name), for_expr); } } } } -const std::unordered_map MultiLevelTiling::kConfigs{ - {common::Target::Arch::NVGPU, - MultiLevelTiling::Config{ - /*bind_axis*/ std::vector{"blockIdx.x", "threadIdx.x"}, - /*tile_struct*/ std::string("SSSRRSRS"), - /*read_cache_memory_type*/ std::string("shared"), - /*read_cache_levels*/ std::vector{4}, - /*write_cache_memory_type*/ std::string("local"), - /*write_cache_levels*/ std::vector{3}, - }}, - {common::Target::Arch::X86, - MultiLevelTiling::Config{ - /*bind_axis*/ std::vector{}, - /*tile_struct*/ std::string("SSRSRS"), - /*read_cache_memory_type*/ std::string("local"), - /*read_cache_levels*/ std::vector{3}, - /*write_cache_memory_type*/ std::string("local"), - /*write_cache_levels*/ std::vector{2}, - }}}; +const std::unordered_map + MultiLevelTiling::kConfigs{ + {common::Target::Arch::NVGPU, + MultiLevelTiling::Config{ + /*bind_axis*/ std::vector{"blockIdx.x", + "threadIdx.x"}, + /*tile_struct*/ std::string("SSSRRSRS"), + /*read_cache_memory_type*/ std::string("shared"), + /*read_cache_levels*/ std::vector{4}, + /*write_cache_memory_type*/ std::string("local"), + /*write_cache_levels*/ std::vector{3}, + }}, + {common::Target::Arch::X86, + MultiLevelTiling::Config{ + /*bind_axis*/ std::vector{}, + /*tile_struct*/ std::string("SSRSRS"), + /*read_cache_memory_type*/ std::string("local"), + /*read_cache_levels*/ std::vector{3}, + /*write_cache_memory_type*/ std::string("local"), + /*write_cache_levels*/ std::vector{2}, + }}}; } // namespace auto_schedule } // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h index 7c54047d8a81b..fbbf3efd0bf60 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h @@ -72,9 +72,11 @@ class MultiLevelTiling : public AutoGenRule { // 
Returns true if sche_block_realize is applicable by MultiLevelTiling
   bool MeetCondition(const ir::ScheduleBlockRealize& sche_block_realize) const;
 
-  RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override;
+  RuleApplyType AnalyseApplyType(SearchState state,
+                                 const std::string& block_name) const override;
 
-  std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override;
+  std::vector<SearchState> ApplyOnBlock(SearchState state,
+                                        const std::string& block_name) override;
 
   // Sample a pair of integers (a, b) such that a * b = extent
   template <typename T>
@@ -88,10 +90,10 @@ class MultiLevelTiling : public AutoGenRule {
     if (candidates.size() == 0) {
       return {1, T(extent)};
     }
-    int index           = rand() % candidates.size();
+    int index = rand() % candidates.size();
     std::vector<T> pick = candidates[index];
     if (rand() % 2 != 0) {
-      T tmp   = pick[0];
+      T tmp = pick[0];
       pick[0] = pick[1];
       pick[1] = tmp;
     }
@@ -101,7 +103,8 @@ class MultiLevelTiling : public AutoGenRule {
   // Sample num_split integers whose product equals extent
   template <typename T>
   std::vector<T> SampleTileSplit(T extent, int num_split) const {
-    CHECK_GT(num_split, 0) << "num_split in SampleTileSplit must be greater than 0";
+    CHECK_GT(num_split, 0)
+        << "num_split in SampleTileSplit must be greater than 0";
     if (num_split == 1) {
       return {extent};
     }
@@ -109,7 +112,7 @@ class MultiLevelTiling : public AutoGenRule {
     if (num_split == 2) {
       return two_split;
     }
-    int half              = num_split >> 1;
+    int half = num_split >> 1;
     std::vector<T> result = SampleTileSplit(two_split[0], half);
     std::vector<T> remind = SampleTileSplit(two_split[1], num_split - half);
     result.insert(result.end(), remind.begin(), remind.end());
diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc
index ab67854e4cc08..620c775ffdbeb 100644
--- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc
+++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc
@@ -48,11 +48,13 @@ TEST(MultiLevelTile, SampleSplitTwo) {
   Target target = common::DefaultHostTarget();
 #endif
 
-  MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch));
+  MultiLevelTiling multi_level_tiling(
+      target, MultiLevelTiling::kConfigs.at(target.arch));
 
   for (int i = 0; i < 100; ++i) {
-    size_t number_to_split    = rand() % 65535 + 2;  // random number in [2, 2^16]
-    std::vector<size_t> split = multi_level_tiling.SampleSplitTwo(number_to_split);
+    size_t number_to_split = rand() % 65535 + 2;  // random number in [2, 2^16]
+    std::vector<size_t> split =
+        multi_level_tiling.SampleSplitTwo(number_to_split);
     EXPECT_EQ(split.size(), 2UL);
     EXPECT_EQ(split[0] * split[1], number_to_split);
   }
@@ -67,12 +69,14 @@ TEST(MultiLevelTile, SampleTileSplit) {
   Target target = common::DefaultHostTarget();
 #endif
 
-  MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch));
+  MultiLevelTiling multi_level_tiling(
+      target, MultiLevelTiling::kConfigs.at(target.arch));
 
   for (int i = 0; i < 100; ++i) {
-    int number_to_split    = rand() % 65535 + 2;  // random number in [2, 2^16]
-    int split_size         = rand() % 5 + 1;      // random in [1, 5]
-    std::vector<int> split = multi_level_tiling.SampleTileSplit(number_to_split, split_size);
+    int number_to_split = rand() % 65535 + 2;  // random number in [2, 2^16]
+    int split_size = rand() % 5 + 1;           // random in [1, 5]
+    std::vector<int> split =
+        multi_level_tiling.SampleTileSplit(number_to_split, split_size);
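    // A hypothetical trace of the recursion under test (values illustrative
    // only): SampleTileSplit(64, 4) first calls SampleSplitTwo(64), e.g.
    // {8, 8}, then splits each half again, e.g. {2, 4} and {4, 2}, yielding
    // {2, 4, 4, 2} whose product restores 64.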
     EXPECT_EQ(split.size(), static_cast<size_t>(split_size));
     int product = 1;
     for (int num : split) {
@@ -102,21 +106,31 @@ TEST(MultiLevelTile, SimpleLoops) {
 
   poly::StageMap stages = CreateStages({C});
   std::vector<ir::LoweredFunc> funcs =
-      lang::LowerVec("TestMultiLevelTile_SimpleLoops", stages, {C}, {}, {}, nullptr, target, true);
+      lang::LowerVec("TestMultiLevelTile_SimpleLoops",
+                     stages,
+                     {C},
+                     {},
+                     {},
+                     nullptr,
+                     target,
+                     true);
 
   ir::Expr ast_expr = funcs[0]->body;
   VLOG(6) << "Expr before MultiLevelTiling: ";
   VLOG(6) << ast_expr;
 
-  MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch));
+  MultiLevelTiling multi_level_tiling(
+      target, MultiLevelTiling::kConfigs.at(target.arch));
   ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
   SearchState state(ir_schedule, 0, {});
-  EXPECT_EQ(multi_level_tiling.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules);
+  EXPECT_EQ(multi_level_tiling.Init(&ir_schedule),
+            RuleApplyType::kApplyAndPruneOtherRules);
   EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1);
 
   multi_level_tiling.ApplyRandomly();
 
   // ApplyOnBlock
-  EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules);
+  EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"),
+            RuleApplyType::kApplyAndPruneOtherRules);
   auto new_states = multi_level_tiling.ApplyOnBlock(state, "C");
 
   auto test_func = [](ir::IRSchedule* ir_sch) {
     std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
@@ -152,26 +166,30 @@ TEST(MulitLevelTile, MatrixMultiply) {
   Var k(K.as_int32(), "reduce_axis_k");
 
   ir::Tensor C = Compute(
-      {M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C");
+      {M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
+      "C");
 
   poly::StageMap stages = CreateStages({C});
   std::vector<ir::LoweredFunc> funcs =
-      lang::LowerVec("TestMultiLevelTile_MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true);
+      lang::LowerVec("TestMultiLevelTile_MatrixMultiply", stages, {C}, {}, {},
+                     nullptr, target, true);
 
   ir::Expr ast_expr = funcs[0]->body;
   VLOG(6) << "Expr before MultiLevelTiling: ";
   VLOG(6) << ast_expr;
 
-  MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch));
-  ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
-  SearchState state(ir_schedule, 0, {});
-  EXPECT_EQ(multi_level_tiling.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules);
+  MultiLevelTiling multi_level_tiling(
+      target, MultiLevelTiling::kConfigs.at(target.arch));
+  ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
+  SearchState state(ir_schedule, 0, {});
+  EXPECT_EQ(multi_level_tiling.Init(&ir_schedule),
+            RuleApplyType::kApplyAndPruneOtherRules);
   EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1);
 
   multi_level_tiling.ApplyRandomly();
 
   // ApplyOnBlock
-  EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules);
-  auto new_states = multi_level_tiling.ApplyOnBlock(state, "C");
+  EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"),
+            RuleApplyType::kApplyAndPruneOtherRules);
+  auto new_states = multi_level_tiling.ApplyOnBlock(state, "C");
 
   auto test_func = [](ir::IRSchedule* ir_sch) {
     std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
@@ -194,25 +212,28 @@ class TestMultiLevelTiling : public TestAutoGenRuleBase {
 };
 
 TEST_F(TestMultiLevelTiling, Matmul) {
-  default_input_names  = {"X", "Y"};
-  default_output_names = {"temp_matmul_out"};
-  std::vector<int32_t> X_shape = {32, 32};
-  std::vector<int32_t> Y_shape = {32, 32};
+  default_input_names = {"X", "Y"};
+  default_output_names = {"temp_matmul_out"};
+  std::vector<int32_t> X_shape = {32, 32};
+
std::vector Y_shape = {32, 32}; std::vector out_shape = {32, 32}; Initialize(common::DefaultNVGPUTarget()); - frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}}); - ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed); + frontend::Program matmul_op = + tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}}); + ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed); SearchState state(ir_schedule); VLOG(6) << "Original state:\n" << state->DebugString(); // Apply MultiLevelTiling - MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch)); + MultiLevelTiling multi_level_tiling( + target_, MultiLevelTiling::kConfigs.at(target_.arch)); EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), RuleApplyType::kApplyAndPruneOtherRules); - auto new_states = multi_level_tiling.ApplyOnBlock(state, default_output_names[0]); + auto new_states = + multi_level_tiling.ApplyOnBlock(state, default_output_names[0]); VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString(); - std::string ir = GetIR(new_states[0]->ir_schedule); + std::string ir = GetIR(new_states[0]->ir_schedule); std::string expected_ir = R"ROC(Expr 0 { { ScheduleBlock(root) @@ -325,14 +346,15 @@ TEST_F(TestMultiLevelTiling, Matmul) { ASSERT_EQ(ir, expected_ir); // build ir::Module and debug source code - auto ir_module = BuildIRModule(new_states[0]->ir_schedule); + auto ir_module = BuildIRModule(new_states[0]->ir_schedule); auto source_code = GenSourceCode(ir_module); VLOG(6) << "scheduled source code:\n" << source_code; // execute and check precision CheckResult( GenExecutableKernel(ir_module), - GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))), + GenExecutableKernel(BuildIRModule(MakeIRSchedule( + matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))), default_input_names, default_output_names, {X_shape, Y_shape}, @@ -341,26 +363,29 @@ TEST_F(TestMultiLevelTiling, Matmul) { } TEST_F(TestMultiLevelTiling, ReduceSum) { - default_input_names = {"X"}; - default_output_names = {"var_0_tmp"}; - std::vector X_shape = {1, 16, 32}; - std::vector out_shape = {1, 16, 1}; + default_input_names = {"X"}; + default_output_names = {"var_0_tmp"}; + std::vector X_shape = {1, 16, 32}; + std::vector out_shape = {1, 16, 1}; std::vector reduce_dim = {2}; Initialize(common::DefaultNVGPUTarget()); frontend::Program reduce_sum_op = - tests::OpBuilder("reduce_sum").Build({{"X", X_shape}}, {{"dim", reduce_dim}, {"keep_dim", false}}); + tests::OpBuilder("reduce_sum") + .Build({{"X", X_shape}}, {{"dim", reduce_dim}, {"keep_dim", false}}); ir::IRSchedule ir_schedule = MakeIRSchedule(reduce_sum_op); SearchState state(ir_schedule); VLOG(6) << "Original state:\n" << state->DebugString(); // Apply MultiLevelTiling - MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch)); - // EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), RuleApplyType::kCannotApply); + MultiLevelTiling multi_level_tiling( + target_, MultiLevelTiling::kConfigs.at(target_.arch)); + // EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, + // default_output_names[0]), RuleApplyType::kCannotApply); } TEST_F(TestMultiLevelTiling, Pool2d) { - default_input_names = {"input"}; + default_input_names = {"input"}; default_output_names = {"var_0"}; std::vector input_shape{2, 8, 16, 16}; std::vector output_shape{2, 8, 
8, 8}; @@ -368,23 +393,24 @@ TEST_F(TestMultiLevelTiling, Pool2d) { std::vector ksize{3, 3}; std::vector strides{2, 2}; std::vector paddings{1, 1, 1, 1}; - bool ceil_mode = false; - bool exclusive = true; - bool global_pooling = false; - std::string data_format = "NCHW"; - bool adaptive = false; - std::string padding_algorithm = "EXPLICIT"; - frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build({{"input", input_shape}}, - {{"pool_type", pooling_type}, - {"kernel_size", ksize}, - {"stride_size", strides}, - {"padding_size", paddings}, - {"ceil_mode", ceil_mode}, - {"exclusive", exclusive}, - {"global_pooling", global_pooling}, - {"data_format", data_format}, - {"adaptive", adaptive}, - {"padding_algorithm", padding_algorithm}}); + bool ceil_mode = false; + bool exclusive = true; + bool global_pooling = false; + std::string data_format = "NCHW"; + bool adaptive = false; + std::string padding_algorithm = "EXPLICIT"; + frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build( + {{"input", input_shape}}, + {{"pool_type", pooling_type}, + {"kernel_size", ksize}, + {"stride_size", strides}, + {"padding_size", paddings}, + {"ceil_mode", ceil_mode}, + {"exclusive", exclusive}, + {"global_pooling", global_pooling}, + {"data_format", data_format}, + {"adaptive", adaptive}, + {"padding_algorithm", padding_algorithm}}); Initialize(common::DefaultNVGPUTarget()); ir::IRSchedule ir_schedule = MakeIRSchedule(pool2d_program, fixed_rand_seed); @@ -403,10 +429,11 @@ TEST_F(TestMultiLevelTiling, Pool2d) { MultiLevelTiling multi_level_tiling(target_, mlt_config); EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), RuleApplyType::kApplyAndPruneOtherRules); - auto new_states = multi_level_tiling.ApplyOnBlock(state, default_output_names[0]); + auto new_states = + multi_level_tiling.ApplyOnBlock(state, default_output_names[0]); VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString(); - std::string ir = GetIR(new_states[0]->ir_schedule); + std::string ir = GetIR(new_states[0]->ir_schedule); std::string expected_ir = R"ROC(Expr 0 { { ScheduleBlock(root) @@ -529,19 +556,20 @@ Expr 1 { ASSERT_EQ(ir, expected_ir); // build ir::Module and debug source code - auto ir_module = BuildIRModule(new_states[0]->ir_schedule); + auto ir_module = BuildIRModule(new_states[0]->ir_schedule); auto source_code = GenSourceCode(ir_module); VLOG(6) << "scheduled source code:\n" << source_code; // execute and check precision - CheckResult(GenExecutableKernel(ir_module), - GenExecutableKernel( - BuildIRModule(MakeIRSchedule(pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))), - default_input_names, - default_output_names, - {input_shape}, - {output_shape}, - target_); + CheckResult( + GenExecutableKernel(ir_module), + GenExecutableKernel(BuildIRModule(MakeIRSchedule( + pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))), + default_input_names, + default_output_names, + {input_shape}, + {output_shape}, + target_); } } // namespace auto_schedule diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.cc index ca46f4b54940a..4ada86aec17e1 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.cc @@ -27,7 +27,7 @@ namespace auto_schedule { SkipRule::SkipRule(const common::Target& target) : AutoGenRule(target) {} RuleApplyType SkipRule::Init(ir::IRSchedule* 
ir_schedule) { - ir_schedule_ = ir_schedule; + ir_schedule_ = ir_schedule; num_applicable_ = 1; return RuleApplyType::kApply; } diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h index 837eadd6aafe2..41564a5202b70 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h @@ -34,11 +34,15 @@ class SkipRule : public AutoGenRule { std::string GetRuleName() const override; - RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override { + RuleApplyType AnalyseApplyType(SearchState state, + const std::string& block_name) const override { return RuleApplyType::kApply; } - std::vector ApplyOnBlock(SearchState state, const std::string& block_name) override { return {state}; } + std::vector ApplyOnBlock( + SearchState state, const std::string& block_name) override { + return {state}; + } }; } // namespace auto_schedule diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc index 37e8f3eaa7a81..81c955916f4e7 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc @@ -52,8 +52,9 @@ TEST(SkipRule, Basic) { ir::Tensor C = Compute( {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); - poly::StageMap stages = CreateStages({C}); - std::vector funcs = lang::LowerVec("TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); + poly::StageMap stages = CreateStages({C}); + std::vector funcs = lang::LowerVec( + "TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Expr before SkipRule: "; @@ -69,7 +70,8 @@ TEST(SkipRule, Basic) { // ApplyOnBlock EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply); - std::vector states = skip_rule.ApplyOnBlock(state, "C"); + std::vector states = + skip_rule.ApplyOnBlock(state, "C"); auto test_func = [&ast_expr](ir::IRSchedule* ir_sch) { std::vector exprs = ir_sch->GetModule().GetExprs(); @@ -99,8 +101,9 @@ TEST(SkipRule, ApplyOnSpecificBlock) { ir::Tensor C = Compute( {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); - poly::StageMap stages = CreateStages({C}); - std::vector funcs = lang::LowerVec("TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); + poly::StageMap stages = CreateStages({C}); + std::vector funcs = lang::LowerVec( + "TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Expr before SkipRule: "; @@ -111,7 +114,8 @@ TEST(SkipRule, ApplyOnSpecificBlock) { SearchState state(ir_schedule, 0, {}); EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply); - std::vector states = skip_rule.ApplyOnBlock(state, "C"); + std::vector states = + skip_rule.ApplyOnBlock(state, "C"); std::vector exprs = states[0]->ir_schedule.GetModule().GetExprs(); EXPECT_EQ(exprs.size(), 1UL); diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc index 17601fc695340..c1ba5bb259bdc 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc @@ -42,26 +42,32 @@ using ::cinn::hlir::framework::Shape; 
using ::cinn::hlir::framework::Tensor; void TestAutoGenRuleBase::Initialize(const common::Target& target) { - target_ = target; + target_ = target; backend_compier_ = backends::Compiler::Create(target); } -ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(const frontend::Program& test_program, - utils::LinearRandomEngine::StateType rand_seed, - bool apply_manual_schedule) { +ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule( + const frontend::Program& test_program, + utils::LinearRandomEngine::StateType rand_seed, + bool apply_manual_schedule) { Context::Global().ResetNameId(); auto graph = std::make_shared(test_program, target_); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - LOG_IF(WARNING, graph->fusion_groups.size() > 1) << "Test Graph has more than 1 group"; - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); + LOG_IF(WARNING, graph->fusion_groups.size() > 1) + << "Test Graph has more than 1 group"; + auto& dtype_dict = + graph->GetMutableAttrs>( + "inferdtype"); + auto& shape_dict = graph->GetMutableAttrs< + absl::flat_hash_map>("infershape"); hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_); if (apply_manual_schedule) { lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front()); } else { - lowered_funcs_ = op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front()); + lowered_funcs_ = + op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front()); } CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty"; @@ -76,20 +82,22 @@ std::string TestAutoGenRuleBase::GetIR(const ir::IRSchedule& schedule) { const auto& exprs = schedule.GetModule().GetExprs(); std::stringstream module_stream; for (auto i = 0; i < exprs.size(); ++i) { - module_stream << "Expr " << i << " {\n" << exprs.at(i) << "\n} // end Expr " << i << "\n"; + module_stream << "Expr " << i << " {\n" + << exprs.at(i) << "\n} // end Expr " << i << "\n"; } return module_stream.str(); } ir::Module TestAutoGenRuleBase::BuildIRModule(const ir::IRSchedule& schedule) { auto&& updated_bodys = schedule.GetModule().GetExprs(); - CHECK_EQ(lowered_funcs_.size(), updated_bodys.size()) << "associated exprs size not equal"; + CHECK_EQ(lowered_funcs_.size(), updated_bodys.size()) + << "associated exprs size not equal"; ir::Module::Builder builder("test_bulder", this->target_); for (int i = 0; i < lowered_funcs_.size(); ++i) { - ir::Expr func_body = updated_bodys.at(i); + ir::Expr func_body = updated_bodys.at(i); const ir::LoweredFunc& ori_func = lowered_funcs_.at(i); - auto&& new_func = UpdateFuncWithNewBody(target_, ori_func, func_body); + auto&& new_func = UpdateFuncWithNewBody(target_, ori_func, func_body); builder.AddFunction(new_func); } @@ -102,20 +110,24 @@ std::string TestAutoGenRuleBase::GenSourceCode(const ir::Module& ir_module) { if (target_ == common::DefaultNVGPUTarget()) { codegen = std::make_unique(this->target_); } else { - codegen = std::make_unique(this->target_, CodeGenCX86::Feature::AVX512); + codegen = std::make_unique( + this->target_, CodeGenCX86::Feature::AVX512); } #else - codegen = std::make_unique(this->target_, CodeGenCX86::Feature::AVX512); + codegen = std::make_unique( + this->target_, CodeGenCX86::Feature::AVX512); #endif codegen->SetInlineBuiltinCodes(false); return codegen->Compile(ir_module, CodeGenC::OutputKind::CImpl); } -raw_func_type TestAutoGenRuleBase::GenExecutableKernel(const ir::Module& ir_module) { +raw_func_type TestAutoGenRuleBase::GenExecutableKernel( + const ir::Module& ir_module) { 
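  // The two steps below: JIT-compile the module via the backend compiler,
  // then look up the kernel entry point by the name of the first lowered
  // function.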
auto&& func_name = lowered_funcs_.front()->name; // Compile to machine code backend_compier_->Build(ir_module); - auto test_func_ptr = reinterpret_cast(backend_compier_->Lookup(func_name)); + auto test_func_ptr = reinterpret_cast( + backend_compier_->Lookup(func_name)); return test_func_ptr; } @@ -138,15 +150,19 @@ void MemoryCopy(const float* src, float* dst, int numel, std::string type) { } } -void AddDataToScope( - Scope* scope, const common::Target& target, float* data_ptr, std::string name, const std::vector& shape) { - auto* var = scope->Var(name); +void AddDataToScope(Scope* scope, + const common::Target& target, + float* data_ptr, + std::string name, + const std::vector& shape) { + auto* var = scope->Var(name); auto& tensor = absl::get(*var); CHECK(shape.size()) << "The size of shape can not be 0."; Shape cinn_shape(shape); tensor->Resize(cinn_shape); - auto* tgt_data_ptr = tensor->mutable_data(target); - std::string mem_cpy_type = target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost"; + auto* tgt_data_ptr = tensor->mutable_data(target); + std::string mem_cpy_type = + target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost"; MemoryCopy(data_ptr, tgt_data_ptr, cinn_shape.numel(), mem_cpy_type); } @@ -159,16 +175,20 @@ void CheckResult(raw_func_type test_func, const common::Target& target) { CHECK(input_names.size()) << "The number of inputs must be greater than 0."; CHECK(output_names.size()) << "The number of outputs must be greater than 0."; - CHECK_EQ(input_names.size(), input_shapes.size()) << "The quantity of input_names and input_shapes must be equal."; + CHECK_EQ(input_names.size(), input_shapes.size()) + << "The quantity of input_names and input_shapes must be equal."; CHECK_EQ(output_names.size(), output_shapes.size()) << "The quantity of output_names and output_shapes must be equal."; // Initialize data std::vector input_data_ptrs(input_names.size()); for (int i = 0; i < input_shapes.size(); ++i) { - int input_data_numel = - std::accumulate(input_shapes[i].begin(), input_shapes[i].end(), 1, [](int a, int b) { return a * b; }); - input_data_ptrs[i] = reinterpret_cast(malloc(input_data_numel * sizeof(float))); + int input_data_numel = std::accumulate( + input_shapes[i].begin(), input_shapes[i].end(), 1, [](int a, int b) { + return a * b; + }); + input_data_ptrs[i] = + reinterpret_cast(malloc(input_data_numel * sizeof(float))); for (int j = 0; j < input_data_numel; ++j) { input_data_ptrs[i][j] = (rand() * 1.f) / RAND_MAX; } @@ -177,24 +197,35 @@ void CheckResult(raw_func_type test_func, std::vector expected_output_data_ptrs(output_names.size()); std::vector output_data_numels(output_shapes.size()); for (int i = 0; i < output_shapes.size(); ++i) { - output_data_numels[i] = - std::accumulate(output_shapes[i].begin(), output_shapes[i].end(), 1, [](int a, int b) { return a * b; }); - test_output_data_ptrs[i] = reinterpret_cast(malloc(output_data_numels[i] * sizeof(float))); + output_data_numels[i] = std::accumulate( + output_shapes[i].begin(), output_shapes[i].end(), 1, [](int a, int b) { + return a * b; + }); + test_output_data_ptrs[i] = + reinterpret_cast(malloc(output_data_numels[i] * sizeof(float))); memset(test_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float)); - expected_output_data_ptrs[i] = reinterpret_cast(malloc(output_data_numels[i] * sizeof(float))); - memset(expected_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float)); + expected_output_data_ptrs[i] = + reinterpret_cast(malloc(output_data_numels[i] * 
sizeof(float))); + memset( + expected_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float)); } - auto launch_kernel_fn = [&](raw_func_type& raw_func, std::vector& output_data_ptrs) { + auto launch_kernel_fn = [&](raw_func_type& raw_func, + std::vector& output_data_ptrs) { // Initialize scope Scope scope; // Initialize input data in scope. for (int i = 0; i < input_names.size(); ++i) { - AddDataToScope(&scope, target, input_data_ptrs[i], input_names[i], input_shapes[i]); + AddDataToScope( + &scope, target, input_data_ptrs[i], input_names[i], input_shapes[i]); } // Initialize output data in scope. for (int i = 0; i < output_names.size(); ++i) { - AddDataToScope(&scope, target, output_data_ptrs[i], output_names[i], output_shapes[i]); + AddDataToScope(&scope, + target, + output_data_ptrs[i], + output_names[i], + output_shapes[i]); } // Create Instruction and run @@ -207,9 +238,12 @@ void CheckResult(raw_func_type test_func, // data for (int i = 0; i < output_names.size(); ++i) { - const float* result_ptr = scope.GetTensor(output_names[i])->data(); - std::string mem_cpy_type = target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost"; - MemoryCopy(result_ptr, output_data_ptrs[i], output_data_numels[i], mem_cpy_type); + const float* result_ptr = scope.GetTensor(output_names[i])->data(); + std::string mem_cpy_type = target == common::DefaultNVGPUTarget() + ? "DeviceToHost" + : "HostToHost"; + MemoryCopy( + result_ptr, output_data_ptrs[i], output_data_numels[i], mem_cpy_type); } }; @@ -220,7 +254,8 @@ void CheckResult(raw_func_type test_func, // Check result for (int i = 0; i < output_shapes.size(); ++i) { for (int j = 0; j < output_data_numels[i]; ++j) { - ASSERT_NEAR(test_output_data_ptrs[i][j], expected_output_data_ptrs[i][j], 1e-4); + ASSERT_NEAR( + test_output_data_ptrs[i][j], expected_output_data_ptrs[i][j], 1e-4); } } diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h index 4b5833aca1e13..9fc7c2ced18ee 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h @@ -47,15 +47,18 @@ class TestAutoGenRuleBase : public ::testing::Test { // Initialize context for specified target void Initialize(const common::Target& target); - // construct an ir::IRSchedule by lowering the specified for following AutoGenRule test - ir::IRSchedule MakeIRSchedule(const frontend::Program& test_program, - utils::LinearRandomEngine::StateType rand_seed = -1, - bool apply_manual_schedule = false); + // construct an ir::IRSchedule by lowering the specified for following + // AutoGenRule test + ir::IRSchedule MakeIRSchedule( + const frontend::Program& test_program, + utils::LinearRandomEngine::StateType rand_seed = -1, + bool apply_manual_schedule = false); // Get the IR of bodies in IRSchedule std::string GetIR(const ir::IRSchedule& schedule); - // build ir::Module from the original lowered funcs with their bodies updated by the schedule + // build ir::Module from the original lowered funcs with their bodies updated + // by the schedule ir::Module BuildIRModule(const ir::IRSchedule& schedule); // generate source code with the built ir module @@ -75,9 +78,12 @@ class TestAutoGenRuleBase : public ::testing::Test { * @params-2: Expected function pointer for comparison. * @params-3: Names of input data. * @params-4: Names of output data. 
- * @params-5: Shapes of the input data, each input corresponds to a std::vector. - * @params-6: Shapes of the output data, each output corresponds to a std::vector. - * @params-7: The Target expressing computing platform and architecture of the function to be tested. + * @params-5: Shapes of the input data, each input corresponds to a + * std::vector. + * @params-6: Shapes of the output data, each output corresponds to a + * std::vector. + * @params-7: The Target expressing computing platform and architecture of the + * function to be tested. * @return: void */ void CheckResult(raw_func_type test_func, diff --git a/paddle/cinn/auto_schedule/search_space/block_sampler.h b/paddle/cinn/auto_schedule/search_space/block_sampler.h index 148ae7c25691a..6a01098736cf3 100644 --- a/paddle/cinn/auto_schedule/search_space/block_sampler.h +++ b/paddle/cinn/auto_schedule/search_space/block_sampler.h @@ -26,24 +26,30 @@ namespace auto_schedule { class SearchState; -// Select the next block to be operated for SearchState during the search process +// Select the next block to be operated for SearchState during the search +// process class BlockSampler { public: /** - * @brief Create a BlockSampler with the specific strategy name and necessary construct parameters. + * @brief Create a BlockSampler with the specific strategy name and necessary + * construct parameters. * @param all_blocks All possible blocks to be sampled. - * @param default_remove_policy The default option to determine whether to delete the next block after selecting it. + * @param default_remove_policy The default option to determine whether to + * delete the next block after selecting it. * @param strategy The block sampling strategy. - * Currently, the available strategies are "traversal" and "probabilistic", - * where "traversal" means to select blocks one by one until all blocks are traversed, - * and "probabilistic" means randomly picking blocks according to the given distribution. - * @param weights Used for the probabilistic policy, giving each candidate a weight. + * Currently, the available strategies are "traversal" and + * "probabilistic", where "traversal" means to select blocks one by one until + * all blocks are traversed, and "probabilistic" means randomly picking blocks + * according to the given distribution. + * @param weights Used for the probabilistic policy, giving each candidate a + * weight. */ - static std::unique_ptr Make(const std::vector& all_blocks, - bool default_remove_policy = true, - const std::string& strategy = "traversal", - utils::LinearRandomEngine::StateType rand_seed = 0, - const std::vector& weights = {}); + static std::unique_ptr Make( + const std::vector& all_blocks, + bool default_remove_policy = true, + const std::string& strategy = "traversal", + utils::LinearRandomEngine::StateType rand_seed = 0, + const std::vector& weights = {}); // Return the name of sample strategy virtual const char* Name() const = 0; @@ -56,18 +62,22 @@ class BlockSampler { protected: // A BlockSampler object should be created with the static function Make() - BlockSampler(const std::vector& all_blocks, bool default_remove_policy); + BlockSampler(const std::vector& all_blocks, + bool default_remove_policy); // Select a block to apply rule - // The param remove is used to determine whether to delete the next block after selecting it, - // If remove == true, it will not be sampled in the future. 
+  // The param remove is used to determine whether to delete the next block
+  // after selecting it. If remove == true, it will not be sampled in the
+  // future.
   virtual std::string NextBlock(bool remove) = 0;
 
   // The names of all blocks
-  // Because the Block Expr will be changed in the search process, the name is saved for indexing
+  // Because the Block Expr will be changed in the search process, the name is
+  // saved for indexing
   std::vector<std::string> all_blocks_;
 
-  // The default policy to determine whether to delete the next block after selecting it.
+  // The default policy to determine whether to delete the next block after
+  // selecting it.
   bool default_remove_policy_;
 };
 
@@ -75,7 +85,8 @@ class BlockSampler {
 // which means to select blocks one by one until all blocks are traversed.
 class TraversalBlockSampler : public BlockSampler {
  public:
-  TraversalBlockSampler(const std::vector<ir::Expr>& all_blocks, bool default_remove_policy)
+  TraversalBlockSampler(const std::vector<ir::Expr>& all_blocks,
+                        bool default_remove_policy)
       : BlockSampler(all_blocks, default_remove_policy), cur_idx_(0) {}
 
   const char* Name() const override { return "traversal"; }
@@ -96,7 +107,7 @@ class ProbabilisticBlockSampler : public BlockSampler {
   ProbabilisticBlockSampler(const std::vector<ir::Expr>& all_blocks,
                             bool default_remove_policy,
                             utils::LinearRandomEngine::StateType rand_seed = 0,
-                            const std::vector<int>& weights                = {});
+                            const std::vector<int>& weights = {});
 
   const char* Name() const override { return "probabilistic"; }
 
diff --git a/paddle/cinn/auto_schedule/search_space/block_sampler_test.cc b/paddle/cinn/auto_schedule/search_space/block_sampler_test.cc
index f9430c66ac64f..98f5d4a67b0a9 100644
--- a/paddle/cinn/auto_schedule/search_space/block_sampler_test.cc
+++ b/paddle/cinn/auto_schedule/search_space/block_sampler_test.cc
@@ -24,7 +24,8 @@ namespace auto_schedule {
 std::vector<ir::Expr> CreateTestBlocks() {
   std::vector<ir::Expr> blocks;
   for (int i = 0; i < 3; ++i) {
-    ir::Expr block = ir::ScheduleBlock::Make({}, {}, {}, "block_" + std::to_string(i), ir::Expr());
+    ir::Expr block = ir::ScheduleBlock::Make(
+        {}, {}, {}, "block_" + std::to_string(i), ir::Expr());
     blocks.push_back(ir::ScheduleBlockRealize::Make({}, block));
   }
   return blocks;
@@ -32,9 +33,11 @@ std::vector<ir::Expr> CreateTestBlocks() {
 
 TEST(BlockSampler, Make) {
   std::vector<ir::Expr> mock_blocks = CreateTestBlocks();
-  auto traversal_block_sampler      = BlockSampler::Make(mock_blocks, true, "traversal");
+  auto traversal_block_sampler =
+      BlockSampler::Make(mock_blocks, true, "traversal");
   ASSERT_STREQ(traversal_block_sampler->Name(), "traversal");
-  auto probabilistic_block_sampler = BlockSampler::Make(mock_blocks, true, "probabilistic");
+  auto probabilistic_block_sampler =
+      BlockSampler::Make(mock_blocks, true, "probabilistic");
   ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic");
 }
 
@@ -54,15 +57,17 @@ TEST(TraversalBlockSampler, NextBlock) {
 }
 
 TEST(ProbabilisticBlockSampler, NextBlock) {
-  std::vector<ir::Expr> blocks     = CreateTestBlocks();
-  auto probabilistic_block_sampler = BlockSampler::Make(blocks, false, "probabilistic", 0, {4, 2, 1});
+  std::vector<ir::Expr> blocks = CreateTestBlocks();
+  auto probabilistic_block_sampler =
+      BlockSampler::Make(blocks, false, "probabilistic", 0, {4, 2, 1});
   std::string block_name;
   for (int i = 0; i < 20; ++i) {
     block_name = probabilistic_block_sampler->NextBlock();
     VLOG(6) << "next block name: " << block_name;
   }
 
-  probabilistic_block_sampler = BlockSampler::Make(blocks, true, "probabilistic", 0, {4, 2, 1});
+  probabilistic_block_sampler =
+      BlockSampler::Make(blocks, true,
"probabilistic", 0, {4, 2, 1}); probabilistic_block_sampler->NextBlock(); probabilistic_block_sampler->NextBlock(); probabilistic_block_sampler->NextBlock(); diff --git a/paddle/cinn/auto_schedule/search_space/rule_sampler.h b/paddle/cinn/auto_schedule/search_space/rule_sampler.h index e92adcfb866b5..e46387a15b98f 100644 --- a/paddle/cinn/auto_schedule/search_space/rule_sampler.h +++ b/paddle/cinn/auto_schedule/search_space/rule_sampler.h @@ -30,20 +30,25 @@ class SearchState; class RuleSampler { public: /** - * @brief Create a RuleSampler with the specific strategy name and necessary construct parameters. + * @brief Create a RuleSampler with the specific strategy name and necessary + * construct parameters. * @param potential_rules All possible rules to be sampled. - * @param default_remove_policy The default option to determine whether to delete the next block after selecting it. + * @param default_remove_policy The default option to determine whether to + * delete the next block after selecting it. * @param strategy The rule sampling strategy. - * Currently, the available strategies are "traversal" and "probabilistic", - * where "traversal" means to select rules one by one until all rules are traversed, - * and "probabilistic" means randomly picking rules according to the given distribution. - * @param weights Used for the probabilistic policy, giving each candidate a weight. + * Currently, the available strategies are "traversal" and + * "probabilistic", where "traversal" means to select rules one by one until + * all rules are traversed, and "probabilistic" means randomly picking rules + * according to the given distribution. + * @param weights Used for the probabilistic policy, giving each candidate a + * weight. */ - static std::unique_ptr Make(const std::vector& potential_rules, - bool default_remove_policy = true, - const std::string& strategy = "traversal", - utils::LinearRandomEngine::StateType rand_seed = 0, - const std::vector& weights = {}); + static std::unique_ptr Make( + const std::vector& potential_rules, + bool default_remove_policy = true, + const std::string& strategy = "traversal", + utils::LinearRandomEngine::StateType rand_seed = 0, + const std::vector& weights = {}); // Return the name of sample strategy virtual const char* Name() const = 0; @@ -55,18 +60,21 @@ class RuleSampler { protected: // A RuleSampler object should be created with the static function Make() - RuleSampler(const std::vector& potential_rules, bool default_remove_policy) - : potential_rules_(&potential_rules), default_remove_policy_(default_remove_policy) {} + RuleSampler(const std::vector& potential_rules, + bool default_remove_policy) + : potential_rules_(&potential_rules), + default_remove_policy_(default_remove_policy) {} // Select a rule to apply. - // The param remove is used to determine whether to delete the next rule after selecting it, - // If remove == true, it will not be sampled in the future. + // The param remove is used to determine whether to delete the next rule after + // selecting it, If remove == true, it will not be sampled in the future. virtual AutoGenRule* NextRule(bool remove) = 0; // The pointer refers to all potential rules const std::vector* potential_rules_; - // The default policy to determine whether to delete the next rule after selecting it. + // The default policy to determine whether to delete the next rule after + // selecting it. 
  bool default_remove_policy_;
 };
 
@@ -74,7 +82,8 @@ class RuleSampler {
 // which means to select rules one by one until all rules are traversed.
 class TraversalRuleSampler : public RuleSampler {
  public:
-  TraversalRuleSampler(const std::vector<AutoGenRule*>& potential_rules, bool default_remove_policy)
+  TraversalRuleSampler(const std::vector<AutoGenRule*>& potential_rules,
+                       bool default_remove_policy)
       : RuleSampler(potential_rules, default_remove_policy), cur_idx_(0) {}
 
   const char* Name() const override { return "traversal"; }
@@ -95,7 +104,7 @@ class ProbabilisticRuleSampler : public RuleSampler {
   ProbabilisticRuleSampler(const std::vector<AutoGenRule*>& potential_rules,
                            bool default_remove_policy,
                            utils::LinearRandomEngine::StateType rand_seed = 0,
-                           const std::vector<int>& weights                = {});
+                           const std::vector<int>& weights = {});
 
   const char* Name() const override { return "probabilistic"; }
 
diff --git a/paddle/cinn/auto_schedule/search_space/rule_sampler_test.cc b/paddle/cinn/auto_schedule/search_space/rule_sampler_test.cc
index 2d9ef7be94add..2c21477a1bc59 100644
--- a/paddle/cinn/auto_schedule/search_space/rule_sampler_test.cc
+++ b/paddle/cinn/auto_schedule/search_space/rule_sampler_test.cc
@@ -28,20 +28,23 @@ Target target = common::DefaultNVGPUTarget();
 Target target = common::DefaultHostTarget();
 #endif
 
-std::vector<AutoGenRule*> GenerateTestRules() { return {new AutoUnroll(target), new SkipRule(target)}; }
+std::vector<AutoGenRule*> GenerateTestRules() {
+  return {new AutoUnroll(target), new SkipRule(target)};
+}
 
 TEST(RuleSampler, Make) {
   std::vector<AutoGenRule*> rules = GenerateTestRules();
-  auto traversal_block_sampler    = RuleSampler::Make(rules, true, "traversal");
+  auto traversal_block_sampler = RuleSampler::Make(rules, true, "traversal");
   ASSERT_STREQ(traversal_block_sampler->Name(), "traversal");
-  auto probabilistic_block_sampler = RuleSampler::Make(rules, true, "probabilistic");
+  auto probabilistic_block_sampler =
+      RuleSampler::Make(rules, true, "probabilistic");
   ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic");
 }
 
 TEST(TraversalRuleSampler, NextRule) {
   std::vector<AutoGenRule*> rules = GenerateTestRules();
-  auto traversal_rule_sampler     = RuleSampler::Make(rules, true, "traversal");
-  AutoGenRule* rule               = traversal_rule_sampler->NextRule();
+  auto traversal_rule_sampler = RuleSampler::Make(rules, true, "traversal");
+  AutoGenRule* rule = traversal_rule_sampler->NextRule();
   ASSERT_EQ("AutoUnroll", rule->GetRuleName());
   rule = traversal_rule_sampler->NextRule();
   ASSERT_EQ("SkipRule", rule->GetRuleName());
@@ -50,7 +53,7 @@ TEST(TraversalRuleSampler, NextRule) {
   ASSERT_EQ("AutoUnroll", rule->GetRuleName());
 
   traversal_rule_sampler = RuleSampler::Make(rules, false, "traversal");
-  rule                   = traversal_rule_sampler->NextRule();
+  rule = traversal_rule_sampler->NextRule();
   ASSERT_EQ("AutoUnroll", rule->GetRuleName());
   rule = traversal_rule_sampler->NextRule();
   ASSERT_EQ("AutoUnroll", rule->GetRuleName());
@@ -58,14 +61,16 @@ TEST(ProbabilisticRuleSampler, NextRule) {
   std::vector<AutoGenRule*> rules = GenerateTestRules();
-  auto probabilistic_rule_sampler = RuleSampler::Make(rules, false, "probabilistic", 0, {4, 1});
+  auto probabilistic_rule_sampler =
+      RuleSampler::Make(rules, false, "probabilistic", 0, {4, 1});
   AutoGenRule* rule;
   for (int i = 0; i < 20; ++i) {
     rule = probabilistic_rule_sampler->NextRule();
     VLOG(6) << "next rule name: " << rule->GetRuleName();
   }
 
-  probabilistic_rule_sampler = RuleSampler::Make(rules, true, "probabilistic", 0, {4, 1});
+  probabilistic_rule_sampler =
+      RuleSampler::Make(rules, true, "probabilistic", 0, {4, 1});
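  // Because the sampler above was created with remove == true, each of the
  // two rules can be drawn at most once; after both have been returned, the
  // next call is expected to yield nullptr, as asserted below.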
probabilistic_rule_sampler->NextRule(); probabilistic_rule_sampler->NextRule(); ASSERT_EQ(nullptr, probabilistic_rule_sampler->NextRule()); diff --git a/paddle/cinn/auto_schedule/search_space/search_space.cc b/paddle/cinn/auto_schedule/search_space/search_space.cc index f2fd7b8618f9d..a0bc719269376 100644 --- a/paddle/cinn/auto_schedule/search_space/search_space.cc +++ b/paddle/cinn/auto_schedule/search_space/search_space.cc @@ -39,18 +39,23 @@ DECLARE_bool(auto_schedule_use_cost_model); namespace cinn { namespace auto_schedule { -SearchSpace::SearchSpace(const TuneTask& tune_task, utils::LinearRandomEngine::StateType rand_seed) - : tune_task_(tune_task), rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) { +SearchSpace::SearchSpace(const TuneTask& tune_task, + utils::LinearRandomEngine::StateType rand_seed) + : tune_task_(tune_task), + rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) { const auto& target = tune_task_.target; // initialize a set of rules and they are commonly used by all states // TODO(zhhsplendid): pass correct output names to AutoInline - // sketch_rules_.emplace_back(new AutoInline(target, tune_task_.output_names)); - sketch_rules_.emplace_back(new MultiLevelTiling(target, MultiLevelTiling::kConfigs.at(target.arch))); + // sketch_rules_.emplace_back(new AutoInline(target, + // tune_task_.output_names)); + sketch_rules_.emplace_back( + new MultiLevelTiling(target, MultiLevelTiling::kConfigs.at(target.arch))); sketch_rules_.emplace_back(new AutoUnroll(target)); sketch_rules_.emplace_back(new SkipRule(target)); } -SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model) { +SearchState SearchSpace::GetScheduleMutate(const SearchState& state, + const ExprCostModel& cost_model) { bool has_manual_schedule = false; if (has_manual_schedule) { SearchState ret = ManualScheduleMutate(state); @@ -58,9 +63,11 @@ SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const ExprC } SearchState ret = RandomScheduleMutate(state); if (FLAGS_auto_schedule_use_cost_model) { - ret->predicted_cost = cost_model.Predict(ret->ir_schedule.GetModule(), tune_task_.target); + ret->predicted_cost = + cost_model.Predict(ret->ir_schedule.GetModule(), tune_task_.target); } - VLOG(4) << JoinStatesDebugString("SearchSpace::GetScheduleMutate", {state}, /*verbose=*/VLOG_IS_ON(5)); + VLOG(4) << JoinStatesDebugString( + "SearchSpace::GetScheduleMutate", {state}, /*verbose=*/VLOG_IS_ON(5)); return ret; } @@ -77,9 +84,10 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) { SearchState ret(state); std::vector apply_types(ret->applicable_rules.size()); for (int idx = 0; idx != ret->applicable_rules.size(); ++idx) { - AutoGenRule* rule = ret->applicable_rules.at(idx); + AutoGenRule* rule = ret->applicable_rules.at(idx); RuleApplyType apply_type = rule->Init(&ret->ir_schedule); - VLOG(6) << "Evaluate rule:" << rule->GetRuleName() << "=" << static_cast(apply_type); + VLOG(6) << "Evaluate rule:" << rule->GetRuleName() << "=" + << static_cast(apply_type); apply_types[idx] = apply_type; if (apply_type != RuleApplyType::kCannotApply) { weight_to_rule_index[cur_weight] = idx; @@ -94,7 +102,8 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) { } // 3. 
Sample a schedule on the distribution - int sample_weighted_index = utils::SampleUniformInt(0, cur_weight, &rand_seed_); + int sample_weighted_index = + utils::SampleUniformInt(0, cur_weight, &rand_seed_); auto iter = weight_to_rule_index.upper_bound(sample_weighted_index); --iter; @@ -102,13 +111,15 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) { int sample_rule_index = iter->second; CHECK_LT(sample_rule_index, ret->applicable_rules.size()); AutoGenRule* sample_rule = ret->applicable_rules.at(sample_rule_index); - VLOG(7) << "Apply rule: " << sample_rule->GetRuleName() << " with index=" << sample_weighted_index - iter->first; + VLOG(7) << "Apply rule: " << sample_rule->GetRuleName() + << " with index=" << sample_weighted_index - iter->first; // 4. Apply the schedule change sample_rule->Apply(sample_weighted_index - iter->first); // 5. Remove the rule after applying it if (apply_types.at(sample_rule_index) != RuleApplyType::kCannotApply) { - ret->applicable_rules.erase(ret->applicable_rules.begin() + sample_rule_index); + ret->applicable_rules.erase(ret->applicable_rules.begin() + + sample_rule_index); } return ret; @@ -116,17 +127,20 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) { std::vector SearchSpace::InitSketchWithRandomStrategy(int num) { VLOG(5) << "SearchSpace::GetRandomInitialSketch with num=" << num; - ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()), - utils::ForkRandomState(&rand_seed_)); + ir::IRSchedule init_schedule( + ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()), + utils::ForkRandomState(&rand_seed_)); std::vector init_rules; - std::transform(sketch_rules_.begin(), sketch_rules_.end(), std::back_inserter(init_rules), [](const auto& rule) { - return rule.get(); - }); + std::transform(sketch_rules_.begin(), + sketch_rules_.end(), + std::back_inserter(init_rules), + [](const auto& rule) { return rule.get(); }); std::vector result; while (result.size() < num) { SearchState state(init_schedule, SearchState::NOT_INIT_COST, init_rules); for (int i = 0; i < init_sketch_random_depth_; ++i) { - VLOG(6) << "Generating random sketch with RandomScheduleMutate at depth: " << i; + VLOG(6) << "Generating random sketch with RandomScheduleMutate at depth: " + << i; state = RandomScheduleMutate(state); if (state->applicable_rules.empty()) { break; @@ -134,7 +148,9 @@ std::vector SearchSpace::InitSketchWithRandomStrategy(int num) { } VLOG(5) << JoinStatesDebugString( - "SearchSpace::GetRandomInitialSketch-New_Sketch", {state}, /*verbose=*/VLOG_IS_ON(6)); + "SearchSpace::GetRandomInitialSketch-New_Sketch", + {state}, + /*verbose=*/VLOG_IS_ON(6)); result.emplace_back(std::move(state)); } return result; @@ -142,24 +158,28 @@ std::vector SearchSpace::InitSketchWithRandomStrategy(int num) { std::vector SearchSpace::InitSketchWithRandomPrunedStrategy() { VLOG(5) << "SearchSpace::InitSketchWithRandomPrunedStrategy"; - ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()), - utils::ForkRandomState(&rand_seed_)); - auto all_blocks = init_schedule.GetAllBlocks(); - auto block_sampler = BlockSampler::Make(all_blocks, true, "probabilistic", utils::ForkRandomState(&rand_seed_)); + ir::IRSchedule init_schedule( + ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()), + utils::ForkRandomState(&rand_seed_)); + auto all_blocks = init_schedule.GetAllBlocks(); + auto block_sampler = BlockSampler::Make( + all_blocks, true, "probabilistic", utils::ForkRandomState(&rand_seed_)); std::vector 
init_rules; - std::transform(sketch_rules_.begin(), sketch_rules_.end() - 1, std::back_inserter(init_rules), [](const auto& rule) { - return rule.get(); - }); + std::transform(sketch_rules_.begin(), + sketch_rules_.end() - 1, + std::back_inserter(init_rules), + [](const auto& rule) { return rule.get(); }); CHECK(init_rules.size() > 0) << "number of init rules cannot be 0"; SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {}); std::vector states_buf1{init_state}, states_buf2; - std::vector* p_states_cur = &states_buf1; + std::vector* p_states_cur = &states_buf1; std::vector* p_states_next = &states_buf2; - int total_steps = 0, steps; + int total_steps = 0, steps; std::string block_name; - while ("" != (block_name = block_sampler->NextBlock()) && total_steps < init_sketch_random_depth_) { + while ("" != (block_name = block_sampler->NextBlock()) && + total_steps < init_sketch_random_depth_) { steps = utils::SampleUniformInt(1, init_rules.size() + 1, &rand_seed_); if (total_steps + steps > init_sketch_random_depth_) { steps = init_sketch_random_depth_ - total_steps; @@ -167,51 +187,66 @@ std::vector SearchSpace::InitSketchWithRandomPrunedStrategy() { total_steps += steps; p_states_next->clear(); for (const auto& state : *p_states_cur) { - auto rule_sampler = RuleSampler::Make(init_rules, true, "probabilistic", utils::ForkRandomState(&rand_seed_)); - auto new_states = ApplySketchRule(state, block_name, rule_sampler.get(), steps, false, 1); - p_states_next->insert(p_states_next->end(), new_states.begin(), new_states.end()); + auto rule_sampler = + RuleSampler::Make(init_rules, + true, + "probabilistic", + utils::ForkRandomState(&rand_seed_)); + auto new_states = ApplySketchRule( + state, block_name, rule_sampler.get(), steps, false, 1); + p_states_next->insert( + p_states_next->end(), new_states.begin(), new_states.end()); } std::swap(p_states_cur, p_states_next); } VLOG(5) << JoinStatesDebugString( - "SearchSpace::InitSketchWithRandomPrunedStrategy", *p_states_cur, /*verbose=*/VLOG_IS_ON(6)); + "SearchSpace::InitSketchWithRandomPrunedStrategy", + *p_states_cur, + /*verbose=*/VLOG_IS_ON(6)); return *p_states_cur; } std::vector SearchSpace::InitSketchWithRulePrunedStrategy() { VLOG(5) << "SearchSpace::InitSketchWithRulePrunedStrategy"; - ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()), - utils::ForkRandomState(&rand_seed_)); + ir::IRSchedule init_schedule( + ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()), + utils::ForkRandomState(&rand_seed_)); auto all_blocks = init_schedule.GetAllBlocks(); std::reverse(all_blocks.begin(), all_blocks.end()); auto block_sampler = BlockSampler::Make(all_blocks, true, "traversal"); std::vector init_rules; - std::transform(sketch_rules_.begin(), sketch_rules_.end() - 1, std::back_inserter(init_rules), [](const auto& rule) { - return rule.get(); - }); + std::transform(sketch_rules_.begin(), + sketch_rules_.end() - 1, + std::back_inserter(init_rules), + [](const auto& rule) { return rule.get(); }); CHECK(init_rules.size() > 0) << "number of init rules cannot be 0"; SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {}); std::vector states_buf1{init_state}, states_buf2; - std::vector* p_states_cur = &states_buf1; + std::vector* p_states_cur = &states_buf1; std::vector* p_states_next = &states_buf2; std::string block_name; while ("" != (block_name = block_sampler->NextBlock())) { p_states_next->clear(); for (const auto& state : *p_states_cur) { auto rule_sampler = RuleSampler::Make(init_rules, true, 
"traversal"); - auto new_states = ApplySketchRule(state, block_name, rule_sampler.get(), 0, true); - p_states_next->insert(p_states_next->end(), new_states.begin(), new_states.end()); + auto new_states = + ApplySketchRule(state, block_name, rule_sampler.get(), 0, true); + p_states_next->insert( + p_states_next->end(), new_states.begin(), new_states.end()); } std::swap(p_states_cur, p_states_next); } VLOG(5) << JoinStatesDebugString( - "SearchSpace::InitSketchWithRulePrunedStrategy", *p_states_cur, /*verbose=*/VLOG_IS_ON(6)); + "SearchSpace::InitSketchWithRulePrunedStrategy", + *p_states_cur, + /*verbose=*/VLOG_IS_ON(6)); return *p_states_cur; } -std::vector SearchSpace::GenerateSketches(int num, const std::string& strategy) { +std::vector SearchSpace::GenerateSketches( + int num, const std::string& strategy) { VLOG(4) << "SearchSpace::GenerateSketches with num = " << num; if (strategy == "random") { @@ -239,28 +274,33 @@ std::vector SearchSpace::GenerateSketches(int num, const std::strin } } } - VLOG(4) << JoinStatesDebugString("SearchSpace::GenerateSketches", result, /*verbose=*/VLOG_IS_ON(5)); + VLOG(4) << JoinStatesDebugString( + "SearchSpace::GenerateSketches", result, /*verbose=*/VLOG_IS_ON(5)); return result; } -std::vector SearchSpace::ApplySketchRule(const SearchState& state, - const std::string& block_name, - RuleSampler* rule_sampler, - int steps, - bool prune_by_rule, - double prune_probability) { +std::vector SearchSpace::ApplySketchRule( + const SearchState& state, + const std::string& block_name, + RuleSampler* rule_sampler, + int steps, + bool prune_by_rule, + double prune_probability) { std::list layer{state}; int step = 0; AutoGenRule* rule; - // After determining a SearchState and a block, each rule has two possibilities: apply and not apply. - // In all transfer spaces, select a rule at each step, and collect all possible new states arrived by apply and not - // apply. This forms a tree, and we can use rule pruning or random pruning to reduce the number of sketches. + // After determining a SearchState and a block, each rule has two + // possibilities: apply and not apply. In all transfer spaces, select a rule + // at each step, and collect all possible new states arrived by apply and not + // apply. This forms a tree, and we can use rule pruning or random pruning to + // reduce the number of sketches. VLOG(6) << "Collect the states of all transfers within steps: " << steps; while ((step++ < steps || steps == 0) && (rule = rule_sampler->NextRule())) { VLOG(7) << "step = " << step << ", rule: " << rule->GetRuleName(); std::list new_states; int id = 0; - for (std::list::iterator iter = layer.begin(); iter != layer.end();) { + for (std::list::iterator iter = layer.begin(); + iter != layer.end();) { // Some rules will reduce the number of blocks, such as AutoInline, // so we need to check whether the SearchState still has the block. 
      if (!(*iter)->ir_schedule.HasBlock(block_name)) {
@@ -268,21 +308,26 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state,
        continue;
      }
      auto type = rule->AnalyseApplyType(*iter, block_name);
-      VLOG(7) << "At SearchState " << ++id
-              << ", apply type = " << static_cast<typename std::underlying_type<RuleApplyType>::type>(type);
+      VLOG(7)
+          << "At SearchState " << ++id << ", apply type = "
+          << static_cast<typename std::underlying_type<RuleApplyType>::type>(
+                 type);
      // if cannot apply the rule, skip it
      if (type == RuleApplyType::kCannotApply) {
        ++iter;
        continue;
      }
-      // if can apply the rule, apply it and determine whether to prune the branch that do not apply
-      std::vector<SearchState> tmp_states = rule->ApplyOnBlock(*iter, block_name);
+      // if the rule can be applied, apply it and determine whether to prune
+      // the branch that does not apply
+      std::vector<SearchState> tmp_states =
+          rule->ApplyOnBlock(*iter, block_name);
      new_states.insert(new_states.end(), tmp_states.begin(), tmp_states.end());
      bool need_prune = false;
      if (prune_by_rule) {
        need_prune = (type == RuleApplyType::kApplyAndPruneOtherRules);
      } else {
-        need_prune = (utils::SampleUniformDouble(0, 1, &rand_seed_) < prune_probability);
+        need_prune =
+            (utils::SampleUniformDouble(0, 1, &rand_seed_) < prune_probability);
      }
      if (need_prune) {
        iter = layer.erase(iter);
@@ -290,10 +335,12 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state,
        ++iter;
      }
    }
-    VLOG(7) << "apply on block: " << block_name << ", generate " << new_states.size() << " new states at step " << step;
+    VLOG(7) << "apply on block: " << block_name << ", generate "
+            << new_states.size() << " new states at step " << step;
    layer.splice(layer.end(), std::move(new_states));
  }
-  VLOG(6) << "apply on block: " << block_name << ", generate " << layer.size() - 1 << " more states at all";
+  VLOG(6) << "apply on block: " << block_name << ", generate "
+          << layer.size() - 1 << " more states in total";
  return std::vector<SearchState>(layer.begin(), layer.end());
}

diff --git a/paddle/cinn/auto_schedule/search_space/search_space.h b/paddle/cinn/auto_schedule/search_space/search_space.h
index 4463fa82cfc0a..4a7e0632729ee 100644
--- a/paddle/cinn/auto_schedule/search_space/search_space.h
+++ b/paddle/cinn/auto_schedule/search_space/search_space.h
@@ -40,24 +40,31 @@ namespace auto_schedule {
  */
 class SearchSpace {
  public:
-  SearchSpace(const TuneTask& tune_task, utils::LinearRandomEngine::StateType rand_seed = -1);
+  SearchSpace(const TuneTask& tune_task,
+              utils::LinearRandomEngine::StateType rand_seed = -1);
 
   // Sketch mutate, returns the mutated ModuleExpr and estimated cost
-  virtual SearchState GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model);
+  virtual SearchState GetScheduleMutate(const SearchState& state,
+                                        const ExprCostModel& cost_model);
 
   /**
   * \brief Generate sketches as the initial population of evolutionary search.
   * @param num The number of sketches to generate.
   * @param strategy The strategy to generate sketches,
-   * Current optional strategies are "rule_prune" or "random_prune" or "random".
-   * - "rule_prune": will use rules to prune and generate sketches as efficiently as possible.
-   * - "random_prune": will use the new interface ApplySketchRules() to simulate the random generation of sketches,
-   * and supports the function of a rule returning multiple SearchStates and random pruning by probability.
-   * - "random": will randomly select a block and a rule to apply and repeat this step several times,
-   * however, each rule can only be used on one SearchState at most once.
+   * Current optional strategies are "rule_prune" or "random_prune" or
+   * "random".
+ * - "rule_prune": will use rules to prune and generate sketches as + * efficiently as possible. + * - "random_prune": will use the new interface ApplySketchRules() to simulate + * the random generation of sketches, and supports the function of a rule + * returning multiple SearchStates and random pruning by probability. + * - "random": will randomly select a block and a rule to apply and repeat + * this step several times, however, each rule can only be used on one + * SearchState at most once. * @return Generated sketchs. */ - virtual std::vector GenerateSketches(int num, const std::string& strategy); + virtual std::vector GenerateSketches( + int num, const std::string& strategy); private: // TODO(zhhsplendid): mutate by manual schedule. @@ -69,20 +76,24 @@ class SearchSpace { // Generate num sketchs, each with several rounds of SketchMutate std::vector InitSketchWithRandomStrategy(int num); - // Generate sketch pruned randomly as initial population of evolutionary search + // Generate sketch pruned randomly as initial population of evolutionary + // search std::vector InitSketchWithRandomPrunedStrategy(); - // Generate sketch pruned by rules as initial population of evolutionary search + // Generate sketch pruned by rules as initial population of evolutionary + // search std::vector InitSketchWithRulePrunedStrategy(); /** - * @brief Collect the new states that may be transferred to after applying several rules on a block from a certain - * state. + * @brief Collect the new states that may be transferred to after applying + * several rules on a block from a certain state. * @param state Starting point of state transition. * @param block_name Name of the block to apply the rules to. - * @param rule_sampler Sampler that samples the new rule to apply on the block. + * @param rule_sampler Sampler that samples the new rule to apply on the + * block. * @param steps Number of steps to apply the rule. - * @param prune_by_rule If true, prune the state transition tree by rule, otherwise prune randomly. + * @param prune_by_rule If true, prune the state transition tree by rule, + * otherwise prune randomly. * @param prune_probability Pruning probability of random pruning. 
*/ std::vector ApplySketchRule(const SearchState& state, diff --git a/paddle/cinn/auto_schedule/search_space/search_state.cc b/paddle/cinn/auto_schedule/search_space/search_state.cc index a50d006f61a2c..5812a6e936a8c 100644 --- a/paddle/cinn/auto_schedule/search_space/search_state.cc +++ b/paddle/cinn/auto_schedule/search_space/search_state.cc @@ -29,21 +29,26 @@ namespace cinn { namespace auto_schedule { -SearchState::SearchState(ir::IRSchedule ir_sch, float cost, const std::vector& rules) +SearchState::SearchState(ir::IRSchedule ir_sch, + float cost, + const std::vector& rules) : common::Shared<_SearchState_>(common::make_shared<_SearchState_>()) { - auto* state = get(); - state->ir_schedule = std::move(ir_sch); + auto* state = get(); + state->ir_schedule = std::move(ir_sch); state->applicable_rules = rules; - state->predicted_cost = cost; + state->predicted_cost = cost; } -SearchState SearchState::Copy() const { return SearchState((*this)->ir_schedule, (*this)->predicted_cost, {}); } +SearchState SearchState::Copy() const { + return SearchState((*this)->ir_schedule, (*this)->predicted_cost, {}); +} std::string _SearchState_::DebugString() const { const auto& exprs = ir_schedule.GetModule().GetExprs(); std::stringstream module_stream; for (auto i = 0; i < exprs.size(); ++i) { - module_stream << "Expr " << i << " {\n" << exprs.at(i) << "\n} // end Expr"; + module_stream << "Expr " << i << " {\n" + << exprs.at(i) << "\n} // end Expr"; } const char* fmt_str = R"ROC( @@ -55,8 +60,10 @@ ScheduleDesc { } // end ScheduleDesc predicted_cost: %f)ROC"; - return utils::StringFormat( - fmt_str, module_stream.str().c_str(), ir_schedule.GetTraceDesc().DebugString().c_str(), predicted_cost); + return utils::StringFormat(fmt_str, + module_stream.str().c_str(), + ir_schedule.GetTraceDesc().DebugString().c_str(), + predicted_cost); } bool operator<(const SearchState& left, const SearchState& right) { @@ -94,7 +101,7 @@ class IrNodesStructuralHash : public DfsWithExprsFields { static decltype(ir::kIrNodeTyReprs) Node2Name = ir::kIrNodeTyReprs; if (!expr->defined()) return; auto type_code = static_cast(expr->node_type()); - hash_key_ = utils::HashCombine(hash_key_, type_code); + hash_key_ = utils::HashCombine(hash_key_, type_code); DfsWithExprsFields::Visit(expr); } @@ -111,7 +118,7 @@ class IrNodesStructuralHash : public DfsWithExprsFields { }; size_t SearchStateHash::operator()(const SearchState& s) const { - size_t hash_key = 0; + size_t hash_key = 0; const auto& exprs = s->ir_schedule.GetModule().GetExprs(); for (auto&& expr : exprs) { hash_key = IrNodesStructuralHash(hash_key)(&expr); @@ -119,7 +126,8 @@ size_t SearchStateHash::operator()(const SearchState& s) const { return hash_key; } -bool SearchStateEqual::operator()(const SearchState& lhs, const SearchState& rhs) const { +bool SearchStateEqual::operator()(const SearchState& lhs, + const SearchState& rhs) const { const auto& lhs_exprs = lhs->ir_schedule.GetModule().GetExprs(); const auto& rhs_exprs = rhs->ir_schedule.GetModule().GetExprs(); // compare exprs size firstly @@ -127,20 +135,24 @@ bool SearchStateEqual::operator()(const SearchState& lhs, const SearchState& rhs // compare every expr one by one with ir::IrEqualVisitor for (int i = 0; i < lhs_exprs.size(); ++i) { - ir::IrEqualVisitor compartor(/*allow_name_suffix_diff=*/true); // ignore suffix difference in name + ir::IrEqualVisitor compartor( + /*allow_name_suffix_diff=*/true); // ignore suffix difference in name if (!compartor.Compare(lhs_exprs[i], rhs_exprs[i])) return false; } return 
true; } -std::string JoinStatesDebugString(const std::string& title, const std::vector& states, bool verbose) { +std::string JoinStatesDebugString(const std::string& title, + const std::vector& states, + bool verbose) { std::stringstream ss; ss << title << " states size:" << states.size() << "\n"; SearchStateHash state_hasher; for (size_t i = 0; i < states.size(); ++i) { uint64_t hash_key = state_hasher(states[i]); if (verbose) { - ss << "\tState-" << i << " hash:" << hash_key << "\t content:------>" << states[i]->DebugString() << "\n<------"; + ss << "\tState-" << i << " hash:" << hash_key << "\t content:------>" + << states[i]->DebugString() << "\n<------"; } else { ss << "\tState-" << i << " hash:" << hash_key << "\n"; } diff --git a/paddle/cinn/auto_schedule/search_space/search_state.h b/paddle/cinn/auto_schedule/search_space/search_state.h index f180b2d508452..505c4967b6bb5 100644 --- a/paddle/cinn/auto_schedule/search_space/search_state.h +++ b/paddle/cinn/auto_schedule/search_space/search_state.h @@ -35,7 +35,9 @@ class SearchState : public common::Shared<_SearchState_> { public: SearchState() = default; // create a new SearchState - explicit SearchState(ir::IRSchedule ir_sch, float cost = NOT_INIT_COST, const std::vector& rules = {}); + explicit SearchState(ir::IRSchedule ir_sch, + float cost = NOT_INIT_COST, + const std::vector& rules = {}); // Constant standing for a cost not being initialized static constexpr float NOT_INIT_COST = std::numeric_limits::max(); @@ -62,12 +64,14 @@ struct _SearchState_ : public common::Object { static constexpr char* __type_info__ = "auto_schedule_state"; }; -// SearchStateHash hash functor that visits every AST node and combine their hash of node_type in dfs order +// SearchStateHash hash functor that visits every AST node and combines +// their node_type hashes in DFS order struct SearchStateHash { size_t operator()(const SearchState& s) const; }; -// SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST struct and fields +// SearchStateEqual equality functor that uses ir::IrEqualVisitor to compare +// AST structure and fields struct SearchStateEqual { bool operator()(const SearchState& lhs, const SearchState& rhs) const; }; diff --git a/paddle/cinn/auto_schedule/search_space/search_state_test.cc b/paddle/cinn/auto_schedule/search_space/search_state_test.cc index f0e09ebb8de32..61547d228302f 100644 --- a/paddle/cinn/auto_schedule/search_space/search_state_test.cc +++ b/paddle/cinn/auto_schedule/search_space/search_state_test.cc @@ -36,15 +36,34 @@ TEST(TestSearchState, SearchStateHash_Equal) { {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C"); cinn::common::Context::Global().ResetNameId(); - auto a_plus_const_funcs_1 = - lang::LowerVec("A_plus_const", poly::CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); + auto a_plus_const_funcs_1 = lang::LowerVec("A_plus_const", + poly::CreateStages({A, B}), + {A, B}, + {}, + {}, + nullptr, + target, + true); cinn::common::Context::Global().ResetNameId(); - auto a_plus_const_funcs_2 = - lang::LowerVec("A_plus_const", poly::CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); + auto a_plus_const_funcs_2 = lang::LowerVec("A_plus_const", + poly::CreateStages({A, B}), + {A, B}, + {}, + {}, + nullptr, + target, + true); cinn::common::Context::Global().ResetNameId(); - auto a_plus_b_funcs = lang::LowerVec("A_plus_B", poly::CreateStages({A, C}), {A, C}, {}, {}, nullptr, target, true); + auto a_plus_b_funcs = lang::LowerVec("A_plus_B", + poly::CreateStages({A, C}), + {A,
C}, + {}, + {}, + nullptr, + target, + true); std::string a_plus_const_funcs_1_str = R"ROC(function A_plus_const (_A, _B) { @@ -114,19 +133,25 @@ TEST(TestSearchState, SearchStateHash_Equal) { })ROC"; ASSERT_EQ(a_plus_const_funcs_1.size(), 1); - EXPECT_EQ(a_plus_const_funcs_1_str, utils::GetStreamCnt(a_plus_const_funcs_1.front())); + EXPECT_EQ(a_plus_const_funcs_1_str, + utils::GetStreamCnt(a_plus_const_funcs_1.front())); ASSERT_EQ(a_plus_const_funcs_2.size(), 1); - EXPECT_EQ(a_plus_const_funcs_2_str, utils::GetStreamCnt(a_plus_const_funcs_2.front())); + EXPECT_EQ(a_plus_const_funcs_2_str, + utils::GetStreamCnt(a_plus_const_funcs_2.front())); ASSERT_EQ(a_plus_b_funcs.size(), 1); EXPECT_EQ(a_plus_b_funcs_str, utils::GetStreamCnt(a_plus_b_funcs.front())); - SearchState a_plus_const_state1(ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_1.front()->body}))); - SearchState a_plus_const_state2(ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_2.front()->body}))); - SearchState a_plus_b_state(ir::IRSchedule(ir::ModuleExpr({a_plus_b_funcs.front()->body}))); + SearchState a_plus_const_state1( + ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_1.front()->body}))); + SearchState a_plus_const_state2( + ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_2.front()->body}))); + SearchState a_plus_b_state( + ir::IRSchedule(ir::ModuleExpr({a_plus_b_funcs.front()->body}))); SearchStateHash hash_functor; SearchStateEqual equal_functor; - ASSERT_EQ(hash_functor(a_plus_const_state1), hash_functor(a_plus_const_state2)); + ASSERT_EQ(hash_functor(a_plus_const_state1), + hash_functor(a_plus_const_state2)); ASSERT_TRUE(equal_functor(a_plus_const_state1, a_plus_const_state2)); ASSERT_NE(hash_functor(a_plus_const_state1), hash_functor(a_plus_b_state)); ASSERT_FALSE(equal_functor(a_plus_const_state1, a_plus_b_state)); diff --git a/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc index fb75161ff136a..d139cc4c1d309 100644 --- a/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc +++ b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc @@ -41,17 +41,19 @@ DECLARE_bool(auto_schedule_use_cost_model); namespace cinn { namespace auto_schedule { -EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task, - const ExprCostModel& cost_model, - Database* database, - utils::LinearRandomEngine::StateType rand_seed, - const std::vector>& mutate_rules) +EvolutionarySearch::EvolutionarySearch( + const TuneTask& tune_task, + const ExprCostModel& cost_model, + Database* database, + utils::LinearRandomEngine::StateType rand_seed, + const std::vector>& mutate_rules) : tune_task_(tune_task), cost_model_(cost_model), database_(database), rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)), mutators_(mutate_rules) { - search_space_ = std::make_unique(tune_task, utils::ForkRandomState(&rand_seed_)); + search_space_ = std::make_unique( + tune_task, utils::ForkRandomState(&rand_seed_)); if (mutators_.empty()) { mutators_.push_back(std::make_tuple("mutate_tile_size", 1.0)); } @@ -59,7 +61,8 @@ EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task, for (const auto& mutator : mutators_) { if (std::get<1>(mutator) > 0) { accum_weight += std::get<1>(mutator); - weighted_mutators_.insert(std::make_pair(accum_weight, MutateRule::Make(std::get<0>(mutator)))); + weighted_mutators_.insert( + std::make_pair(accum_weight, MutateRule::Make(std::get<0>(mutator)))); } } @@ -72,80 +75,109 @@ SearchState 
EvolutionarySearch::SearchModuleExpr(const TuningOptions& options) { return SearchModuleExprBests(options)[0]; } -std::vector EvolutionarySearch::SearchModuleExprBests(const TuningOptions& options) { - VLOG(4) << "start SearchModuleExprBests with initial statistics: visited_candidates size=" +std::vector EvolutionarySearch::SearchModuleExprBests( + const TuningOptions& options) { + VLOG(4) << "start SearchModuleExprBests with initial statistics: " + "visited_candidates size=" << visited_candidates_.size(); std::vector init_population; - std::vector topk_from_database = GetTopKCandidatesFromDatabase(options.evolution_pick_database_topk); + std::vector topk_from_database = + GetTopKCandidatesFromDatabase(options.evolution_pick_database_topk); VLOG(4) << JoinStatesDebugString( - "EvolutionarySearch::GetTopKCandidatesFromDatabase", topk_from_database, /*verbose=*/VLOG_IS_ON(5)); - int init_num = options.evolution_init_population_num - topk_from_database.size(); + "EvolutionarySearch::GetTopKCandidatesFromDatabase", + topk_from_database, + /*verbose=*/VLOG_IS_ON(5)); + int init_num = + options.evolution_init_population_num - topk_from_database.size(); std::vector init_sketch = InitSketch(init_num, "rule_prune"); - VLOG(4) << JoinStatesDebugString("EvolutionarySearch::InitSketch", init_sketch, /*verbose=*/VLOG_IS_ON(5)); + VLOG(4) << JoinStatesDebugString( + "EvolutionarySearch::InitSketch", init_sketch, /*verbose=*/VLOG_IS_ON(5)); - init_population.insert(init_population.end(), topk_from_database.begin(), topk_from_database.end()); - init_population.insert(init_population.end(), init_sketch.begin(), init_sketch.end()); + init_population.insert(init_population.end(), + topk_from_database.begin(), + topk_from_database.end()); + init_population.insert( + init_population.end(), init_sketch.begin(), init_sketch.end()); std::vector picked_bests = - Evolve(init_population, options.evolution_cross_over_num, options.num_samples_per_iteration); - VLOG(4) << JoinStatesDebugString("EvolutionarySearch::Evolve", picked_bests, /*verbose=*/VLOG_IS_ON(5)); + Evolve(init_population, + options.evolution_cross_over_num, + options.num_samples_per_iteration); + VLOG(4) << JoinStatesDebugString( + "EvolutionarySearch::Evolve", picked_bests, /*verbose=*/VLOG_IS_ON(5)); return picked_bests; } -std::vector EvolutionarySearch::SearchModuleExprEpsGreedy(const TuningOptions& options) { +std::vector EvolutionarySearch::SearchModuleExprEpsGreedy( + const TuningOptions& options) { std::vector picked_bests = SearchModuleExprBests(options); - int random_num = options.evolution_init_population_num - options.evolution_pick_database_topk; - auto results = PickNextGenerationEpsGreedy(picked_bests, - InitSketch(random_num, "random_prune"), - options.num_samples_per_iteration, - options.evolution_eps_greedy); + int random_num = options.evolution_init_population_num - + options.evolution_pick_database_topk; + auto results = + PickNextGenerationEpsGreedy(picked_bests, + InitSketch(random_num, "random_prune"), + options.num_samples_per_iteration, + options.evolution_eps_greedy); VLOG(4) << JoinStatesDebugString( - "EvolutionarySearch::PickNextGenerationEpsGreedy", results, /*verbose=*/VLOG_IS_ON(5)); + "EvolutionarySearch::PickNextGenerationEpsGreedy", + results, + /*verbose=*/VLOG_IS_ON(5)); return results; } -std::vector EvolutionarySearch::GetTopKCandidatesFromDatabase(int topk) { +std::vector EvolutionarySearch::GetTopKCandidatesFromDatabase( + int topk) { std::vector results; - const auto& task_key = tune_task_.serialized_key; - 
auto records = database_->GetTopK(task_key, topk); + const auto& task_key = tune_task_.serialized_key; + auto records = database_->GetTopK(task_key, topk); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); for (auto&& record : records) { - ir::IRSchedule ir_sch(optim::IRCopy(task_registry->Get(task_key)->module_expr), - utils::ForkRandomState(&rand_seed_)); + ir::IRSchedule ir_sch( + optim::IRCopy(task_registry->Get(task_key)->module_expr), + utils::ForkRandomState(&rand_seed_)); ir::ScheduleDesc::ReplayWithProto(record.trace, &ir_sch); results.emplace_back(SearchState(std::move(ir_sch), record.predicted_cost)); } return results; } -void ApplyPostScheduleRules(ir::IRSchedule* schedule, - const std::vector>& post_schedule_rules) { +void ApplyPostScheduleRules( + ir::IRSchedule* schedule, + const std::vector>& post_schedule_rules) { schedule->TagPostSchedule(); for (const auto& post_rule : post_schedule_rules) { post_rule->Apply(schedule); } } -std::vector EvolutionarySearch::InitSketch(int num, const std::string& strategy) { +std::vector EvolutionarySearch::InitSketch( + int num, const std::string& strategy) { VLOG(4) << "InitSketch with num:" << num << ", strategy: " << strategy; - std::vector states = search_space_->GenerateSketches(num, strategy); - auto post_schedule_fn = [this, &states](int index) { + std::vector states = + search_space_->GenerateSketches(num, strategy); + auto post_schedule_fn = [this, &states](int index) { ApplyPostScheduleRules(&states[index]->ir_schedule, post_schedule_rules_); }; - utils::parallel_run(post_schedule_fn, utils::SequenceDispatcher(0, states.size()), states.size()); + utils::parallel_run(post_schedule_fn, + utils::SequenceDispatcher(0, states.size()), + states.size()); return states; } -SearchState EvolutionarySearch::CrossOver(const SearchState& state1, const SearchState& state2) { +SearchState EvolutionarySearch::CrossOver(const SearchState& state1, + const SearchState& state2) { // TODO(CtfGo): tracing CrossOver with IRSchedule std::vector cross_over_exprs; - std::vector father_exprs = state1->ir_schedule.GetModule().GetExprs(); - std::vector mother_exprs = state2->ir_schedule.GetModule().GetExprs(); + std::vector father_exprs = + state1->ir_schedule.GetModule().GetExprs(); + std::vector mother_exprs = + state2->ir_schedule.GetModule().GetExprs(); CHECK_EQ(father_exprs.size(), mother_exprs.size()) - << "CrossOver ModuleExpr in EvolutionarySearch must have same number of AST"; + << "CrossOver ModuleExpr in EvolutionarySearch must have the same " + "number of ASTs"; for (size_t i = 0; i < father_exprs.size(); ++i) { if (utils::SampleUniformInt(0, 2, &rand_seed_) == 0) { @@ -154,44 +186,57 @@ SearchState EvolutionarySearch::CrossOver(const SearchState& state1, const Searc cross_over_exprs.push_back(optim::IRCopy(mother_exprs[i])); } } - auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs), utils::ForkRandomState(&rand_seed_))); + auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs), + utils::ForkRandomState(&rand_seed_))); if (FLAGS_auto_schedule_use_cost_model) { - res->predicted_cost = cost_model_.Predict(res->ir_schedule.GetModule(), tune_task_.target); + res->predicted_cost = + cost_model_.Predict(res->ir_schedule.GetModule(), tune_task_.target); } - VLOG(5) << JoinStatesDebugString("EvolutionarySearch::CrossOver", {state1, state2, res}, /*verbose=*/VLOG_IS_ON(6)); + VLOG(5) << JoinStatesDebugString("EvolutionarySearch::CrossOver", + {state1, state2, res}, + /*verbose=*/VLOG_IS_ON(6)); return res; }
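A minimal standalone sketch of the per-expression selection inside CrossOver above, under stated assumptions: int stands in for ir::Expr, std::mt19937 stands in for CINN's LinearRandomEngine, and the real code deep-copies each chosen expression with optim::IRCopy rather than sharing it.

#include <cassert>
#include <random>
#include <vector>

// Build a child by taking element i from either parent with probability 1/2,
// mirroring the utils::SampleUniformInt(0, 2, &rand_seed_) == 0 coin flip.
std::vector<int> UniformCrossOver(const std::vector<int>& father,
                                  const std::vector<int>& mother,
                                  std::mt19937* rng) {
  assert(father.size() == mother.size());  // same number of ASTs per module
  std::uniform_int_distribution<int> coin(0, 1);
  std::vector<int> child;
  child.reserve(father.size());
  for (size_t i = 0; i < father.size(); ++i) {
    child.push_back(coin(*rng) == 0 ? father[i] : mother[i]);
  }
  return child;
}

Because each position is inherited independently, all mixtures of the two parents' expressions are reachable; the cost model then ranks the offspring.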
-SearchState EvolutionarySearch::Mutate(const SearchState& state, utils::LinearRandomEngine::StateType* rand_seed) { - CHECK_GT(weighted_mutators_.size(), 0) << "There is no mutate rule can be applied."; +SearchState EvolutionarySearch::Mutate( + const SearchState& state, utils::LinearRandomEngine::StateType* rand_seed) { + CHECK_GT(weighted_mutators_.size(), 0) + << "There is no mutate rule that can be applied."; double accu_weight = (weighted_mutators_.rbegin())->first; CHECK_GT(accu_weight, 0) << "The accumulate weight must be greater than 0."; // sample a mutate rule double sample_weight = utils::SampleUniformDouble(0, accu_weight, rand_seed); - auto sampled_iter = weighted_mutators_.upper_bound(sample_weight); - MutateRule* mutator = sampled_iter->second.get(); + auto sampled_iter = weighted_mutators_.upper_bound(sample_weight); + MutateRule* mutator = sampled_iter->second.get(); CHECK(mutator) << "mutator not defined"; // apply mutation on the trace of SearchState - auto trace = state->ir_schedule.GetTraceDesc(); + auto trace = state->ir_schedule.GetTraceDesc(); auto new_trace = mutator->Apply(trace, rand_seed); - // replay the mutated trace on original ModuleExpr to generate a new ir_schedule - const auto& task_key = tune_task_.serialized_key; + // replay the mutated trace on original ModuleExpr to generate a new + // ir_schedule + const auto& task_key = tune_task_.serialized_key; InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); - ir::IRSchedule new_ir_sch(optim::IRCopy(task_registry->Get(task_key)->module_expr), - utils::ForkRandomState(rand_seed)); + ir::IRSchedule new_ir_sch( + optim::IRCopy(task_registry->Get(task_key)->module_expr), + utils::ForkRandomState(rand_seed)); new_trace.Replay(&new_ir_sch, true); ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_); auto res = SearchState(std::move(new_ir_sch)); - VLOG(5) << JoinStatesDebugString("EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6)); + VLOG(5) << JoinStatesDebugString( + "EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6)); return res; } -std::vector EvolutionarySearch::Evolve(const std::vector& population, - int cross_over_num, - int ret_num) { +std::vector EvolutionarySearch::Evolve( + const std::vector& population, + int cross_over_num, + int ret_num) { VLOG(4) << utils::StringFormat( - "Evolve with population size=%lu,cross_over_num:%lu,ret_num:%lu", population.size(), cross_over_num, ret_num); + "Evolve with population size=%lu,cross_over_num:%d,ret_num:%d", + population.size(), + cross_over_num, + ret_num); int generation_num = population.size(); if (generation_num == 0) { return std::vector(); @@ -199,40 +244,56 @@ std::vector EvolutionarySearch::Evolve(const std::vector evolution(population); for (SearchState& search_state : evolution) { - if (search_state->predicted_cost == SearchState::NOT_INIT_COST && FLAGS_auto_schedule_use_cost_model) { - search_state->predicted_cost = cost_model_.Predict(search_state->ir_schedule.GetModule(), tune_task_.target); + if (search_state->predicted_cost == SearchState::NOT_INIT_COST && + FLAGS_auto_schedule_use_cost_model) { + search_state->predicted_cost = cost_model_.Predict( + search_state->ir_schedule.GetModule(), tune_task_.target); } } - VLOG(4) << JoinStatesDebugString("EvolutionarySearch::Evolve: Init evolution:", evolution, /*verbose=*/VLOG_IS_ON(5)); + VLOG(4) << JoinStatesDebugString( + "EvolutionarySearch::Evolve: Init evolution:", + evolution, + /*verbose=*/VLOG_IS_ON(5)); // cross over for (int i = 0; i <
cross_over_num; ++i) { - int first_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_); - int second_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_); + int first_rand_idx = + utils::SampleUniformInt(0, generation_num, &rand_seed_); + int second_rand_idx = + utils::SampleUniformInt(0, generation_num, &rand_seed_); while (first_rand_idx == second_rand_idx) { second_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_); } - evolution.push_back(CrossOver(population[first_rand_idx], population[second_rand_idx])); + evolution.push_back( + CrossOver(population[first_rand_idx], population[second_rand_idx])); } VLOG(4) << JoinStatesDebugString( - "EvolutionarySearch::Evolve: after CrossOver evolution:", evolution, /*verbose=*/VLOG_IS_ON(5)); + "EvolutionarySearch::Evolve: after CrossOver evolution:", + evolution, + /*verbose=*/VLOG_IS_ON(5)); // mutate std::vector mutated_individuals(evolution.size()); - std::vector rand_seeds(evolution.size()); + std::vector rand_seeds( + evolution.size()); for (int i = 0; i < rand_seeds.size(); ++i) { rand_seeds[i] = utils::ForkRandomState(&rand_seed_); } - auto mutate_fn = [this, &evolution, &mutated_individuals, &rand_seeds](int index) { + auto mutate_fn = [this, &evolution, &mutated_individuals, &rand_seeds]( + int index) { mutated_individuals[index] = Mutate(evolution[index], &rand_seeds[index]); }; - utils::parallel_run(mutate_fn, utils::SequenceDispatcher(0, evolution.size()), evolution.size()); + utils::parallel_run(mutate_fn, + utils::SequenceDispatcher(0, evolution.size()), + evolution.size()); if (FLAGS_auto_schedule_use_cost_model) { for (size_t i = 0; i < mutated_individuals.size(); ++i) { - mutated_individuals[i]->predicted_cost = - cost_model_.Predict(mutated_individuals[i]->ir_schedule.GetModule(), tune_task_.target); + mutated_individuals[i]->predicted_cost = cost_model_.Predict( + mutated_individuals[i]->ir_schedule.GetModule(), tune_task_.target); } } VLOG(4) << JoinStatesDebugString( - "EvolutionarySearch::Evolve: mutated individuals:", mutated_individuals, /*verbose=*/VLOG_IS_ON(5)); + "EvolutionarySearch::Evolve: mutated individuals:", + mutated_individuals, + /*verbose=*/VLOG_IS_ON(5)); // select top ret_num with predicted cost utils::SizedMultiSet evolution_with_cost(ret_num); for (size_t i = 0; i < evolution.size(); ++i) { @@ -241,25 +302,29 @@ std::vector EvolutionarySearch::Evolve(const std::vector>(); + auto selected_individuals = + evolution_with_cost.ReturnAsContainer>(); VLOG(4) << JoinStatesDebugString( - "EvolutionarySearch::Evolve: selected individuals:", selected_individuals, /*verbose=*/VLOG_IS_ON(5)); + "EvolutionarySearch::Evolve: selected individuals:", + selected_individuals, + /*verbose=*/VLOG_IS_ON(5)); return selected_individuals; } -std::vector EvolutionarySearch::PickNextGenerationEpsGreedy(const std::vector& picked_bests, - const std::vector& random_init, - int num, - float eps_greedy) { +std::vector EvolutionarySearch::PickNextGenerationEpsGreedy( + const std::vector& picked_bests, + const std::vector& random_init, + int num, + float eps_greedy) { int num_rands = num * eps_greedy; int num_bests = num - num_rands; std::vector result; SearchState selected; int deduplicated_cnt = 0; - int best_idx = 0; - int rand_idx = 0; + int best_idx = 0; + int rand_idx = 0; while (result.size() < num) { if (result.size() < num_bests && best_idx < picked_bests.size()) { selected = picked_bests[best_idx]; @@ -276,18 +341,23 @@ std::vector 
EvolutionarySearch::PickNextGenerationEpsGreedy(const s if (!visited_candidates_.count(selected)) { // deduplicate VLOG(4) << JoinStatesDebugString( - "EvolutionarySearch::PickNextGenerationEpsGreedy-Selected", {selected}, /*verbose=*/VLOG_IS_ON(5)); + "EvolutionarySearch::PickNextGenerationEpsGreedy-Selected", + {selected}, + /*verbose=*/VLOG_IS_ON(5)); visited_candidates_.insert(selected); result.push_back(selected); } else { ++deduplicated_cnt; VLOG(4) << JoinStatesDebugString( - "EvolutionarySearch::PickNextGenerationEpsGreedy-Deduplicated", {selected}, /*verbose=*/VLOG_IS_ON(5)); + "EvolutionarySearch::PickNextGenerationEpsGreedy-Deduplicated", + {selected}, + /*verbose=*/VLOG_IS_ON(5)); } } VLOG(4) << utils::StringFormat( - "PickNextGenerationEpsGreedy: picked_bests size=%lu,random_init size=%lu,num=%d," + "PickNextGenerationEpsGreedy: picked_bests size=%lu,random_init " + "size=%lu,num=%d," "eps_greedy=%f,deduplicated_cnt=%d,result size=%lu", picked_bests.size(), random_init.size(), diff --git a/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h index 21005f53988a1..a9215ae5c29c6 100644 --- a/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h +++ b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h @@ -41,11 +41,12 @@ class EvolutionarySearch { * @param tune_task: the TuneTask this class works on. This class doesn't * take ownership of the pointer. */ - EvolutionarySearch(const TuneTask& tune_task, - const ExprCostModel& cost_model, - Database* database, - utils::LinearRandomEngine::StateType rand_seed = -1, - const std::vector>& mutate_rules = {}); + EvolutionarySearch( + const TuneTask& tune_task, + const ExprCostModel& cost_model, + Database* database, + utils::LinearRandomEngine::StateType rand_seed = -1, + const std::vector>& mutate_rules = {}); /** * Destructor @@ -55,14 +56,16 @@ class EvolutionarySearch { /** * Run the evolutionary search for one iteration. * - * @return SearchState containing the best ir::ModuleExpr searched in this iteration + * @return SearchState containing the best ir::ModuleExpr searched in this + * iteration */ SearchState SearchModuleExpr(const TuningOptions& options); /** * Run the evolutionary search for one iteration. * - * @return SearchState(s) containing best ir::ModuleExpr(s) searched in this iteration + * @return SearchState(s) containing the best ir::ModuleExpr(s) searched in + * this iteration */ std::vector SearchModuleExprBests(const TuningOptions& options); @@ -77,7 +80,8 @@ class EvolutionarySearch { * "eps * total_return_size" random samples and * "(1 - eps) * total_return_size" best searched samples. */ - std::vector SearchModuleExprEpsGreedy(const TuningOptions& options); + std::vector SearchModuleExprEpsGreedy( + const TuningOptions& options); #ifdef CINN_WITH_TEST /** @@ -87,13 +91,23 @@ class EvolutionarySearch { * @param search_space: the mock search space, note that EvolutionarySearch * takes the ownership. */ - void SetSearchSpace(SearchSpace* search_space) { search_space_.reset(search_space); } + void SetSearchSpace(SearchSpace* search_space) { + search_space_.reset(search_space); + } - // Method only be called during testing, it is a wrapper of private method InitSketch(). - std::vector TestInitSketch(int num, const std::string& strategy) { return InitSketch(num, strategy); } + // Method that should only be called during testing; it is a wrapper of the + // private method InitSketch().
+ std::vector TestInitSketch(int num, + const std::string& strategy) { + return InitSketch(num, strategy); + } - // Method only be called during testing, it is a wrapper of private method Evolve(). - std::vector TestEvolve(const std::vector& population, int cross_over_num, int ret_num) { + // Method that should only be called during testing; it is a wrapper of the + // private method Evolve(). + std::vector TestEvolve( + const std::vector& population, + int cross_over_num, + int ret_num) { return Evolve(population, cross_over_num, ret_num); } #endif @@ -105,26 +119,34 @@ class EvolutionarySearch { * \brief Generate sketch as initial population of evolutionary search. * @param num The number of sketches to generate. * @param strategy The strategy to generate sketches, - * Current optional strategies are "rule_prune" or "random_prune" or "random". - * - "rule_prune": will use rules to prune and generate sketches as efficiently as possible. - * - "random_prune": will use the new interface ApplySketchRules() to simulate the random generation of sketches, - * and supports the function of a rule returning multiple SearchStates and random pruning by probability. - * - "random": will randomly select a block and a rule to apply and repeat this step several times, - * however, each rule can only be used on one SearchState at most once. + * Current optional strategies are "rule_prune" or "random_prune" or + * "random". + * - "rule_prune": will use rules to prune and generate sketches as + * efficiently as possible. + * - "random_prune": will use the new interface ApplySketchRules() to simulate + * the random generation of sketches; it supports a rule returning multiple + * SearchStates as well as random pruning by probability. + * - "random": will randomly select a block and a rule to apply and repeat + * this step several times; however, each rule can be applied to a given + * SearchState at most once. * @return Generated sketches.
*/ std::vector InitSketch(int num, const std::string& strategy); - SearchState Mutate(const SearchState& state, utils::LinearRandomEngine::StateType* rand_seed); + SearchState Mutate(const SearchState& state, + utils::LinearRandomEngine::StateType* rand_seed); SearchState CrossOver(const SearchState& state1, const SearchState& state2); - std::vector Evolve(const std::vector& population, int cross_over_num, int ret_num); + std::vector Evolve(const std::vector& population, + int cross_over_num, + int ret_num); - std::vector PickNextGenerationEpsGreedy(const std::vector& population, - const std::vector& random_init, - int num, - float eps_greedy); + std::vector PickNextGenerationEpsGreedy( + const std::vector& population, + const std::vector& random_init, + int num, + float eps_greedy); private: std::unique_ptr search_space_; @@ -132,7 +154,8 @@ class EvolutionarySearch { const ExprCostModel& cost_model_; // not owned Database* database_; // not owned // used to de-duplicate states with the same structural IR - std::unordered_set visited_candidates_; + std::unordered_set + visited_candidates_; // mutate rule names and their weights std::vector> mutators_; // mutate rules, the key is the accumulated weight of each mutate rule diff --git a/paddle/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc index cab2249cc5f2f..23743384c71f3 100644 --- a/paddle/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc +++ b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc @@ -34,17 +34,23 @@ namespace cinn { namespace auto_schedule { -std::vector CreateTasks(const frontend::Program& program, const Target& target) { +std::vector CreateTasks(const frontend::Program& program, + const Target& target) { auto graph = std::make_shared(program, target); TaskCreator task_creator; - auto tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); - const auto& dtype_dict = graph->GetAttrs>("inferdtype"); - const auto& shape_dict = graph->GetAttrs>("infershape"); - auto op_lowerer = std::make_unique(dtype_dict, shape_dict, target); + auto tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); + const auto& dtype_dict = + graph->GetAttrs>( + "inferdtype"); + const auto& shape_dict = graph->GetAttrs< + absl::flat_hash_map>("infershape"); + auto op_lowerer = std::make_unique( + dtype_dict, shape_dict, target); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); for (auto i = 0; i < tasks.size(); ++i) { tasks[i].Initialize(shape_dict, dtype_dict, op_lowerer.get()); - task_registry->Regist(tasks[i].serialized_key, ir::ModuleExpr(tasks[i].GetLoweredFuncBodyExprs())); + task_registry->Regist(tasks[i].serialized_key, + ir::ModuleExpr(tasks[i].GetLoweredFuncBodyExprs())); } return tasks; } @@ -64,7 +70,8 @@ class MockSearchSpace : public SearchSpace { int GetModuleExprSize() const { return module_expr_size_; } - std::vector GenerateSketches(int num, const std::string& strategy) override { + std::vector GenerateSketches( + int num, const std::string& strategy) override { std::vector ret; for (int i = 0; i < num; ++i) { std::vector exprs; @@ -79,12 +86,13 @@ class MockSearchSpace : public SearchSpace { private: int module_expr_size_ = 10; - int min_expr_value_ = 0; + int min_expr_value_ = 0; }; class MockCostModel : public ExprCostModel { - float Predict(const ir::ModuleExpr& sample, const common::Target& target) const override { - float cost = 0.0f; + float Predict(const ir::ModuleExpr& sample, + const
common::Target& target) const override { + float cost = 0.0f; std::vector exprs = sample.GetExprs(); for (const ir::Expr& expr : exprs) { if (expr.as_int32()) { @@ -97,10 +105,11 @@ class MockCostModel : public ExprCostModel { TEST(EvolutionarySearch, GetOneBest) { TuneTask mock_tune_task; - mock_tune_task.serialized_key = "mock_task"; - mock_tune_task.target = common::DefaultTarget(); + mock_tune_task.serialized_key = "mock_task"; + mock_tune_task.target = common::DefaultTarget(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); - task_registry->Regist(mock_tune_task.serialized_key, ir::ModuleExpr({ir::Expr(0)})); + task_registry->Regist(mock_tune_task.serialized_key, + ir::ModuleExpr({ir::Expr(0)})); MockCostModel cost_model; TuningOptions options; Database db(2); @@ -109,7 +118,7 @@ TEST(EvolutionarySearch, GetOneBest) { MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task); // Ownership is transferred so don't delete mock_search_space evolutionary_search.SetSearchSpace(mock_search_space); - SearchState best_state = evolutionary_search.SearchModuleExpr(options); + SearchState best_state = evolutionary_search.SearchModuleExpr(options); std::vector exprs = best_state->ir_schedule.GetModule().GetExprs(); EXPECT_GE(exprs.size(), 1UL); for (const ir::Expr& e : exprs) { @@ -119,10 +128,11 @@ TEST(EvolutionarySearch, GetOneBest) { TEST(EvolutionarySearch, GetEpsGreedy) { TuneTask mock_tune_task; - mock_tune_task.serialized_key = "mock_task"; - mock_tune_task.target = common::DefaultTarget(); + mock_tune_task.serialized_key = "mock_task"; + mock_tune_task.target = common::DefaultTarget(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); - task_registry->Regist(mock_tune_task.serialized_key, ir::ModuleExpr({ir::Expr(0)})); + task_registry->Regist(mock_tune_task.serialized_key, + ir::ModuleExpr({ir::Expr(0)})); ExprCostModel cost_model; TuningOptions options; Database db(2); @@ -131,10 +141,12 @@ TEST(EvolutionarySearch, GetEpsGreedy) { MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task); // Ownership is transferred so don't delete mock_search_space evolutionary_search.SetSearchSpace(mock_search_space); - std::vector search_states = evolutionary_search.SearchModuleExprEpsGreedy(options); + std::vector search_states = + evolutionary_search.SearchModuleExprEpsGreedy(options); EXPECT_GE(search_states.size(), 1UL); - size_t expr_size = static_cast(mock_search_space->GetModuleExprSize()); + size_t expr_size = + static_cast(mock_search_space->GetModuleExprSize()); for (const SearchState& state : search_states) { EXPECT_EQ(state->ir_schedule.GetModule().GetExprs().size(), expr_size); } @@ -142,7 +154,9 @@ TEST(EvolutionarySearch, GetEpsGreedy) { TEST(EvolutionarySearch, Evolve) { auto target = common::DefaultNVGPUTarget(); - auto tasks = CreateTasks(tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}), target); + auto tasks = CreateTasks( + tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}), + target); CHECK_EQ(tasks.size(), 1); ExprCostModel cost_model; std::vector cost_model_samples(1); @@ -150,7 +164,7 @@ TEST(EvolutionarySearch, Evolve) { for (size_t i = 0; i < 2; ++i) { ir::ModuleExpr me({ir::Expr(tasks[0].lowered_funcs[0])}); cost_model_samples[0] = &me; - cost_model_labels[0] = i + 10; + cost_model_labels[0] = i + 10; cost_model.Update(cost_model_samples, cost_model_labels, target); } @@ -160,22 +174,25 @@ TEST(EvolutionarySearch, Evolve) { EvolutionarySearch 
evolutionary_search(tasks[0], cost_model, &db); - int num_population = 10; - std::vector init_sketch = evolutionary_search.TestInitSketch(num_population, "rule_prune"); + int num_population = 10; + std::vector init_sketch = + evolutionary_search.TestInitSketch(num_population, "rule_prune"); for (int i = 0; i < num_population; ++i) { ir::ModuleExpr me(init_sketch[i]->ir_schedule.GetModule()); cost_model_samples[0] = &me; - cost_model_labels[0] = i; + cost_model_labels[0] = i; cost_model.Update(cost_model_samples, cost_model_labels, target); } VLOG(6) << "init sketch costs:"; for (auto s : init_sketch) { VLOG(6) << "cost = " << s->predicted_cost; } - std::vector*population_pre_ptr = &init_sketch, *population_next_ptr; + std::vector*population_pre_ptr = &init_sketch, + *population_next_ptr; std::vector population; for (int i = 0; i < 10; ++i) { - population = evolutionary_search.TestEvolve(*population_pre_ptr, /*cross_over_num*/ 0, /*ret_num*/ 10); + population = evolutionary_search.TestEvolve( + *population_pre_ptr, /*cross_over_num*/ 0, /*ret_num*/ 10); population_next_ptr = &population; VLOG(6) << "population[" << i + 1 << "] costs:"; double total_cost_pre = 0.0, total_cost_next = 0.0; diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h index 02fcfa088e6fd..5a0b097cf451c 100644 --- a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h @@ -34,11 +34,14 @@ class MutateRule { * @param rand_seed The random seed for mutation. * @return The mutated trace. */ - virtual ir::ScheduleDesc Apply(const ir::ScheduleDesc& trace, utils::LinearRandomEngine::StateType* rand_seed) = 0; + virtual ir::ScheduleDesc Apply( + const ir::ScheduleDesc& trace, + utils::LinearRandomEngine::StateType* rand_seed) = 0; /** * @brief Create a MutateRule with name. - * @param name The name of mutate rule, consisting of lowercase letters and underscores + * @param name The name of mutate rule, consisting of lowercase letters and + * underscores * @return The created MutateRule. 
*/ static std::unique_ptr Make(const std::string& name); diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.cc b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.cc index 4026d295dbea8..d1eebb86d2a8b 100644 --- a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.cc +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.cc @@ -44,8 +44,11 @@ std::vector FindSampledTiles(const ScheduleDesc& trace) { break; } if (step.type == "SamplePerfectTile") { - std::vector tile_factors = absl::get>(step.attrs.at("decision")); - CHECK(tile_factors.size() >= 2) << "factors size must be greater equal than 2, which is " << tile_factors.size(); + std::vector tile_factors = + absl::get>(step.attrs.at("decision")); + CHECK(tile_factors.size() >= 2) + << "factors size must be greater than or equal to 2, but it is " + << tile_factors.size(); tiles.push_back(std::make_tuple(step, tile_factors, step_idx)); } ++step_idx; @@ -57,9 +60,9 @@ std::vector FindSampledTiles(const ScheduleDesc& trace) { ScheduleDesc DoMutateTileSize(const ScheduleDesc& trace, const SampledTile& tile, LinearRandomEngine::StateType* rand_seed) { - ScheduleDesc::Step step = std::get<0>(tile); + ScheduleDesc::Step step = std::get<0>(tile); std::vector tile_factors = std::get<1>(tile); - int split_size = tile_factors.size(); + int split_size = tile_factors.size(); // Step 1. Choose 2 loops with index: 'loop_x' and 'loop_y' int loop_x, loop_y; @@ -89,10 +92,13 @@ ScheduleDesc DoMutateTileSize(const ScheduleDesc& trace, // Step 2. Choose the divisor for the mutation. int divisor; if (loop_y == split_size - 1) { - int max_innermost_factor = absl::get(step.attrs.at("max_innermost_factor")); + int max_innermost_factor = + absl::get(step.attrs.at("max_innermost_factor")); int max_optional_factor_idx = optional_factors.size() - 1; for (; max_optional_factor_idx > 0; --max_optional_factor_idx) { - if (optional_factors.at(max_optional_factor_idx) * tile_factors.at(loop_y) <= max_innermost_factor) { + if (optional_factors.at(max_optional_factor_idx) * + tile_factors.at(loop_y) <= + max_innermost_factor) { break; } } @@ -103,27 +109,32 @@ ScheduleDesc DoMutateTileSize(const ScheduleDesc& trace, } continue; } - divisor = optional_factors.at(utils::SampleUniformInt(1, max_optional_factor_idx + 1, rand_seed)); + divisor = optional_factors.at( + utils::SampleUniformInt(1, max_optional_factor_idx + 1, rand_seed)); } else { - divisor = optional_factors.at(utils::SampleUniformInt(1, optional_factors.size(), rand_seed)); + divisor = optional_factors.at( + utils::SampleUniformInt(1, optional_factors.size(), rand_seed)); } // Step 3. Determine the new tile value - VLOG(6) << "DoMutateTileSize: divisor = " << divisor << ", before mutate: \n" - << "factors[" << loop_x << "] = " << tile_factors[loop_x] << ", factors[" << loop_y - << "] = " << tile_factors[loop_y]; + VLOG(6) << "DoMutateTileSize: divisor = " << divisor + << ", before mutate: \n" + << "factors[" << loop_x << "] = " << tile_factors[loop_x] + << ", factors[" << loop_y << "] = " << tile_factors[loop_y]; tile_factors[loop_x] /= divisor; tile_factors[loop_y] *= divisor; VLOG(6) << "after mutate: \n" - << "factors[" << loop_x << "] = " << tile_factors[loop_x] << ", factors[" << loop_y - << "] = " << tile_factors[loop_y]; + << "factors[" << loop_x << "] = " << tile_factors[loop_x] + << ", factors[" << loop_y << "] = " << tile_factors[loop_y]; // Step 4.
Create a new step with new tile values and return the new trace int step_idx = std::get<2>(tile); return trace.ForkAndUpdate(step_idx, tile_factors, true); } } -ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace, LinearRandomEngine::StateType* rand_seed) { - VLOG(6) << "Start applying MutateTileSize, old trace: \n" << trace.DebugString(); +ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace, + LinearRandomEngine::StateType* rand_seed) { + VLOG(6) << "Start applying MutateTileSize, old trace: \n" + << trace.DebugString(); std::vector sample_tile_steps; std::vector> sample_tile_data; @@ -132,9 +143,12 @@ ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace, LinearRandomEngine VLOG(6) << "MutateTileSize failed, try other mutate rules."; return trace; } - int sample_step_idx = utils::SampleUniformInt(0, sampled_tiles.size(), rand_seed); - auto new_trace = DoMutateTileSize(trace, sampled_tiles.at(sample_step_idx), rand_seed); - VLOG(6) << "End applying MutateTileSize, new trace: \n" << new_trace.DebugString(); + int sample_step_idx = + utils::SampleUniformInt(0, sampled_tiles.size(), rand_seed); + auto new_trace = + DoMutateTileSize(trace, sampled_tiles.at(sample_step_idx), rand_seed); + VLOG(6) << "End applying MutateTileSize, new trace: \n" + << new_trace.DebugString(); return new_trace; } diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h index 7f860204f1a39..0d4c557618dea 100644 --- a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h @@ -20,13 +20,16 @@ namespace cinn { namespace auto_schedule { /** - * The rule to mutate tile size, witch will modify the factors of the Split primitive. + * The rule to mutate tile size, which will modify the factors of the Split + * primitive. */ class MutateTileSize : public MutateRule { public: MutateTileSize() = default; - ir::ScheduleDesc Apply(const ir::ScheduleDesc& trace, utils::LinearRandomEngine::StateType* rand_seed) override; + ir::ScheduleDesc Apply( + const ir::ScheduleDesc& trace, + utils::LinearRandomEngine::StateType* rand_seed) override; }; } // namespace auto_schedule diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc index c334761174fff..dfd895b72ed9f 100644 --- a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc @@ -42,32 +42,45 @@ TEST(MutateTileSize, Basic) { Var k(K.as_int32(), "reduce_axis_k"); ir::Tensor C = Compute( - {M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + {M, N}, + [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, + "C"); poly::StageMap stages = CreateStages({A, B, C}); std::vector funcs = - lang::LowerVec("TestMutateTileSize_Basic", stages, {A, B, C}, {}, {}, nullptr, target, true); + lang::LowerVec("TestMutateTileSize_Basic", + stages, + {A, B, C}, + {}, + {}, + nullptr, + target, + true); ir::Expr ast_expr = funcs[0]->body; VLOG(6) << "Original Expr: "; VLOG(6) << ast_expr; ir::ModuleExpr module_expr({ast_expr}); - // We need to fix the seed as a constant to ensure that the result can be repeated. + // We need to fix the seed as a constant to ensure that the result can be + // repeated.
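// (Added note: LinearRandomEngine is a deterministic pseudo-random
// generator, so every SamplePerfectTile and mutation decision below should
// be a pure function of this constant seed, which keeps the expected-IR
// string comparison later in the test stable across runs.)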
utils::LinearRandomEngine::StateType rand_seed = 123; ir::IRSchedule ir_schedule(module_expr, rand_seed); ir::IRSchedule new_ir_schedule(ir_schedule); // apply schedule - auto loops = ir_schedule.GetLoops("C"); + auto loops = ir_schedule.GetLoops("C"); auto factors = ir_schedule.SamplePerfectTile(loops[0], 2, kSize); auto splited = ir_schedule.Split(loops[0], factors); // apply mutate MutateTileSize mutator; - ir::ScheduleDesc sch_desc = mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed); + ir::ScheduleDesc sch_desc = + mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed); sch_desc.Replay(&new_ir_schedule, true); - VLOG(6) << "Expr before mutate tile size: \n" << ir_schedule.GetModule().GetExprs()[0]; - VLOG(6) << "Expr after mutate tile size: \n" << new_ir_schedule.GetModule().GetExprs()[0]; + VLOG(6) << "Expr before mutate tile size: \n" + << ir_schedule.GetModule().GetExprs()[0]; + VLOG(6) << "Expr after mutate tile size: \n" + << new_ir_schedule.GetModule().GetExprs()[0]; std::string target_new_ir = R"ROC({ ScheduleBlock(root) @@ -111,7 +124,8 @@ TEST(MutateTileSize, Basic) { sch_desc = mutator.Apply(sch_desc, &rand_seed); for (auto&& step : sch_desc.Steps()) { if (step.type == "SamplePerfectTile") { - std::vector tile_factors = absl::get>(step.attrs.at("decision")); + std::vector tile_factors = + absl::get>(step.attrs.at("decision")); ASSERT_EQ(tile_factors.size(), last_tile_factors.size()); ASSERT_NE(tile_factors[0], last_tile_factors[0]); ASSERT_NE(tile_factors[1], last_tile_factors[1]); diff --git a/paddle/cinn/auto_schedule/task/task_creator.cc b/paddle/cinn/auto_schedule/task/task_creator.cc index c22ce44587633..0d8e48e8d0a79 100644 --- a/paddle/cinn/auto_schedule/task/task_creator.cc +++ b/paddle/cinn/auto_schedule/task/task_creator.cc @@ -36,7 +36,8 @@ using ::cinn::hlir::framework::NodeData; std::vector TaskCreator::CreateTuneTaskOpLevel(Graph* graph) { std::vector ret_tasks; - const std::vector>* groups = &graph->fusion_groups; + const std::vector>* groups = + &graph->fusion_groups; std::vector> non_fused_groups; // The input graph doesn't run Op Fusion if (graph->fusion_groups.empty()) { @@ -48,7 +49,7 @@ std::vector TaskCreator::CreateTuneTaskOpLevel(Graph* graph) { for (const auto& sub_graph : *groups) { ret_tasks.emplace_back(TuneTask()); ret_tasks.back().subgraph = sub_graph; - ret_tasks.back().target = graph->target_; + ret_tasks.back().target = graph->target_; } return ret_tasks; } diff --git a/paddle/cinn/auto_schedule/task/task_creator_test.cc b/paddle/cinn/auto_schedule/task/task_creator_test.cc index cc7d7e0b3dd82..60b5ebec0e808 100644 --- a/paddle/cinn/auto_schedule/task/task_creator_test.cc +++ b/paddle/cinn/auto_schedule/task/task_creator_test.cc @@ -39,10 +39,10 @@ Program CreateAddProgram() { constexpr int N = 24; NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {M, N}, "A"); - auto b = builder.CreateInput(Float(32), {M, N}, "B"); - auto c = builder.Add(a, b); - auto d = builder.Add(a, c); + auto a = builder.CreateInput(Float(32), {M, N}, "A"); + auto b = builder.CreateInput(Float(32), {M, N}, "B"); + auto c = builder.Add(a, b); + auto d = builder.Add(a, c); auto program = builder.Build(); return program; @@ -55,7 +55,7 @@ TEST(TaskCreator, Basic) { Target target = common::DefaultHostTarget(); #endif Program prog = CreateAddProgram(); - auto graph = std::make_shared(prog, target); + auto graph = std::make_shared(prog, target); TaskCreator task_creator; std::vector tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); diff 
--git a/paddle/cinn/auto_schedule/task/task_optimizer.cc b/paddle/cinn/auto_schedule/task/task_optimizer.cc index a3a50c98dee6f..f988a03eb301b 100644 --- a/paddle/cinn/auto_schedule/task/task_optimizer.cc +++ b/paddle/cinn/auto_schedule/task/task_optimizer.cc @@ -47,14 +47,18 @@ namespace auto_schedule { using cinn::hlir::op::ExternalApiRegistry; -// *** forward declarations of auxiliary functions to be used in this file only *** -// update a scheduled function with several post-processors -ir::LoweredFunc FuncWithUpdatedBody(const common::Target& target, const ir::LoweredFunc& old_func, ir::Expr& body); +// *** forward declarations of auxiliary functions to be used in this file +// only *** +// update a scheduled function with several post-processors +ir::LoweredFunc FuncWithUpdatedBody(const common::Target& target, + const ir::LoweredFunc& old_func, + ir::Expr& body); // check whether a scheduled lowered function is valid -bool PruneInvalid(const ir::LoweredFunc& lowered_func, const common::Target& target); +bool PruneInvalid(const ir::LoweredFunc& lowered_func, + const common::Target& target); // exclude some special tasks bool IsForbiddenToTune(const TuneTask* task); -// tell whether the task has been wrapped by custom_call in TransToCustomCallPass +// tell whether the task has been wrapped by custom_call in +// TransToCustomCallPass bool IsWrappedByCustomCall(const TuneTask* task); // tell whether the task has registered external api bool HasExternalApi(const TuneTask* task); @@ -75,10 +79,11 @@ FunctionGroup TaskOptimizer::Optimize(const TuningOptions& options) { if (IsForbiddenToTune(task_) || IsWrappedByCustomCall(task_)) { return task_->op_lowerer->Lower(task_->subgraph); } - // TODO(CtfGo): the input/output names of a Graph::Group will be changed in Lowering by OpLowerer currently, - // so we should revert them after following different lower methods, remove this hard code by fixing the - // decoupling between lowering and BuildInstructions - auto initial_input_names = task_->subgraph->input_names; + // TODO(CtfGo): the input/output names of a Graph::Group will be changed in + // Lowering by OpLowerer currently, so we should revert them after the + // different lowering paths; remove this hard code by decoupling lowering + // from BuildInstructions + auto initial_input_names = task_->subgraph->input_names; auto initial_output_names = task_->subgraph->output_names; std::vector candidates; @@ -87,12 +92,15 @@ FunctionGroup TaskOptimizer::Optimize(const TuningOptions& options) { if (HasExternalApi(task_)) { candidates.emplace_back(OptimizeByExternal(options.num_measure_trials > 0)); } - sort(candidates.begin(), candidates.end(), [](const auto& lhs, const auto& rhs) { return lhs.cost < rhs.cost; }); + sort(candidates.begin(), + candidates.end(), + [](const auto& lhs, const auto& rhs) { return lhs.cost < rhs.cost; }); auto&& best = candidates.front(); - VLOG(4) << "Total candidates=" << candidates.size() << ", the best from=" << best.from << ", cost=" << best.cost; + VLOG(4) << "Total candidates=" << candidates.size() + << ", the best from=" << best.from << ", cost=" << best.cost; // revert input/output names - task_->subgraph->input_names = initial_input_names; + task_->subgraph->input_names = initial_input_names; task_->subgraph->output_names = initial_output_names; return best.functions; } @@ -109,18 +117,22 @@ TaskOptimizer::Result TaskOptimizer::OptimizeByManual(bool need_measured) { } SearchState state(ir::IRSchedule(ir::ModuleExpr(std::move(func_bodys)))); -
// the manual is regarded as the second best in default, so we set its cost 0.0 + // the manual schedule is regarded as the second best by default, so we set + // its cost to 0.0 result.cost = 0.0; - // add the specific prefix in front of serialized_key to be store/load measured record for manual schedule + // add the specific prefix in front of serialized_key so we can store/load + // the measured record for the manual schedule std::string measured_key = kManualMeasuredKeyPrefix + task_->serialized_key; if (need_measured && database_->Count(measured_key) == 0) { std::vector inputs(1); - inputs.back().task = task_; + inputs.back().task = task_; inputs.back().lowered_funcs = result.functions; VLOG(4) << "Measure manual schedule"; - std::vector measure_outputs = schedule_measurer_->Measure(inputs); - database_->AddRecord(TuningRecord(measured_key, state, measure_outputs[0].execution_cost)); + std::vector measure_outputs = + schedule_measurer_->Measure(inputs); + database_->AddRecord( + TuningRecord(measured_key, state, measure_outputs[0].execution_cost)); } auto measured_records = database_->LookUp(measured_key); @@ -133,26 +145,32 @@ TaskOptimizer::Result TaskOptimizer::OptimizeByManual(bool need_measured) { TaskOptimizer::Result TaskOptimizer::OptimizeByExternal(bool need_measured) { static constexpr char* kExternalMeasuredKeyPrefix = "@ExternalMeasured:\n"; TaskOptimizer::Result result("External"); - auto nodes = task_->subgraph->CollectNodes(); + auto nodes = task_->subgraph->CollectNodes(); auto* first_node = nodes.front(); // set the necessary field for lowering with external api - std::string original_op = first_node->op()->name; + std::string original_op = first_node->op()->name; first_node->attrs.attr_store["original_op"] = original_op; - first_node->attrs.op = hlir::framework::Operator::Get("custom_call"); - result.functions = task_->op_lowerer->Lower(task_->subgraph); + first_node->attrs.op = hlir::framework::Operator::Get("custom_call"); + result.functions = task_->op_lowerer->Lower(task_->subgraph); - // add the specific prefix in front of serialized_key to be store/load measured record for external api - result.cost = -1.0; // the external is regarded as the best in default, so we set its cost -1.0 + // add the specific prefix in front of serialized_key so we can store/load + // the measured record for the external api + result.cost = -1.0; // the external is regarded as the best by default, so + // we set its cost to -1.0 std::string measured_key = kExternalMeasuredKeyPrefix + task_->serialized_key; if (need_measured && database_->Count(measured_key) == 0) { std::vector inputs(1); - inputs.back().task = task_; + inputs.back().task = task_; inputs.back().lowered_funcs = result.functions; VLOG(4) << "Measure external api"; - std::vector measure_outputs = schedule_measurer_->Measure(inputs); - // the SearchState of external is invalid and will not be used, so we just put a temporary one - database_->AddRecord(TuningRecord(measured_key, SearchState(ir::IRSchedule()), measure_outputs[0].execution_cost)); + std::vector measure_outputs = + schedule_measurer_->Measure(inputs); + // the SearchState of external is invalid and will not be used, so we just + // put a temporary one + database_->AddRecord(TuningRecord(measured_key, + SearchState(ir::IRSchedule()), + measure_outputs[0].execution_cost)); } auto measured_records = database_->LookUp(measured_key); @@ -165,10 +183,11 @@ TaskOptimizer::Result TaskOptimizer::OptimizeByExternal(bool need_measured) { bool IsForbiddenToTune(const TuneTask* task) { // TODO(CtfGo): some
operators may change its linked edges in // TransToCustomCallPass, like conv2d, we will skip these ops in auto-schedule - // because they can't revert original links for no schedule and manual schedule lowering. + // because they can't revert original links for no schedule and manual + // schedule lowering. static std::unordered_set links_changed_ops = {"conv2d"}; - auto nodes = task->subgraph->CollectNodes(); - auto&& op_name = nodes.front()->op()->name; + auto nodes = task->subgraph->CollectNodes(); + auto&& op_name = nodes.front()->op()->name; if (nodes.size() == 1 && links_changed_ops.count(op_name)) { VLOG(5) << "Op:" << op_name << " is forbidden to call external_api"; return true; @@ -178,20 +197,23 @@ bool IsForbiddenToTune(const TuneTask* task) { } bool HasExternalApi(const TuneTask* task) { - auto nodes = task->subgraph->CollectNodes(); + auto nodes = task->subgraph->CollectNodes(); auto* first_node = nodes.front(); - if (nodes.size() == 1 && ExternalApiRegistry::Global()->Has(first_node->op()->name, task->target)) { + if (nodes.size() == 1 && ExternalApiRegistry::Global()->Has( + first_node->op()->name, task->target)) { return true; } return false; } bool IsWrappedByCustomCall(const TuneTask* task) { - auto nodes = task->subgraph->CollectNodes(); + auto nodes = task->subgraph->CollectNodes(); auto* first_node = nodes.front(); if (nodes.size() == 1 && first_node->op()->name == "custom_call") { - CHECK(first_node->attrs.attr_store.count("original_op")) << "a custom_call op must store its original op name"; - std::string op_name = absl::get(first_node->attrs.attr_store.at("original_op")); + CHECK(first_node->attrs.attr_store.count("original_op")) + << "a custom_call op must store its original op name"; + std::string op_name = + absl::get(first_node->attrs.attr_store.at("original_op")); VLOG(5) << "Op:" << op_name << " was wrapped as custom_call"; return true; } @@ -199,35 +221,42 @@ bool IsWrappedByCustomCall(const TuneTask* task) { return false; } -TaskOptimizer::Result TaskOptimizer::OptimizeByEvolution(const TuningOptions& options) { +TaskOptimizer::Result TaskOptimizer::OptimizeByEvolution( + const TuningOptions& options) { CHECK_EQ(options.num_measure_trials % options.num_samples_per_iteration, 0) - << "TuningOptions.num_measure_trials % TuningOptions.num_samples_per_iteration must be 0."; + << "TuningOptions.num_measure_trials % " + "TuningOptions.num_samples_per_iteration must be 0."; - VLOG(4) << "Optimizing TuneTask with num_measure_trials:" << options.num_measure_trials + VLOG(4) << "Optimizing TuneTask with num_measure_trials:" + << options.num_measure_trials << ", LoweredFunc before optimization is:"; VLOG(4) << "lowered function size = " << task_->lowered_funcs.size(); for (size_t i = 0; i < task_->lowered_funcs.size(); ++i) { - VLOG(4) << "lowered_funcs[" << i << "] detail:\n" << task_->lowered_funcs[i]; + VLOG(4) << "lowered_funcs[" << i << "] detail:\n" + << task_->lowered_funcs[i]; } if (evolutionary_search_ == nullptr) { // TODO(zhhsplendid): check whether the options is same as previous, // if not, we should create new EvolutionarySearch - evolutionary_search_ = - std::make_unique(*task_, cost_model_, database_, utils::ForkRandomState(&rand_seed_)); + evolutionary_search_ = std::make_unique( + *task_, cost_model_, database_, utils::ForkRandomState(&rand_seed_)); } TaskOptimizer::Result result("Evolution"); auto& optimized_funcs = result.functions; - auto& best_cost = result.cost; + auto& best_cost = result.cost; // use initial lowered function as default result 
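The CHECK_EQ at the top of OptimizeByEvolution guarantees the measurement budget splits into whole per-iteration batches. A small sketch of why that invariant matters, with illustrative values:

#include <cassert>
#include <iostream>

int main() {
  const int num_measure_trials = 8;         // illustrative budget
  const int num_samples_per_iteration = 2;  // batch size per search round
  // Mirrors the CHECK_EQ above: an uneven split would leave a partial batch.
  assert(num_measure_trials % num_samples_per_iteration == 0);
  int measured_count = 0, rounds = 0;
  while (measured_count < num_measure_trials) {
    measured_count += num_samples_per_iteration;  // one measured batch
    ++rounds;
  }
  std::cout << rounds << " full rounds\n";  // prints "4 full rounds"
  return 0;
}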
optimized_funcs = optim::IRCopy(task_->lowered_funcs); - if (options.num_measure_trials == 0) { // no need to measure and simply return the best searched + if (options.num_measure_trials == + 0) { // no need to measure and simply return the best searched std::vector measure_candidates; - std::vector states = SearchOneRound(options, &measure_candidates); + std::vector states = + SearchOneRound(options, &measure_candidates); if (!states.empty()) { if (FLAGS_auto_schedule_use_cost_model) { - best_cost = cost_model_.Predict(states.front()->ir_schedule.GetModule(), task_->target); + best_cost = cost_model_.Predict(states.front()->ir_schedule.GetModule(), + task_->target); } optimized_funcs = measure_candidates[0].lowered_funcs; } else { @@ -236,7 +265,7 @@ TaskOptimizer::Result TaskOptimizer::OptimizeByEvolution(const TuningOptions& op return result; } - int measured_count = 0; + int measured_count = 0; uint32_t continuous_empty_cnt = 0; while (measured_count < options.num_measure_trials) { VLOG(4) << "Launch a new search, current measured_count:" << measured_count; @@ -245,25 +274,30 @@ TaskOptimizer::Result TaskOptimizer::OptimizeByEvolution(const TuningOptions& op if (states.empty()) { // no new valid candidate achieved ++continuous_empty_cnt; if (continuous_empty_cnt <= kMaxRetryContinuousEmpty_) { - VLOG(4) << "No valid state searched, continuous_empty_cnt=" << continuous_empty_cnt; + VLOG(4) << "No valid state searched, continuous_empty_cnt=" + << continuous_empty_cnt; continue; } else { - LOG(WARNING) - << "OptimizeByEvolution will be exited in advance due to continuous invalid search, final measured_count=" - << measured_count; + LOG(WARNING) << "OptimizeByEvolution will be exited in advance due to " + "continuous invalid search, final measured_count=" + << measured_count; break; } } continuous_empty_cnt = 0; // reset if get valid candidates - VLOG(4) << "ScheduleMeasurer start with input size=" << measure_inputs.size(); - std::vector measure_outputs = schedule_measurer_->Measure(measure_inputs); + VLOG(4) << "ScheduleMeasurer start with input size=" + << measure_inputs.size(); + std::vector measure_outputs = + schedule_measurer_->Measure(measure_inputs); CHECK_EQ(measure_outputs.size(), states.size()) - << "ScheduleMeasurer didn't output same number of MeasureOutput of states in TaskOptimizer"; + << "ScheduleMeasurer didn't output same number of MeasureOutput of " + "states in TaskOptimizer"; // record to database for (size_t i = 0; i < states.size(); ++i) { - database_->AddRecord( - TuningRecord(measure_inputs[i].task->serialized_key, states[i], measure_outputs[i].execution_cost)); + database_->AddRecord(TuningRecord(measure_inputs[i].task->serialized_key, + states[i], + measure_outputs[i].execution_cost)); } // update cost model @@ -272,19 +306,21 @@ TaskOptimizer::Result TaskOptimizer::OptimizeByEvolution(const TuningOptions& op std::vector cost_model_labels(states.size()); for (size_t i = 0; i < states.size(); ++i) { cost_model_samples[i] = &(states[i]->ir_schedule.GetModule()); - cost_model_labels[i] = measure_outputs[i].execution_cost; + cost_model_labels[i] = measure_outputs[i].execution_cost; } - VLOG(4) << utils::StringFormat("Update CostModel with samples size=%lu,labels size=%lu", - cost_model_samples.size(), - cost_model_labels.size()); + VLOG(4) << utils::StringFormat( + "Update CostModel with samples size=%lu,labels size=%lu", + cost_model_samples.size(), + cost_model_labels.size()); cost_model_.Update(cost_model_samples, cost_model_labels, task_->target); } // update the 
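The search loop above tolerates a bounded run of empty rounds before exiting early, and resets the counter on any success. A compact sketch of that retry pattern; SearchOnce is a hypothetical stand-in for one evolutionary search round:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical search step; returns an empty batch to simulate invalid rounds.
std::vector<int> SearchOnce() { return {}; }

int main() {
  const uint32_t kMaxRetryContinuousEmpty = 3;  // illustrative cap
  uint32_t continuous_empty_cnt = 0;
  int measured_count = 0;
  while (measured_count < 8) {
    std::vector<int> states = SearchOnce();
    if (states.empty()) {
      // Tolerate a bounded run of empty rounds, then give up early rather
      // than spinning forever.
      if (++continuous_empty_cnt > kMaxRetryContinuousEmpty) {
        std::cout << "exit early, measured_count=" << measured_count << "\n";
        break;
      }
      continue;
    }
    continuous_empty_cnt = 0;  // reset on any valid batch
    measured_count += static_cast<int>(states.size());
  }
  return 0;
}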
best for (size_t i = 0; i < measure_outputs.size(); ++i) { if (measure_outputs[i].execution_cost < best_cost) { - VLOG(4) << "Update best candidate with execution_cost:" << measure_outputs[i].execution_cost << "us"; - best_cost = measure_outputs[i].execution_cost; + VLOG(4) << "Update best candidate with execution_cost:" + << measure_outputs[i].execution_cost << "us"; + best_cost = measure_outputs[i].execution_cost; optimized_funcs = measure_inputs[i].lowered_funcs; } } @@ -295,20 +331,27 @@ TaskOptimizer::Result TaskOptimizer::OptimizeByEvolution(const TuningOptions& op return result; } -std::vector TaskOptimizer::SearchOneRound(const TuningOptions& options, - std::vector* measure_candidates) { - std::vector states = evolutionary_search_->SearchModuleExprEpsGreedy(options); - VLOG(4) << JoinStatesDebugString("TaskOptimizer::EvolutionarySearch-Result", states, /*verbose=*/VLOG_IS_ON(5)); +std::vector TaskOptimizer::SearchOneRound( + const TuningOptions& options, + std::vector* measure_candidates) { + std::vector states = + evolutionary_search_->SearchModuleExprEpsGreedy(options); + VLOG(4) << JoinStatesDebugString("TaskOptimizer::EvolutionarySearch-Result", + states, + /*verbose=*/VLOG_IS_ON(5)); size_t valid_cnt = 0; for (size_t i = 0; i < states.size(); ++i) { - std::vector best_exprs = states[i]->ir_schedule.GetModule().GetExprs(); + std::vector best_exprs = + states[i]->ir_schedule.GetModule().GetExprs(); CHECK_EQ(best_exprs.size(), task_->lowered_funcs.size()) - << "RuntimeError: Expr size is not equal to LoweredFunc size in TaskOptimizer"; + << "RuntimeError: Expr size is not equal to LoweredFunc size in " + "TaskOptimizer"; auto init_funcs = optim::IRCopy(task_->lowered_funcs); std::vector valid_funcs; for (size_t j = 0; j < best_exprs.size(); ++j) { - auto updated_f = UpdateFuncWithNewBody(task_->target, init_funcs[j], best_exprs[j]); + auto updated_f = + UpdateFuncWithNewBody(task_->target, init_funcs[j], best_exprs[j]); if (PruneInvalid(updated_f, task_->target)) { VLOG(4) << "PruneInvalid states-" << i; break; @@ -320,42 +363,52 @@ std::vector TaskOptimizer::SearchOneRound(const TuningOptions& opti if (valid_funcs.size() == init_funcs.size()) { states[valid_cnt++] = states[i]; measure_candidates->emplace_back(MeasureInput()); - measure_candidates->back().task = task_; + measure_candidates->back().task = task_; measure_candidates->back().lowered_funcs = std::move(valid_funcs); } } states.erase(states.begin() + valid_cnt, states.end()); - CHECK_EQ(states.size(), measure_candidates->size()) << "result size of states not equal to measure_candidates"; - VLOG(4) << "EvolutionarySearch return size=" << states.size() << ", valid count=" << valid_cnt; - VLOG(4) << JoinStatesDebugString("TaskOptimizer::SearchOneRound-Result", states, /*verbose=*/VLOG_IS_ON(5)); + CHECK_EQ(states.size(), measure_candidates->size()) + << "result size of states not equal to measure_candidates"; + VLOG(4) << "EvolutionarySearch return size=" << states.size() + << ", valid count=" << valid_cnt; + VLOG(4) << JoinStatesDebugString("TaskOptimizer::SearchOneRound-Result", + states, + /*verbose=*/VLOG_IS_ON(5)); return states; } -// detect the limit of available shared memory on the current NVGPU with CUDA runtime +// detect the limit of available shared memory on the current NVGPU with CUDA +// runtime size_t GetGPUSharedMemoryLimit() { #ifdef CINN_WITH_CUDA int device_id; CUDA_CALL(cudaGetDevice(&device_id)); cudaDeviceProp prop; CUDA_CALL(cudaGetDeviceProperties(&prop, device_id)); - VLOG(4) << 
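SearchOneRound keeps only states whose functions all survive pruning, using the states[valid_cnt++] = states[i] idiom followed by a tail erase. The same stable in-place compaction, sketched on plain ints:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> states = {3, -1, 7, -4, 2};  // negatives count as invalid
  std::size_t valid_cnt = 0;
  for (std::size_t i = 0; i < states.size(); ++i) {
    if (states[i] >= 0) {
      states[valid_cnt++] = states[i];  // survivors keep their relative order
    }
  }
  states.erase(states.begin() + static_cast<std::ptrdiff_t>(valid_cnt),
               states.end());
  for (int s : states) std::cout << s << ' ';  // prints "3 7 2"
  std::cout << '\n';
  return 0;
}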
utils::StringFormat("GPU-%d GPUSharedMemoryLimit=%d", device_id, prop.sharedMemPerBlock); + VLOG(4) << utils::StringFormat( + "GPU-%d GPUSharedMemoryLimit=%d", device_id, prop.sharedMemPerBlock); return prop.sharedMemPerBlock; #else return 0; #endif } -// detect the limit of available local/stack memory on the current NVGPU with CUDA runtime +// detect the limit of available local/stack memory on the current NVGPU with +// CUDA runtime size_t GetGPULocalStackLimit() { #ifdef CINN_WITH_CUDA int device_id; CUDA_CALL(cudaGetDevice(&device_id)); cudaDeviceProp prop; CUDA_CALL(cudaGetDeviceProperties(&prop, device_id)); - size_t limit = prop.totalGlobalMem / prop.multiProcessorCount / prop.maxThreadsPerMultiProcessor; + size_t limit = prop.totalGlobalMem / prop.multiProcessorCount / + prop.maxThreadsPerMultiProcessor; VLOG(4) << utils::StringFormat( - "GPU-%d totalGlobalMem=%lu,maxThreadsPerMultiProcessor=%d,multiProcessorCount=%d, calculated " + "GPU-%d " + "totalGlobalMem=%lu,maxThreadsPerMultiProcessor=%d,multiProcessorCount=%" + "d, calculated " "GPULocalStackLimit=%lu", device_id, prop.totalGlobalMem, @@ -368,14 +421,16 @@ size_t GetGPULocalStackLimit() { #endif } -// check whether usage of the specific memory type in the lowered_func exceeds hardware limit +// check whether usage of the specific memory type in the lowered_func exceeds +// hardware limit bool IsGPUMemoryUsageExceedLimit(const ir::LoweredFunc& lowered_func, const ir::MemoryType& used_memory_type, const size_t limit_bytes) { std::unordered_set visited; size_t used_bytes_cnt = 0; for (auto&& buf : lowered_func->temp_bufs) { - VLOG(5) << "temp buf name=" << buf->name << ", numel=" << buf->numel() << ",dtype=" << buf->dtype; + VLOG(5) << "temp buf name=" << buf->name << ", numel=" << buf->numel() + << ",dtype=" << buf->dtype; if (buf->memory_type == used_memory_type && !visited.count(buf->name)) { used_bytes_cnt += buf->numel() * buf->dtype.bytes(); visited.insert(buf->name); @@ -385,18 +440,26 @@ bool IsGPUMemoryUsageExceedLimit(const ir::LoweredFunc& lowered_func, return used_bytes_cnt >= limit_bytes; } -bool PruneInvalid(const ir::LoweredFunc& lowered_func, const common::Target& target) { +bool PruneInvalid(const ir::LoweredFunc& lowered_func, + const common::Target& target) { static const size_t kGPUSharedMemoryLimitBytes = GetGPUSharedMemoryLimit(); - static const size_t kGPULocalStackLimitBytes = GetGPULocalStackLimit(); + static const size_t kGPULocalStackLimitBytes = GetGPULocalStackLimit(); if (target == common::DefaultNVGPUTarget()) { - if (IsGPUMemoryUsageExceedLimit(lowered_func, ir::MemoryType::GPUShared, kGPUSharedMemoryLimitBytes)) { - VLOG(5) << ir::MemoryType::GPUShared << " memory usage exceeds limit, func:\n" << lowered_func; + if (IsGPUMemoryUsageExceedLimit(lowered_func, + ir::MemoryType::GPUShared, + kGPUSharedMemoryLimitBytes)) { + VLOG(5) << ir::MemoryType::GPUShared + << " memory usage exceeds limit, func:\n" + << lowered_func; return true; } - if (IsGPUMemoryUsageExceedLimit(lowered_func, ir::MemoryType::GPULocal, kGPULocalStackLimitBytes)) { - VLOG(5) << ir::MemoryType::GPULocal << " memory usage exceeds limit, func:\n" << lowered_func; + if (IsGPUMemoryUsageExceedLimit( + lowered_func, ir::MemoryType::GPULocal, kGPULocalStackLimitBytes)) { + VLOG(5) << ir::MemoryType::GPULocal + << " memory usage exceeds limit, func:\n" + << lowered_func; return true; } } diff --git a/paddle/cinn/auto_schedule/task/task_optimizer.h b/paddle/cinn/auto_schedule/task/task_optimizer.h index 5d70e540d55ed..849a8f423bcb9 
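IsGPUMemoryUsageExceedLimit above sums numel times dtype-width over a function's temp buffers, counting each named buffer only once via a visited set. A standalone sketch with a hypothetical Buf struct standing in for ir::Buffer:

#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical temp buffer: name, element count, element width in bytes.
struct Buf {
  std::string name;
  std::size_t numel;
  std::size_t dtype_bytes;
};

bool ExceedsLimit(const std::vector<Buf>& temp_bufs, std::size_t limit_bytes) {
  std::unordered_set<std::string> visited;
  std::size_t used = 0;
  for (const Buf& b : temp_bufs) {
    // A buffer may appear more than once; count each named buffer one time.
    if (visited.insert(b.name).second) used += b.numel * b.dtype_bytes;
  }
  return used >= limit_bytes;
}

int main() {
  std::vector<Buf> bufs = {{"_A", 1024, 4}, {"_A", 1024, 4}, {"_B", 2048, 4}};
  // 12288 bytes used vs. a 48KB shared-memory limit: prints 0 (within limit).
  std::cout << ExceedsLimit(bufs, 49152) << "\n";
  return 0;
}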
100644 --- a/paddle/cinn/auto_schedule/task/task_optimizer.h +++ b/paddle/cinn/auto_schedule/task/task_optimizer.h @@ -45,7 +45,8 @@ class TaskOptimizer { std::string from; double cost; FunctionGroup functions; - Result(const std::string& from_type) : from(from_type), cost(std::numeric_limits::max()) {} + Result(const std::string& from_type) + : from(from_type), cost(std::numeric_limits::max()) {} }; Result OptimizeByManual(bool need_measure); @@ -53,7 +54,9 @@ class TaskOptimizer { Result OptimizeByEvolution(const TuningOptions& options); // call search candidates once by EvolutionarySearch and prune invalid ones - std::vector SearchOneRound(const TuningOptions& options, std::vector* measure_candidates); + std::vector SearchOneRound( + const TuningOptions& options, + std::vector* measure_candidates); private: // the max retry times if continuously get empty result diff --git a/paddle/cinn/auto_schedule/task/task_registry.h b/paddle/cinn/auto_schedule/task/task_registry.h index 2e57fc7151ebd..749c16c68b49f 100644 --- a/paddle/cinn/auto_schedule/task/task_registry.h +++ b/paddle/cinn/auto_schedule/task/task_registry.h @@ -31,7 +31,8 @@ struct InitialTaskInfo { std::string task_key; ir::ModuleExpr module_expr; - InitialTaskInfo(const std::string& task_key, const ir::ModuleExpr& module_expr) + InitialTaskInfo(const std::string& task_key, + const ir::ModuleExpr& module_expr) : task_key(task_key), module_expr(module_expr) {} }; @@ -45,19 +46,25 @@ class InitialTaskRegistry : public Registry { // Get the initial ModuleExpr of a task. inline const InitialTaskInfo* Get(const std::string& task_key) { - const InitialTaskInfo* task_info = Registry::Find(task_key); - CHECK(task_info) << "InitialTaskInfo [" << task_key << "] is not registered"; + const InitialTaskInfo* task_info = + Registry::Find(task_key); + CHECK(task_info) << "InitialTaskInfo [" << task_key + << "] is not registered"; return task_info; } // Check if the task info with task_key exists; - inline const bool Has(const std::string& task_key) { return nullptr != Registry::Find(task_key); } + inline const bool Has(const std::string& task_key) { + return nullptr != Registry::Find(task_key); + } // Regist the initial ModuleExpr of a task into the map - inline void Regist(const std::string& task_key, const ir::ModuleExpr& module_expr) { + inline void Regist(const std::string& task_key, + const ir::ModuleExpr& module_expr) { std::lock_guard guard(registering_mutex); if (fmap_.count(task_key) == 0) { - InitialTaskInfo* task_info = new InitialTaskInfo(task_key, optim::IRCopy(module_expr)); + InitialTaskInfo* task_info = + new InitialTaskInfo(task_key, optim::IRCopy(module_expr)); __REGISTER__(task_key, task_info); } } @@ -67,7 +74,8 @@ class InitialTaskRegistry : public Registry { CINN_DISALLOW_COPY_AND_ASSIGN(InitialTaskRegistry); // Regist the initial ModuleExpr of a task. 
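InitialTaskRegistry above is a register-once map guarded by a mutex: the first registration of a task_key wins and later ones are ignored. A minimal sketch of that shape; ToyRegistry and its string payload are illustrative, not CINN's types:

#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>

class ToyRegistry {
 public:
  static ToyRegistry* Global() {
    static ToyRegistry inst;  // process-wide singleton, like Global() above
    return &inst;
  }
  void Regist(const std::string& key, const std::string& value) {
    std::lock_guard<std::mutex> guard(mutex_);
    fmap_.emplace(key, value);  // no-op when the key is already registered
  }
  bool Has(const std::string& key) {
    std::lock_guard<std::mutex> guard(mutex_);
    return fmap_.count(key) != 0;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, std::string> fmap_;
};

int main() {
  ToyRegistry::Global()->Regist("fn_add", "module_expr_v1");
  ToyRegistry::Global()->Regist("fn_add", "module_expr_v2");  // ignored
  std::cout << ToyRegistry::Global()->Has("fn_add") << "\n";  // prints 1
  return 0;
}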
- inline InitialTaskInfo* __REGISTER__(const std::string& task_key, InitialTaskInfo* task_info) { + inline InitialTaskInfo* __REGISTER__(const std::string& task_key, + InitialTaskInfo* task_info) { fmap_[task_key] = task_info; const_list_.push_back(task_info); entry_list_.push_back(task_info); diff --git a/paddle/cinn/auto_schedule/task/task_registry_test.cc b/paddle/cinn/auto_schedule/task/task_registry_test.cc index bf68df11481fc..24e823f82fb43 100644 --- a/paddle/cinn/auto_schedule/task/task_registry_test.cc +++ b/paddle/cinn/auto_schedule/task/task_registry_test.cc @@ -34,16 +34,21 @@ DECLARE_bool(cinn_ir_schedule); namespace cinn { namespace auto_schedule { -std::vector CreateTasks(hlir::framework::Graph* graph, const common::Target& target) { +std::vector CreateTasks(hlir::framework::Graph* graph, + const common::Target& target) { // create tasks TaskCreator task_creator; std::vector tasks = task_creator.CreateTuneTaskOpLevel(graph); - const auto& dtype_dict = graph->GetAttrs>("inferdtype"); - const auto& shape_dict = graph->GetAttrs>("infershape"); + const auto& dtype_dict = + graph->GetAttrs>( + "inferdtype"); + const auto& shape_dict = graph->GetAttrs< + absl::flat_hash_map>("infershape"); std::unique_ptr op_lowerer = - std::make_unique(dtype_dict, shape_dict, target); + std::make_unique( + dtype_dict, shape_dict, target); for (TuneTask& task : tasks) { task.Initialize(shape_dict, dtype_dict, op_lowerer.get()); VLOG(3) << "Add a task with serialized_key:\n" << task.serialized_key; @@ -52,7 +57,8 @@ std::vector CreateTasks(hlir::framework::Graph* graph, const common::T return tasks; } -std::shared_ptr CreateAddProgram(const common::Target& target) { +std::shared_ptr CreateAddProgram( + const common::Target& target) { frontend::NetBuilder builder("test"); auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A"); @@ -64,7 +70,7 @@ std::shared_ptr CreateAddProgram(const common::Target& t TEST(TestTaskRegistry, basic) { FLAGS_auto_schedule_use_cost_model = true; - FLAGS_cinn_ir_schedule = true; + FLAGS_cinn_ir_schedule = true; #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -72,7 +78,7 @@ TEST(TestTaskRegistry, basic) { Target target = common::DefaultHostTarget(); #endif std::shared_ptr graph = CreateAddProgram(target); - std::vector tasks = CreateTasks(graph.get(), target); + std::vector tasks = CreateTasks(graph.get(), target); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); @@ -89,8 +95,10 @@ TEST(TestTaskRegistry, basic) { ASSERT_EQ(new_expr.GetExprs().size(), module_exprs[i].GetExprs().size()); for (int j = 0; j < new_expr.GetExprs().size(); ++j) { - VLOG(3) << "expr " << j << " of task " << key << " : " << new_expr.GetExprs().at(j); - ASSERT_EQ(utils::GetStreamCnt(new_expr.GetExprs().at(j)), utils::GetStreamCnt(module_exprs[i].GetExprs().at(j))); + VLOG(3) << "expr " << j << " of task " << key << " : " + << new_expr.GetExprs().at(j); + ASSERT_EQ(utils::GetStreamCnt(new_expr.GetExprs().at(j)), + utils::GetStreamCnt(module_exprs[i].GetExprs().at(j))); } } diff --git a/paddle/cinn/auto_schedule/task/tune_task.cc b/paddle/cinn/auto_schedule/task/tune_task.cc index 9e79ec6ac7b4b..a6c11a4e4d58b 100644 --- a/paddle/cinn/auto_schedule/task/tune_task.cc +++ b/paddle/cinn/auto_schedule/task/tune_task.cc @@ -30,15 +30,17 @@ namespace cinn { namespace auto_schedule { -void TuneTask::Initialize(const absl::flat_hash_map& shape_dict, - const absl::flat_hash_map& dtype_dict, - hlir::framework::OpLowerer* lower_handler) { +void 
TuneTask::Initialize( + const absl::flat_hash_map& + shape_dict, + const absl::flat_hash_map& dtype_dict, + hlir::framework::OpLowerer* lower_handler) { CHECK(lower_handler != nullptr) << "op_lowerer can't be nullptr"; op_lowerer = lower_handler; // Set lowered_funcs and analyze output names. - this->lowered_funcs = op_lowerer->LowerWithoutSchedule(subgraph); - this->output_names = GetOutputNamesFromLoweredFunc(this->lowered_funcs); + this->lowered_funcs = op_lowerer->LowerWithoutSchedule(subgraph); + this->output_names = GetOutputNamesFromLoweredFunc(this->lowered_funcs); this->serialized_key = SerializeToString(shape_dict, dtype_dict); } @@ -50,34 +52,46 @@ std::vector TuneTask::GetLoweredFuncBodyExprs() const { return result; } -std::string TuneTask::SerializeToString(const absl::flat_hash_map& shape_dict, - const absl::flat_hash_map& dtype_dict) { +std::string TuneTask::SerializeToString( + const absl::flat_hash_map& + shape_dict, + const absl::flat_hash_map& dtype_dict) { std::stringstream ss; ss << target << "\n\n"; // print target - // local function to print dtype,shape of out/in variables of the specified node - auto print_node_links_fn = [&](const std::vector>& links, bool is_input) { - int printed_num = 0; - for (auto&& edge : links) { - const auto* var_node = is_input ? edge->source()->safe_as() - : edge->sink()->safe_as(); - CHECK(var_node) << "var node invalid"; - auto sit = shape_dict.find(var_node->id()); - CHECK(sit != shape_dict.end()) << "can't find shape of variable:" << var_node->id(); - auto dit = dtype_dict.find(var_node->id()); - CHECK(dit != dtype_dict.end()) << "can't find dtype of variable:" << var_node->id(); - if (printed_num > 0) { - ss << ", "; - } - ++printed_num; - // TODO(CtfGo): CINN uses the names of input/output NodeData ids as arguments of the LoweredFunc in the Lower - // process, so it will result in different LoweredFuncs for two Nodes even though they represents the same - // operator. Here we add `var_node->id()` into the serialized_key to distinguish them, otherwise AutoTuner will - // get wrong TuningRecords when querying cached results from database. In the future, we should remove - // name-related limit in Lower process, to avoid duplicate tuning tasks with same operators. - ss << var_node->id() << "->" << cinn::common::Type2Str(dit->second) << "[" + utils::Join(sit->second, ",") << "]"; - } - }; + // local function to print dtype,shape of out/in variables of the specified + // node + auto print_node_links_fn = + [&](const std::vector>& links, + bool is_input) { + int printed_num = 0; + for (auto&& edge : links) { + const auto* var_node = + is_input ? edge->source()->safe_as() + : edge->sink()->safe_as(); + CHECK(var_node) << "var node invalid"; + auto sit = shape_dict.find(var_node->id()); + CHECK(sit != shape_dict.end()) + << "can't find shape of variable:" << var_node->id(); + auto dit = dtype_dict.find(var_node->id()); + CHECK(dit != dtype_dict.end()) + << "can't find dtype of variable:" << var_node->id(); + if (printed_num > 0) { + ss << ", "; + } + ++printed_num; + // TODO(CtfGo): CINN uses the names of input/output NodeData ids as + // arguments of the LoweredFunc in the Lower process, so it will + // result in different LoweredFuncs for two Nodes even though they + // represents the same operator. Here we add `var_node->id()` into the + // serialized_key to distinguish them, otherwise AutoTuner will get + // wrong TuningRecords when querying cached results from database. 
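TuneTask::SerializeToString renders each variable as name->dtype[shape], comma-separated, so structurally identical subgraphs with different variable names hash to different keys. A self-contained sketch of that key building, with std containers standing in for the absl::flat_hash_map attributes used above:

#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, std::string> dtype_dict = {
      {"A", "float32"}, {"B", "float32"}};
  std::unordered_map<std::string, std::vector<int>> shape_dict = {
      {"A", {32, 24}}, {"B", {32, 24}}};

  std::stringstream ss;
  int printed_num = 0;
  for (std::string id : {"A", "B"}) {  // fixed order for a stable key
    if (printed_num++ > 0) ss << ", ";
    ss << id << "->" << dtype_dict.at(id) << "[";
    const std::vector<int>& shape = shape_dict.at(id);
    for (std::size_t i = 0; i < shape.size(); ++i) {
      ss << (i > 0 ? "," : "") << shape[i];
    }
    ss << "]";
  }
  std::cout << ss.str() << "\n";  // A->float32[32,24], B->float32[32,24]
  return 0;
}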
In + // the future, we should remove name-related limit in Lower process, + // to avoid duplicate tuning tasks with same operators. + ss << var_node->id() << "->" << cinn::common::Type2Str(dit->second) + << "[" + utils::Join(sit->second, ",") << "]"; + } + }; // print each node of the subgraph ss << "Group {\n"; diff --git a/paddle/cinn/auto_schedule/task/tune_task.h b/paddle/cinn/auto_schedule/task/tune_task.h index c253878b94fa6..2921f41a0f5fd 100644 --- a/paddle/cinn/auto_schedule/task/tune_task.h +++ b/paddle/cinn/auto_schedule/task/tune_task.h @@ -36,11 +36,14 @@ namespace auto_schedule { class TuneTask { public: TuneTask() = default; - TuneTask(std::shared_ptr group) : subgraph(group) {} + TuneTask(std::shared_ptr group) + : subgraph(group) {} // Initialize a task - void Initialize(const absl::flat_hash_map& shape_dict, - const absl::flat_hash_map& dtype_dict, - hlir::framework::OpLowerer* lower_handler); + void Initialize( + const absl::flat_hash_map& + shape_dict, + const absl::flat_hash_map& dtype_dict, + hlir::framework::OpLowerer* lower_handler); // Extract bodies in lowered_funcs() and return std::vector GetLoweredFuncBodyExprs() const; @@ -55,14 +58,16 @@ class TuneTask { std::vector lowered_funcs; // names of the output arguments of lowered_funcs_ std::unordered_set output_names; - // serialized string of this task, it contains struct,shape,dtype,input/output variable name - // of the subgraph and can be further used to hash + // serialized string of this task, it contains struct,shape,dtype,input/output + // variable name of the subgraph and can be further used to hash std::string serialized_key; private: // Serialize this task as a string contains specific fields of it - std::string SerializeToString(const absl::flat_hash_map& shape_dict, - const absl::flat_hash_map& dtype_dict); + std::string SerializeToString( + const absl::flat_hash_map& + shape_dict, + const absl::flat_hash_map& dtype_dict); }; } // namespace auto_schedule diff --git a/paddle/cinn/auto_schedule/task/tune_task_test.cc b/paddle/cinn/auto_schedule/task/tune_task_test.cc index f434af1187368..b0b4a27e9b6d2 100755 --- a/paddle/cinn/auto_schedule/task/tune_task_test.cc +++ b/paddle/cinn/auto_schedule/task/tune_task_test.cc @@ -49,10 +49,10 @@ Program CreateAddProgram() { constexpr int N = 24; NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {M, N}, "A"); - auto b = builder.CreateInput(Float(32), {M, N}, "B"); - auto c = builder.Add(a, b); - auto d = builder.Add(a, c); + auto a = builder.CreateInput(Float(32), {M, N}, "A"); + auto b = builder.CreateInput(Float(32), {M, N}, "B"); + auto c = builder.Add(a, b); + auto d = builder.Add(a, c); auto program = builder.Build(); return program; @@ -65,17 +65,20 @@ TEST(TuneTask, GraphToUnoptLoweredFunc_NoPass) { #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); #else - Target target = common::DefaultHostTarget(); + Target target = common::DefaultHostTarget(); #endif Program prog = CreateAddProgram(); - auto graph = std::make_shared(prog, target); + auto graph = std::make_shared(prog, target); TaskCreator task_creator; std::vector tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); ASSERT_EQ(tasks.size(), 2UL); - const auto& shape_dict = graph->GetAttrs>("infershape"); - const auto& dtype_dict = graph->GetAttrs>("inferdtype"); + const auto& shape_dict = graph->GetAttrs< + absl::flat_hash_map>("infershape"); + const auto& dtype_dict = + graph->GetAttrs>( + "inferdtype"); OpLowerer op_lowerer(dtype_dict, shape_dict, 
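The tests fetch typed graph attributes through GetAttrs with the value type spelled at the call site ("infershape", "inferdtype"). One common way to back such an interface is type erasure via std::any (C++17); the AttrStore class below is an illustrative sketch, not CINN's implementation:

#include <any>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

class AttrStore {
 public:
  template <typename T>
  void Set(const std::string& name, T value) {
    attrs_[name] = std::move(value);  // erase the static type on write
  }
  template <typename T>
  const T& Get(const std::string& name) const {
    // Recover the static type on read; throws std::bad_any_cast on mismatch.
    return std::any_cast<const T&>(attrs_.at(name));
  }

 private:
  std::unordered_map<std::string, std::any> attrs_;
};

int main() {
  using ShapeDict = std::unordered_map<std::string, std::vector<int>>;
  AttrStore graph;
  graph.Set<ShapeDict>("infershape", {{"A", {32, 24}}});
  std::cout << graph.Get<ShapeDict>("infershape").at("A")[0] << "\n";  // 32
  return 0;
}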
target); std::stringstream ss; @@ -127,7 +130,7 @@ TEST(TuneTask, GraphToUnoptLoweredFunc_NoPass) { } )ROC"; #else - std::string target_str = R"ROC( + std::string target_str = R"ROC( { ScheduleBlock(root) { @@ -173,10 +176,10 @@ TEST(TuneTask, GraphToUnoptLoweredFunc_ApplyPass) { #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); #else - Target target = common::DefaultHostTarget(); + Target target = common::DefaultHostTarget(); #endif Program prog = CreateAddProgram(); - auto graph = std::make_shared(prog, target); + auto graph = std::make_shared(prog, target); ApplyPass(graph.get(), "OpFusionPass"); TaskCreator task_creator; @@ -184,8 +187,11 @@ TEST(TuneTask, GraphToUnoptLoweredFunc_ApplyPass) { ASSERT_EQ(tasks.size(), 1UL); - const auto& shape_dict = graph->GetAttrs>("infershape"); - const auto& dtype_dict = graph->GetAttrs>("inferdtype"); + const auto& shape_dict = graph->GetAttrs< + absl::flat_hash_map>("infershape"); + const auto& dtype_dict = + graph->GetAttrs>( + "inferdtype"); OpLowerer op_lowerer(dtype_dict, shape_dict, target); @@ -236,7 +242,7 @@ TEST(TuneTask, GraphToUnoptLoweredFunc_ApplyPass) { )ROC"; #else - std::string target_str = R"ROC( + std::string target_str = R"ROC( { ScheduleBlock(root) { @@ -277,16 +283,20 @@ TEST(TuneTask, SerializeToString) { #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); #else - Target target = common::DefaultHostTarget(); + Target target = common::DefaultHostTarget(); #endif Program prog = CreateAddProgram(); - auto graph = std::make_shared(prog, target); + auto graph = std::make_shared(prog, target); TaskCreator task_creator; - std::vector single_tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); - - const auto& shape_dict = graph->GetAttrs>("infershape"); - const auto& dtype_dict = graph->GetAttrs>("inferdtype"); + std::vector single_tasks = + task_creator.CreateTuneTaskOpLevel(graph.get()); + + const auto& shape_dict = graph->GetAttrs< + absl::flat_hash_map>("infershape"); + const auto& dtype_dict = + graph->GetAttrs>( + "inferdtype"); OpLowerer op_lowerer(dtype_dict, shape_dict, target); ASSERT_EQ(single_tasks.size(), 2UL); for (auto&& task : single_tasks) { @@ -301,7 +311,7 @@ Group { } )ROC"; #else - std::string single_add_str = R"ROC(Target + std::string single_add_str = R"ROC(Target Group { (var_1->float32[32,24]) = elementwise_add(A->float32[32,24], B->float32[32,24]) @@ -311,7 +321,8 @@ Group { EXPECT_EQ(single_tasks[0].serialized_key, single_add_str); ApplyPass(graph.get(), "OpFusionPass"); - std::vector fused_tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); + std::vector fused_tasks = + task_creator.CreateTuneTaskOpLevel(graph.get()); ASSERT_EQ(fused_tasks.size(), 1UL); fused_tasks[0].Initialize(shape_dict, dtype_dict, &op_lowerer); diff --git a/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.cc b/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.cc index 8cfa6067fe95b..38a5de33bb16a 100644 --- a/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.cc +++ b/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.cc @@ -27,7 +27,9 @@ int EfficiencyPriority::NextTaskId() { return -1; } -bool EfficiencyPriority::IsTaskToTune(const TuneTask* task) { return config_.minimum_gain_threshold > 0.0; } +bool EfficiencyPriority::IsTaskToTune(const TuneTask* task) { + return config_.minimum_gain_threshold > 0.0; +} } // namespace auto_schedule } // namespace cinn diff --git a/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.h 
b/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.h index a42ebc290a0f0..9b37492a32de8 100644 --- a/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.h +++ b/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.h @@ -25,7 +25,8 @@ namespace auto_schedule { // is picking a task with the maximum earnings ratio. class EfficiencyPriority : public TaskScheduler { public: - EfficiencyPriority(const std::vector& tasks, const Config& config) : TaskScheduler(tasks, config) {} + EfficiencyPriority(const std::vector& tasks, const Config& config) + : TaskScheduler(tasks, config) {} const char* Name() const override { return "efficiency_priority"; }; diff --git a/paddle/cinn/auto_schedule/task_scheduler/round_robin.h b/paddle/cinn/auto_schedule/task_scheduler/round_robin.h index bbd862b70e721..a2dc0555aae43 100644 --- a/paddle/cinn/auto_schedule/task_scheduler/round_robin.h +++ b/paddle/cinn/auto_schedule/task_scheduler/round_robin.h @@ -25,7 +25,8 @@ namespace auto_schedule { // is picking a task to tune once a time iteratively. class RoundRobin : public TaskScheduler { public: - RoundRobin(const std::vector& tasks, const Config& config) : TaskScheduler(tasks, config) {} + RoundRobin(const std::vector& tasks, const Config& config) + : TaskScheduler(tasks, config) {} const char* Name() const override { return "round_robin"; }; diff --git a/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.cc b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.cc index c9682ca9adc4a..eed2ad3d66970 100644 --- a/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.cc +++ b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.cc @@ -23,9 +23,10 @@ namespace cinn { namespace auto_schedule { -std::unique_ptr TaskScheduler::Make(const std::vector& tasks, - const Config& config, - const std::string& strategy) { +std::unique_ptr TaskScheduler::Make( + const std::vector& tasks, + const Config& config, + const std::string& strategy) { CHECK_GT(tasks.size(), 0) << "Empty task list"; if (strategy == "round_robin") { return std::make_unique(tasks, config); @@ -37,7 +38,8 @@ std::unique_ptr TaskScheduler::Make(const std::vector& return nullptr; } -TaskScheduler::TaskScheduler(const std::vector& tasks, const Config& config) +TaskScheduler::TaskScheduler(const std::vector& tasks, + const Config& config) : tasks_(&tasks), config_(config), cur_task_id_(0) {} void TaskScheduler::Reset() { cur_task_id_ = 0; } diff --git a/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.h b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.h index 862ab7f2c3314..193f715d0a85f 100644 --- a/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.h +++ b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.h @@ -38,9 +38,10 @@ class TaskScheduler { // Create a TaskScheduler with the specific strategy name // and necessary construct parameters. 
- static std::unique_ptr Make(const std::vector& tasks, - const Config& config, - const std::string& strategy = "round_robin"); + static std::unique_ptr Make( + const std::vector& tasks, + const Config& config, + const std::string& strategy = "round_robin"); // Reset associated states to schedule at the beginning void Reset(); diff --git a/paddle/cinn/auto_schedule/task_scheduler/task_scheduler_test.cc b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler_test.cc index e75778d27d5de..3f955b4a820b0 100644 --- a/paddle/cinn/auto_schedule/task_scheduler/task_scheduler_test.cc +++ b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler_test.cc @@ -30,7 +30,8 @@ TEST(TaskScheduler, Make) { auto round_robin = TaskScheduler::Make(tasks, config); ASSERT_STREQ(round_robin->Name(), "round_robin"); - auto efficiency_priority = TaskScheduler::Make(tasks, config, "efficiency_priority"); + auto efficiency_priority = + TaskScheduler::Make(tasks, config, "efficiency_priority"); ASSERT_STREQ(efficiency_priority->Name(), "efficiency_priority"); } @@ -48,7 +49,8 @@ TEST(EfficiencyPriorityScheduler, NextTaskId) { std::vector tasks(3); TaskScheduler::Config config; config.minimum_gain_threshold = -1.0; - auto efficiency_priority = TaskScheduler::Make(tasks, config, "efficiency_priority"); + auto efficiency_priority = + TaskScheduler::Make(tasks, config, "efficiency_priority"); ASSERT_EQ(-1, efficiency_priority->NextTaskId()); } diff --git a/paddle/cinn/auto_schedule/tests/performance_comparison_test.cc b/paddle/cinn/auto_schedule/tests/performance_comparison_test.cc index 694e75b5e2b38..79b4dc95d180c 100644 --- a/paddle/cinn/auto_schedule/tests/performance_comparison_test.cc +++ b/paddle/cinn/auto_schedule/tests/performance_comparison_test.cc @@ -32,20 +32,29 @@ #include "paddle/cinn/utils/data_util.h" #include "test/cpp/cinn/program_builder.h" -/* This test is used as a tool to evaluate or compare performance of 3 schedules(no schedule, manual schedule, - * auto-schedule). One can specify which schedules to be evaluated through `FLAGS_evaluate_knobs` and specify which - * operator or model through `--gtest_filter=PerformanceTester.xx`, for example, `FLAGS_evaluate_knobs=4 - * --gtest_filter=PerformanceTester.Matmul` means it will evaluate auto-schedule on Matmul operator. You can refer to - * explanation of following flags or parameters for more detail. +/* This test is used as a tool to evaluate or compare performance of 3 + * schedules(no schedule, manual schedule, auto-schedule). One can specify which + * schedules to be evaluated through `FLAGS_evaluate_knobs` and specify which + * operator or model through `--gtest_filter=PerformanceTester.xx`, for example, + * `FLAGS_evaluate_knobs=4 + * --gtest_filter=PerformanceTester.Matmul` means it will evaluate auto-schedule + * on Matmul operator. You can refer to explanation of following flags or + * parameters for more detail. */ -DEFINE_string(resnet50_model_dir, "./ResNet50", "the path to paddle model resnet50."); +DEFINE_string(resnet50_model_dir, + "./ResNet50", + "the path to paddle model resnet50."); // Flags that control which schedule tests will be run. -// Bit with index 0 controls no schedule test, means options = 1 = "001" will run no schedule test. -// Bit with index 1 controls manual schedule test, means options = 2 = "010" will run manual schedule test. -// Bit with index 2 controls auto schedule test, means options = 4 = "100" will run auto schedule test. 
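FLAGS_evaluate_knobs in the performance test is decoded as a 3-bit mask, one bit per schedule kind, with none() falling back to validation-only builds. A sketch of that decoding with std::bitset:

#include <bitset>
#include <iostream>

int main() {
  // Bit 0 = no schedule, bit 1 = manual schedule, bit 2 = auto schedule,
  // per the flag comment above; 4 == 0b100 selects only auto schedule.
  int evaluate_knobs_flag = 4;  // illustrative; -1 means "unset" in the test
  std::bitset<3> knobs(evaluate_knobs_flag);
  if (knobs.none()) {
    std::cout << "build no/manual schedule only (validation mode)\n";
  } else {
    if (knobs.test(0)) std::cout << "run no schedule test\n";
    if (knobs.test(1)) std::cout << "run manual schedule test\n";
    if (knobs.test(2)) std::cout << "run auto schedule test\n";
  }
  return 0;
}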
-// The default value is -1, which means that this flag is disabled to set the options -DEFINE_int32(evaluate_knobs, -1, "the options to control which schedule tests will be run."); +// Bit with index 0 controls no schedule test, means options = 1 = "001" will +// run no schedule test. Bit with index 1 controls manual schedule test, means +// options = 2 = "010" will run manual schedule test. Bit with index 2 controls +// auto schedule test, means options = 4 = "100" will run auto schedule test. +// The default value is -1, which means that this flag is disabled to set the +// options +DEFINE_int32(evaluate_knobs, + -1, + "the options to control which schedule tests will be run."); DECLARE_int32(cinn_parallel_compile_size); namespace cinn { @@ -64,7 +73,8 @@ class PerformanceTester : public ::testing::Test { int repeat_times = 2; // the num_tuning_rounds for auto tuning int num_tuning_rounds = 2; - // knobs to control which schedules will be measured, refer to FLAGS_evaluate_knobs explanation + // knobs to control which schedules will be measured, refer to + // FLAGS_evaluate_knobs explanation std::bitset<3> evaluate_knobs = 0UL; }; @@ -76,50 +86,66 @@ class PerformanceTester : public ::testing::Test { } VLOG(3) << "evaluate_knobs = " << options_.evaluate_knobs; - auto worker_fn = [this, &program]( - const std::string& schedule_name, BuildRuntimeProgramFn build_fn, bool execute = true) { + auto worker_fn = [this, &program](const std::string& schedule_name, + BuildRuntimeProgramFn build_fn, + bool execute = true) { Context::Global().ResetNameId(); VLOG(3) << "Initialize graph."; auto graph = std::make_shared(program, target_); VLOG(3) << "Apply graph pass."; hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); VLOG(3) << "Build " << schedule_name << " program."; - auto scope = BuildScope(target_, graph); - auto graph_compiler = std::make_unique(target_, scope, graph); - auto runtime_program = (this->*build_fn)(graph.get(), graph_compiler.get()); + auto scope = BuildScope(target_, graph); + auto graph_compiler = + std::make_unique(target_, scope, graph); + auto runtime_program = + (this->*build_fn)(graph.get(), graph_compiler.get()); if (execute) { VLOG(3) << "Execute " << schedule_name << " program."; runtime_program->ExecuteTest(options_.repeat_times); } }; - // if no one is set, build no/manual schedule cases to ensure their build functions are valid + // if no one is set, build no/manual schedule cases to ensure their build + // functions are valid if (options_.evaluate_knobs.none()) { - worker_fn("no schedule", &PerformanceTester::BuildNoScheduleProgram, /* execute */ false); - worker_fn("manual schedule", &PerformanceTester::BuildManualScheduleProgram, /* execute */ false); + worker_fn("no schedule", + &PerformanceTester::BuildNoScheduleProgram, + /* execute */ false); + worker_fn("manual schedule", + &PerformanceTester::BuildManualScheduleProgram, + /* execute */ false); } else { if (options_.evaluate_knobs.test(0)) { worker_fn("no schedule", &PerformanceTester::BuildNoScheduleProgram); } if (options_.evaluate_knobs.test(1)) { - worker_fn("manual schedule", &PerformanceTester::BuildManualScheduleProgram); + worker_fn("manual schedule", + &PerformanceTester::BuildManualScheduleProgram); } if (options_.evaluate_knobs.test(2)) { - worker_fn("auto schedule", &PerformanceTester::BuildAutoScheduleProgram); + worker_fn("auto schedule", + &PerformanceTester::BuildAutoScheduleProgram); } } } protected: - using BuildRuntimeProgramFn = std::unique_ptr (PerformanceTester::*)(Graph*, - 
GraphCompiler*); - - std::unique_ptr BuildNoScheduleProgram(Graph* graph, GraphCompiler* graph_compiler) { - const auto& dtype_dict = graph->GetAttrs>("inferdtype"); - const auto& shape_dict = graph->GetAttrs>("infershape"); + using BuildRuntimeProgramFn = std::unique_ptr ( + PerformanceTester::*)(Graph*, GraphCompiler*); + + std::unique_ptr BuildNoScheduleProgram( + Graph* graph, GraphCompiler* graph_compiler) { + const auto& dtype_dict = + graph->GetAttrs>( + "inferdtype"); + const auto& shape_dict = graph->GetAttrs< + absl::flat_hash_map>( + "infershape"); std::shared_ptr op_lowerer = - std::make_unique(dtype_dict, shape_dict, target_); + std::make_unique( + dtype_dict, shape_dict, target_); GraphCompiler::CompileOptions compile_options; compile_options.with_instantiate_variables = true; @@ -130,31 +156,36 @@ class PerformanceTester : public ::testing::Test { compile_options.groups = graph->fusion_groups; for (auto group : graph->fusion_groups) { - compile_options.lowered_funcs.push_back(op_lowerer->LowerWithoutSchedule(group)); + compile_options.lowered_funcs.push_back( + op_lowerer->LowerWithoutSchedule(group)); } - VLOG(3) << "===========================No Schedule LoweredFunc Begin==========================="; + VLOG(3) << "===========================No Schedule LoweredFunc " + "Begin==========================="; for (const auto& funcvec : compile_options.lowered_funcs) { for (const auto& func : funcvec) { VLOG(3) << func; } } - VLOG(3) << "===========================No Schedule LoweredFunc End============================="; + VLOG(3) << "===========================No Schedule LoweredFunc " + "End============================="; return graph_compiler->Build(compile_options).runtime_program; } - std::unique_ptr BuildManualScheduleProgram(Graph* graph, GraphCompiler* graph_compiler) { + std::unique_ptr BuildManualScheduleProgram( + Graph* graph, GraphCompiler* graph_compiler) { return graph_compiler->Build(); } - std::unique_ptr BuildAutoScheduleProgram(Graph* graph, GraphCompiler* graph_compiler) { + std::unique_ptr BuildAutoScheduleProgram( + Graph* graph, GraphCompiler* graph_compiler) { auto tuner = std::make_unique(target_, graph); AutoTuner::Config tuning_config; TuningOptions tuning_options; - tuning_options.num_tuning_rounds = options_.num_tuning_rounds; - tuning_options.num_measure_trials = 2; + tuning_options.num_tuning_rounds = options_.num_tuning_rounds; + tuning_options.num_measure_trials = 2; tuning_options.num_samples_per_iteration = 2; tuner->Initialize(tuning_config, graph_compiler); @@ -164,13 +195,15 @@ class PerformanceTester : public ::testing::Test { compile_options.with_instantiate_variables = true; compile_options.Apply(tuning_result); - VLOG(3) << "===========================Auto Schedule LoweredFunc Begin==========================="; + VLOG(3) << "===========================Auto Schedule LoweredFunc " + "Begin==========================="; for (const auto& funcvec : compile_options.lowered_funcs) { for (const auto& func : funcvec) { VLOG(3) << func; } } - VLOG(3) << "===========================Auto Schedule LoweredFunc End============================="; + VLOG(3) << "===========================Auto Schedule LoweredFunc " + "End============================="; return graph_compiler->Build(compile_options).runtime_program; } @@ -185,35 +218,42 @@ class PerformanceTester : public ::testing::Test { constexpr int batch_size = 2; -TEST_F(PerformanceTester, Mul) { Evaluate(tests::OpBuilder("mul").Build({{"X", {32, 16}}, {"Y", {16, 32}}})); } +TEST_F(PerformanceTester, 
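BuildRuntimeProgramFn above is a pointer-to-member-function type, which lets worker_fn pick a build strategy at the call site via (this->*build_fn)(...). A stripped-down sketch of that dispatch; Tester and its methods are hypothetical:

#include <iostream>
#include <string>

class Tester {
 public:
  using BuildFn = std::string (Tester::*)(int);  // pointer-to-member type

  std::string BuildNoSchedule(int id) {
    return "no-schedule#" + std::to_string(id);
  }
  std::string BuildAutoSchedule(int id) {
    return "auto-schedule#" + std::to_string(id);
  }

  void Evaluate(const std::string& name, BuildFn build_fn) {
    // The member-function pointer is bound to `this` at the call site.
    std::cout << name << " -> " << (this->*build_fn)(7) << "\n";
  }
};

int main() {
  Tester t;
  t.Evaluate("no schedule", &Tester::BuildNoSchedule);
  t.Evaluate("auto schedule", &Tester::BuildAutoSchedule);
  return 0;
}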
Mul) { + Evaluate(tests::OpBuilder("mul").Build({{"X", {32, 16}}, {"Y", {16, 32}}})); +} TEST_F(PerformanceTester, Add) { - Evaluate(tests::OpBuilder("elementwise_add").Build({{"X", {1, 56, 56, 256}}, {"Y", {1, 56, 56, 256}}})); + Evaluate(tests::OpBuilder("elementwise_add") + .Build({{"X", {1, 56, 56, 256}}, {"Y", {1, 56, 56, 256}}})); } TEST_F(PerformanceTester, Matmul) { - Evaluate(tests::OpBuilder("matmul").Build({{"X", {batch_size, 2048}}, {"Y", {2048, 1000}}})); + Evaluate(tests::OpBuilder("matmul").Build( + {{"X", {batch_size, 2048}}, {"Y", {2048, 1000}}})); } -TEST_F(PerformanceTester, Relu) { Evaluate(tests::OpBuilder("relu").Build({{"X", {batch_size, 64, 56, 56}}})); } +TEST_F(PerformanceTester, Relu) { + Evaluate(tests::OpBuilder("relu").Build({{"X", {batch_size, 64, 56, 56}}})); +} TEST_F(PerformanceTester, Conv2d) { std::vector strides{2, 2}; std::vector paddings{3, 3}; std::vector dilations{1, 1}; - int groups = 1; - std::string conv_type = "forward"; - std::string data_format = "NCHW"; + int groups = 1; + std::string conv_type = "forward"; + std::string data_format = "NCHW"; std::string padding_algorithm = "EXPLICIT"; - Evaluate(tests::OpBuilder("conv2d").Build({{"X", {batch_size, 3, 224, 224}}, {"W", {64, 3, 7, 7}}}, - {{"stride", strides}, - {"padding", paddings}, - {"dilation", dilations}, - {"groups", groups}, - {"conv_type", conv_type}, - {"data_format", data_format}, - {"padding_algorithm", padding_algorithm}})); + Evaluate(tests::OpBuilder("conv2d").Build( + {{"X", {batch_size, 3, 224, 224}}, {"W", {64, 3, 7, 7}}}, + {{"stride", strides}, + {"padding", paddings}, + {"dilation", dilations}, + {"groups", groups}, + {"conv_type", conv_type}, + {"data_format", data_format}, + {"padding_algorithm", padding_algorithm}})); } TEST_F(PerformanceTester, Pool2d) { @@ -222,24 +262,25 @@ TEST_F(PerformanceTester, Pool2d) { std::vector ksize{3, 3}; std::vector strides{2, 2}; std::vector paddings{1, 1, 1, 1}; - bool ceil_mode = false; - bool exclusive = true; - bool global_pooling = false; - std::string data_format = "NCHW"; - bool adaptive = false; + bool ceil_mode = false; + bool exclusive = true; + bool global_pooling = false; + std::string data_format = "NCHW"; + bool adaptive = false; std::string padding_algorithm = "EXPLICIT"; - Evaluate(tests::OpBuilder("pool2d").Build({{"X", {batch_size, 64, 112, 112}}}, - {{"pool_type", pooling_type}, - {"kernel_size", ksize}, - {"stride_size", strides}, - {"padding_size", paddings}, - {"ceil_mode", ceil_mode}, - {"exclusive", exclusive}, - {"global_pooling", global_pooling}, - {"data_format", data_format}, - {"adaptive", adaptive}, - {"padding_algorithm", padding_algorithm}})); + Evaluate(tests::OpBuilder("pool2d").Build( + {{"X", {batch_size, 64, 112, 112}}}, + {{"pool_type", pooling_type}, + {"kernel_size", ksize}, + {"stride_size", strides}, + {"padding_size", paddings}, + {"ceil_mode", ceil_mode}, + {"exclusive", exclusive}, + {"global_pooling", global_pooling}, + {"data_format", data_format}, + {"adaptive", adaptive}, + {"padding_algorithm", padding_algorithm}})); } TEST_F(PerformanceTester, BatchNorm) { @@ -248,60 +289,73 @@ TEST_F(PerformanceTester, BatchNorm) { std::vector bias_shape{64}; std::vector mean_shape{64}; std::vector variance_shape{64}; - float epsilon = 1e-5f; - float momentum = 0.9f; + float epsilon = 1e-5f; + float momentum = 0.9f; const std::string& data_layout = "NCHW"; - Evaluate( - tests::OpBuilder("batch_norm") - .Build( - {{"X", {batch_size, 64, 112, 112}}, {"scale", {64}}, {"bias", {64}}, {"mean", {64}}, 
{"variance", {64}}}, - {{"epsilon", epsilon}, {"momentum", momentum}, {"data_layout", data_layout}})); + Evaluate(tests::OpBuilder("batch_norm") + .Build({{"X", {batch_size, 64, 112, 112}}, + {"scale", {64}}, + {"bias", {64}}, + {"mean", {64}}, + {"variance", {64}}}, + {{"epsilon", epsilon}, + {"momentum", momentum}, + {"data_layout", data_layout}})); } TEST_F(PerformanceTester, Reshape) { std::vector output_shape{batch_size, 2048}; - Evaluate(tests::OpBuilder("reshape").Build({{"X", {batch_size, 2048, 1, 1}}}, {{"shape", output_shape}})); + Evaluate(tests::OpBuilder("reshape").Build({{"X", {batch_size, 2048, 1, 1}}}, + {{"shape", output_shape}})); } TEST_F(PerformanceTester, Softmax) { - std::vector axes = {-1}; - std::string mode = "fast"; + std::vector axes = {-1}; + std::string mode = "fast"; std::string data_format = "AnyLayout"; - Evaluate(tests::OpBuilder("softmax").Build({{"X", {batch_size, 1000}}}, - {{"axes", axes}, {"mode", mode}, {"data_format", data_format}})); + Evaluate(tests::OpBuilder("softmax").Build( + {{"X", {batch_size, 1000}}}, + {{"axes", axes}, {"mode", mode}, {"data_format", data_format}})); } TEST_F(PerformanceTester, Scale) { - float scale = 1.0f; - float bias = 0.0f; + float scale = 1.0f; + float bias = 0.0f; bool bias_after_scale = true; - Evaluate(tests::OpBuilder("scale").Build({{"X", {batch_size, 1000}}}, - {{"scale", scale}, {"bias", bias}, {"bias_after_scale", bias_after_scale}})); + Evaluate(tests::OpBuilder("scale").Build( + {{"X", {batch_size, 1000}}}, + {{"scale", scale}, + {"bias", bias}, + {"bias_after_scale", bias_after_scale}})); } TEST_F(PerformanceTester, LookupTable) { int64_t padding_idx = -1; - Evaluate( - tests::OpBuilder("lookup_table") - .Build({{"table", {50001, 768}}, {"ids", {10, 128, 1}, common::Int(64)}}, {{"padding_idx", padding_idx}})); + Evaluate(tests::OpBuilder("lookup_table") + .Build({{"table", {50001, 768}}, + {"ids", {10, 128, 1}, common::Int(64)}}, + {{"padding_idx", padding_idx}})); } TEST_F(PerformanceTester, Gather) { int axis = 3; Evaluate(tests::OpBuilder("gather").Build( - {{"operand", {10, 12, 128, 512}}, {"index", {1, 1, 1, 128}, common::Int(32)}}, {{"axis", axis}})); + {{"operand", {10, 12, 128, 512}}, + {"index", {1, 1, 1, 128}, common::Int(32)}}, + {{"axis", axis}})); } // paddle model test TEST_F(PerformanceTester, ResNet50) { CHECK_NE(FLAGS_resnet50_model_dir, ""); - std::unordered_map> feeds = {{"inputs", {batch_size, 3, 224, 224}}}; + std::unordered_map> feeds = { + {"inputs", {batch_size, 3, 224, 224}}}; Evaluate(cinn::frontend::PaddleModelConvertor(common::DefaultNVGPUTarget()) .LoadModel(FLAGS_resnet50_model_dir, true, feeds)); } diff --git a/paddle/cinn/backends/_x86_builtin_source.cc b/paddle/cinn/backends/_x86_builtin_source.cc index 1fc10c1d3ce59..bc698c530f5e2 100644 --- a/paddle/cinn/backends/_x86_builtin_source.cc +++ b/paddle/cinn/backends/_x86_builtin_source.cc @@ -62,7 +62,8 @@ struct StackVec { memcpy(&res.data_[0], (const value_type*)base + offset, num_bytes()); } - static self_type Load(const void* base, const StackVec& offset) { + static self_type Load(const void* base, + const StackVec& offset) { self_type res; for (size_t i = 0; i < Num; i++) { res.data_[i] = ((const value_type*)base)[offset[i]]; @@ -150,12 +151,16 @@ struct ExternalVec { // AVX256 load //@{ inline __m256 cinn_avx256_load(const float* dst) { return _mm256_load_ps(dst); } -inline __m256d cinn_avx256_load(const double* dst) { return _mm256_load_pd(dst); } +inline __m256d cinn_avx256_load(const double* dst) { + return 
_mm256_load_pd(dst); +} //@} // AVX512 load //@{ inline __m512 cinn_avx512_load(const float* dst) { return _mm512_load_ps(dst); } -inline __m512d cinn_avx512_load(const double* dst) { return _mm512_load_pd(dst); } +inline __m512d cinn_avx512_load(const double* dst) { + return _mm512_load_pd(dst); +} //@} // FP32x8 * FP32x8 @@ -320,10 +325,18 @@ inline void cinn_avx512_div(double* dst, double* a, double* b) { inline __m512 cinn_avx512_add(const __m512& a, const __m512& b); -inline __m256 cinn_avx256_add_float(const __m256& a, const __m256& b) { return _mm256_add_ps(a, b); } -inline __m256d cinn_avx256_add_double(const __m256d& a, const __m256d& b) { return _mm256_add_pd(a, b); } -inline __m512 cinn_avx512_add_float(const __m512& a, const __m512& b) { return _mm512_add_ps(a, b); } -inline __m512d cinn_avx512_add_double(const __m512d& a, const __m512d& b) { return _mm512_add_pd(a, b); } +inline __m256 cinn_avx256_add_float(const __m256& a, const __m256& b) { + return _mm256_add_ps(a, b); +} +inline __m256d cinn_avx256_add_double(const __m256d& a, const __m256d& b) { + return _mm256_add_pd(a, b); +} +inline __m512 cinn_avx512_add_float(const __m512& a, const __m512& b) { + return _mm512_add_ps(a, b); +} +inline __m512d cinn_avx512_add_double(const __m512d& a, const __m512d& b) { + return _mm512_add_pd(a, b); +} //! set1 // @{ @@ -335,38 +348,82 @@ inline __m512d cinn_avx512_set1(double value) { return _mm512_set1_pd(value); } //! store // @{ -inline void cinn_avx512_store(float* dst, const __m512& x) { _mm512_store_ps(dst, x); } -inline void cinn_avx512_store(double* dst, const __m512d& x) { _mm512_store_pd(dst, x); } -inline void cinn_avx256_store(float* dst, const __m256& x) { _mm256_store_ps(dst, x); } -inline void cinn_avx256_store(double* dst, const __m256d& x) { _mm256_store_pd(dst, x); } +inline void cinn_avx512_store(float* dst, const __m512& x) { + _mm512_store_ps(dst, x); +} +inline void cinn_avx512_store(double* dst, const __m512d& x) { + _mm512_store_pd(dst, x); +} +inline void cinn_avx256_store(float* dst, const __m256& x) { + _mm256_store_ps(dst, x); +} +inline void cinn_avx256_store(double* dst, const __m256d& x) { + _mm256_store_pd(dst, x); +} // @} //! add // @{ -inline __m256 cinn_avx256_add(const __m256& a, const __m256& b) { return _mm256_add_ps(a, b); } -inline __m256d cinn_avx256_add(const __m256d& a, const __m256d& b) { return _mm256_add_pd(a, b); } -inline __m512 cinn_avx512_add(const __m512& a, const __m512& b) { return _mm512_add_ps(a, b); } -inline __m512d cinn_avx512_add(const __m512d& a, const __m512d& b) { return _mm512_add_pd(a, b); } +inline __m256 cinn_avx256_add(const __m256& a, const __m256& b) { + return _mm256_add_ps(a, b); +} +inline __m256d cinn_avx256_add(const __m256d& a, const __m256d& b) { + return _mm256_add_pd(a, b); +} +inline __m512 cinn_avx512_add(const __m512& a, const __m512& b) { + return _mm512_add_ps(a, b); +} +inline __m512d cinn_avx512_add(const __m512d& a, const __m512d& b) { + return _mm512_add_pd(a, b); +} // @} //! 
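The wrappers above rely on C++ overloading so generated code can call a single name per operation and let the argument type (float vs. double vectors) select the intrinsic. A runnable sketch of the pattern; toy_avx256_add is an illustrative name, and the program needs an AVX-capable build (e.g. -mavx) and CPU:

#include <immintrin.h>
#include <iostream>

// One overloaded name, two element types: the compiler picks _ps or _pd.
inline __m256 toy_avx256_add(const __m256& a, const __m256& b) {
  return _mm256_add_ps(a, b);
}
inline __m256d toy_avx256_add(const __m256d& a, const __m256d& b) {
  return _mm256_add_pd(a, b);
}

int main() {
  __m256 a = _mm256_set1_ps(1.5f);
  __m256 b = _mm256_set1_ps(2.5f);
  float out[8];
  _mm256_storeu_ps(out, toy_avx256_add(a, b));  // resolves to the _ps overload
  std::cout << out[0] << "\n";  // prints 4
  return 0;
}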
mul // @{ -inline __m256 cinn_avx256_mul(const __m256& a, const __m256& b) { return _mm256_mul_ps(a, b); } -inline __m256d cinn_avx256_mul(const __m256d& a, const __m256d& b) { return _mm256_mul_pd(a, b); } -inline __m512 cinn_avx512_mul(const __m512& a, const __m512& b) { return _mm512_mul_ps(a, b); } -inline __m512d cinn_avx512_mul(const __m512d& a, const __m512d& b) { return _mm512_mul_pd(a, b); } +inline __m256 cinn_avx256_mul(const __m256& a, const __m256& b) { + return _mm256_mul_ps(a, b); +} +inline __m256d cinn_avx256_mul(const __m256d& a, const __m256d& b) { + return _mm256_mul_pd(a, b); +} +inline __m512 cinn_avx512_mul(const __m512& a, const __m512& b) { + return _mm512_mul_ps(a, b); +} +inline __m512d cinn_avx512_mul(const __m512d& a, const __m512d& b) { + return _mm512_mul_pd(a, b); +} // @} //! fma // @{ -inline __m128 cinn_avx128_fma(const __m128& a, const __m128& b, const __m128& c) { return _mm_fmadd_ps(a, b, c); } -inline __m128d cinn_avx128_fma(const __m128d& a, const __m128d& b, const __m128d& c) { return _mm_fmadd_pd(a, b, c); } -inline __m256 cinn_avx256_fma(const __m256& a, const __m256& b, const __m256& c) { return _mm256_fmadd_ps(a, b, c); } -inline __m256d cinn_avx256_fma(const __m256d& a, const __m256d& b, const __m256d& c) { +inline __m128 cinn_avx128_fma(const __m128& a, + const __m128& b, + const __m128& c) { + return _mm_fmadd_ps(a, b, c); +} +inline __m128d cinn_avx128_fma(const __m128d& a, + const __m128d& b, + const __m128d& c) { + return _mm_fmadd_pd(a, b, c); +} +inline __m256 cinn_avx256_fma(const __m256& a, + const __m256& b, + const __m256& c) { + return _mm256_fmadd_ps(a, b, c); +} +inline __m256d cinn_avx256_fma(const __m256d& a, + const __m256d& b, + const __m256d& c) { return _mm256_fmadd_pd(a, b, c); } -inline __m512 cinn_avx512_fma(const __m512& a, const __m512& b, const __m512& c) { return _mm512_fmadd_ps(a, b, c); } -inline __m512d cinn_avx512_fma(const __m512d& a, const __m512d& b, const __m512d& c) { +inline __m512 cinn_avx512_fma(const __m512& a, + const __m512& b, + const __m512& c) { + return _mm512_fmadd_ps(a, b, c); +} +inline __m512d cinn_avx512_fma(const __m512d& a, + const __m512d& b, + const __m512d& c) { return _mm512_fmadd_pd(a, b, c); } // @} diff --git a/paddle/cinn/backends/codegen_c.cc b/paddle/cinn/backends/codegen_c.cc index 453376d774d0e..239ba9cf59226 100644 --- a/paddle/cinn/backends/codegen_c.cc +++ b/paddle/cinn/backends/codegen_c.cc @@ -62,7 +62,8 @@ void CodeGenC::Compile(const ir::Module &module, const Outputs &outputs) { CodeGenC::CodeGenC(Target target) : ir::IrPrinter(ss_) {} -std::string CodeGenC::Compile(const ir::Module &module, OutputKind output_kind) { +std::string CodeGenC::Compile(const ir::Module &module, + OutputKind output_kind) { if (output_kind == OutputKind::CHeader) { GenerateHeaderFile(module); } else if (output_kind == OutputKind::CImpl) { @@ -120,9 +121,12 @@ std::string CodeGenC::GetTypeName(Type type) { if (type.is_customized_type()) { CHECK(!type.customized_type().empty()) << "customized_type can't be empty."; auto customized_name = type.customized_type(); - // get name of a cuda built-in vector type, it is started with a 'CudaVectorType::' prefix - if (utils::Startswith(customized_name, common::customized_type::kcuda_builtin_vector_t)) { - customized_name.erase(0, strlen(common::customized_type::kcuda_builtin_vector_t)); + // get name of a cuda built-in vector type, it is started with a + // 'CudaVectorType::' prefix + if (utils::Startswith(customized_name, + 
common::customized_type::kcuda_builtin_vector_t)) { + customized_name.erase( + 0, strlen(common::customized_type::kcuda_builtin_vector_t)); } return customized_name; } @@ -188,8 +192,8 @@ void CodeGenC::Visit(const ir::Not *op) { } void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v()); } void CodeGenC::Visit(const ir::For *op) { - Expr extent = op->extent; - Expr min = op->min; + Expr extent = op->extent; + Expr min = op->min; int num_task = 1; if (op->is_parallel()) { os() << "int num_task = max_concurrency();\n"; @@ -204,10 +208,10 @@ void CodeGenC::Visit(const ir::For *op) { Print((op->extent + num_task_var - 1) / num_task_var); os() << ";\n"; CHECK_EQ(min.as_int32(), 0); - auto task_id = Var("task_id"); + auto task_id = Var("task_id"); auto n_per_task = Var("n_per_task"); - min = task_id * n_per_task; - extent = (task_id + 1) * n_per_task; + min = task_id * n_per_task; + extent = (task_id + 1) * n_per_task; DoIndent(); } os() << "for ("; @@ -344,7 +348,8 @@ void CodeGenC::Visit(const ir::Call *op) { PrintCallArgs(op); os() << ")"; } else if (op->is_extern_call()) { - const auto &fn_name = ExternFunctionEmitterRegistry::Global().Lookup(ExternFuncID{backend_C, op->name.c_str()}); + const auto &fn_name = ExternFunctionEmitterRegistry::Global().Lookup( + ExternFuncID{backend_C, op->name.c_str()}); if (!fn_name.empty()) { ExternFunctionLLVMEmitter emitter(fn_name); emitter.BindCodeGen(this); @@ -527,11 +532,13 @@ void CodeGenC::Visit(const ir::Let *op) { } void CodeGenC::Visit(const ir::Reduce *op) { - LOG(FATAL) << "Reduce IR is just for internal representation, should not be used for CodeGen."; + LOG(FATAL) << "Reduce IR is just for internal representation, should not be " + "used for CodeGen."; } void CodeGenC::Visit(const ir::Ramp *op) { - os() << "StackVec<" << op->lanes << "," << GetTypeRepr(op->type().ElementOf()) << ">::Ramp("; + os() << "StackVec<" << op->lanes << "," << GetTypeRepr(op->type().ElementOf()) + << ">::Ramp("; Print(op->base); os() << ", "; Print(op->stride); @@ -541,7 +548,8 @@ void CodeGenC::Visit(const ir::Ramp *op) { } void CodeGenC::Visit(const ir::Broadcast *op) { - os() << "StackVec<" << op->lanes << "," << GetTypeRepr(op->type().ElementOf()) << ">::Broadcast("; + os() << "StackVec<" << op->lanes << "," << GetTypeRepr(op->type().ElementOf()) + << ">::Broadcast("; Print(op->value); os() << ", "; os() << op->lanes << ")"; @@ -564,7 +572,9 @@ void CodeGenC::PrintCastExpr(const std::string &type, Expr e) { os() << ")"; } -void CodeGenC::PrintShape(const std::vector &shape, char leftb, char rightb) { +void CodeGenC::PrintShape(const std::vector &shape, + char leftb, + char rightb) { os() << leftb << " "; for (int i = 0; i < shape.size() - 1; i++) { @@ -582,22 +592,31 @@ void CodeGenC::Visit(const ir::_LoweredFunc_ *op) { DoIndent(); - CHECK_EQ(op->alloc_output_buffer_exprs.size(), op->dealloc_output_buffer_exprs.size()) + CHECK_EQ(op->alloc_output_buffer_exprs.size(), + op->dealloc_output_buffer_exprs.size()) << "the count of allocation and deallocaton expressions is not match"; std::vector new_body; - std::vector create_temp_buffers = op->PrepareCreateTempBufferExprs(); - std::vector alloca_temp_buffers = op->PrepareAllocTempBufferExprs(); + std::vector create_temp_buffers = op->PrepareCreateTempBufferExprs(); + std::vector alloca_temp_buffers = op->PrepareAllocTempBufferExprs(); std::vector dealloca_temp_buffers = op->PrepareDeallocTempBufferExprs(); -#define APPEND_TO_NEW_BODY(field__) new_body.insert(std::end(new_body), 
std::begin(op->field__), std::end(op->field__)); +#define APPEND_TO_NEW_BODY(field__) \ + new_body.insert( \ + std::end(new_body), std::begin(op->field__), std::end(op->field__)); APPEND_TO_NEW_BODY(argument_prepare_exprs) - new_body.insert(std::end(new_body), std::begin(create_temp_buffers), std::end(create_temp_buffers)); + new_body.insert(std::end(new_body), + std::begin(create_temp_buffers), + std::end(create_temp_buffers)); APPEND_TO_NEW_BODY(alloc_output_buffer_exprs) - new_body.insert(std::end(new_body), std::begin(alloca_temp_buffers), std::end(alloca_temp_buffers)); + new_body.insert(std::end(new_body), + std::begin(alloca_temp_buffers), + std::end(alloca_temp_buffers)); APPEND_TO_NEW_BODY(buffer_data_cast_exprs) new_body.push_back(op->body); - new_body.insert(std::end(new_body), std::begin(dealloca_temp_buffers), std::end(dealloca_temp_buffers)); + new_body.insert(std::end(new_body), + std::begin(dealloca_temp_buffers), + std::end(dealloca_temp_buffers)); APPEND_TO_NEW_BODY(dealloc_output_buffer_exprs) Expr func_body = ir::Block::Make(new_body); @@ -618,7 +637,8 @@ void CodeGenC::PrintFileGuardOpen(const std::string &name) { os() << "\n"; } void CodeGenC::PrintFileGuardClose(const std::string &module_name) { - os() << utils::StringFormat("#endif // _%s_CINN_H_\n", Uppercase(module_name).c_str()); + os() << utils::StringFormat("#endif // _%s_CINN_H_\n", + Uppercase(module_name).c_str()); } void CodeGenC::PrintBufferCreation(const std::vector &buffers) { @@ -626,10 +646,13 @@ void CodeGenC::PrintBufferCreation(const std::vector &buffers) { // Ignore the buffer in other devices. if (!buffer->is_on_host()) continue; DoIndent(); - auto buffer_ptr_type = Type().set_customized_type(common::customized_type::kbuffer_t).set_cpp_handle(); - Var variable = ir::_Var_::Make(buffer->name, buffer_ptr_type); - auto expr = ir::intrinsics::BufferCreate::Make(buffer); - expr = ir::Let::Make(variable, expr); + auto buffer_ptr_type = + Type() + .set_customized_type(common::customized_type::kbuffer_t) + .set_cpp_handle(); + Var variable = ir::_Var_::Make(buffer->name, buffer_ptr_type); + auto expr = ir::intrinsics::BufferCreate::Make(buffer); + expr = ir::Let::Make(variable, expr); Print(expr); os() << ";\n"; } @@ -711,7 +734,9 @@ void CodeGenC::PrintStackVecType(Type type, int lanes) { void CodeGenC::Visit(const ir::PrimitiveNode *op) { CINN_NOT_IMPLEMENTED } void CodeGenC::Visit(const ir::_BufferRange_ *op) { CINN_NOT_IMPLEMENTED } void CodeGenC::Visit(const ir::ScheduleBlock *op) { CINN_NOT_IMPLEMENTED } -void CodeGenC::Visit(const ir::ScheduleBlockRealize *op) { CINN_NOT_IMPLEMENTED } +void CodeGenC::Visit(const ir::ScheduleBlockRealize *op) { + CINN_NOT_IMPLEMENTED +} void CodeGenC::Visit(const ir::IntrinsicOp *op) { switch (op->getKind()) { @@ -841,11 +866,13 @@ std::string ReadWholeFile(const std::string &path) { } void CodeGenC::PrintBuiltinCodes() { - CHECK(!FLAGS_cinn_x86_builtin_code_root.empty()) << "The flag cinn_x86_builtin_code_root should be set first"; + CHECK(!FLAGS_cinn_x86_builtin_code_root.empty()) + << "The flag cinn_x86_builtin_code_root should be set first"; const std::string x86_code_file = "_x86_builtin_source.cc"; - auto source = ReadWholeFile(FLAGS_cinn_x86_builtin_code_root + "/" + x86_code_file); + auto source = + ReadWholeFile(FLAGS_cinn_x86_builtin_code_root + "/" + x86_code_file); os() << source << "\n"; } diff --git a/paddle/cinn/backends/codegen_c.h b/paddle/cinn/backends/codegen_c.h index 479300a1d6f38..a44f99fea606f 100755 --- a/paddle/cinn/backends/codegen_c.h +++ 
b/paddle/cinn/backends/codegen_c.h @@ -76,7 +76,9 @@ class CodeGenC : public ir::IrPrinter { os() << ")"; } - void PrintShape(const std::vector& shape, char leftb = '{', char rightb = '}'); + void PrintShape(const std::vector& shape, + char leftb = '{', + char rightb = '}'); virtual void PrintIncludes(); void PrintBuiltinCodes(); @@ -101,7 +103,8 @@ class CodeGenC : public ir::IrPrinter { NODETY_FORALL(__DEFINE_VISIT) #undef __DEFINE_VISIT -#define __DEFINE_VISIT(op__) void Visit(const ir::intrinsics::op__* op) override; +#define __DEFINE_VISIT(op__) \ + void Visit(const ir::intrinsics::op__* op) override; INTRINSIC_KIND_FOR_EACH(__DEFINE_VISIT) #undef __DEFINE_VISIT diff --git a/paddle/cinn/backends/codegen_c_test.cc b/paddle/cinn/backends/codegen_c_test.cc index a72be4f5d7468..f0e6e238734f5 100755 --- a/paddle/cinn/backends/codegen_c_test.cc +++ b/paddle/cinn/backends/codegen_c_test.cc @@ -62,11 +62,11 @@ TEST(CodeGenC, module) { Target target; target.arch = Target::Arch ::X86; target.bits = Target::Bit ::k32; - target.os = Target::OS ::Linux; + target.os = Target::OS ::Linux; Module::Builder builder("module1", target); auto stages = CreateStages({A, B, C}); - auto func = Lower("add1", stages, {A, B, C}); + auto func = Lower("add1", stages, {A, B, C}); builder.AddFunction(func); @@ -124,7 +124,8 @@ void add1(void* _args, int32_t num_args); CodeGenC compiler(target); compiler.SetInlineBuiltinCodes(false); Outputs outputs; - outputs = outputs.c_header("./generated_module1.h").c_source("./_generated_module1.cc"); + outputs = outputs.c_header("./generated_module1.h") + .c_source("./_generated_module1.cc"); compiler.Compile(builder.Build(), outputs); } } @@ -144,7 +145,9 @@ TEST(CodeGenC, matmul) { Var k(20, "k0"); Tensor C = Compute( - {Expr(100), Expr(50)}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + {Expr(100), Expr(50)}, + [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, + "C"); auto stages = CreateStages({A, B, C}); @@ -154,7 +157,8 @@ TEST(CodeGenC, matmul) { builder.AddBuffer(C->buffer); { // main - std::vector returns({lang::ReturnType{Float(32), C->shape, C->name}}); + std::vector returns( + {lang::ReturnType{Float(32), C->shape, C->name}}); auto tensors = lang::CallLowered("matmul", {A, B}, returns); @@ -242,33 +246,39 @@ TEST(CodeGenC, matmul_tile) { {M, N}, [&](Var i, Var j) { return Expr(0.f); }, "C_init"); Tensor C = Compute( - {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + {M, N}, + [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, + "C"); auto stages = CreateStages({C, C_init}); stages[C]->ShareBufferWith(stages[C_init]); { - auto _i_outer_i_inner_j_outer_j_inner_ = stages[C_init]->Tile(0, 1, bn.as_int32(), bn.as_int32()); // NOLINT - auto &i_outer = std::get<0>(_i_outer_i_inner_j_outer_j_inner_); - auto &i_inner = std::get<1>(_i_outer_i_inner_j_outer_j_inner_); - auto &j_outer = std::get<2>(_i_outer_i_inner_j_outer_j_inner_); - auto &j_inner = std::get<3>(_i_outer_i_inner_j_outer_j_inner_); + auto _i_outer_i_inner_j_outer_j_inner_ = + stages[C_init]->Tile(0, 1, bn.as_int32(), bn.as_int32()); // NOLINT + auto &i_outer = std::get<0>(_i_outer_i_inner_j_outer_j_inner_); + auto &i_inner = std::get<1>(_i_outer_i_inner_j_outer_j_inner_); + auto &j_outer = std::get<2>(_i_outer_i_inner_j_outer_j_inner_); + auto &j_inner = std::get<3>(_i_outer_i_inner_j_outer_j_inner_); stages[C_init]->Reorder({i_outer, j_outer, i_inner, j_inner}); } { - auto 
_i_outer_i_inner_j_outer_j_inner_ = stages[C]->Tile(0, 1, bn.as_int32(), bn.as_int32()); // NOLINT - auto &i_outer = std::get<0>(_i_outer_i_inner_j_outer_j_inner_); - auto &i_inner = std::get<1>(_i_outer_i_inner_j_outer_j_inner_); - auto &j_outer = std::get<2>(_i_outer_i_inner_j_outer_j_inner_); - auto &j_inner = std::get<3>(_i_outer_i_inner_j_outer_j_inner_); - auto _k_outer_k_inner_ = stages[C]->Split(poly::Iterator("k0"), 4); // NOLINT - auto &k_outer = std::get<0>(_k_outer_k_inner_); - auto &k_inner = std::get<1>(_k_outer_k_inner_); + auto _i_outer_i_inner_j_outer_j_inner_ = + stages[C]->Tile(0, 1, bn.as_int32(), bn.as_int32()); // NOLINT + auto &i_outer = std::get<0>(_i_outer_i_inner_j_outer_j_inner_); + auto &i_inner = std::get<1>(_i_outer_i_inner_j_outer_j_inner_); + auto &j_outer = std::get<2>(_i_outer_i_inner_j_outer_j_inner_); + auto &j_inner = std::get<3>(_i_outer_i_inner_j_outer_j_inner_); + auto _k_outer_k_inner_ = + stages[C]->Split(poly::Iterator("k0"), 4); // NOLINT + auto &k_outer = std::get<0>(_k_outer_k_inner_); + auto &k_inner = std::get<1>(_k_outer_k_inner_); stages[C]->Reorder({i_outer, j_outer, i_inner, j_inner, k_outer, k_inner}); } - stages[C_init]->ComputeAtSchedule(stages[C], 3, poly::Stage::kComputeAtBefore); + stages[C_init]->ComputeAtSchedule( + stages[C], 3, poly::Stage::kComputeAtBefore); // Code gen auto func = Lower("matmul", stages, {A, B, C}); @@ -332,21 +342,28 @@ TEST(CodeGenC, matmul_packed) { // TODO(Superjomn) Make sure the domain works. Var k(K.as_int32(), "k0"); auto packedB = Compute( - {N / bn, K, bn}, [&](Expr x, Expr y, Expr z) { return B(y, x * bn + z); }, "PackedB"); + {N / bn, K, bn}, + [&](Expr x, Expr y, Expr z) { return B(y, x * bn + z); }, + "PackedB"); auto C = Compute( - {M, N}, [&](Expr i, Expr j) { return ReduceSum(A(i, k) * packedB(j / bn, k, j % bn), {k}); }, "C"); + {M, N}, + [&](Expr i, Expr j) { + return ReduceSum(A(i, k) * packedB(j / bn, k, j % bn), {k}); + }, + "C"); auto stages = CreateStages({packedB, C}); { - auto _i_outer_i_inner_j_outer_j_inner_ = stages[C]->Tile(0, 1, bn.as_int32(), bn.as_int32()); - auto &i_outer = std::get<0>(_i_outer_i_inner_j_outer_j_inner_); - auto &i_inner = std::get<1>(_i_outer_i_inner_j_outer_j_inner_); - auto &j_outer = std::get<2>(_i_outer_i_inner_j_outer_j_inner_); - auto &j_inner = std::get<3>(_i_outer_i_inner_j_outer_j_inner_); - auto _k_outer_k_inner_ = stages[C]->Split(poly::Iterator("k0"), 4); - auto &k_outer = std::get<0>(_k_outer_k_inner_); - auto &k_inner = std::get<1>(_k_outer_k_inner_); + auto _i_outer_i_inner_j_outer_j_inner_ = + stages[C]->Tile(0, 1, bn.as_int32(), bn.as_int32()); + auto &i_outer = std::get<0>(_i_outer_i_inner_j_outer_j_inner_); + auto &i_inner = std::get<1>(_i_outer_i_inner_j_outer_j_inner_); + auto &j_outer = std::get<2>(_i_outer_i_inner_j_outer_j_inner_); + auto &j_inner = std::get<3>(_i_outer_i_inner_j_outer_j_inner_); + auto _k_outer_k_inner_ = stages[C]->Split(poly::Iterator("k0"), 4); + auto &k_outer = std::get<0>(_k_outer_k_inner_); + auto &k_inner = std::get<1>(_k_outer_k_inner_); stages[C]->Reorder({i_outer, j_outer, i_inner, j_inner, k_outer, k_inner}); } @@ -417,7 +434,9 @@ TEST(CodeGenC, call_extern) { Placeholder x("x", {M}); ir::Tensor y = Compute( - {M}, [=](Var i) -> Expr { return lang::CallExtern("tanh", {x(i)}); }, "y"); + {M}, + [=](Var i) -> Expr { return lang::CallExtern("tanh", {x(i)}); }, + "y"); auto stages = CreateStages({y}); diff --git a/paddle/cinn/backends/codegen_c_x86.cc b/paddle/cinn/backends/codegen_c_x86.cc index 
994ef1191e675..6164bfa3c16d0 100644 --- a/paddle/cinn/backends/codegen_c_x86.cc +++ b/paddle/cinn/backends/codegen_c_x86.cc @@ -17,10 +17,18 @@ namespace cinn { namespace backends { -void CodeGenCX86::Visit(const ir::Add *op) { VisitBinaryOp(op, op->a(), op->b(), "add"); } -void CodeGenCX86::Visit(const ir::Sub *op) { VisitBinaryOp(op, op->a(), op->b(), "sub"); } -void CodeGenCX86::Visit(const ir::Mul *op) { VisitBinaryOp(op, op->a(), op->b(), "mul"); } -void CodeGenCX86::Visit(const ir::Div *op) { VisitBinaryOp(op, op->a(), op->b(), "div"); } +void CodeGenCX86::Visit(const ir::Add *op) { + VisitBinaryOp(op, op->a(), op->b(), "add"); +} +void CodeGenCX86::Visit(const ir::Sub *op) { + VisitBinaryOp(op, op->a(), op->b(), "sub"); +} +void CodeGenCX86::Visit(const ir::Mul *op) { + VisitBinaryOp(op, op->a(), op->b(), "mul"); +} +void CodeGenCX86::Visit(const ir::Div *op) { + VisitBinaryOp(op, op->a(), op->b(), "div"); +} void CodeGenCX86::Visit(const ir::Load *op) { Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1); @@ -86,7 +94,7 @@ void CodeGenCX86::Visit(const ir::Store *op) { } void CodeGenCX86::PrintVecInputArgument(const Expr *op) { - int bits = op->type().bits() * op->type().lanes(); + int bits = op->type().bits() * op->type().lanes(); auto *broadcast_n = op->As(); if (op->type().lanes() == 1 || broadcast_n) { diff --git a/paddle/cinn/backends/codegen_c_x86.h b/paddle/cinn/backends/codegen_c_x86.h index 75a13ad22e873..643efee39108b 100644 --- a/paddle/cinn/backends/codegen_c_x86.h +++ b/paddle/cinn/backends/codegen_c_x86.h @@ -27,14 +27,14 @@ namespace backends { */ class CodeGenCX86 : public CodeGenC { public: - //! The X86 CPU supports some following features. We use SSE or AVX to accelerate the basic operations if forloop is - //! vectorized. + //! The X86 CPU supports some following features. We use SSE or AVX to + //! accelerate the basic operations if forloop is vectorized. enum class Feature : int { - None = 0, - SSE = 1, //! support SSE instruction set. + None = 0, + SSE = 1, //! support SSE instruction set. AVX256 = 1 << 1, // ! support AVX256 instruction set. AVX512 = 1 << 2, // ! support AVX512 instruction set. - BLAS = 1 << 3, // ! support BLAS library. + BLAS = 1 << 3, // ! support BLAS library. }; Feature feature{Feature::None}; @@ -44,7 +44,8 @@ class CodeGenCX86 : public CodeGenC { * @param target The device. * @param features Features it supported. */ - CodeGenCX86(Target target, Feature feature) : CodeGenC(target), feature(feature) {} + CodeGenCX86(Target target, Feature feature) + : CodeGenC(target), feature(feature) {} protected: void Visit(const ir::Add *op) override; @@ -67,10 +68,18 @@ class CodeGenCX86 : public CodeGenC { //! Check the features. // @{ - bool SupportsSSE() { return static_cast(feature) & static_cast(Feature::SSE); } - bool SupportsAVX256() { return static_cast(feature) & static_cast(Feature::AVX256); } - bool SupportsAVX512() { return static_cast(feature) & static_cast(Feature::AVX512); } - bool SupportsBLAS() { return static_cast(feature) & static_cast(Feature::BLAS); } + bool SupportsSSE() { + return static_cast(feature) & static_cast(Feature::SSE); + } + bool SupportsAVX256() { + return static_cast(feature) & static_cast(Feature::AVX256); + } + bool SupportsAVX512() { + return static_cast(feature) & static_cast(Feature::AVX512); + } + bool SupportsBLAS() { + return static_cast(feature) & static_cast(Feature::BLAS); + } // @} //! 
Print (and prepare) a argument in vectorize type, for example: @@ -84,10 +93,11 @@ class CodeGenCX86 : public CodeGenC { void PrintAbsAddr(const Op *op) { os() << op->tensor.template As()->name << " + "; - auto index = op->index(); + auto index = op->index(); auto *ramp_n = index.template As(); if (ramp_n) { - CHECK(!ramp_n->base.template As()) << "base of a Ramp node should not be Ramp type"; + CHECK(!ramp_n->base.template As()) + << "base of a Ramp node should not be Ramp type"; Print(ramp_n->base); } else { Print(op->index()); @@ -99,8 +109,12 @@ class CodeGenCX86 : public CodeGenC { }; template -void CodeGenCX86::VisitBinaryOp(const Op *op, Expr a, Expr b, const std::string &op_repr) { - CHECK_EQ(a.type(), b.type()) << " a is : " << a << ", and b is : " << b << ". op_repr is : " << op_repr; +void CodeGenCX86::VisitBinaryOp(const Op *op, + Expr a, + Expr b, + const std::string &op_repr) { + CHECK_EQ(a.type(), b.type()) << " a is : " << a << ", and b is : " << b + << ". op_repr is : " << op_repr; // scalar. if (a.type().lanes() == 1) { diff --git a/paddle/cinn/backends/codegen_c_x86_test.cc b/paddle/cinn/backends/codegen_c_x86_test.cc index 4f2dddb319d63..9e1821f7b0200 100644 --- a/paddle/cinn/backends/codegen_c_x86_test.cc +++ b/paddle/cinn/backends/codegen_c_x86_test.cc @@ -35,15 +35,15 @@ TEST(CodeGenCX86, basic) { using namespace ir; // NOLINT - const int M = 100; - const int K = 200; - const int N = 500; + const int M = 100; + const int K = 200; + const int N = 500; const int bn = 32; Target target; target.arch = Target::Arch ::X86; target.bits = Target::Bit ::k32; - target.os = Target::OS ::Linux; + target.os = Target::OS ::Linux; Placeholder A("A", {M, N}); Placeholder B("B", {M, N}); diff --git a/paddle/cinn/backends/codegen_cuda_dev.cc b/paddle/cinn/backends/codegen_cuda_dev.cc index 6eb1232a069ed..ee028d5eb9e04 100644 --- a/paddle/cinn/backends/codegen_cuda_dev.cc +++ b/paddle/cinn/backends/codegen_cuda_dev.cc @@ -14,8 +14,8 @@ #include "paddle/cinn/backends/codegen_cuda_dev.h" -#include #include +#include #include #include @@ -49,13 +49,14 @@ const std::string &CodeGenCUDA_Dev::GetSourceHeader() { return source_header_; } CodeGenCUDA_Dev::CodeGenCUDA_Dev(Target target) : CodeGenC(target) {} std::string CodeGenCUDA_Dev::Compile(const ir::Module &module, bool for_nvrtc) { - for_nvrtc_ = for_nvrtc; + for_nvrtc_ = for_nvrtc; auto source = Compile(module, OutputKind::CImpl); return source; } -void CodeGenCUDA_Dev::Compile(const ir::Module &module, const Outputs &outputs) { +void CodeGenCUDA_Dev::Compile(const ir::Module &module, + const Outputs &outputs) { ir::IrVerify(Expr(module)); CodeGenC::inline_builtin_codes_ = false; @@ -83,13 +84,15 @@ std::string CodeGenCUDA_Dev::Compile(const ir::LoweredFunc &func) { return ss_.str(); } -std::vector CodeGenCUDA_Dev::GenerateBufferAliasExprs(const ir::_LoweredFunc_ *op, - const std::vector &temp_buffers) { - std::set temp_buffer_set(temp_buffers.begin(), temp_buffers.end()); +std::vector CodeGenCUDA_Dev::GenerateBufferAliasExprs( + const ir::_LoweredFunc_ *op, const std::vector &temp_buffers) { + std::set temp_buffer_set(temp_buffers.begin(), + temp_buffers.end()); // prepare temp buffer alias std::vector buffer_alias; auto tensors = ir::CollectIRNodes(op->body, [&](const Expr *x) { - return x->as_tensor() && x->as_tensor()->buffer.defined() && temp_buffer_set.count(x->as_tensor()->buffer); + return x->as_tensor() && x->as_tensor()->buffer.defined() && + temp_buffer_set.count(x->as_tensor()->buffer); }); // unique tensors @@ -99,7 +102,7 
@@ std::vector CodeGenCUDA_Dev::GenerateBufferAliasExprs(const ir::_LoweredFu } for (auto &t : unique_tensors) { - auto data_type = t->type(); + auto data_type = t->type(); auto data_ptr_type = data_type; data_ptr_type.set_cpp_handle(); @@ -124,10 +127,11 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) { std::vector new_body; auto alloca_temp_buffers = op->PrepareAllocTempBufferExprs(); - auto temp_buffer_alias = GenerateBufferAliasExprs(op, op->temp_bufs); - auto alis_var_exprs = op->CudaAliasVarExprs(); + auto temp_buffer_alias = GenerateBufferAliasExprs(op, op->temp_bufs); + auto alis_var_exprs = op->CudaAliasVarExprs(); -#define APPEND_TO_NEW_BODY(field__) new_body.insert(std::end(new_body), std::begin(field__), std::end(field__)); +#define APPEND_TO_NEW_BODY(field__) \ + new_body.insert(std::end(new_body), std::begin(field__), std::end(field__)); APPEND_TO_NEW_BODY(alloca_temp_buffers) APPEND_TO_NEW_BODY(temp_buffer_alias) APPEND_TO_NEW_BODY(alis_var_exprs) @@ -145,7 +149,8 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) { } void CodeGenCUDA_Dev::Visit(const ir::_Var_ *op) { - if (utils::Startswith(op->name, "threadIdx") || utils::Startswith(op->name, "blockIdx")) { + if (utils::Startswith(op->name, "threadIdx") || + utils::Startswith(op->name, "blockIdx")) { os() << "(int)" + op->name; } else { os() << op->name; @@ -197,7 +202,8 @@ void CodeGenCUDA_Dev::PrintFunctionDeclaration(const ir::_LoweredFunc_ *op) { void CodeGenCUDA_Dev::PrintFuncArg(const ir::Argument &arg) { if (arg.is_buffer()) { - // In CUDA kernel, only primitive type is supported, so we replace the buffer with T*j + // In CUDA kernel, only primitive type is supported, so we replace the + // buffer with T* if (arg.is_input()) os() << "const "; os() << GetTypeRepr(arg.buffer_arg()->dtype); os() << "* "; @@ -216,7 +222,8 @@ void CodeGenCUDA_Dev::PrintFuncArg(const ir::Argument &arg) { void CodeGenCUDA_Dev::PrintBuiltinCodes() {} -std::string CodeGenCUDA_Dev::Compile(const ir::Module &module, CodeGenC::OutputKind output_kind) { +std::string CodeGenCUDA_Dev::Compile(const ir::Module &module, + CodeGenC::OutputKind output_kind) { if (output_kind == OutputKind::CHeader) { GenerateHeaderFile(module); } else if (output_kind == OutputKind::CImpl) { @@ -268,7 +275,8 @@ void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) { break; default: - LOG(FATAL) << "CUDA device codegen not support memory " << buffer->name << ", type " << buffer->memory_type; + LOG(FATAL) << "CUDA device codegen does not support memory " << buffer->name + << ", type " << buffer->memory_type; } } @@ -321,7 +329,8 @@ void CodeGenCUDA_Dev::Visit(const ir::Let *op) { // identify vectorized tensors by checking their dtypes are customized_type // with customized_type::kcuda_builtin_vector_t prefix, and save their names if (op->type().is_customized() && - utils::Startswith(op->type().customized_type(), common::customized_type::kcuda_builtin_vector_t)) { + utils::Startswith(op->type().customized_type(), + common::customized_type::kcuda_builtin_vector_t)) { os() << GetTypeRepr(op->type()); if (op->type().is_cpp_handle()) { os() << " " << kCKeywordRestrict; @@ -329,7 +338,8 @@ void CodeGenCUDA_Dev::Visit(const ir::Let *op) { os() << " "; Print(op->symbol); vectorized_tensor_names_.insert(utils::GetStreamCnt(op->symbol)); - // skip "=0" in "half8 temp = 0;" sincethe operator= of half8 may not overloaded. + // skip "=0" in "half8 temp = 0;" since the operator= of half8 may not be + // overloaded.
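+ // An illustrative note (not verbatim output): with this early return a + // vectorized let prints as "half8 temp;"; emitting "half8 temp = 0;" instead + // could fail to compile when half8 does not overload operator=.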
if (op->body.As() && op->body.As()->value == 0) { return; } @@ -340,8 +350,11 @@ void CodeGenCUDA_Dev::Visit(const ir::Let *op) { } } -bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op, ir::Expr index_expr, bool is_store) { - static constexpr char index2suffix[8] = {'x', 'y', 'z', 'w', 'v', 'u', 't', 's'}; +bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op, + ir::Expr index_expr, + bool is_store) { + static constexpr char index2suffix[8] = { + 'x', 'y', 'z', 'w', 'v', 'u', 't', 's'}; // addr of op should be a place of tensor and the index is simple int number if (!op->is_addr_tensor() || !index_expr.As()) { @@ -363,22 +376,25 @@ bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op, if (is_store && tensor->type().is_cpp_handle()) { os() << tensor->name << "[" << index << "]"; } else { - os() << tensor->name << (tensor->type().is_cpp_handle() ? "->" : ".") << index2suffix[index]; + os() << tensor->name << (tensor->type().is_cpp_handle() ? "->" : ".") + << index2suffix[index]; } return true; } void CodeGenCUDA_Dev::Visit(const ir::Load *op) { - // overload this visit function to especially deal with the case when it accesses - // element at a cuda built-in vector, others still resolve to CodeGenC + // overload this visit function to specially handle the case when it + // accesses an element of a cuda built-in vector; other cases still resolve + // to CodeGenC if (!PrintBuiltinVectorAccess(op, op->index(), false)) { CodeGenC::Visit(op); } } void CodeGenCUDA_Dev::Visit(const ir::Store *op) { - // overload this visit function to especially deal with the case when it accesses - // element at a cuda built-in vector, others still resolve to CodeGenC + // overload this visit function to specially handle the case when it + // accesses an element of a cuda built-in vector; other cases still resolve + // to CodeGenC if (PrintBuiltinVectorAccess(op, op->index(), true)) { os() << " = "; Print(op->value); diff --git a/paddle/cinn/backends/codegen_cuda_dev.h b/paddle/cinn/backends/codegen_cuda_dev.h index 44607d76a283b..1abe891b619f0 100644 --- a/paddle/cinn/backends/codegen_cuda_dev.h +++ b/paddle/cinn/backends/codegen_cuda_dev.h @@ -36,9 +36,11 @@ namespace backends { /** * CUDA device code generator. * - * It generates the device function, e.g, the function called "myadd" will have a __global__ functon called - * "myadd_kernel", different from codegen_c, the declaration of the "myadd_kernel" function has an expanded argument - * list, which finally similar to `__global__ void myadd(float* __restrict__ A, float* __restrict__ B, int n);` + * It generates the device function, e.g., the function called "myadd" will + * have a __global__ function called "myadd_kernel"; different from codegen_c, + * the declaration of the "myadd_kernel" function has an expanded argument + * list, which is finally similar to `__global__ void myadd(float* + * __restrict__ A, float* __restrict__ B, int n);` */ class CodeGenCUDA_Dev : public CodeGenC { public: @@ -55,7 +57,8 @@ class CodeGenCUDA_Dev : public CodeGenC { std::string Compile(const ir::LoweredFunc& func); /** - * \brief Print a function argument in CUDA syntax. Currently, just some decoration of __restrict__. + * \brief Print a function argument in CUDA syntax. Currently, just some + * decoration of __restrict__. * @param arg the argument. * @return the representation in CUDA syntax.
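+ * + * An example (illustrative): an input buffer argument holding float32 data + * is printed roughly as `const float* __restrict__ A`, matching the expanded + * kernel signature shown in the class comment above.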
* @@ -79,7 +82,9 @@ class CodeGenCUDA_Dev : public CodeGenC { void Visit(const ir::Let* op) override; // Print element access at a cuda built-in vector on a load/store node - bool PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger* op, ir::Expr index, bool is_store); + bool PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger* op, + ir::Expr index, + bool is_store); void PrintBuiltinCodes(); @@ -89,19 +94,23 @@ class CodeGenCUDA_Dev : public CodeGenC { void PrintTempBufferAliasDefinition(const ir::Buffer& buffer); - std::vector GenerateBufferAliasExprs(const ir::_LoweredFunc_* op, const std::vector& temp_buffers); + std::vector GenerateBufferAliasExprs( + const ir::_LoweredFunc_* op, const std::vector& temp_buffers); /** - * Print the function declaration, this is different from C, we expand the arguments and get something like - * `__global__ void myadd(float* __restrict__ A, float* __restrict__ B, int n);` + * Print the function declaration; this is different from C: we expand the + * arguments and get something like + * `__global__ void myadd(float* __restrict__ A, float* __restrict__ B, int + * n);` */ void PrintFunctionDeclaration(const ir::_LoweredFunc_* op); private: Target target_; bool for_nvrtc_{false}; - // names of vectorized tensors from `Let` statments where dtypes of the tensors - // are customized_type with customized_type::kcuda_builtin_vector_t prefix + // names of vectorized tensors from `Let` statements where dtypes of the + // tensors are customized_type with customized_type::kcuda_builtin_vector_t + // prefix std::unordered_set vectorized_tensor_names_; static const std::string source_header_; }; diff --git a/paddle/cinn/backends/codegen_cuda_host.cc b/paddle/cinn/backends/codegen_cuda_host.cc index f1ee430b68997..9c44c8302742a 100644 --- a/paddle/cinn/backends/codegen_cuda_host.cc +++ b/paddle/cinn/backends/codegen_cuda_host.cc @@ -32,22 +32,25 @@ using cinn::common::float16; const int kArgsArrayMaxLen = 20; -llvm::Value* CodeGenCUDA_Host::LowerGPUKernelLauncher(const ir::_LoweredFunc_* func) { - auto body = func->body; +llvm::Value* CodeGenCUDA_Host::LowerGPUKernelLauncher( + const ir::_LoweredFunc_* func) { + auto body = func->body; auto* call_ir = body.As(); CHECK(call_ir); // Create the function // @{ - auto* function_type = GenFunctionTypeFromCinnFunction(func, true); - llvm::Function* function = llvm::Function::Create(function_type, llvm::Function::ExternalLinkage, func->name, m_); + auto* function_type = GenFunctionTypeFromCinnFunction(func, true); + llvm::Function* function = llvm::Function::Create( + function_type, llvm::Function::ExternalLinkage, func->name, m_); function->setCallingConv(llvm::CallingConv::C); function->setHasUWTable(); std::vector ll_function_args; - std::transform(function->arg_begin(), function->arg_end(), std::back_inserter(ll_function_args), [](auto& arg) { - return std::addressof(arg); - }); + std::transform(function->arg_begin(), + function->arg_end(), + std::back_inserter(ll_function_args), + [](auto& arg) { return std::addressof(arg); }); // @} llvm::BasicBlock* entry = llvm::BasicBlock::Create( @@ -57,8 +60,8 @@ llvm::Value* CodeGenCUDA_Host::LowerGPUKernelLauncher(const ir::_LoweredFunc_* f /*InsertBefore=*/nullptr); b_->SetInsertPoint(entry); - auto* kernel_args = ll_function_args[0]; - auto* kernel_args_count = ll_function_args[1]; + auto* kernel_args = ll_function_args[0]; + auto* kernel_args_count = ll_function_args[1]; llvm::Value* kernel_stream = nullptr; if (ll_function_args.size() == 3) { kernel_stream =
ll_function_args[2]; @@ -68,13 +71,16 @@ llvm::Value* CodeGenCUDA_Host::LowerGPUKernelLauncher(const ir::_LoweredFunc_* f CHECK_EQ(kernel_args_count->getType(), ll_int32_ty()); // int32 std::unordered_map global_args = { - {KERNEL_ARGS, kernel_args}, {KERNEL_ARGS_NUM, kernel_args_count}, {KERNEL_STREAM, kernel_stream}}; + {KERNEL_ARGS, kernel_args}, + {KERNEL_ARGS_NUM, kernel_args_count}, + {KERNEL_STREAM, kernel_stream}}; auto ret_type = CinnTypeToLLVMType(Void(), m_); std::vector args_type; for (auto r_arg : call_ir->read_args) { if (r_arg.is_var()) { - if (r_arg.as_var()->type().is_cpp_handle() || r_arg.as_var()->type().is_string()) { + if (r_arg.as_var()->type().is_cpp_handle() || + r_arg.as_var()->type().is_string()) { args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); } else if (r_arg.as_var()->type().is_int(32)) { args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); @@ -120,9 +126,12 @@ llvm::Value* CodeGenCUDA_Host::LowerGPUKernelLauncher(const ir::_LoweredFunc_* f for (auto& r_arg : call_ir->read_args) { if (r_arg.is_var()) { if (r_arg.as_var()->type().is_string()) { - auto kvalue = m_->getOrInsertGlobal(r_arg.as_var()->name + "_ptr_", b_->getInt8PtrTy()); - call_args.push_back(b_->CreateLoad(b_->getInt8PtrTy(), kvalue, r_arg.as_var()->name + "_ptr_load")); - } else if (r_arg.as_var()->type().is_cpp_handle() || r_arg.as_var()->type().is_int(32)) { + auto kvalue = m_->getOrInsertGlobal(r_arg.as_var()->name + "_ptr_", + b_->getInt8PtrTy()); + call_args.push_back(b_->CreateLoad( + b_->getInt8PtrTy(), kvalue, r_arg.as_var()->name + "_ptr_load")); + } else if (r_arg.as_var()->type().is_cpp_handle() || + r_arg.as_var()->type().is_int(32)) { CHECK(global_args.count(r_arg.as_var()->name)); call_args.push_back(global_args[r_arg.as_var()->name]); } else { @@ -148,15 +157,19 @@ llvm::Value* CodeGenCUDA_Host::LowerGPUKernelLauncher(const ir::_LoweredFunc_* f } else if (r_arg.type().is_uint(64)) { call_args.push_back(b_->getInt64(r_arg.as_uint64())); } else if (r_arg.type().is_float(32)) { - call_args.push_back(llvm::ConstantFP::get(b_->getFloatTy(), llvm::APFloat(r_arg.as_float()))); + call_args.push_back(llvm::ConstantFP::get( + b_->getFloatTy(), llvm::APFloat(r_arg.as_float()))); } else if (r_arg.type().is_float(64)) { - call_args.push_back(llvm::ConstantFP::get(b_->getDoubleTy(), llvm::APFloat(r_arg.as_double()))); + call_args.push_back(llvm::ConstantFP::get( + b_->getDoubleTy(), llvm::APFloat(r_arg.as_double()))); } else if (r_arg.type().is_bfloat16()) { - call_args.push_back( - llvm::ConstantFP::get(b_->getBFloatTy(), llvm::APFloat(static_cast(r_arg.as_bfloat16())))); + call_args.push_back(llvm::ConstantFP::get( + b_->getBFloatTy(), + llvm::APFloat(static_cast(r_arg.as_bfloat16())))); } else if (r_arg.type().is_float16()) { - call_args.push_back( - llvm::ConstantFP::get(b_->getHalfTy(), llvm::APFloat(static_cast(r_arg.as_float16())))); + call_args.push_back(llvm::ConstantFP::get( + b_->getHalfTy(), + llvm::APFloat(static_cast(r_arg.as_float16())))); } else { CINN_NOT_IMPLEMENTED; } diff --git a/paddle/cinn/backends/codegen_cuda_host.h b/paddle/cinn/backends/codegen_cuda_host.h index c91ab76a91a8e..5d311b5808d45 100644 --- a/paddle/cinn/backends/codegen_cuda_host.h +++ b/paddle/cinn/backends/codegen_cuda_host.h @@ -31,11 +31,15 @@ namespace backends { */ class CodeGenCUDA_Host : public CodeGenLLVM { public: - explicit CodeGenCUDA_Host(llvm::Module *m, llvm::IRBuilder<> *b, const std::shared_ptr &vars = nullptr) + explicit CodeGenCUDA_Host(llvm::Module *m, + llvm::IRBuilder<> *b, + 
const std::shared_ptr &vars = nullptr) : CodeGenLLVM(m, b, vars) {} using CodeGenLLVM::Visit; - llvm::Value *Visit(const ir::_LoweredFunc_ *func) override { return LowerGPUKernelLauncher(func); } + llvm::Value *Visit(const ir::_LoweredFunc_ *func) override { + return LowerGPUKernelLauncher(func); + } private: /** @@ -43,10 +47,12 @@ class CodeGenCUDA_Host : public CodeGenLLVM { * * We launch a CUDA kernel in the following way: * - * 1. a GPU function (called fn) will compiled to PTX and lower by CUDA driver to a function pointer, which we store - * as a `void*` type global variable [fn_kernel_ptr] in LLVM module. - * 2. when lower the host launcher, we replace the Call of the original kernel [fn] to a Call of - * `cinn_call_cuda_kernel` method which is registered as an external function. + * 1. a GPU function (called fn) will be compiled to PTX and lowered by the + * CUDA driver to a function pointer, which we store as a `void*` type global + * variable [fn_kernel_ptr] in the LLVM module. + * 2. when lowering the host launcher, we replace the Call of the original + * kernel [fn] with a Call of the `cinn_call_cuda_kernel` method, which is + * registered as an external function. * */ llvm::Value *LowerGPUKernelLauncher(const ir::_LoweredFunc_ *func); diff --git a/paddle/cinn/backends/codegen_cuda_util.h b/paddle/cinn/backends/codegen_cuda_util.h index 51677bf3c0530..09f35d6698565 100755 --- a/paddle/cinn/backends/codegen_cuda_util.h +++ b/paddle/cinn/backends/codegen_cuda_util.h @@ -33,13 +33,13 @@ namespace backends { #define KERNEL_STREAM "kernel_stream" /** - * Split a CINN Module into two separate modules, one cantains the host functions, the other contains the device - * kernels. + * Split a CINN Module into two separate modules, one containing the host + * functions, the other containing the device kernels. * * This contains some process: * - * - replace the original kernel function with a Call node and add it to the first module, add a device kernel function - * to the second module. + * - replace the original kernel function with a Call node and add it to the + * first module, add a device kernel function to the second module.
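+ * + * For example (an illustrative sketch; the precise IR is built by + * CollectHostFunctionVisitor below): a lowered function `fn` is split roughly + * into + * \code + * // host module: a stub keeping the name `fn` that forwards to the runtime + * fn(kernel_args, kernel_args_num, kernel_stream) { + * call_cuda_kernel(fn_kernel, kernel_args, kernel_args_num, + * grid_x, grid_y, grid_z, block_x, block_y, block_z, + * kernel_stream); + * } + * // device module: the real kernel, renamed with the `_kernel` suffix + * __global__ void fn_kernel(...); + * \endcode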
*/ std::tuple SplitCudaAndHostModule(ir::Module module); @@ -48,11 +48,13 @@ namespace detail { struct CollectHostFunctionVisitor : public ir::IRMutator<> { explicit CollectHostFunctionVisitor(const std::string& module_name) : host_module_builder(module_name + "_host", common::DefaultHostTarget()), - device_module_builder(module_name + "_gpu_device", common::DefaultNVGPUTarget()) {} + device_module_builder(module_name + "_gpu_device", + common::DefaultNVGPUTarget()) {} std::tuple operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); - return std::make_tuple(host_module_builder.Build(), device_module_builder.Build()); + return std::make_tuple(host_module_builder.Build(), + device_module_builder.Build()); } private: @@ -65,7 +67,8 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> { } auto host_func = CreateHostFunctionGivenDeviceKernel(op); host_module_builder.AddFunction(host_func.as_lowered_func_ref()); - device_module_builder.AddFunction(CreateDeviceFunctionGivenDeviceKernel(*expr).as_lowered_func_ref()); + device_module_builder.AddFunction( + CreateDeviceFunctionGivenDeviceKernel(*expr).as_lowered_func_ref()); } } @@ -89,45 +92,49 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> { */ Expr CreateHostFunctionGivenDeviceKernel(const ir::_LoweredFunc_* func) { // std::vector args; - // NOTE the suffix `__ptr` makes this argument lower to a pointer in LLVM backend. - // args.push_back(Var("args__ptr", type_of())); + // NOTE the suffix `__ptr` makes this argument lower to a pointer in LLVM + // backend. + // args.push_back(Var("args__ptr", type_of())); // args.push_back(Var("num_args", type_of())); ir::Var kernel_ptr(GenDeviceKernelName(func->name), type_of()); ir::Var kernel_args(KERNEL_ARGS, type_of()); ir::Var kernel_args_num(KERNEL_ARGS_NUM, type_of()); ir::Var kernel_stream(KERNEL_STREAM, type_of()); - auto call_extern_api = ir::Call::Make(Void(), - runtime::intrinsic::call_cuda_kernel, - {kernel_ptr, - kernel_args, - kernel_args_num, - Expr(func->cuda_axis_info.grid_dim(0)), // grid_x - Expr(func->cuda_axis_info.grid_dim(1)), // grid_y - Expr(func->cuda_axis_info.grid_dim(2)), // grid_z - Expr(func->cuda_axis_info.block_dim(0)), // block_x - Expr(func->cuda_axis_info.block_dim(1)), // block_y - Expr(func->cuda_axis_info.block_dim(2)), // block_z - kernel_stream}, - {}, - ir::CallType::Extern, - ir::FunctionRef(), - 0); - std::vector arguments = {ir::Argument(kernel_args, ir::Argument::IO::kOutput), - ir::Argument(kernel_args_num, ir::Argument::IO::kInput), - ir::Argument(kernel_stream, ir::Argument::IO::kOutput)}; + auto call_extern_api = + ir::Call::Make(Void(), + runtime::intrinsic::call_cuda_kernel, + {kernel_ptr, + kernel_args, + kernel_args_num, + Expr(func->cuda_axis_info.grid_dim(0)), // grid_x + Expr(func->cuda_axis_info.grid_dim(1)), // grid_y + Expr(func->cuda_axis_info.grid_dim(2)), // grid_z + Expr(func->cuda_axis_info.block_dim(0)), // block_x + Expr(func->cuda_axis_info.block_dim(1)), // block_y + Expr(func->cuda_axis_info.block_dim(2)), // block_z + kernel_stream}, + {}, + ir::CallType::Extern, + ir::FunctionRef(), + 0); + std::vector arguments = { + ir::Argument(kernel_args, ir::Argument::IO::kOutput), + ir::Argument(kernel_args_num, ir::Argument::IO::kInput), + ir::Argument(kernel_stream, ir::Argument::IO::kOutput)}; return ir::_LoweredFunc_::Make(func->name, arguments, call_extern_api, {}); } Expr CreateDeviceFunctionGivenDeviceKernel(Expr expr) { - auto copied = optim::IRCopy(expr); + auto copied = optim::IRCopy(expr); auto* lowered_func =
copied.as_lowered_func(); lowered_func->name = GenDeviceKernelName(lowered_func->name); return copied; } - inline std::string GenDeviceKernelName(const std::string& fn) { return fn + "_kernel"; } + inline std::string GenDeviceKernelName(const std::string& fn) { + return fn + "_kernel"; + } private: ir::Module::Builder host_module_builder; diff --git a/paddle/cinn/backends/codegen_debug_test.cc b/paddle/cinn/backends/codegen_debug_test.cc index 317e9b9957440..a156f5475b3db 100644 --- a/paddle/cinn/backends/codegen_debug_test.cc +++ b/paddle/cinn/backends/codegen_debug_test.cc @@ -52,7 +52,10 @@ CUdeviceptr CreateCudaMemory(const std::vector& shape, const T* data) { CUdeviceptr cuda_ptr = cuMemAlloc(&cuda_ptr, numel * sizeof(T)); if (data != nullptr) { - CUDA_CALL(cudaMemcpy(reinterpret_cast(cuda_ptr), data, numel * sizeof(T), cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy(reinterpret_cast(cuda_ptr), + data, + numel * sizeof(T), + cudaMemcpyHostToDevice)); } return cuda_ptr; } @@ -108,13 +111,16 @@ void __launch_bounds__(512) fn_relu_1_kernel(const float* __restrict__ var_1, fl ASSERT_FALSE(ptx.empty()); CUDAModule cuda_module(ptx, CUDAModule::Kind::PTX); - CUdeviceptr var = CreateCudaMemory(/* shape */ {64 * 112 * 112}, /* data */ nullptr); - CUdeviceptr out = CreateCudaMemory(/* shape */ {64 * 112 * 112}, /* data */ nullptr); + CUdeviceptr var = + CreateCudaMemory(/* shape */ {64 * 112 * 112}, /* data */ nullptr); + CUdeviceptr out = + CreateCudaMemory(/* shape */ {64 * 112 * 112}, /* data */ nullptr); void* args[] = {&var, &out}; dim3 grid(512, 1, 1); dim3 block(512, 1, 1); - cuda_module.LaunchKernel(/*device_id*/ 0, "fn_relu_1_kernel", grid, block, args); + cuda_module.LaunchKernel( + /*device_id*/ 0, "fn_relu_1_kernel", grid, block, args); } } // namespace backends diff --git a/paddle/cinn/backends/compiler.cc b/paddle/cinn/backends/compiler.cc index 5289e6a95196c..72d69ccc3fed3 100644 --- a/paddle/cinn/backends/compiler.cc +++ b/paddle/cinn/backends/compiler.cc @@ -38,8 +38,9 @@ static constexpr int DebugLogMaxLen = 30000; SourceCodePrint::SourceCodePrint() { if (!FLAGS_cinn_source_code_save_path.empty()) { - LOG(INFO) << "The CINN auto generated source code will writing into file: \"" << FLAGS_cinn_source_code_save_path - << "\""; + LOG(INFO) + << "The CINN auto-generated source code will be written into file: \"" + << FLAGS_cinn_source_code_save_path << "\""; of.open(FLAGS_cinn_source_code_save_path, std::ios_base::out); } } @@ -55,11 +56,14 @@ void SourceCodePrint::write(const std::string& source_code) { if (of.is_open()) { of << source_code << std::endl; } else if (!FLAGS_cinn_source_code_save_path.empty()) { - LOG(WARNING) << "Failed to open \"" << FLAGS_cinn_source_code_save_path << "\", source code will print."; + LOG(WARNING) << "Failed to open \"" << FLAGS_cinn_source_code_save_path + << "\", the source code will be printed."; if (source_code.size() > DebugLogMaxLen) { - LOG(INFO) << "[CUDA] source code-0:\n" << source_code.substr(0, DebugLogMaxLen); + LOG(INFO) << "[CUDA] source code-0:\n" + << source_code.substr(0, DebugLogMaxLen); for (int i = 1; i * DebugLogMaxLen < source_code.size(); ++i) { - LOG(INFO) << "[CUDA] source code-" << i << ":\n" << source_code.substr(DebugLogMaxLen * i, DebugLogMaxLen); + LOG(INFO) << "[CUDA] source code-" << i << ":\n" + << source_code.substr(DebugLogMaxLen * i, DebugLogMaxLen); } } else { LOG(INFO) << "[CUDA] source code:\n" << source_code; @@ -80,9 +84,10 @@ void Compiler::Build(const Module& module, const std::string& code) { std::string
Compiler::GetSourceCode(const ir::Module& module) { if (target_.arch == Target::Arch::NVGPU) { #ifdef CINN_WITH_CUDA - auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT - auto& host_module = std::get<0>(_host_module_device_module_); - auto& device_module = std::get<1>(_host_module_device_module_); + auto _host_module_device_module_ = + SplitCudaAndHostModule(module); // NOLINT + auto& host_module = std::get<0>(_host_module_device_module_); + auto& device_module = std::get<1>(_host_module_device_module_); CodeGenCUDA_Dev codegen(target_); auto source_code = codegen.Compile(device_module); return source_code; @@ -104,11 +109,12 @@ void Compiler::BuildDefault(const Module& module) { } } -void Compiler::CompileCudaModule(const Module& module, const std::string& code) { +void Compiler::CompileCudaModule(const Module& module, + const std::string& code) { #ifdef CINN_WITH_CUDA auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT - auto& host_module = std::get<0>(_host_module_device_module_); - auto& device_module = std::get<1>(_host_module_device_module_); + auto& host_module = std::get<0>(_host_module_device_module_); + auto& device_module = std::get<1>(_host_module_device_module_); VLOG(3) << "[CUDA] host module:\n" << host_module; VLOG(3) << "[CUDA] device module:\n" << device_module; @@ -119,24 +125,30 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code) } else { source_code = code; } - CHECK(!source_code.empty()) << "Compile CUDA C code failed from device module:\n" << device_module; + CHECK(!source_code.empty()) + << "Compile CUDA C code failed from device module:\n" + << device_module; VLOG(3) << "[CUDA] C:\n" << source_code; SourceCodePrint::GetInstance()->write(source_code); using runtime::cuda::CUDAModule; nvrtc::Compiler compiler; auto ptx = compiler(source_code); - CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << source_code; - cuda_module_.reset( - new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX)); + CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" + << source_code; + cuda_module_.reset(new CUDAModule(ptx, + compiler.compile_to_cubin() + ? 
CUDAModule::Kind::CUBIN + : CUDAModule::Kind::PTX)); RuntimeSymbols symbols; for (auto& fn : device_module.functions()) { std::string kernel_fn_name = fn->name; - auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name); + auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name); CHECK(fn_kernel); - symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast(fn_kernel)); + symbols.RegisterVar(kernel_fn_name + "_ptr_", + reinterpret_cast(fn_kernel)); } engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols)); @@ -147,9 +159,13 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code) #endif } -void Compiler::CompileX86Module(const Module& module) { engine_->Link(module); } +void Compiler::CompileX86Module(const Module& module) { + engine_->Link(module); +} -void Compiler::ExportObject(const std::string& path) { engine_->ExportObject(path); } +void Compiler::ExportObject(const std::string& path) { + engine_->ExportObject(path); +} void* Compiler::Lookup(absl::string_view fn_name) { CHECK(engine_); diff --git a/paddle/cinn/backends/compiler.h b/paddle/cinn/backends/compiler.h index 9fbed5c518977..1293125129a81 100644 --- a/paddle/cinn/backends/compiler.h +++ b/paddle/cinn/backends/compiler.h @@ -73,11 +73,13 @@ class Compiler final { void* Lookup(absl::string_view fn_name); private: - void CompileCudaModule(const ir::Module& module, const std::string& code = ""); + void CompileCudaModule(const ir::Module& module, + const std::string& code = ""); void CompileX86Module(const ir::Module& module); - explicit Compiler(const Target& target) : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {} + explicit Compiler(const Target& target) + : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {} CINN_DISALLOW_COPY_AND_ASSIGN(Compiler); diff --git a/paddle/cinn/backends/compiler_test.cc b/paddle/cinn/backends/compiler_test.cc index e415eaa14e157..84abedd91e5b6 100644 --- a/paddle/cinn/backends/compiler_test.cc +++ b/paddle/cinn/backends/compiler_test.cc @@ -42,9 +42,9 @@ TEST(Compiler, x86) { { // test x86 auto _A_B_C_ = create_module(); // NOLINT - auto& A = std::get<0>(_A_B_C_); - auto& B = std::get<1>(_A_B_C_); - auto& C = std::get<2>(_A_B_C_); + auto& A = std::get<0>(_A_B_C_); + auto& B = std::get<1>(_A_B_C_); + auto& C = std::get<2>(_A_B_C_); auto stages = CreateStages({C}); @@ -59,9 +59,15 @@ TEST(Compiler, x86) { auto* fnp = compiler->Lookup("fn"); ASSERT_TRUE(fnp); - auto* Ab = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - auto* Bb = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - auto* Cb = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_zero().Build(); + auto* Ab = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); + auto* Bb = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); + auto* Cb = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_zero() + .Build(); auto args = common::ArgsBuilder().Add(Ab).Add(Bb).Add(Cb).Build(); reinterpret_cast(fnp)(args.data(), args.size()); @@ -91,10 +97,10 @@ TEST(Compiler, cuda) { { // cuda auto _A_B_C_ = create_module(); // NOLINT - auto& A = std::get<0>(_A_B_C_); - auto& B = std::get<1>(_A_B_C_); - auto& C = std::get<2>(_A_B_C_); - auto stages = CreateStages({C}); + auto& A = std::get<0>(_A_B_C_); + auto& B = std::get<1>(_A_B_C_); + auto& C = std::get<2>(_A_B_C_); + 
auto stages = CreateStages({C}); stages[C]->Bind(0, "blockIdx.x"); stages[C]->Bind(1, "threadIdx.x"); @@ -110,9 +116,15 @@ TEST(Compiler, cuda) { auto* fnp = compiler->Lookup("fn"); ASSERT_TRUE(fnp); - auto* Ab = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - auto* Bb = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - auto* Cb = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_zero().Build(); + auto* Ab = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); + auto* Bb = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); + auto* Cb = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_zero() + .Build(); // allocate CUDA buffer void *Ag, *Bg, *Cg; @@ -138,7 +150,8 @@ TEST(Compiler, cuda) { timer.Start(); void* stream = nullptr; for (int i = 0; i < 1000; i++) { - reinterpret_cast(fnp)(args.data(), args.size(), stream); + reinterpret_cast(fnp)( + args.data(), args.size(), stream); } CUDA_CALL(cudaDeviceSynchronize()); @@ -146,7 +159,8 @@ TEST(Compiler, cuda) { LOG(INFO) << "latency: " << latency / 1000; std::vector ch(M.as_int32() * N.as_int32(), 0.f); - CUDA_CALL(cudaMemcpy(ch.data(), Cg, ch.size() * sizeof(float), cudaMemcpyDeviceToHost)); + CUDA_CALL(cudaMemcpy( + ch.data(), Cg, ch.size() * sizeof(float), cudaMemcpyDeviceToHost)); auto* Ad = reinterpret_cast(Ab->memory); auto* Bd = reinterpret_cast(Bb->memory); @@ -173,17 +187,22 @@ TEST(Compiler, sqrt) { auto A = Compute( {N, C, H, W}, [=](Expr n, Expr c, Expr h, Expr w) { - return (input(n, c, h, w) - mean(c)) * scale(c) / lang::Sqrt(variance(c) + Expr(epsilon)) + bias(c); + return (input(n, c, h, w) - mean(c)) * scale(c) / + lang::Sqrt(variance(c) + Expr(epsilon)) + + bias(c); }, "A"); - auto B = hlir::pe::Pool2d(input, {3, 3}, {1, 1}, {1, 1, 1, 1}, "max", false, false); + auto B = hlir::pe::Pool2d( + input, {3, 3}, {1, 1}, {1, 1, 1, 1}, "max", false, false); - auto BB = hlir::pe::BatchNorm_NCHW(input, scale, bias, mean, variance, epsilon, "batchnorm"); + auto BB = hlir::pe::BatchNorm_NCHW( + input, scale, bias, mean, variance, epsilon, "batchnorm"); auto stages = CreateStages({input, mean, scale, variance, A, bias, B[0], BB}); - auto fn = Lower("fn", stages, {input, mean, scale, bias, variance, A, B[0], BB}); + auto fn = + Lower("fn", stages, {input, mean, scale, bias, variance, A, B[0], BB}); Module::Builder builder("some", common::DefaultHostTarget()); builder.AddFunction(fn); diff --git a/paddle/cinn/backends/extern_func_emitter.cc b/paddle/cinn/backends/extern_func_emitter.cc index 83d18060ec122..9c1c7fa347dbc 100644 --- a/paddle/cinn/backends/extern_func_emitter.cc +++ b/paddle/cinn/backends/extern_func_emitter.cc @@ -37,19 +37,22 @@ ExternFunctionEmitterRegistry& ExternFunctionEmitterRegistry::Global() { return x; } -void ExternFunctionEmitterRegistry::Register(const ExternFuncID& name, const std::string& x) { +void ExternFunctionEmitterRegistry::Register(const ExternFuncID& name, + const std::string& x) { #ifdef CINN_WITH_DEBUG if (FLAGS_verbose_function_register) { - RAW_LOG_INFO("Register extern function emitter [%s]", utils::GetStreamCnt(name).c_str()); + RAW_LOG_INFO("Register extern function emitter [%s]", + utils::GetStreamCnt(name).c_str()); } #endif // CINN_WITH_DEBUG CHECK(!x.empty()) << "Extern Function name is empty."; data_[name] = x; } -const std::string& ExternFunctionEmitterRegistry::Lookup(const ExternFuncID& name) const { 
+const std::string& ExternFunctionEmitterRegistry::Lookup( + const ExternFuncID& name) const { static const std::string not_found = ""; - auto it = data_.find(name); + auto it = data_.find(name); if (it != data_.end()) { return it->second; } @@ -74,8 +77,10 @@ const FunctionProto& ExternFunctionEmitter::func_proto() const { namespace std { -size_t hash::operator()(const cinn::backends::ExternFuncID& x) const { - return absl::Hash{}(x.name) ^ absl::Hash{}(x.backend_id); +size_t hash::operator()( + const cinn::backends::ExternFuncID& x) const { + return absl::Hash{}(x.name) ^ + absl::Hash{}(x.backend_id); } } // namespace std diff --git a/paddle/cinn/backends/extern_func_emitter.h b/paddle/cinn/backends/extern_func_emitter.h index 98631055be904..e460f3fdb202b 100644 --- a/paddle/cinn/backends/extern_func_emitter.h +++ b/paddle/cinn/backends/extern_func_emitter.h @@ -13,8 +13,8 @@ // limitations under the License. /** - * \file Implements the ExternFuncEmitter class, which is the base of all the emitter of extern function in the - * backends. + * \file Implements the ExternFuncEmitter class, which is the base of all the + * emitters of extern functions in the backends. */ #pragma once @@ -44,14 +44,14 @@ namespace cinn { namespace backends { //! IDs of backends. -static const char* backend_C = "C"; +static const char* backend_C = "C"; static const char* backend_llvm_host = "llvm_host"; -static const char* backend_llvm_x86 = "llvm_x86"; -static const char* backend_nvgpu = "nvgpu"; +static const char* backend_llvm_x86 = "llvm_x86"; +static const char* backend_nvgpu = "nvgpu"; /** - * \brief Base class of the emitter of all the extern functions able to trigger inside CINN CodeGen system. - * There are some common attributes and interfaces. + * \brief Base class for the emitters of all the extern functions that can be + * triggered inside the CINN CodeGen system. There are some common attributes + * and interfaces. */ class ExternFunctionEmitter { public: @@ -63,7 +63,8 @@ class ExternFunctionEmitter { */ virtual const char* func_name() const = 0; /** - * Emit a store node, if the call node's RetValuePacked is true, otherwise Emit a Call node. + * Emit a Store node if the call node's RetValuePacked is true; otherwise + * emit a Call node. */ void Emit(const ir::Call* op, bool insert_global_if_missing = false) { @@ -82,13 +83,12 @@ class ExternFunctionEmitter { * s = Call(some_func, arg0) * \endcode * - * If this function returns true, some pass will applied and transform the IR to - * \code - * Call(some_func, get_addr(s) - * \endcode + * If this function returns true, some pass will be applied to transform the + * IR to + * \code + * Call(some_func, get_addr(s)) + * \endcode * - * The `RetValuePacked` should be true when the external function modify an existing buffer (or some view of it) due - * to that the C language can't return a container. + * The `RetValuePacked` should be true when the external function modifies + * an existing buffer (or some view of it), because the C language can't + * return a container.
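+ * + * For example (an illustrative note, not a new rule): a vectorized extern + * such as `__cinn_host_tanh_v`, named in this backend, fills an output + * buffer in place instead of returning a value, so its emitter reports + * RetValuePacked() as true and the result slot `s` is passed as the packed + * output address.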
*/ virtual bool RetValuePacked() const = 0; @@ -107,7 +107,8 @@ struct ExternFuncID { std::string name; std::string backend_id; - ExternFuncID(const char* name, const char* backend_id) : name(name), backend_id(backend_id) {} + ExternFuncID(const char* name, const char* backend_id) + : name(name), backend_id(backend_id) {} friend std::ostream& operator<<(std::ostream& os, const ExternFuncID& x); friend bool operator==(const ExternFuncID& a, const ExternFuncID& b) { diff --git a/paddle/cinn/backends/extern_func_emitter_builtin.cc b/paddle/cinn/backends/extern_func_emitter_builtin.cc index f3ce3a5521b85..b502cf0eeff23 100644 --- a/paddle/cinn/backends/extern_func_emitter_builtin.cc +++ b/paddle/cinn/backends/extern_func_emitter_builtin.cc @@ -22,11 +22,17 @@ namespace cinn { namespace backends { -void ExternFunctionLLVMEmitter::BindCodeGen(void* codegen) { codegen_ = reinterpret_cast(codegen); } +void ExternFunctionLLVMEmitter::BindCodeGen(void* codegen) { + codegen_ = reinterpret_cast(codegen); +} -const char* ExternFunctionLLVMEmitter::func_name() const { return fn_name_.c_str(); } +const char* ExternFunctionLLVMEmitter::func_name() const { + return fn_name_.c_str(); +} -bool ExternFunctionLLVMEmitter::RetValuePacked() const { return fn_proto().ret_type.is_void(); } +bool ExternFunctionLLVMEmitter::RetValuePacked() const { + return fn_proto().ret_type.is_void(); +} FunctionProto& ExternFunctionLLVMEmitter::fn_proto() const { auto* proto = ExternFunctionProtoRegistry::Global().Lookup(fn_name_); @@ -54,14 +60,17 @@ void ExternFunctionLLVMEmitter::EmitImpl(const ir::Call* op) { CHECK(codegen_); CodeGenLLVMforEmitter codegen_for_emitter(codegen_); llvm::Function* custom_function = llvm::dyn_cast( - codegen_for_emitter.m()->getOrInsertFunction(fn_name_, llvm_fn_type()).getCallee()); + codegen_for_emitter.m() + ->getOrInsertFunction(fn_name_, llvm_fn_type()) + .getCallee()); CHECK(custom_function) << "No function registered in JIT called " << fn_name_; custom_function->setCallingConv(llvm::CallingConv::C); std::vector args; for (auto& v : op->read_args) { if (v.as_tensor()) { - args.push_back(codegen_for_emitter.GetVar(v.as_tensor()->buffer->name, false)); + args.push_back( + codegen_for_emitter.GetVar(v.as_tensor()->buffer->name, false)); } else { auto* arg = codegen_for_emitter.Visit(&v); args.push_back(arg); @@ -69,16 +78,18 @@ void ExternFunctionLLVMEmitter::EmitImpl(const ir::Call* op) { } for (auto& v : op->write_args) { if (v.as_tensor()) { - args.push_back(codegen_for_emitter.GetVar(v.as_tensor()->buffer->name, false)); + args.push_back( + codegen_for_emitter.GetVar(v.as_tensor()->buffer->name, false)); } else { auto* arg = codegen_->Visit(&v); args.push_back(arg); } } - VLOG(3) << "function type " << op->name << ": " << DumpToString(*custom_function); + VLOG(3) << "function type " << op->name << ": " + << DumpToString(*custom_function); - auto* command = codegen_for_emitter.b()->CreateCall(custom_function, args); + auto* command = codegen_for_emitter.b()->CreateCall(custom_function, args); codegen_->extern_func_emit_res_ = command; VLOG(3) << "call: " << DumpToString(*command); } diff --git a/paddle/cinn/backends/extern_func_emitter_builtin.h b/paddle/cinn/backends/extern_func_emitter_builtin.h index 80301d7de62a7..9c896f3d85c61 100644 --- a/paddle/cinn/backends/extern_func_emitter_builtin.h +++ b/paddle/cinn/backends/extern_func_emitter_builtin.h @@ -28,7 +28,7 @@ namespace backends { //! 
Function names -static const char* extern_tanh_host_repr = "__cinn_host_tanh_fp32"; +static const char* extern_tanh_host_repr = "__cinn_host_tanh_fp32"; static const char* extern_tanh_v_host_repr = "__cinn_host_tanh_v"; /** @@ -36,12 +36,14 @@ static const char* extern_tanh_v_host_repr = "__cinn_host_tanh_v"; */ class CodeGenLLVMforEmitter : public CodeGenLLVM { public: - explicit CodeGenLLVMforEmitter(CodeGenLLVM* x) : CodeGenLLVM(x->m(), x->b(), x->named_vars()) {} + explicit CodeGenLLVMforEmitter(CodeGenLLVM* x) + : CodeGenLLVM(x->m(), x->b(), x->named_vars()) {} }; class ExternFunctionLLVMEmitter : public ExternFunctionEmitter { public: - explicit ExternFunctionLLVMEmitter(const std::string& fn_name) : fn_name_(fn_name) {} + explicit ExternFunctionLLVMEmitter(const std::string& fn_name) + : fn_name_(fn_name) {} void BindCodeGen(void* codegen) override; const char* func_name() const override; diff --git a/paddle/cinn/backends/extern_func_jit_register.cc b/paddle/cinn/backends/extern_func_jit_register.cc index f56a266d8e2b8..e528c7a3efe7c 100644 --- a/paddle/cinn/backends/extern_func_jit_register.cc +++ b/paddle/cinn/backends/extern_func_jit_register.cc @@ -26,9 +26,11 @@ void RegisterExternFunctionHelper(const std::string &fn_name, ExternFunctionProtoRegistry::Global().Register(fn_name, fn_proto.release()); CHECK(ExternFunctionProtoRegistry::Global().Lookup(fn_name)); - ExternFunctionEmitterRegistry::Global().Register(ExternFuncID{TargetToBackendRepr(target), fn_name.c_str()}, fn_name); + ExternFunctionEmitterRegistry::Global().Register( + ExternFuncID{TargetToBackendRepr(target), fn_name.c_str()}, fn_name); - GlobalSymbolRegistry::Global().RegisterFn(fn_name, reinterpret_cast(fn_ptr)); + GlobalSymbolRegistry::Global().RegisterFn(fn_name, + reinterpret_cast(fn_ptr)); } void RegisterExternFunction::End() { diff --git a/paddle/cinn/backends/extern_func_jit_register.h b/paddle/cinn/backends/extern_func_jit_register.h index b9ca806d4f8d0..c0f59ec5d2ff5 100644 --- a/paddle/cinn/backends/extern_func_jit_register.h +++ b/paddle/cinn/backends/extern_func_jit_register.h @@ -13,7 +13,8 @@ // limitations under the License. /** - * \file This file defines some functions and macros to help register the extern functions into JIT. + * \file This file defines some functions and macros to help register the extern + * functions into JIT. */ #pragma once @@ -33,45 +34,59 @@ #include "paddle/cinn/common/macros.h" /** - * Helper to register an external function into CINN, including the prototype, the function address. + * Helper to register an external function into CINN, including the prototype, + * the function address. * @param fn__: name of the function * @param target__: the Target. */ #define REGISTER_EXTERN_FUNC_HELPER(fn__, target__) \ - ::cinn::backends::RegisterExternFunction(#fn__, target__, reinterpret_cast(fn__)) + ::cinn::backends::RegisterExternFunction( \ + #fn__, target__, reinterpret_cast(fn__)) -#define REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__) ::cinn::backends::RegisterExternFunction(#fn__, target__) +#define REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__) \ + ::cinn::backends::RegisterExternFunction(#fn__, target__) /** * Register an external function with one input and one output. 
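 *
 * Usage sketch (the function name here is hypothetical; the target helper is
 * the one the tests in this patch use):
 * \code
 * float my_tanh(float x);
 * REGISTER_EXTERN_FUNC_1_IN_1_OUT(my_tanh, common::DefaultHostTarget(),
 *                                 float, float);
 * \endcode
 * The macro expands to the RegisterExternFunction builder chain below; the
 * helper above then records the prototype, registers the emitter, and exposes
 * the symbol to the JIT.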
*/ #define REGISTER_EXTERN_FUNC_1_IN_1_OUT(fn__, target__, in_type__, out_type__) \ - REGISTER_EXTERN_FUNC_HELPER(fn__, target__).SetRetType().AddInputType().End() + REGISTER_EXTERN_FUNC_HELPER(fn__, target__) \ + .SetRetType() \ + .AddInputType() \ + .End() /** * Register an external function with one input and one output. */ -#define REGISTER_EXTERN_FUNC_2_IN_1_OUT(fn__, target__, in_type1__, in_type2__, out_type__) \ - REGISTER_EXTERN_FUNC_HELPER(fn__, target__) \ - .SetRetType() \ - .AddInputType() \ - .AddInputType() \ +#define REGISTER_EXTERN_FUNC_2_IN_1_OUT( \ + fn__, target__, in_type1__, in_type2__, out_type__) \ + REGISTER_EXTERN_FUNC_HELPER(fn__, target__) \ + .SetRetType() \ + .AddInputType() \ + .AddInputType() \ .End() /** - * Register a sourced function(No function address, called in generated source code). + * Register a sourced function(No function address, called in generated source + * code). */ -#define REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(fn__, target__, in_type__, out_type__) \ - REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__).SetRetType().AddInputType().End() +#define REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + fn__, target__, in_type__, out_type__) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__) \ + .SetRetType() \ + .AddInputType() \ + .End() /** - * Register a sourced function(No function address, called in generated source code). + * Register a sourced function(No function address, called in generated source + * code). */ -#define REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(fn__, target__, in_type1__, in_type2__, out_type__) \ - REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__) \ - .SetRetType() \ - .AddInputType() \ - .AddInputType() \ +#define REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + fn__, target__, in_type1__, in_type2__, out_type__) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__) \ + .SetRetType() \ + .AddInputType() \ + .AddInputType() \ .End() namespace cinn { @@ -99,8 +114,13 @@ struct RegisterExternFunction { * @param target Target of the function. * @param fn_ptr Address of the function, not valid if leave as null. */ - RegisterExternFunction(const std::string& fn_name, Target target, void* fn_ptr = nullptr) - : fn_name_(fn_name), target_(target), fn_ptr_(fn_ptr), fn_proto_builder_(fn_name) {} + RegisterExternFunction(const std::string& fn_name, + Target target, + void* fn_ptr = nullptr) + : fn_name_(fn_name), + target_(target), + fn_ptr_(fn_ptr), + fn_proto_builder_(fn_name) {} /** * Add an input type. @@ -140,7 +160,8 @@ struct RegisterExternFunction { * @param handle The handle to help inference the shape. * @return itself. 
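 *
 * A typical handle simply mirrors the shape of one argument; the tanh_v
 * prototype in this patch builds its FunctionProto with the same helper:
 * \code
 * .SetShapeInference(FunctionProto::ShapeFollowNthArgument(0))
 * \endcode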
*/ - RegisterExternFunction& SetShapeInference(FunctionProto::shape_inference_t handle) { + RegisterExternFunction& SetShapeInference( + FunctionProto::shape_inference_t handle) { fn_proto_builder_.SetShapeInference(handle); return *this; } diff --git a/paddle/cinn/backends/extern_func_protos.cc b/paddle/cinn/backends/extern_func_protos.cc index e9c737e48eb12..819d21f5eacae 100644 --- a/paddle/cinn/backends/extern_func_protos.cc +++ b/paddle/cinn/backends/extern_func_protos.cc @@ -22,26 +22,38 @@ namespace backends { ExternFunctionProtoRegistry::ExternFunctionProtoRegistry() { static const std::vector extern_funcs_fp32_unary = { - "exp", "erf", "sigmoid", "sqrt", "log", "log2", "log10", "floor", "ceil", "round", "trunc", "cos", - "cosh", "tan", "tanh", "sin", "sinh", "acos", "acosh", "asin", "asinh", "atan", "atanh", "fabs"}; - static const std::vector extern_funcs_float_bool_unary = {"isnan", "isfinite", "isinf"}; - static const std::vector extern_funcs_int_binary = { - "left_shift", "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not"}; - static const std::vector extern_funcs_int_int_unary = {"bitwise_not"}; + "exp", "erf", "sigmoid", "sqrt", "log", "log2", "log10", "floor", + "ceil", "round", "trunc", "cos", "cosh", "tan", "tanh", "sin", + "sinh", "acos", "acosh", "asin", "asinh", "atan", "atanh", "fabs"}; + static const std::vector extern_funcs_float_bool_unary = { + "isnan", "isfinite", "isinf"}; + static const std::vector extern_funcs_int_binary = { + "left_shift", + "right_shift", + "bitwise_or", + "bitwise_and", + "bitwise_xor", + "bitwise_not"}; + static const std::vector extern_funcs_int_int_unary = { + "bitwise_not"}; for (int i = 0; i < extern_funcs_fp32_unary.size(); ++i) { - auto* proto = new FunctionProto(extern_funcs_fp32_unary[i], {Float(32)}, Float(32)); + auto* proto = + new FunctionProto(extern_funcs_fp32_unary[i], {Float(32)}, Float(32)); Register(proto->name, proto); } for (int i = 0; i < extern_funcs_float_bool_unary.size(); ++i) { - auto* proto = new FunctionProto(extern_funcs_float_bool_unary[i], {Float(32)}, Bool()); + auto* proto = new FunctionProto( + extern_funcs_float_bool_unary[i], {Float(32)}, Bool()); Register(proto->name, proto); } for (int i = 0; i < extern_funcs_int_binary.size(); ++i) { - auto* proto = new FunctionProto(extern_funcs_int_binary[i], {Int(32), Int(32)}, Int(32)); + auto* proto = new FunctionProto( + extern_funcs_int_binary[i], {Int(32), Int(32)}, Int(32)); Register(proto->name, proto); } for (int i = 0; i < extern_funcs_int_int_unary.size(); ++i) { - auto* proto = new FunctionProto(extern_funcs_int_int_unary[i], {Int(32)}, Int(32)); + auto* proto = + new FunctionProto(extern_funcs_int_int_unary[i], {Int(32)}, Int(32)); Register(proto->name, proto); } @@ -57,8 +69,11 @@ ExternFunctionProtoRegistry& ExternFunctionProtoRegistry::Global() { namespace detail { FunctionProto* CreateTanhVProto() { - return new FunctionProto( - extern_func__tanh_v, {type_of()}, {type_of()}, Void(), FunctionProto::ShapeFollowNthArgument(0)); + return new FunctionProto(extern_func__tanh_v, + {type_of()}, + {type_of()}, + Void(), + FunctionProto::ShapeFollowNthArgument(0)); } } // namespace detail diff --git a/paddle/cinn/backends/function_prototype.cc b/paddle/cinn/backends/function_prototype.cc index 66e80a525f274..9f360d58f9a0c 100644 --- a/paddle/cinn/backends/function_prototype.cc +++ b/paddle/cinn/backends/function_prototype.cc @@ -43,9 +43,12 @@ bool FunctionProto::Match(const ir::Call *op) const { void FunctionProto::AssertMatch(const 
ir::Call *op) const { CHECK_EQ(name, op->name); - CHECK_EQ(ret_type, op->type()) << "function proto " << name << " check failed"; - CHECK_EQ(op->read_args.size(), readonly_arg_types.size()) << "function proto " << name << " check failed"; - CHECK_EQ(op->write_args.size(), mutable_arg_types.size()) << "function proto " << name << " check failed"; + CHECK_EQ(ret_type, op->type()) + << "function proto " << name << " check failed"; + CHECK_EQ(op->read_args.size(), readonly_arg_types.size()) + << "function proto " << name << " check failed"; + CHECK_EQ(op->write_args.size(), mutable_arg_types.size()) + << "function proto " << name << " check failed"; auto get_type = [](Expr u) { if (u.as_tensor() || u.as_buffer()) { @@ -73,9 +76,11 @@ void FunctionProto::AssertMatch(const ir::Call *op) const { void FunctionProto::CheckValid() { if (ret_type.is_void()) { CHECK(!mutable_arg_types.empty()) - << "A void function should have at least one mutable argument to output something"; + << "A void function should have at least one mutable argument to " + "output something"; } else { - CHECK(mutable_arg_types.empty()) << "A function with return should not have mutable argument"; + CHECK(mutable_arg_types.empty()) + << "A function with return should not have mutable argument"; } } @@ -109,7 +114,8 @@ FunctionProto *FunctionProtoRegistry::Lookup(const std::string &name) { return nullptr; } -FunctionProto *FunctionProtoRegistry::Register(absl::string_view name, FunctionProto *x) { +FunctionProto *FunctionProtoRegistry::Register(absl::string_view name, + FunctionProto *x) { #ifdef CINN_WITH_DEBUG if (FLAGS_verbose_function_register) { RAW_LOG_INFO("Register function prototype [%s]", name.data()); diff --git a/paddle/cinn/backends/function_prototype.h b/paddle/cinn/backends/function_prototype.h index 9950e4f4dad03..5c84cdc1ce0ee 100644 --- a/paddle/cinn/backends/function_prototype.h +++ b/paddle/cinn/backends/function_prototype.h @@ -28,8 +28,8 @@ namespace cinn { namespace backends { struct FunctionProto { - using shape_inference_t = - std::function /*shape*/ (const std::vector& /*arguments*/, int /*value_offset*/)>; + using shape_inference_t = std::function /*shape*/ ( + const std::vector& /*arguments*/, int /*value_offset*/)>; std::string name; std::vector readonly_arg_types; @@ -50,7 +50,7 @@ struct FunctionProto { FunctionProto(const std::string& name, const std::vector& readonly_arg_types, const std::vector& mutable_arg_types, - Type ret_type = Void(), + Type ret_type = Void(), shape_inference_t shape_inference = shape_inference_t()); /** @@ -59,7 +59,9 @@ struct FunctionProto { * @param input_types The input types. * @param ret_type The return type. 
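 *
 * Example, mirroring the unary fp32 prototypes registered in
 * extern_func_protos.cc above:
 * \code
 * auto* proto = new FunctionProto("tanh", {Float(32)}, Float(32));
 * Register(proto->name, proto);
 * \endcode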
*/ - FunctionProto(const std::string& name, const std::vector& input_types, Type ret_type) + FunctionProto(const std::string& name, + const std::vector& input_types, + Type ret_type) : name(name), readonly_arg_types(input_types), ret_type(ret_type) {} /** diff --git a/paddle/cinn/backends/ir_schedule_test.cc b/paddle/cinn/backends/ir_schedule_test.cc index 5fd5453bffb82..2207a3d9cd073 100644 --- a/paddle/cinn/backends/ir_schedule_test.cc +++ b/paddle/cinn/backends/ir_schedule_test.cc @@ -48,17 +48,18 @@ TEST(IrSchedule, split_and_fuse1) { auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_split_and_fuse1", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_split_and_fuse1", stages, {A, B}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); - auto fused = ir_sch.Fuse("B", {0, 1}); + auto fused = ir_sch.Fuse("B", {0, 1}); auto splited = ir_sch.Split(fused, {4, -1}); auto loops = ir_sch.GetLoops("B"); - fused = ir_sch.Fuse(loops); - splited = ir_sch.Split(fused, {256, -1}); + fused = ir_sch.Fuse(loops); + splited = ir_sch.Split(fused, {256, -1}); Module::Builder builder("module1", target); for (auto& i : func) { @@ -106,16 +107,18 @@ TEST(IrSchedule, split_and_fuse2) { auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_split_and_fuse2", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_split_and_fuse2", stages, {A, B}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); auto loops = ir_sch.GetLoops("B"); - auto fused = ir_sch.Fuse(loops); + auto fused = ir_sch.Fuse(loops); auto splited = ir_sch.Split(fused, {-1, 20}); - VLOG(3) << "After split {-1, 20}, IR is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After split {-1, 20}, IR is : " + << ir_sch.GetModule().GetExprs().at(0); Module::Builder builder("module1", target); for (auto& i : func) { @@ -167,14 +170,15 @@ TEST(IrSchedule, reorder1) { auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_reorder1", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_reorder1", stages, {A, B}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); auto splited = ir_sch.Split("B", 0, {-1, 4}); - splited = ir_sch.Split("B", 2, {-1, 2}); + splited = ir_sch.Split("B", 2, {-1, 2}); auto loops = ir_sch.GetLoops("B"); ir_sch.Reorder({loops[4], loops[0]}); @@ -233,14 +237,15 @@ TEST(IrSchedule, reorder2) { auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_reorder2", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_reorder2", stages, {A, B}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); auto splited = ir_sch.Split("B", 0, {-1, 4}); - splited = ir_sch.Split("B", 2, {-1, 2}); + splited = ir_sch.Split("B", 2, {-1, 2}); ir_sch.Reorder("B", {4, 2, 3, 1, 0}); @@ -298,16 +303,17 @@ TEST(IrSchedule, reorder3) { auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_reorder3", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = 
cinn::lang::LowerVec( + "test_reorder3", stages, {A, B}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); auto all_blocks = ir_sch.GetAllBlocks(); - auto loops = ir_sch.GetLoops(all_blocks[0]); + auto loops = ir_sch.GetLoops(all_blocks[0]); auto splited = ir_sch.Split(loops[0], {-1, 5}); - splited = ir_sch.Split("B", 2, {-1, 2}); + splited = ir_sch.Split("B", 2, {-1, 2}); ir_sch.Reorder("B", {3, 1, 2, 0, 4}); @@ -367,18 +373,19 @@ TEST(IrSchedule, reorder4) { auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_reorder4", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_reorder4", stages, {A, B}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); auto all_blocks = ir_sch.GetAllBlocks(); - auto block_b = ir_sch.GetBlock("B"); - auto loops = ir_sch.GetLoops(block_b); + auto block_b = ir_sch.GetBlock("B"); + auto loops = ir_sch.GetLoops(block_b); auto splited = ir_sch.Split("B", 0, {-1, 10}); - splited = ir_sch.Split("B", 2, {-1, 5}); + splited = ir_sch.Split("B", 2, {-1, 5}); ir_sch.Reorder("B", {0, 2, 1, 3, 4}); @@ -439,7 +446,8 @@ TEST(IrSchedule, parallel) { {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_parallel", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_parallel", stages, {A, B}, {}, {}, nullptr, target, true); CHECK(!func.empty()); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -504,7 +512,8 @@ TEST(IrSchedule, vectorize) { {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_vectorize", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_vectorize", stages, {A, B}, {}, {}, nullptr, target, true); CHECK(!func.empty()); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -577,7 +586,8 @@ TEST(IrSchedule, unroll) { {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_unroll", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_unroll", stages, {A, B}, {}, {}, nullptr, target, true); CHECK(!func.empty()); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -649,7 +659,8 @@ TEST(IrSchedule, bind) { {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_bind", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_bind", stages, {A, B}, {}, {}, nullptr, target, true); CHECK(!func.empty()); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -695,7 +706,8 @@ TEST(IrSchedule, simple_compute_at) { auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_simple_compute_at", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_simple_compute_at", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); auto ast_expr = func[0]->body; @@ -703,11 +715,11 @@ TEST(IrSchedule, simple_compute_at) { ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); - auto fused = ir_sch.Fuse("B", {0, 1}); + auto 
fused = ir_sch.Fuse("B", {0, 1}); auto splited = ir_sch.Split(fused, {-1, 1024}); - fused = ir_sch.Fuse("C", {0, 1}); - splited = ir_sch.Split(fused, {-1, 1024}); + fused = ir_sch.Fuse("C", {0, 1}); + splited = ir_sch.Split(fused, {-1, 1024}); auto block_b = ir_sch.GetBlock("B"); ir_sch.SimpleComputeAt(block_b, splited[1]); @@ -769,7 +781,8 @@ TEST(IrSchedule, compute_at0) { auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_compute_at0", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_at0", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); auto ast_expr = func[0]->body; @@ -777,11 +790,11 @@ TEST(IrSchedule, compute_at0) { ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); - auto fused = ir_sch.Fuse("B", {0, 1}); + auto fused = ir_sch.Fuse("B", {0, 1}); auto splited = ir_sch.Split(fused, {-1, 1024}); - fused = ir_sch.Fuse("C", {0, 1}); - splited = ir_sch.Split(fused, {-1, 1024}); + fused = ir_sch.Fuse("C", {0, 1}); + splited = ir_sch.Split(fused, {-1, 1024}); auto block_b = ir_sch.GetBlock("B"); ir_sch.ComputeAt(block_b, splited[1]); @@ -844,7 +857,8 @@ TEST(IrSchedule, compute_at1) { auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_compute_at1", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_at1", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); auto ast_expr = func[0]->body; @@ -853,7 +867,7 @@ TEST(IrSchedule, compute_at1) { ir::IRSchedule ir_sch(mod_expr); auto block_b = ir_sch.GetBlock("B"); - auto loops = ir_sch.GetLoops("C"); + auto loops = ir_sch.GetLoops("C"); ir_sch.ComputeAt(block_b, loops[1]); @@ -915,7 +929,8 @@ TEST(IrSchedule, compute_at2) { auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_compute_at2", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_at2", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); auto ast_expr = func[0]->body; @@ -924,7 +939,7 @@ TEST(IrSchedule, compute_at2) { ir::IRSchedule ir_sch(mod_expr); auto block_b = ir_sch.GetBlock("B"); - auto loops = ir_sch.GetLoops("C"); + auto loops = ir_sch.GetLoops("C"); ir_sch.ComputeAt(block_b, loops[0]); @@ -986,7 +1001,8 @@ TEST(IrSchedule, compute_at3) { auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_compute_at3", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_at3", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); auto ast_expr = func[0]->body; @@ -996,7 +1012,7 @@ TEST(IrSchedule, compute_at3) { auto block_b = ir_sch.GetBlock("B"); - auto fused = ir_sch.Fuse("C", {0, 1}); + auto fused = ir_sch.Fuse("C", {0, 1}); auto splited = ir_sch.Split(fused, {32, -1}); auto loops = ir_sch.GetLoops("C"); @@ -1066,7 +1082,8 @@ TEST(IrSchedule, compute_at4) { auto stages = CreateStages({A, B, C}); stages[B]->SetBuffer("local"); - auto func = cinn::lang::LowerVec("test_compute_at4", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_at4", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); auto ast_expr = func[0]->body; @@ -1075,7 +1092,7 @@ TEST(IrSchedule, compute_at4) { ir::IRSchedule ir_sch(mod_expr); auto block_b = ir_sch.GetBlock("B"); - auto loops = ir_sch.GetLoops("C"); + auto loops = 
ir_sch.GetLoops("C"); ir_sch.ComputeAt(block_b, loops[1]); @@ -1127,7 +1144,8 @@ TEST(IrSchedule, compute_at5) { auto stages = CreateStages({A, B, C}); stages[B]->SetBuffer("local"); - auto func = cinn::lang::LowerVec("test_compute_at5", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_at5", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); auto ast_expr = func[0]->body; @@ -1136,7 +1154,7 @@ TEST(IrSchedule, compute_at5) { ir::IRSchedule ir_sch(mod_expr); auto block_b = ir_sch.GetBlock("B"); - auto loops = ir_sch.GetLoops("C"); + auto loops = ir_sch.GetLoops("C"); ir_sch.ComputeAt(block_b, loops[0]); @@ -1189,7 +1207,8 @@ TEST(IrSchedule, compute_at6) { auto stages = CreateStages({A, B, C}); stages[B]->SetBuffer("local"); - auto func = cinn::lang::LowerVec("test_compute_at6", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_at6", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); auto ast_expr = func[0]->body; @@ -1199,7 +1218,7 @@ TEST(IrSchedule, compute_at6) { auto block_b = ir_sch.GetBlock("B"); - auto fused = ir_sch.Fuse("C", {0, 1}); + auto fused = ir_sch.Fuse("C", {0, 1}); auto splited = ir_sch.Split(fused, {32, -1}); auto loops = ir_sch.GetLoops("C"); @@ -1253,7 +1272,8 @@ TEST(IrSchedule, cache_read1) { auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_cache_read1", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_cache_read1", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); @@ -1335,7 +1355,8 @@ TEST(IrSchedule, cache_read2) { auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_cache_read2", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_cache_read2", stages, {A, B}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); @@ -1351,7 +1372,8 @@ TEST(IrSchedule, cache_read2) { auto loops = ir_sch.GetLoops("B"); ir_sch.ComputeAt(a_cache, loops[1]); - VLOG(1) << "After CacheRead and ComputeAt, IR is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(1) << "After CacheRead and ComputeAt, IR is : " + << ir_sch.GetModule().GetExprs().at(0); Module::Builder builder("module1", target); for (auto& i : func) { @@ -1403,7 +1425,8 @@ TEST(IrSchedule, cache_write1) { auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_cache_write1", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_cache_write1", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); @@ -1417,7 +1440,8 @@ TEST(IrSchedule, cache_write1) { auto block_c = ir_sch.GetBlock("C"); auto c_cache = ir_sch.CacheWrite(block_c, 0, "local"); - VLOG(1) << "After CacheWrite, IR is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(1) << "After CacheWrite, IR is : " + << ir_sch.GetModule().GetExprs().at(0); Module::Builder builder("module1", target); for (auto& i : func) { @@ -1485,7 +1509,8 @@ TEST(IrSchedule, cache_write2) { auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_cache_write2", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_cache_write2", stages, {A, B}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); @@ -1496,10 +1521,11 @@ TEST(IrSchedule, cache_write2) { auto block_b = ir_sch.GetBlock("B"); auto b_cache = ir_sch.CacheWrite(block_b, 0, 
"local"); - auto loops = ir_sch.GetLoops("B"); + auto loops = ir_sch.GetLoops("B"); ir_sch.ComputeAt(b_cache, loops[1]); - VLOG(1) << "After CacheWrite and ComputeAt, IR is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(1) << "After CacheWrite and ComputeAt, IR is : " + << ir_sch.GetModule().GetExprs().at(0); Module::Builder builder("module1", target); for (auto& i : func) { @@ -1554,7 +1580,8 @@ TEST(IrSchedule, cache_read3) { auto stages = CreateStages({A, B, C}); stages[B]->SetBuffer("local"); - auto func = cinn::lang::LowerVec("test_cache_read3", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_cache_read3", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); @@ -1634,7 +1661,8 @@ TEST(IrSchedule, cache_write3) { auto stages = CreateStages({A, B, C}); stages[B]->SetBuffer("shared"); - auto func = cinn::lang::LowerVec("test_cache_write3", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_cache_write3", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); @@ -1652,7 +1680,8 @@ TEST(IrSchedule, cache_write3) { auto loops_b = ir_sch.GetLoops("B"); ir_sch.SyncThreads(loops_b[0]); - VLOG(1) << "After CacheWrite, IR is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(1) << "After CacheWrite, IR is : " + << ir_sch.GetModule().GetExprs().at(0); Module::Builder builder("module1", target); for (auto& i : func) { @@ -1714,7 +1743,8 @@ TEST(IrSchedule, sync_threads) { auto stages = CreateStages({A, B, C}); stages[B]->SetBuffer("shared"); - auto func = cinn::lang::LowerVec("test_sync_threads", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_sync_threads", stages, {A, C}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); @@ -1727,12 +1757,13 @@ TEST(IrSchedule, sync_threads) { auto b_cache = ir_sch.CacheWrite(block_b, 0, "local"); auto block_c = ir_sch.GetBlock("C"); auto c_cache = ir_sch.CacheWrite(block_c, 0, "local"); - block_c = ir_sch.GetBlock("C"); + block_c = ir_sch.GetBlock("C"); ir_sch.SyncThreads(block_c, false); block_b = ir_sch.GetBlock("B"); ir_sch.SyncThreads(block_b); - VLOG(1) << "After CacheWrite and SyncThreads, IR is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(1) << "After CacheWrite and SyncThreads, IR is : " + << ir_sch.GetModule().GetExprs().at(0); Module::Builder builder("module1", target); for (auto& i : func) { @@ -1787,11 +1818,14 @@ TEST(IrSchedule, cache_write4) { Placeholder A("A", {M, N, N}); Var k(32, "k0"); auto B = Compute( - {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, j, k), {k}); }, "B"); + {M, N}, + [&](Var i, Var j) { return lang::ReduceSum(A(i, j, k), {k}); }, + "B"); auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_cache_write4", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_cache_write4", stages, {A, B}, {}, {}, nullptr, target, true); CHECK_EQ(func.size(), 1U); @@ -1802,9 +1836,10 @@ TEST(IrSchedule, cache_write4) { auto block_b = ir_sch.GetBlock("B"); auto b_cache = ir_sch.CacheWrite(block_b, 0, "local"); - auto loops = ir_sch.GetLoops("B"); + auto loops = ir_sch.GetLoops("B"); - VLOG(1) << "After CacheWrite, IR is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(1) << "After CacheWrite, IR is : " + << ir_sch.GetModule().GetExprs().at(0); Module::Builder builder("module1", target); for (auto& i : func) { @@ -1867,7 +1902,8 @@ TEST(IrSchedule, rfactor) { "B"); auto stages = 
CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_rfactor", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_rfactor", stages, {A, B}, {}, {}, nullptr, target, true); CHECK(!func.empty()); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -1875,7 +1911,7 @@ TEST(IrSchedule, rfactor) { ir::IRSchedule ir_sch(mod_expr); auto loops = ir_sch.GetLoops("B"); CHECK_EQ(loops.size(), 3U); - auto new_rf_tensor = ir_sch.Rfactor(loops[2], 0); + auto new_rf_tensor = ir_sch.Rfactor(loops[2], 0); auto* new_rf_tensor_ref = new_rf_tensor.As(); CHECK(new_rf_tensor_ref); CHECK(new_rf_tensor_ref->buffer.defined()); @@ -1993,7 +2029,8 @@ TEST(IrSchedule, rfactor1) { "B"); auto stages = CreateStages({A, B}); - auto func = cinn::lang::LowerVec("test_rfactor", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_rfactor", stages, {A, B}, {}, {}, nullptr, target, true); CHECK(!func.empty()); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -2001,7 +2038,7 @@ TEST(IrSchedule, rfactor1) { ir::IRSchedule ir_sch(mod_expr); auto loops = ir_sch.GetLoops("B"); CHECK_EQ(loops.size(), 3U); - auto new_rf_tensor = ir_sch.Rfactor(loops[1], 1); + auto new_rf_tensor = ir_sch.Rfactor(loops[1], 1); auto* new_rf_tensor_ref = new_rf_tensor.As(); CHECK(new_rf_tensor_ref); CHECK(new_rf_tensor_ref->buffer.defined()); @@ -2113,10 +2150,13 @@ TEST(IrSchedule, rfactor2) { Placeholder B("B", {K, N}); Var k(16, "k0"); auto C = Compute( - {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + {M, N}, + [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, + "C"); auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_rfactor", stages, {A, B, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_rfactor", stages, {A, B, C}, {}, {}, nullptr, target, true); CHECK(!func.empty()); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -2124,7 +2164,7 @@ TEST(IrSchedule, rfactor2) { ir::IRSchedule ir_sch(mod_expr); auto loops = ir_sch.GetLoops("C"); CHECK_EQ(loops.size(), 3U); - auto new_rf_tensor = ir_sch.Rfactor(loops[2], 0); + auto new_rf_tensor = ir_sch.Rfactor(loops[2], 0); auto* new_rf_tensor_ref = new_rf_tensor.As(); CHECK(new_rf_tensor_ref); CHECK(new_rf_tensor_ref->buffer.defined()); @@ -2241,13 +2281,18 @@ TEST(IrSchedule, compute_inline1) { Placeholder A("A", {M, N, P}); auto B = Compute( - {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B"); + {M, N, P}, + [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, + "B"); auto C = Compute( - {M, N, P}, [&](Var i, Var j, Var k) { return B(j, i, k) * Expr(2.f); }, "C"); + {M, N, P}, + [&](Var i, Var j, Var k) { return B(j, i, k) * Expr(2.f); }, + "C"); auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_compute_inline1", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_inline1", stages, {A, C}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -2256,7 +2301,8 @@ TEST(IrSchedule, compute_inline1) { auto block_b = ir_sch.GetBlock("B"); ir_sch.ComputeInline(block_b); - VLOG(1) << "After ComputeInline, IR is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(1) << "After ComputeInline, IR is : " + << ir_sch.GetModule().GetExprs().at(0); Module::Builder builder("module1", target); for (auto& i : func) { 
builder.AddFunction(i); @@ -2306,13 +2352,18 @@ TEST(IrSchedule, compute_inline2) { Placeholder A("A", {M, N, P}); auto B = Compute( - {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B"); + {M, N, P}, + [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, + "B"); auto C = Compute( - {M, N, P}, [&](Var i, Var j, Var k) { return B(i, j, k) * Expr(2.f); }, "C"); + {M, N, P}, + [&](Var i, Var j, Var k) { return B(i, j, k) * Expr(2.f); }, + "C"); auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_compute_inline2", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_inline2", stages, {A, C}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -2320,11 +2371,12 @@ TEST(IrSchedule, compute_inline2) { ir::IRSchedule ir_sch(mod_expr); auto block_b = ir_sch.GetBlock("B"); - auto loops = ir_sch.GetLoops("C"); + auto loops = ir_sch.GetLoops("C"); ir_sch.ComputeAt(block_b, loops[1]); block_b = ir_sch.GetBlock("B"); ir_sch.ComputeInline(block_b); - VLOG(1) << "After ComputeInline, IR is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(1) << "After ComputeInline, IR is : " + << ir_sch.GetModule().GetExprs().at(0); Module::Builder builder("module1", target); for (auto& i : func) { builder.AddFunction(i); @@ -2375,14 +2427,19 @@ TEST(IrSchedule, compute_inline3) { Placeholder A("A", {M, N, P}); auto B = Compute( - {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B"); + {M, N, P}, + [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, + "B"); auto C = Compute( - {M, N, P}, [&](Var i, Var j, Var k) { return B(j, i, k) * Expr(2.f); }, "C"); + {M, N, P}, + [&](Var i, Var j, Var k) { return B(j, i, k) * Expr(2.f); }, + "C"); auto stages = CreateStages({A, B, C}); stages[B]->SetBuffer("local"); - auto func = cinn::lang::LowerVec("test_compute_inline3", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_inline3", stages, {A, C}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -2391,7 +2448,8 @@ TEST(IrSchedule, compute_inline3) { auto block_b = ir_sch.GetBlock("B"); ir_sch.ComputeInline(block_b); - VLOG(1) << "After ComputeInline, IR is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(1) << "After ComputeInline, IR is : " + << ir_sch.GetModule().GetExprs().at(0); Module::Builder builder("module1", target); for (auto& i : func) { @@ -2431,14 +2489,19 @@ TEST(IrSchedule, compute_inline4) { Placeholder A("A", {M, N, P}); auto B = Compute( - {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B"); + {M, N, P}, + [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, + "B"); auto C = Compute( - {M, N, P}, [&](Var i, Var j, Var k) { return B(i, j, k) * Expr(2.f); }, "C"); + {M, N, P}, + [&](Var i, Var j, Var k) { return B(i, j, k) * Expr(2.f); }, + "C"); auto stages = CreateStages({A, B, C}); stages[B]->SetBuffer("local"); - auto func = cinn::lang::LowerVec("test_compute_inline4", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_inline4", stages, {A, C}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -2446,11 +2509,12 @@ TEST(IrSchedule, compute_inline4) { ir::IRSchedule ir_sch(mod_expr); auto block_b = ir_sch.GetBlock("B"); - auto loops = ir_sch.GetLoops("C"); + auto loops = ir_sch.GetLoops("C"); 
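// A condensed sketch of the scheduling pattern these tests repeat (all names
// come from the surrounding code): block and loop handles are re-fetched
// after every mutating primitive before they are used again.
//
//   ir::IRSchedule ir_sch(mod_expr);
//   auto block_b = ir_sch.GetBlock("B");
//   auto loops   = ir_sch.GetLoops("C");
//   ir_sch.ComputeAt(block_b, loops[1]);  // move B's computation under C's loop
//   block_b = ir_sch.GetBlock("B");       // re-fetch before the next primitive
//   ir_sch.ComputeInline(block_b);        // then fold B's body into C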
ir_sch.ComputeAt(block_b, loops[1]); block_b = ir_sch.GetBlock("B"); ir_sch.ComputeInline(block_b); - VLOG(1) << "After ComputeInline, IR is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(1) << "After ComputeInline, IR is : " + << ir_sch.GetModule().GetExprs().at(0); Module::Builder builder("module1", target); for (auto& i : func) { builder.AddFunction(i); @@ -2493,7 +2557,8 @@ TEST(IrSchedule, reverse_compute_inline1) { auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_compute_inline1", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_inline1", stages, {A, C}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -2549,13 +2614,18 @@ TEST(IrSchedule, reverse_compute_inline2) { Placeholder A("A", {M, N, P}); auto B = Compute( - {M, N, P}, [&](Var i, Var j, Var k) { return Expr(1.f) + A(i, j, k); }, "B"); + {M, N, P}, + [&](Var i, Var j, Var k) { return Expr(1.f) + A(i, j, k); }, + "B"); auto C = Compute( - {N, M, P}, [&](Var i, Var j, Var k) { return Expr(2.f) * B(j, i, k); }, "C"); + {N, M, P}, + [&](Var i, Var j, Var k) { return Expr(2.f) * B(j, i, k); }, + "C"); auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_compute_inline1", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_compute_inline1", stages, {A, C}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -2613,13 +2683,18 @@ TEST(IrSchedule, copytransform1) { Placeholder A("A", {M, N, P}); auto B = Compute( - {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B"); + {M, N, P}, + [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, + "B"); auto C = Compute( - {M, N, P}, [&](Var i, Var j, Var k) { return B(j, i, k) * Expr(2.f); }, "C"); + {M, N, P}, + [&](Var i, Var j, Var k) { return B(j, i, k) * Expr(2.f); }, + "C"); auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_copytransform1", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_copytransform1", stages, {A, C}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -2629,12 +2704,12 @@ TEST(IrSchedule, copytransform1) { auto block_c = ir_sch.GetBlock("C"); auto loops_c = ir_sch.GetLoops(block_c); auto splited = ir_sch.Split(loops_c[1], {-1, 4}); - block_c = ir_sch.GetBlock("C"); - loops_c = ir_sch.GetLoops(block_c); - splited = ir_sch.Split(loops_c[0], {-1, 8}); + block_c = ir_sch.GetBlock("C"); + loops_c = ir_sch.GetLoops(block_c); + splited = ir_sch.Split(loops_c[0], {-1, 8}); auto block_b = ir_sch.GetBlock("B"); - block_c = ir_sch.GetBlock("C"); + block_c = ir_sch.GetBlock("C"); ir_sch.CopyTransformAndLoopInfo(block_b, block_c); Module::Builder builder("module1", target); @@ -2699,13 +2774,18 @@ TEST(IrSchedule, copytransform2) { Placeholder A("A", {M, N, P}); auto B = Compute( - {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B"); + {M, N, P}, + [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, + "B"); auto C = Compute( - {M, M, P}, [&](Var i, Var j, Var k) { return B(i, j, k) * Expr(2.f); }, "C"); + {M, M, P}, + [&](Var i, Var j, Var k) { return B(i, j, k) * Expr(2.f); }, + "C"); auto stages = CreateStages({A, B, C}); - auto func = cinn::lang::LowerVec("test_copytransform2", stages, {A, C}, {}, {}, nullptr, target, true); + auto func = 
cinn::lang::LowerVec( + "test_copytransform2", stages, {A, C}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; std::vector vec_ast{ast_expr}; @@ -2715,12 +2795,12 @@ TEST(IrSchedule, copytransform2) { auto block_c = ir_sch.GetBlock("C"); auto loops_c = ir_sch.GetLoops(block_c); auto splited = ir_sch.Split(loops_c[1], {-1, 4}); - block_c = ir_sch.GetBlock("C"); - loops_c = ir_sch.GetLoops(block_c); - splited = ir_sch.Split(loops_c[0], {-1, 8}); + block_c = ir_sch.GetBlock("C"); + loops_c = ir_sch.GetLoops(block_c); + splited = ir_sch.Split(loops_c[0], {-1, 8}); auto block_b = ir_sch.GetBlock("B"); - block_c = ir_sch.GetBlock("C"); + block_c = ir_sch.GetBlock("C"); ir_sch.CopyTransformAndLoopInfo(block_b, block_c); Module::Builder builder("module1", target); for (auto& i : func) { @@ -2780,10 +2860,16 @@ TEST(IrSchedule, Annotate) { auto B = Compute( {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); - auto funcs = cinn::lang::LowerVec( - "test_annotate", CreateStages({A, B}), {A, B}, {}, {}, nullptr, common::DefaultHostTarget(), true); + auto funcs = cinn::lang::LowerVec("test_annotate", + CreateStages({A, B}), + {A, B}, + {}, + {}, + nullptr, + common::DefaultHostTarget(), + true); ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body})); - auto fused = ir_sch.Fuse("B", {0, 1}); + auto fused = ir_sch.Fuse("B", {0, 1}); auto block_b = ir_sch.GetBlock("B"); ir_sch.Annotate(block_b, "k1", int(64)); block_b = ir_sch.GetBlock("B"); @@ -2806,7 +2892,8 @@ TEST(IrSchedule, Annotate) { } } })ROC"; - ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetModule().GetExprs().front()), expected_expr); + ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetModule().GetExprs().front()), + expected_expr); } TEST(IrSchedule, Unannotate) { @@ -2817,10 +2904,16 @@ TEST(IrSchedule, Unannotate) { auto B = Compute( {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); - auto funcs = cinn::lang::LowerVec( - "test_unannotate", CreateStages({A, B}), {A, B}, {}, {}, nullptr, common::DefaultHostTarget(), true); + auto funcs = cinn::lang::LowerVec("test_unannotate", + CreateStages({A, B}), + {A, B}, + {}, + {}, + nullptr, + common::DefaultHostTarget(), + true); ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body})); - auto fused = ir_sch.Fuse("B", {0, 1}); + auto fused = ir_sch.Fuse("B", {0, 1}); auto block_b = ir_sch.GetBlock("B"); ir_sch.Annotate(block_b, "k1", int(64)); block_b = ir_sch.GetBlock("B"); @@ -2850,7 +2943,8 @@ TEST(IrSchedule, Unannotate) { } } })ROC"; - ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetModule().GetExprs().front()), expected_expr); + ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetModule().GetExprs().front()), + expected_expr); } TEST(IrSchedule, ComplexIndices) { @@ -2865,14 +2959,22 @@ TEST(IrSchedule, ComplexIndices) { poly::StageMap stages = CreateStages({B}); std::vector funcs = - lang::LowerVec("TestIrSchedule_ReduceSum", stages, {A, B}, {}, {}, nullptr, target, true); + lang::LowerVec("TestIrSchedule_ReduceSum", + stages, + {A, B}, + {}, + {}, + nullptr, + target, + true); ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body})); VLOG(3) << "Lowered Expr:" << ir_sch.GetModule().GetExprs().front(); auto loops_b = ir_sch.GetLoops("B"); CHECK_EQ(loops_b.size(), 2); ir_sch.Split("B", 0, {8, -1}); - ir_sch.Split("B", 2, {32, -1}); // after first splited, loops size has added to 3 + ir_sch.Split( + "B", 2, {32, -1}); // after first splited, loops size has added to 3 VLOG(3) << "Splited Expr:" << ir_sch.GetModule().GetExprs().front(); CHECK_EQ(ir_sch.GetLoops("B").size(), 4); @@ -2880,22 +2982,29 @@ 
TEST(IrSchedule, ComplexIndices) { VLOG(3) << "Reordered Expr:\n" << ir_sch.GetModule().GetExprs().front(); auto block_b = ir_sch.GetBlock("B"); - auto a_cache = ir_sch.CacheRead(block_b, 1, "shared"); // actually the read_buffer A should be indexed by 0 + auto a_cache = ir_sch.CacheRead( + block_b, + 1, + "shared"); // actually the read_buffer A should be indexed by 0 VLOG(3) << "CacheRead-A Expr:\n" << ir_sch.GetModule().GetExprs().front(); loops_b = ir_sch.GetLoops("B"); ir_sch.ComputeAt(a_cache, loops_b[0]); - VLOG(3) << "A_cache-ComputeAt-B Expr:\n" << ir_sch.GetModule().GetExprs().front(); + VLOG(3) << "A_cache-ComputeAt-B Expr:\n" + << ir_sch.GetModule().GetExprs().front(); - block_b = ir_sch.GetBlock("B"); + block_b = ir_sch.GetBlock("B"); auto b_cache = ir_sch.CacheWrite(block_b, 0, "local"); VLOG(3) << "CacheWrite-B Expr:\n" << ir_sch.GetModule().GetExprs().front(); auto loops_b_cache = - ir_sch.GetLoops(b_cache.As()->schedule_block.As()->name); + ir_sch.GetLoops(b_cache.As() + ->schedule_block.As() + ->name); block_b = ir_sch.GetBlock("B"); ir_sch.ReverseComputeAt(block_b, loops_b_cache[1]); - VLOG(3) << "B-ReverseComputeAt-B_cache Expr:\n" << ir_sch.GetModule().GetExprs().front(); + VLOG(3) << "B-ReverseComputeAt-B_cache Expr:\n" + << ir_sch.GetModule().GetExprs().front(); Module::Builder builder("module1", target); for (auto& i : funcs) { @@ -2955,11 +3064,17 @@ TEST(IrSchedule, SamplePerfectTile) { {M}, [&](Expr i) { return A(i) + 1; }, "B"); poly::StageMap stages = CreateStages({A, B}); - auto funcs = cinn::lang::LowerVec( - "test_sampleperfecttile", stages, {A, B}, {}, {}, nullptr, common::DefaultHostTarget(), true); + auto funcs = cinn::lang::LowerVec("test_sampleperfecttile", + stages, + {A, B}, + {}, + {}, + nullptr, + common::DefaultHostTarget(), + true); ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body})); - auto loops_b = ir_sch.GetLoops("B"); + auto loops_b = ir_sch.GetLoops("B"); std::vector result = ir_sch.SamplePerfectTile(loops_b[0], 3, 64); ASSERT_EQ(result.size(), 3); } @@ -2974,14 +3089,20 @@ TEST(IrSchedule, GetChildBlocks) { {M, N, K}, [&A](Var i, Var j, Var k) { return A(i, j, k); }, "B"); auto C = Compute( {M, N, K}, [&B](Var i, Var j, Var k) { return B(i, j, k); }, "C"); - auto funcs = cinn::lang::LowerVec( - "test_getchildblocks", CreateStages({A, B, C}), {A, C}, {}, {}, nullptr, common::DefaultHostTarget(), true); + auto funcs = cinn::lang::LowerVec("test_getchildblocks", + CreateStages({A, B, C}), + {A, C}, + {}, + {}, + nullptr, + common::DefaultHostTarget(), + true); ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body})); auto block_b = ir_sch.GetBlock("B"); - auto loops = ir_sch.GetLoops("C"); + auto loops = ir_sch.GetLoops("C"); ir_sch.ComputeAt(block_b, loops[1]); - loops = ir_sch.GetLoops("B"); + loops = ir_sch.GetLoops("B"); auto root_block = ir_sch.GetRootBlock(loops[1]); std::string expected_expr = R"ROC(ScheduleBlock(B) @@ -2994,7 +3115,8 @@ TEST(IrSchedule, GetChildBlocks) { i0_0, i1_0, i2_0 = axis.bind(i, j, k) C[i0_0, i1_0, i2_0] = B[i0_0, i1_0, i2_0] })ROC"; - ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetChildBlocks(root_block)), expected_expr); + ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetChildBlocks(root_block)), + expected_expr); } TEST(IrSchedule, SampleCategorical) { @@ -3007,11 +3129,18 @@ TEST(IrSchedule, SampleCategorical) { {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k); }, "B"); poly::StageMap stages = CreateStages({A, B}); std::vector decision; - auto funcs = cinn::lang::LowerVec( - "test_samplecategorical", stages, {A, 
B}, {}, {}, nullptr, common::DefaultHostTarget(), true); + auto funcs = cinn::lang::LowerVec("test_samplecategorical", + stages, + {A, B}, + {}, + {}, + nullptr, + common::DefaultHostTarget(), + true); ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body})); - Expr result = ir_sch.SampleCategorical({1, 2, 3}, {1.0, 2.0, 3.0}, {decision}); + Expr result = + ir_sch.SampleCategorical({1, 2, 3}, {1.0, 2.0, 3.0}, {decision}); ASSERT_EQ(result.type(), Int(32)); } diff --git a/paddle/cinn/backends/llvm/codegen_llvm.cc b/paddle/cinn/backends/llvm/codegen_llvm.cc index 67b4979c3fb58..5e49c36525b30 100644 --- a/paddle/cinn/backends/llvm/codegen_llvm.cc +++ b/paddle/cinn/backends/llvm/codegen_llvm.cc @@ -33,6 +33,11 @@ #include #include +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/Alignment.h" #include "paddle/cinn/backends/extern_func_emitter.h" #include "paddle/cinn/backends/extern_func_emitter_builtin.h" #include "paddle/cinn/backends/llvm/llvm_util.h" @@ -45,11 +50,6 @@ #include "paddle/cinn/runtime/cinn_runtime.h" #include "paddle/cinn/runtime/intrinsic.h" #include "paddle/cinn/utils/string.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Verifier.h" -#include "llvm/Support/Alignment.h" namespace cinn { namespace backends { @@ -110,16 +110,19 @@ CodeGenLLVM::CodeGenLLVM(llvm::Module *m, } symbol_table_->PushScope(); // Create a new scope by default. - md_builder_ = std::make_unique(b_->getContext()); - md_tbaa_root_ = md_builder_->createTBAARoot("cinn-tbaa"); + md_builder_ = std::make_unique(b_->getContext()); + md_tbaa_root_ = md_builder_->createTBAARoot("cinn-tbaa"); md_tbaa_alias_set_ = md_builder_->createTBAANode("cinn-alias", md_tbaa_root_); InitTarget(target_); } CodeGenLLVM::~CodeGenLLVM() {} -llvm::Value *CodeGenLLVM::EmitVectorSlice(llvm::Value *vec, int begin, int extent) { - int numel = llvm::dyn_cast(vec->getType())->getNumElements(); +llvm::Value *CodeGenLLVM::EmitVectorSlice(llvm::Value *vec, + int begin, + int extent) { + int numel = + llvm::dyn_cast(vec->getType())->getNumElements(); if (extent == numel && begin == 0) return vec; CHECK(begin >= 0 && extent <= numel) << "Slicing out of bound!"; @@ -138,18 +141,21 @@ llvm::Value *CodeGenLLVM::EmitVectorSlice(llvm::Value *vec, int begin, int exten llvm::Value *CodeGenLLVM::EmitVectorPad(llvm::Value *vec, int lanes) { #if LLVM_VERSION_MAJOR <= 10 - llvm::Value *mask = llvm::UndefValue::get(llvm::VectorType::get(b_->getInt32Ty(), lanes)); -#else llvm::Value *mask = - llvm::UndefValue::get(llvm::VectorType::get(b_->getInt32Ty(), llvm::ElementCount(lanes, false /*Scalable*/))); + llvm::UndefValue::get(llvm::VectorType::get(b_->getInt32Ty(), lanes)); +#else + llvm::Value *mask = llvm::UndefValue::get(llvm::VectorType::get( + b_->getInt32Ty(), llvm::ElementCount(lanes, false /*Scalable*/))); #endif - int numel = llvm::dyn_cast(vec->getType())->getNumElements(); + int numel = + llvm::dyn_cast(vec->getType())->getNumElements(); CHECK(numel <= lanes); if (numel == lanes) return vec; for (int i = 0; i < numel; i++) { - mask = - InsertElement(mask, llvm::ConstantInt::get(b_->getInt32Ty(), i), llvm::ConstantInt::get(b_->getInt32Ty(), i)); + mask = InsertElement(mask, + llvm::ConstantInt::get(b_->getInt32Ty(), i), + llvm::ConstantInt::get(b_->getInt32Ty(), i)); } return ShuffleVector(vec, vec, mask); @@ -163,10 +169,12 @@ llvm::Value 
*CodeGenLLVM::EmitVectorConcat(std::vector vecs) { while (vecs.size() > 1) { std::vector new_vecs; for (size_t i = 0; i < vecs.size() - 1; i += 2) { - auto *lhs = vecs[i]; - auto *rhs = vecs[i + 1]; - const auto lhs_lanes = llvm::dyn_cast(lhs->getType())->getNumElements(); - const auto rhs_lanes = llvm::dyn_cast(rhs->getType())->getNumElements(); + auto *lhs = vecs[i]; + auto *rhs = vecs[i + 1]; + const auto lhs_lanes = + llvm::dyn_cast(lhs->getType())->getNumElements(); + const auto rhs_lanes = + llvm::dyn_cast(rhs->getType())->getNumElements(); if (lhs_lanes < rhs_lanes) { lhs = EmitVectorPad(lhs, rhs_lanes); } else if (lhs_lanes > rhs_lanes) { @@ -189,29 +197,39 @@ llvm::Value *CodeGenLLVM::EmitVectorConcat(std::vector vecs) { return EmitVectorSlice(vecs[0], 0, lanes); } -llvm::Value *CodeGenLLVM::EmitBinaryOp( - llvm::Value *lhs, llvm::Value *rhs, char opcode, bool is_integral, bool is_signed) { +llvm::Value *CodeGenLLVM::EmitBinaryOp(llvm::Value *lhs, + llvm::Value *rhs, + char opcode, + bool is_integral, + bool is_signed) { llvm::Instruction::BinaryOps ops; CHECK_EQ(lhs->getType(), rhs->getType()) << "the types of operands of binary operation are mismatch" - << ", lhs[" << DumpToString(*lhs) << "] " << opcode << " rhs[" << DumpToString(*rhs) << "]" - << ", lhs_type[" << DumpToString(*lhs->getType()) << "], rhs_type[" << DumpToString(*rhs->getType()) << "]"; + << ", lhs[" << DumpToString(*lhs) << "] " << opcode << " rhs[" + << DumpToString(*rhs) << "]" + << ", lhs_type[" << DumpToString(*lhs->getType()) << "], rhs_type[" + << DumpToString(*rhs->getType()) << "]"; switch (opcode) { case '+': - ops = is_integral ? llvm::Instruction::BinaryOps::Add : llvm::Instruction::BinaryOps::FAdd; + ops = is_integral ? llvm::Instruction::BinaryOps::Add + : llvm::Instruction::BinaryOps::FAdd; break; case '-': - ops = is_integral ? llvm::Instruction::BinaryOps::Sub : llvm::Instruction::BinaryOps::FSub; + ops = is_integral ? llvm::Instruction::BinaryOps::Sub + : llvm::Instruction::BinaryOps::FSub; break; case '*': - ops = is_integral ? llvm::Instruction::BinaryOps::Mul : llvm::Instruction::BinaryOps::FMul; + ops = is_integral ? llvm::Instruction::BinaryOps::Mul + : llvm::Instruction::BinaryOps::FMul; break; case '/': - ops = is_integral ? (is_signed ? llvm::Instruction::BinaryOps::SDiv : llvm::Instruction::BinaryOps::UDiv) + ops = is_integral ? (is_signed ? llvm::Instruction::BinaryOps::SDiv + : llvm::Instruction::BinaryOps::UDiv) : llvm::Instruction::BinaryOps::FDiv; break; case '%': - ops = is_integral ? (is_signed ? llvm::Instruction::BinaryOps::SRem : llvm::Instruction::BinaryOps::URem) + ops = is_integral ? (is_signed ? 
llvm::Instruction::BinaryOps::SRem + : llvm::Instruction::BinaryOps::URem) : llvm::Instruction::BinaryOps::FRem; break; default: @@ -249,16 +267,22 @@ llvm::Value *CodeGenLLVM::Visit(const ir::FloatImm *op) { return nullptr; } -llvm::Value *CodeGenLLVM::LLVMGenGlobalStringVar(const std::string &data) { return b_->CreateGlobalStringPtr(data); } +llvm::Value *CodeGenLLVM::LLVMGenGlobalStringVar(const std::string &data) { + return b_->CreateGlobalStringPtr(data); +} -llvm::Value *CodeGenLLVM::Visit(const ir::StringImm *op) { return LLVMGenGlobalStringVar(op->value); } +llvm::Value *CodeGenLLVM::Visit(const ir::StringImm *op) { + return LLVMGenGlobalStringVar(op->value); +} llvm::Value *CodeGenLLVM::Visit(const ir::Add *op) { - return EmitBinaryOp(Visit(&op->a()), Visit(&op->b()), '+', is_integral_type(op->type())); + return EmitBinaryOp( + Visit(&op->a()), Visit(&op->b()), '+', is_integral_type(op->type())); } llvm::Value *CodeGenLLVM::Visit(const ir::Sub *op) { - return EmitBinaryOp(Visit(&op->a()), Visit(&op->b()), '-', is_integral_type(op->type())); + return EmitBinaryOp( + Visit(&op->a()), Visit(&op->b()), '-', is_integral_type(op->type())); } llvm::Value *CodeGenLLVM::Visit(const ir::Mul *op) { @@ -268,11 +292,13 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Mul *op) { } llvm::Value *CodeGenLLVM::Visit(const ir::Div *op) { - return EmitBinaryOp(Visit(&op->a()), Visit(&op->b()), '/', is_integral_type(op->type())); + return EmitBinaryOp( + Visit(&op->a()), Visit(&op->b()), '/', is_integral_type(op->type())); } llvm::Value *CodeGenLLVM::Visit(const ir::Mod *op) { - return EmitBinaryOp(Visit(&op->a()), Visit(&op->b()), '%', is_integral_type(op->type())); + return EmitBinaryOp( + Visit(&op->a()), Visit(&op->b()), '%', is_integral_type(op->type())); } #define __IR_EMITTER_DEFINE_CMP_VISITOR(__sop, __uop, __fop) \ @@ -289,23 +315,39 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Mod *op) { } \ return EmitComparison(predicate, lhs, rhs, b_) -llvm::Value *CodeGenLLVM::Visit(const ir::EQ *op) { __IR_EMITTER_DEFINE_CMP_VISITOR(EQ, EQ, OEQ); } +llvm::Value *CodeGenLLVM::Visit(const ir::EQ *op) { + __IR_EMITTER_DEFINE_CMP_VISITOR(EQ, EQ, OEQ); +} -llvm::Value *CodeGenLLVM::Visit(const ir::NE *op) { __IR_EMITTER_DEFINE_CMP_VISITOR(NE, NE, ONE); } +llvm::Value *CodeGenLLVM::Visit(const ir::NE *op) { + __IR_EMITTER_DEFINE_CMP_VISITOR(NE, NE, ONE); +} -llvm::Value *CodeGenLLVM::Visit(const ir::LT *op) { __IR_EMITTER_DEFINE_CMP_VISITOR(SLT, ULT, OLT); } +llvm::Value *CodeGenLLVM::Visit(const ir::LT *op) { + __IR_EMITTER_DEFINE_CMP_VISITOR(SLT, ULT, OLT); +} -llvm::Value *CodeGenLLVM::Visit(const ir::LE *op) { __IR_EMITTER_DEFINE_CMP_VISITOR(SLE, ULE, OLE); } +llvm::Value *CodeGenLLVM::Visit(const ir::LE *op) { + __IR_EMITTER_DEFINE_CMP_VISITOR(SLE, ULE, OLE); +} -llvm::Value *CodeGenLLVM::Visit(const ir::GT *op) { __IR_EMITTER_DEFINE_CMP_VISITOR(SGT, UGT, OGT); } +llvm::Value *CodeGenLLVM::Visit(const ir::GT *op) { + __IR_EMITTER_DEFINE_CMP_VISITOR(SGT, UGT, OGT); +} -llvm::Value *CodeGenLLVM::Visit(const ir::GE *op) { __IR_EMITTER_DEFINE_CMP_VISITOR(SGE, UGE, OGE); } +llvm::Value *CodeGenLLVM::Visit(const ir::GE *op) { + __IR_EMITTER_DEFINE_CMP_VISITOR(SGE, UGE, OGE); +} #undef __IR_EMITTER_DEFINE_CMP_VISITOR -llvm::Value *CodeGenLLVM::Visit(const ir::And *op) { return And(Visit(&op->a()), Visit(&op->b())); } +llvm::Value *CodeGenLLVM::Visit(const ir::And *op) { + return And(Visit(&op->a()), Visit(&op->b())); +} -llvm::Value *CodeGenLLVM::Visit(const ir::Or *op) { return Or(Visit(&op->a()), 
Visit(&op->b())); } +llvm::Value *CodeGenLLVM::Visit(const ir::Or *op) { + return Or(Visit(&op->a()), Visit(&op->b())); +} llvm::Value *CodeGenLLVM::Visit(const ir::Min *op) { auto *lhs = Visit(&op->a()); @@ -344,11 +386,13 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Minus *op) { return (op->type().is_int() || op->type().is_uint()) ? Neg(v) : FNeg(v); } -llvm::Value *CodeGenLLVM::Visit(const ir::Not *op) { return Not(Visit(&op->v())); } +llvm::Value *CodeGenLLVM::Visit(const ir::Not *op) { + return Not(Visit(&op->v())); +} llvm::Value *CodeGenLLVM::Visit(const ir::Cast *op) { auto from = op->v().type(); - auto to = op->type(); + auto to = op->type(); llvm::Type *source = CinnTypeToLLVMType(from, m_); llvm::Type *target = CinnTypeToLLVMType(to, m_); @@ -360,7 +404,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Cast *op) { // pod_value_t cast to a value. if (op->v().type().is_customized_type() && - op->v().type().customized_type() == common::customized_type::kpod_value_t) { // pod_value_t operator + op->v().type().customized_type() == + common::customized_type::kpod_value_t) { // pod_value_t operator llvm::Function *callee{}; if (op->type().is_bool()) { callee = m_->getFunction(runtime::intrinsic::pod_value_to_bool); @@ -390,7 +435,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Cast *op) { callee = m_->getFunction(runtime::intrinsic::pod_value_to_float16); } else if (op->type() == type_of()) { callee = m_->getFunction(runtime::intrinsic::pod_value_to_void_p); - } else if (op->type() == type_of() || op->type() == type_of()) { + } else if (op->type() == type_of() || + op->type() == type_of()) { callee = m_->getFunction(runtime::intrinsic::pod_value_to_buffer_p); } else { LOG(ERROR) << "can't cast cinn_pod_value_t to " << op->type(); @@ -415,10 +461,10 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Cast *op) { if (to.is_bool()) { if (from.is_float()) { llvm::Constant *zero = llvm::ConstantFP::get(source, 0.); - value = FCmpONE(value, zero); + value = FCmpONE(value, zero); } else { llvm::Constant *zero = llvm::ConstantInt::get(source, 0); - value = ICmpNE(value, zero); + value = ICmpNE(value, zero); } break; } @@ -464,14 +510,17 @@ llvm::Value *CodeGenLLVM::CreateSerialFor(const ir::For *op, int stride) { do { break; llvm::BasicBlock *preheader_bb = b_->GetInsertBlock(); - auto *for_begin = llvm::BasicBlock::Create(b_->getContext(), "for_begin", b_->GetInsertBlock()->getParent()); - auto *for_body = llvm::BasicBlock::Create(b_->getContext(), "for_body", b_->GetInsertBlock()->getParent()); - auto *for_end = llvm::BasicBlock::Create(b_->getContext(), "for_end", b_->GetInsertBlock()->getParent()); + auto *for_begin = llvm::BasicBlock::Create( + b_->getContext(), "for_begin", b_->GetInsertBlock()->getParent()); + auto *for_body = llvm::BasicBlock::Create( + b_->getContext(), "for_body", b_->GetInsertBlock()->getParent()); + auto *for_end = llvm::BasicBlock::Create( + b_->getContext(), "for_end", b_->GetInsertBlock()->getParent()); Br(for_begin); b_->SetInsertPoint(for_begin); - auto *begin = Visit(&op->min); + auto *begin = Visit(&op->min); auto *loop_value = PHI(begin->getType(), 2); loop_value->addIncoming(begin, preheader_bb); @@ -488,27 +537,34 @@ llvm::Value *CodeGenLLVM::CreateSerialFor(const ir::For *op, int stride) { symbol_table_->Erase(op->loop_var->name); } - auto loop_next = Add(loop_value, llvm::ConstantInt::get(b_->getInt32Ty(), stride), "indvar.inc", true, true); + auto loop_next = Add(loop_value, + llvm::ConstantInt::get(b_->getInt32Ty(), stride), + "indvar.inc", + true, + true); 
loop_value->addIncoming(loop_next, b_->GetInsertBlock()); Br(for_begin); b_->SetInsertPoint(for_end); return nullptr; - // llvm::AllocaInst *loop_var = Alloca(b_->getInt32Ty(), nullptr, op->loop_var->name); - // loop_var->setAlignment(llvm::Align(4)); + // llvm::AllocaInst *loop_var = Alloca(b_->getInt32Ty(), nullptr, + // op->loop_var->name); loop_var->setAlignment(llvm::Align(4)); // SetVar(op->loop_var->name, loop_var); } while (false); //////////////////////////////////// llvm::BasicBlock *preheader_bb = b_->GetInsertBlock(); - llvm::BasicBlock *exit_bb = nullptr; + llvm::BasicBlock *exit_bb = nullptr; llvm::BasicBlock::iterator insert_point = b_->GetInsertPoint(); if (insert_point == preheader_bb->end()) { CHECK(!preheader_bb->getTerminator()); - exit_bb = llvm::BasicBlock::Create(b_->getContext(), "loop_exit", b_->GetInsertBlock()->getParent(), nullptr); + exit_bb = llvm::BasicBlock::Create(b_->getContext(), + "loop_exit", + b_->GetInsertBlock()->getParent(), + nullptr); } else { CHECK(preheader_bb->getTerminator()); exit_bb = preheader_bb->splitBasicBlock(insert_point, "loop_exit"); @@ -516,29 +572,37 @@ llvm::Value *CodeGenLLVM::CreateSerialFor(const ir::For *op, int stride) { } llvm::BasicBlock *header_bb = - llvm::BasicBlock::Create(b_->getContext(), "loop_header", b_->GetInsertBlock()->getParent(), nullptr); + llvm::BasicBlock::Create(b_->getContext(), + "loop_header", + b_->GetInsertBlock()->getParent(), + nullptr); llvm::BasicBlock *body_bb = - llvm::BasicBlock::Create(b_->getContext(), "loop_body", b_->GetInsertBlock()->getParent(), nullptr); + llvm::BasicBlock::Create(b_->getContext(), + "loop_body", + b_->GetInsertBlock()->getParent(), + nullptr); llvm::Function *func = preheader_bb->getParent(); - b_->SetInsertPoint(&func->getEntryBlock(), func->getEntryBlock().getFirstInsertionPt()); + b_->SetInsertPoint(&func->getEntryBlock(), + func->getEntryBlock().getFirstInsertionPt()); llvm::Value *old_var = GetVar(op->loop_var->name); // loop iterator - llvm::AllocaInst *loop_var = Alloca(b_->getInt32Ty(), nullptr, op->loop_var->name); + llvm::AllocaInst *loop_var = + Alloca(b_->getInt32Ty(), nullptr, op->loop_var->name); loop_var->setAlignment(llvm::Align(4)); SetVar(op->loop_var->name, loop_var); b_->SetInsertPoint(preheader_bb); llvm::Value *start_index = Visit(&op->min); - llvm::Value *end_index = Visit(&op->extent); + llvm::Value *end_index = Visit(&op->extent); Store(start_index, loop_var); CHECK(!preheader_bb->getTerminator()); Br(header_bb); // loop_header b_->SetInsertPoint(header_bb); - llvm::Value *indvar = Load(loop_var, "indvar"); + llvm::Value *indvar = Load(loop_var, "indvar"); llvm::Value *exit_cond = ICmpSGE(indvar, end_index); CondBr(/*Cond=*/exit_cond, /*True=*/exit_bb, @@ -564,8 +628,8 @@ llvm::Value *CodeGenLLVM::CreateSerialFor(const ir::For *op, int stride) { loop_metadata.push_back(temp_node.get()); // TODO(fc500110): Loop vectorize - // auto *vectorization = op->metadata.vectorization ? b_->getTrue() : b_->getFalse(); - // loop_metadata.push_back(llvm::MDNode::get( + // auto *vectorization = op->metadata.vectorization ? 
b_->getTrue() : + // b_->getFalse(); loop_metadata.push_back(llvm::MDNode::get( // ctx, {llvm::MDString::get(ctx, "llvm.loop.vectorize.enable"), // llvm::ConstantAsMetadata::get(b_->getFalse())})); @@ -583,9 +647,9 @@ llvm::Value *CodeGenLLVM::CreateSerialFor(const ir::For *op, int stride) { } /* - loop_metadata.push_back(llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, llvm_unroll_metadata)})); - auto loop_id = llvm::MDNode::get(ctx, loop_metadata); - loop_id->replaceOperandWith(0, loop_id); + loop_metadata.push_back(llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, + llvm_unroll_metadata)})); auto loop_id = llvm::MDNode::get(ctx, + loop_metadata); loop_id->replaceOperandWith(0, loop_id); back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id); */ @@ -599,7 +663,9 @@ llvm::Value *CodeGenLLVM::CreateSerialFor(const ir::For *op, int stride) { return nullptr; } -llvm::Value *CodeGenLLVM::Visit(const ir::For *op) { return CreateSerialFor(op); } +llvm::Value *CodeGenLLVM::Visit(const ir::For *op) { + return CreateSerialFor(op); +} llvm::Value *CodeGenLLVM::Visit(const ir::PolyFor *op) { CINN_NOT_IMPLEMENTED @@ -607,7 +673,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::PolyFor *op) { } llvm::Value *CodeGenLLVM::Visit(const ir::Select *op) { - return Select(Visit(&op->condition), Visit(&op->true_value), Visit(&op->false_value)); + return Select( + Visit(&op->condition), Visit(&op->true_value), Visit(&op->false_value)); } llvm::Value *CodeGenLLVM::Visit(const ir::IfThenElse *op) { @@ -615,15 +682,18 @@ llvm::Value *CodeGenLLVM::Visit(const ir::IfThenElse *op) { bool emit_else = op->false_case.defined(); - auto &ll_ctx = b_->getContext(); + auto &ll_ctx = b_->getContext(); auto *ll_function = b_->GetInsertBlock()->getParent(); - llvm::Value *cond = Visit(&op->condition); - llvm::BasicBlock *then_block = llvm::BasicBlock::Create(ll_ctx, "if-then", ll_function); - llvm::BasicBlock *end_block = llvm::BasicBlock::Create(ll_ctx, "if-end", ll_function); + llvm::Value *cond = Visit(&op->condition); + llvm::BasicBlock *then_block = + llvm::BasicBlock::Create(ll_ctx, "if-then", ll_function); + llvm::BasicBlock *end_block = + llvm::BasicBlock::Create(ll_ctx, "if-end", ll_function); if (op->false_case.defined()) { - llvm::BasicBlock *else_block = llvm::BasicBlock::Create(ll_ctx, "if-else", ll_function); + llvm::BasicBlock *else_block = + llvm::BasicBlock::Create(ll_ctx, "if-else", ll_function); CondBr(cond, then_block, else_block); // true case @@ -652,8 +722,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Block *op) { llvm::Value *ret = nullptr; - llvm::BasicBlock *block = - llvm::BasicBlock::Create(b_->getContext(), "block", b_->GetInsertBlock()->getParent(), nullptr); + llvm::BasicBlock *block = llvm::BasicBlock::Create( + b_->getContext(), "block", b_->GetInsertBlock()->getParent(), nullptr); Br(block); b_->SetInsertPoint(block); @@ -665,17 +735,26 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Block *op) { return ret; } -llvm::Value *CodeGenLLVM::Visit(const ir::PrimitiveNode *) { CINN_NOT_IMPLEMENTED return nullptr; } -llvm::Value *CodeGenLLVM::Visit(const ir::_BufferRange_ *) { CINN_NOT_IMPLEMENTED return nullptr; } -llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlock *) { CINN_NOT_IMPLEMENTED return nullptr; } -llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlockRealize *) { CINN_NOT_IMPLEMENTED return nullptr; } +llvm::Value *CodeGenLLVM::Visit(const ir::PrimitiveNode *) { + CINN_NOT_IMPLEMENTED return nullptr; +} +llvm::Value *CodeGenLLVM::Visit(const ir::_BufferRange_ *) { + 
CINN_NOT_IMPLEMENTED return nullptr; +} +llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlock *) { + CINN_NOT_IMPLEMENTED return nullptr; +} +llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlockRealize *) { + CINN_NOT_IMPLEMENTED return nullptr; +} llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) { if (op->name == runtime::intrinsic::debug_log_repr) { return EmitCall_debug_info(op); } else if (op->is_extern_call()) { - auto emitter_id = ExternFuncID{backend_llvm_host, op->name.c_str()}; - const auto &fn_name = ExternFunctionEmitterRegistry::Global().Lookup(emitter_id); + auto emitter_id = ExternFuncID{backend_llvm_host, op->name.c_str()}; + const auto &fn_name = + ExternFunctionEmitterRegistry::Global().Lookup(emitter_id); if (!fn_name.empty()) { ExternFunctionLLVMEmitter emitter(fn_name); emitter.BindCodeGen(this); @@ -701,8 +780,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) { if (op->is_cinn_call()) { auto arg = ir::intrinsics::GetAddr::Make(op->read_args[0]); - args[0] = Visit(&arg); - args[0] = BitCast(args[0], ll_void_p_ty(), "cast_to_void_p"); + args[0] = Visit(&arg); + args[0] = BitCast(args[0], ll_void_p_ty(), "cast_to_void_p"); } return Call(callee, std::move(args)); @@ -715,7 +794,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Module_ *op) { } for (auto &fn : op->functions) { - VLOG(1) << "JIT Linking function [" << fn.As()->name << "]"; + VLOG(1) << "JIT Linking function [" << fn.As()->name + << "]"; ir::Expr fn_expr(fn); auto fnll = Visit(&fn_expr); @@ -741,7 +821,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Var_ *op) { return result; } -void CodeGenLLVM::Scalarize(const Expr &e, std::function flambda) { +void CodeGenLLVM::Scalarize( + const Expr &e, std::function flambda) { if (const ir::Ramp *ramp = e.As()) { for (int i = 0; i < ramp->type().lanes(); ++i) { Expr offset = ramp->base + (ramp->stride * i); @@ -762,12 +843,13 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Load *op) { if (auto *tensor_op = op->tensor.As()) { array = GetVar(tensor_op->name); } else if (auto *var_op = op->tensor.As()) { - array = GetVar(var_op->name); + array = GetVar(var_op->name); is_alias = alias_vars_.count(const_cast(var_op)); } else { array = Visit(&op->tensor); } - CHECK(array) << "fail to Visit Load node: " << Expr(const_cast(op)); + CHECK(array) << "fail to Visit Load node: " + << Expr(const_cast(op)); ir::Expr index = op->index(); if (index.type().lanes() <= 1) { @@ -775,11 +857,13 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Load *op) { indices.push_back(Visit(&index)); // auto load_inst = Load(InBoundsGEP(array, std::move(indices))); - auto *load_inst = AlignedLoad(InBoundsGEP(array, std::move(indices)), llvm::MaybeAlign()); + auto *load_inst = + AlignedLoad(InBoundsGEP(array, std::move(indices)), llvm::MaybeAlign()); /* if (is_alias) { - llvm::MDNode *meta = md_builder_->createTBAANode("cinn-alias", md_tbaa_root_); - load_inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); + llvm::MDNode *meta = md_builder_->createTBAANode("cinn-alias", + md_tbaa_root_); load_inst->setMetadata("tbaa", + md_builder_->createTBAAStructTagNode(meta, meta, 0)); } */ if (auto *load_tensor = op->tensor.as_tensor()) { @@ -788,32 +872,35 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Load *op) { { int alignment = op->type().bits(); - alignment = 8; + alignment = 8; CHECK_GT(alignment, 0); load_inst->setAlignment(llvm::Align(std::min(alignment, 8))); } // TODO(fc500110): tbaa AliasAnalysis // auto md_tbaa_root = md_builder_->createTBAARoot("cinn-tbaa"); - // auto 
md_tbaa_alias_set = md_builder_->createTBAANode("cinn-alias", md_tbaa_root); - // llvm::MDNode *meta = md_tbaa_alias_set; - // load_inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); + // auto md_tbaa_alias_set = md_builder_->createTBAANode("cinn-alias", + // md_tbaa_root); llvm::MDNode *meta = md_tbaa_alias_set; + // load_inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, + // meta, 0)); return load_inst; } else { // vector load Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1); - llvm::Value *buffer = Visit(&op->tensor); + llvm::Value *buffer = Visit(&op->tensor); if (dense_strided_ramp.defined()) { CHECK(op->type().is_vector()); return DenseVectorLoad(op); } // scalarize load - Type type = op->type(); - int alignment = type.bits() / 8; - llvm::Value *ret = llvm::UndefValue::get(CinnTypeToLLVMType(type, m_, true)); - auto flambda = [&](int i, llvm::Value *index) { - auto *ptr = CreateBufferPtr(type.ElementOf(), buffer, index); - llvm::LoadInst *load_inst = b_->CreateAlignedLoad(ptr, llvm::Align(alignment), "load_vec"); - ret = b_->CreateInsertElement(ret, load_inst, ll_const_int32(i)); + Type type = op->type(); + int alignment = type.bits() / 8; + llvm::Value *ret = + llvm::UndefValue::get(CinnTypeToLLVMType(type, m_, true)); + auto flambda = [&](int i, llvm::Value *index) { + auto *ptr = CreateBufferPtr(type.ElementOf(), buffer, index); + llvm::LoadInst *load_inst = + b_->CreateAlignedLoad(ptr, llvm::Align(alignment), "load_vec"); + ret = b_->CreateInsertElement(ret, load_inst, ll_const_int32(i)); if (auto *load_tensor = op->tensor.as_tensor()) { AddTbaaMetadata(load_inst, load_tensor->name, op->index()); } @@ -829,7 +916,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Store *op) { if (auto *tensor_op = op->tensor.As()) { array = GetVar(tensor_op->name); } else if (auto *var_op = op->tensor.As()) { - array = GetVar(var_op->name); + array = GetVar(var_op->name); is_alias = alias_vars_.count(const_cast(var_op)); } CHECK(array) << "array is null"; @@ -840,62 +927,75 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Store *op) { std::vector indices; indices.push_back(Visit(&index)); - // auto *store_inst = Store(Visit(&op->value), InBoundsGEP(array, std::move(indices))); - auto *store_inst = AlignedStore(Visit(&op->value), InBoundsGEP(array, std::move(indices)), llvm::MaybeAlign()); + // auto *store_inst = Store(Visit(&op->value), InBoundsGEP(array, + // std::move(indices))); + auto *store_inst = AlignedStore(Visit(&op->value), + InBoundsGEP(array, std::move(indices)), + llvm::MaybeAlign()); /* if (is_alias) { - llvm::MDNode *meta = md_builder_->createTBAANode("cinn-alias", md_tbaa_root_); - store_inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); + llvm::MDNode *meta = md_builder_->createTBAANode("cinn-alias", + md_tbaa_root_); store_inst->setMetadata("tbaa", + md_builder_->createTBAAStructTagNode(meta, meta, 0)); } */ { int alignment = op->type().bits(); - alignment = 8; + alignment = 8; CHECK_GT(alignment, 0); store_inst->setAlignment(llvm::Align(std::min(alignment, 8))); } // TODO(fc500110): tbaa AliasAnalysis // auto md_tbaa_root = md_builder_->createTBAARoot("cinn-tbaa"); - // auto md_tbaa_alias_set = md_builder_->createTBAANode("cinn-alias", md_tbaa_root); - // llvm::MDNode *meta = md_tbaa_alias_set; - // store_inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); + // auto md_tbaa_alias_set = md_builder_->createTBAANode("cinn-alias", + // md_tbaa_root); llvm::MDNode *meta 
= md_tbaa_alias_set; + // store_inst->setMetadata("tbaa", + // md_builder_->createTBAAStructTagNode(meta, meta, 0)); AddTbaaMetadata(store_inst, op->tensor.as_tensor()->name, op->index()); return store_inst; } else { // vector store Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1); - auto ramp_expr = op->index(); - auto *ramp = index.As(); - auto *buffer = Visit(&op->tensor); - auto *value = Visit(&op->value); + auto ramp_expr = op->index(); + auto *ramp = index.As(); + auto *buffer = Visit(&op->tensor); + auto *value = Visit(&op->value); if (dense_strided_ramp.defined()) { // stride 1 int total_lanes = op->type().lanes(); - int step = naive_vec_alignment_ / op->type().ElementOf().bits(); + int step = naive_vec_alignment_ / op->type().ElementOf().bits(); // fit the total_lanes in native_lanes(split into multiple native steps) for (int offset = 0; offset < total_lanes; offset += total_lanes) { int lanes = total_lanes; Expr base = common::AutoSimplify(ramp->base + offset); optim::VarModSimplify(&base); - auto *ptr = CreateBufferPtr(op->type().ElementOf(), buffer, Visit(&base)); - auto *vtype = llvm::VectorType::get(CinnTypeToLLVMType(op->type().ElementOf(), m_, true), - llvm::ElementCount(lanes, false /*Scalable*/)) + auto *ptr = + CreateBufferPtr(op->type().ElementOf(), buffer, Visit(&base)); + auto *vtype = llvm::VectorType::get( + CinnTypeToLLVMType(op->type().ElementOf(), m_, true), + llvm::ElementCount(lanes, false /*Scalable*/)) ->getPointerTo(); int alignment = std::max(op->type().ElementOf().bits() / 8, 1); llvm::StoreInst *inst = - b_->CreateAlignedStore(CreateVecSlice(value, offset, lanes), b_->CreatePointerCast(ptr, vtype), alignment); + b_->CreateAlignedStore(CreateVecSlice(value, offset, lanes), + b_->CreatePointerCast(ptr, vtype), + alignment); AddTbaaMetadata(inst, op->tensor.as_tensor()->name, base); return inst; } } // scalarize store - Type type = op->type(); - int alignment = type.bits() / 8; - llvm::Value *ret = llvm::UndefValue::get(CinnTypeToLLVMType(type, m_, true)); - auto flambda = [&](int i, llvm::Value *index) { + Type type = op->type(); + int alignment = type.bits() / 8; + llvm::Value *ret = + llvm::UndefValue::get(CinnTypeToLLVMType(type, m_, true)); + auto flambda = [&](int i, llvm::Value *index) { auto *ptr = CreateBufferPtr(type.ElementOf(), buffer, index); llvm::StoreInst *store_inst = - b_->CreateAlignedStore(b_->CreateExtractElement(value, i), ptr, llvm::Align(alignment), "store_vec"); + b_->CreateAlignedStore(b_->CreateExtractElement(value, i), + ptr, + llvm::Align(alignment), + "store_vec"); ret = b_->CreateInsertElement(ret, store_inst, ll_const_int32(i)); if (auto *store_tensor = op->tensor.as_tensor()) { AddTbaaMetadata(store_inst, store_tensor->name, op->index()); @@ -909,7 +1009,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Store *op) { llvm::Value *CodeGenLLVM::Visit(const ir::Alloc *op) { auto *buffer_op = op->destination.As(); - auto *buffer = GetVar(buffer_op->name); + auto *buffer = GetVar(buffer_op->name); CHECK(buffer); return buffer; @@ -922,7 +1022,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Free *op) { return nullptr; } -llvm::Value *CodeGenLLVM::Visit(const ir::_Buffer_ *op) { return GetVar(op->name); } +llvm::Value *CodeGenLLVM::Visit(const ir::_Buffer_ *op) { + return GetVar(op->name); +} llvm::Value *CodeGenLLVM::Visit(const ir::_Tensor_ *op) { return GetVar(op->name); @@ -934,12 +1036,14 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Tensor_ *op) { return SetVar(buffer_op->name, Visit(buffer_op)); } -template 
::value, int> = 0> +template ::value, int> = 0> void appendBody(std::vector &new_body, T &&v) { new_body.push_back(v); } -template ::value, int> = 1> +template ::value, int> = 1> void appendBody(std::vector &new_body, T &&v) { new_body.insert(new_body.end(), v.begin(), v.end()); } @@ -948,12 +1052,13 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_LoweredFunc_ *op) { auto init_function_state = [this]() { alias_vars_.clear(); }; init_function_state(); - CHECK_EQ(op->alloc_output_buffer_exprs.size(), op->dealloc_output_buffer_exprs.size()) + CHECK_EQ(op->alloc_output_buffer_exprs.size(), + op->dealloc_output_buffer_exprs.size()) << "the count of allocation and deallocation expressions does not match"; std::vector new_body; - auto create_temp_buffers = op->PrepareCreateTempBufferExprs(); - auto alloca_temp_buffers = op->PrepareAllocTempBufferExprs(); + auto create_temp_buffers = op->PrepareCreateTempBufferExprs(); + auto alloca_temp_buffers = op->PrepareAllocTempBufferExprs(); auto dealloca_temp_buffers = op->PrepareDeallocTempBufferExprs(); appendBody(new_body, op->argument_prepare_exprs); @@ -974,7 +1079,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_LoweredFunc_ *op) { /*Result=*/b_->getVoidTy(), /*Params=*/std::move(arg_types), /*isVarArg=*/false); - CHECK(m_->getFunction(op->name) == nullptr) << "function[" << op->name << "] exists"; + CHECK(m_->getFunction(op->name) == nullptr) + << "function[" << op->name << "] exists"; f_ = llvm::Function::Create( /*FunctionType=*/function_type, @@ -987,7 +1093,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_LoweredFunc_ *op) { std::vector args; args.reserve(f_->arg_size()); std::transform( - f_->arg_begin(), f_->arg_end(), std::back_inserter(args), [](auto &arg) { return std::addressof(arg); }); + f_->arg_begin(), f_->arg_end(), std::back_inserter(args), [](auto &arg) { + return std::addressof(arg); + }); llvm::BasicBlock *entry = llvm::BasicBlock::Create( /*Context=*/b_->getContext(), @@ -1012,8 +1120,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Let *op) { if (op->body.defined()) { SetVar(name, Visit(&op->body)); } else { - llvm::AllocaInst *inst = Alloca(CinnTypeToLLVMType(op->type(), m_), nullptr, name); - auto get_align = [](int n) { + llvm::AllocaInst *inst = + Alloca(CinnTypeToLLVMType(op->type(), m_), nullptr, name); + auto get_align = [](int n) { int i{0}, r{1}; while (n > r) { r *= 2; @@ -1022,7 +1131,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Let *op) { return r / 8; }; int align_bits = std::max(op->type().bits(), 8); - int align = get_align(align_bits); + int align = get_align(align_bits); inst->setAlignment(llvm::Align(align)); SetVar(name, inst); } @@ -1030,9 +1139,13 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Let *op) { return GetVar(name); } -llvm::Value *CodeGenLLVM::Visit(const ir::Reduce *op) { __IR_EMITTER_NOT_IMPLEMENTED(op); } +llvm::Value *CodeGenLLVM::Visit(const ir::Reduce *op) { + __IR_EMITTER_NOT_IMPLEMENTED(op); +} -llvm::Value *CodeGenLLVM::Visit(const ir::Ramp *op) { __IR_EMITTER_NOT_IMPLEMENTED(op); } +llvm::Value *CodeGenLLVM::Visit(const ir::Ramp *op) { + __IR_EMITTER_NOT_IMPLEMENTED(op); +} llvm::Value *CodeGenLLVM::Visit(const ir::Broadcast *op) { #if LLVM_VERSION_MAJOR >= 11 @@ -1040,15 +1153,18 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Broadcast *op) { #else const int elem_count = op->lanes; #endif - llvm::Value *value = Visit(&op->value); - llvm::Constant *undef = llvm::UndefValue::get(llvm::VectorType::get(value->getType(), elem_count)); - llvm::Constant *zero = llvm::ConstantInt::get(ll_int32_ty(), 0); -
value = b_->CreateInsertElement(undef, value, zero, "broadcast"); + llvm::Value *value = Visit(&op->value); + llvm::Constant *undef = llvm::UndefValue::get( + llvm::VectorType::get(value->getType(), elem_count)); + llvm::Constant *zero = llvm::ConstantInt::get(ll_int32_ty(), 0); + value = b_->CreateInsertElement(undef, value, zero, "broadcast"); llvm::Constant *zeros = llvm::ConstantVector::getSplat(elem_count, zero); return b_->CreateShuffleVector(value, undef, zeros, "broadcast_shuffle"); } -llvm::Value *CodeGenLLVM::Visit(const ir::FracOp *op) { __IR_EMITTER_NOT_IMPLEMENTED(op); } +llvm::Value *CodeGenLLVM::Visit(const ir::FracOp *op) { + __IR_EMITTER_NOT_IMPLEMENTED(op); +} llvm::Value *CodeGenLLVM::Visit(const ir::Product *op) { auto size = op->operands().size(); @@ -1088,7 +1204,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Sum *op) { void CodeGenLLVM::Compile(const ir::Module &module) { Visit(module.self()); } -llvm::Value *CodeGenLLVM::EmitCall_buffer_malloc(const ir::Call *op) { return nullptr; } +llvm::Value *CodeGenLLVM::EmitCall_buffer_malloc(const ir::Call *op) { + return nullptr; +} llvm::Value *CodeGenLLVM::EmitCall_get_address(const ir::Call *op) { if (auto *read_var = op->read_args.front().as_var()) { @@ -1125,7 +1243,8 @@ llvm::Value *CodeGenLLVM::SetVar(const std::string &name, llvm::Value *val) { return val; } -llvm::FunctionType *CodeGenLLVM::GenFunctionTypeFromCinnFunction(const ir::_LoweredFunc_ *func, bool with_buffer_type) { +llvm::FunctionType *CodeGenLLVM::GenFunctionTypeFromCinnFunction( + const ir::_LoweredFunc_ *func, bool with_buffer_type) { auto func_ret_type = CinnTypeToLLVMType(Void(), m_); std::vector arg_types; for (auto &arg : func->args) { @@ -1151,7 +1270,7 @@ llvm::Value *CodeGenLLVM::DenseVectorLoad(const ir::Load *op) { auto *ramp = index.As(); CHECK(ramp); - int load_lanes = op->type().lanes(); + int load_lanes = op->type().lanes(); int native_lanes = naive_vec_alignment_ / op->type().bits(); std::vector slices; @@ -1164,7 +1283,7 @@ llvm::Value *CodeGenLLVM::DenseVectorLoad(const ir::Load *op) { auto slice_base = common::AutoSimplify(ramp->base + i); optim::VarModSimplify(&slice_base); auto slide_stride = Expr(1); - auto slide_index = slice_base; + auto slide_index = slice_base; #if LLVM_VERSION_MAJOR >= 11 const llvm::ElementCount elem_count(slice_lanes, /*scalable*/ false); @@ -1172,14 +1291,18 @@ llvm::Value *CodeGenLLVM::DenseVectorLoad(const ir::Load *op) { const int elem_count = slice_lanes; #endif - llvm::Type *slice_type = llvm::VectorType::get(CinnTypeToLLVMType(op->type().ElementOf(), m_, true), elem_count); + llvm::Type *slice_type = llvm::VectorType::get( + CinnTypeToLLVMType(op->type().ElementOf(), m_, true), elem_count); - llvm::Value *elt_ptr = CreateBufferPtr(op->type().ElementOf(), buffer, Visit(&slice_base)); - llvm::Value *vec_ptr = b_->CreatePointerCast(elt_ptr, slice_type->getPointerTo(), "get_vec_ptr"); + llvm::Value *elt_ptr = + CreateBufferPtr(op->type().ElementOf(), buffer, Visit(&slice_base)); + llvm::Value *vec_ptr = b_->CreatePointerCast( + elt_ptr, slice_type->getPointerTo(), "get_vec_ptr"); int alignment = std::max(op->type().ElementOf().bits() / 8, 1); - llvm::Instruction *load_inst = b_->CreateAlignedLoad(vec_ptr, llvm::Align(alignment), "load_vec"); + llvm::Instruction *load_inst = + b_->CreateAlignedLoad(vec_ptr, llvm::Align(alignment), "load_vec"); AddTbaaMetadata(load_inst, op->tensor.as_tensor()->name, op->index()); slices.push_back(load_inst); @@ -1190,22 +1313,29 @@ llvm::Value 
*CodeGenLLVM::DenseVectorLoad(const ir::Load *op) { return slices[0]; } -llvm::Value *CodeGenLLVM::CreateBufferVecPtr(Type t, llvm::Value *buffer, llvm::Value *index) { +llvm::Value *CodeGenLLVM::CreateBufferVecPtr(Type t, + llvm::Value *buffer, + llvm::Value *index) { CHECK_GT(t.lanes(), 1) << "type is not a vector type: " << t; - llvm::PointerType *btype = llvm::dyn_cast(buffer->getType()); + llvm::PointerType *btype = + llvm::dyn_cast(buffer->getType()); CHECK(btype); - llvm::PointerType *ptype = CinnTypeToLLVMType(t, m_)->getPointerTo(btype->getAddressSpace()); + llvm::PointerType *ptype = + CinnTypeToLLVMType(t, m_)->getPointerTo(btype->getAddressSpace()); if (btype != ptype) { buffer = b_->CreatePointerCast(buffer, ptype); } return b_->CreateInBoundsGEP(buffer, index); } -llvm::Value *CodeGenLLVM::CreateBufferPtr(Type t, llvm::Value *buffer, llvm::Value *index) { +llvm::Value *CodeGenLLVM::CreateBufferPtr(Type t, + llvm::Value *buffer, + llvm::Value *index) { CHECK_EQ(t.lanes(), 1); auto *btype = llvm::dyn_cast(buffer->getType()); CHECK(btype); - auto *ptype = CinnTypeToLLVMType(t, m_)->getPointerTo(btype->getAddressSpace()); + auto *ptype = + CinnTypeToLLVMType(t, m_)->getPointerTo(btype->getAddressSpace()); CHECK(ptype); if (btype != ptype) { buffer = b_->CreatePointerCast(buffer, ptype, "pointer_cast"); @@ -1213,8 +1343,11 @@ llvm::Value *CodeGenLLVM::CreateBufferPtr(Type t, llvm::Value *buffer, llvm::Val return b_->CreateInBoundsGEP(buffer, index, "buffer_ptr"); } -llvm::Value *CodeGenLLVM::CreateVecSlice(llvm::Value *vec, int begin, int lanes) { - int total_lanes = llvm::dyn_cast(vec->getType())->getNumElements(); +llvm::Value *CodeGenLLVM::CreateVecSlice(llvm::Value *vec, + int begin, + int lanes) { + int total_lanes = + llvm::dyn_cast(vec->getType())->getNumElements(); CHECK_LE(begin + lanes, total_lanes); if (lanes == total_lanes && begin == 0) return vec; // full slice std::vector indices; @@ -1222,7 +1355,8 @@ llvm::Value *CodeGenLLVM::CreateVecSlice(llvm::Value *vec, int begin, int lanes) indices.push_back(ll_const_int32(begin + i)); } llvm::Constant *undef = llvm::UndefValue::get(vec->getType()); - return b_->CreateShuffleVector(vec, undef, llvm::ConstantVector::get(indices)); + return b_->CreateShuffleVector( + vec, undef, llvm::ConstantVector::get(indices)); } void CodeGenLLVM::InitTarget(const Target &target) { @@ -1257,19 +1391,22 @@ bool LLVM_WillVarLowerAsPointer(const std::string &var_name) { return var_name == "_args" || utils::Endswith(var_name, "__ptr"); } -void CodeGenLLVM::AddTbaaMetadata(llvm::Instruction *inst, absl::string_view buffer, Expr index) { - // If the index is constant, generate some TBAA info that helps LLVM understand our loads/stores aren't aliased. +void CodeGenLLVM::AddTbaaMetadata(llvm::Instruction *inst, + absl::string_view buffer, + Expr index) { + // If the index is constant, generate some TBAA info that helps LLVM + // understand our loads/stores aren't aliased. 
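// Reference sketch (not part of this patch): the metadata built below,
// spelled out with the same llvm::MDBuilder calls this file uses elsewhere.
// The buffer name "my_buf" and the width/base values are invented for
// illustration.
//
//   llvm::MDBuilder builder(b_->getContext());
//   llvm::MDNode *root = builder.createTBAARoot("cinn buffer");
//   llvm::MDNode *node = builder.createTBAAScalarTypeNode("my_buf", root);
//   // For a constant index (e.g. width 4, base 8) the loop below appends
//   // nodes "my_buf.width1024.base0", ..., "my_buf.width4.base8", so two
//   // accesses whose width/base ranges are provably disjoint may be
//   // reordered by LLVM's alias analysis.
//   inst->setMetadata("tbaa", builder.createTBAAStructTagNode(node, node, 0));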
bool constant_index = false; - int base = 0; - int width = 1; + int base = 0; + int width = 1; if (index.defined()) { if (const ir::Ramp *ramp = index.As()) { auto *pstride_int = ramp->stride.As(); - auto *pbase_int = ramp->base.As(); + auto *pbase_int = ramp->base.As(); if (pstride_int && pbase_int) { int stride = pstride_int->value; - base = pbase_int->value; + base = pbase_int->value; CHECK_GE(base, 0); width = NextPowerOfTwo(ramp->lanes * stride); @@ -1282,8 +1419,8 @@ void CodeGenLLVM::AddTbaaMetadata(llvm::Instruction *inst, absl::string_view buf } else { auto *pbase_int = index.As(); if (pbase_int) { - int pbase = pbase_int->value; - base = pbase; + int pbase = pbase_int->value; + base = pbase; constant_index = true; } } @@ -1291,16 +1428,18 @@ void CodeGenLLVM::AddTbaaMetadata(llvm::Instruction *inst, absl::string_view buf llvm::MDBuilder builder(b_->getContext()); - // Add type-based-alias-analysis metadata to the pointer, so that loads and stores to different buffers can get - // reordered. + // Add type-based-alias-analysis metadata to the pointer, so that loads and + // stores to different buffers can get reordered. llvm::MDNode *tbaa = builder.createTBAARoot("cinn buffer"); - tbaa = builder.createTBAAScalarTypeNode(std::string(buffer), tbaa); + tbaa = builder.createTBAAScalarTypeNode(std::string(buffer), tbaa); - // Add metadata for constant indices to allow loads and stores to the same buffer to get reordered. + // Add metadata for constant indices to allow loads and stores to the same + // buffer to get reordered. if (constant_index) { for (int w = 1024; w >= width; w /= 2) { int b = (base / w) * w; - tbaa = builder.createTBAAScalarTypeNode(utils::StringFormat("%s.width%d.base%d", buffer.data(), w, b), tbaa); + tbaa = builder.createTBAAScalarTypeNode( + utils::StringFormat("%s.width%d.base%d", buffer.data(), w, b), tbaa); } } @@ -1324,17 +1463,19 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::BufferGetDataHandle *op) { return Call(callee, std::move(args)); } -llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::BufferGetDataConstHandle *op) { +llvm::Value *CodeGenLLVM::Visit( + const ir::intrinsics::BufferGetDataConstHandle *op) { std::vector args({Visit(&op->buffer)}); auto *callee = m_->getFunction("cinn_buffer_get_data_const_handle"); return Call(callee, std::move(args)); } llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::BufferCreate *op) { - auto *callee = m_->getFunction(runtime::intrinsic::buffer_create_default); + auto *callee = m_->getFunction(runtime::intrinsic::buffer_create_default); auto buffer_node = op->buffer.as_buffer(); CHECK(buffer_node); - std::vector args({ll_const_int32(buffer_node->target.runtime_arch())}); + std::vector args( + {ll_const_int32(buffer_node->target.runtime_arch())}); uint64_t memory_size = (buffer_node->dtype.ElementOf().bits() + 7) / 8; for (auto shape : buffer_node->shape) { int shape_int = shape.as_int32(); @@ -1352,7 +1493,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::GetAddr *op) { } else if (auto *n = op->data.as_buffer()) { return GetVar(n->name); } - if (auto *n = op->data.As()) { // get the address to an element in a buffer + if (auto *n = + op->data + .As()) { // get the address to an element in a buffer auto *e = Visit(&op->data); if (auto *e_load = llvm::dyn_cast(e)) { return e_load->getPointerOperand(); @@ -1369,10 +1512,11 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::ArgsConstruct *op) { var = ir::intrinsics::GetAddr::Make(var); llvm::Value *ll_var = Visit(&var); - var = 
ir::Cast::Make(type_of(), var); + var = ir::Cast::Make(type_of(), var); Expr num_args(static_cast(op->args.size())); - args.push_back(BitCast(ll_var, ll_cinn_pod_p_ty(), "cast_to_pod_value_t_ptr")); + args.push_back( + BitCast(ll_var, ll_cinn_pod_p_ty(), "cast_to_pod_value_t_ptr")); args.push_back(Visit(&num_args)); for (auto &arg : op->args) { args.push_back(Visit(&arg)); } @@ -1382,9 +1526,10 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::ArgsConstruct *op) { return Call(callee, std::move(args)); } -llvm::Function *CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, - llvm::Type *ret_type, - llvm::ArrayRef arg_types) { +llvm::Function *CodeGenLLVM::GetIntrinsicDecl( + llvm::Intrinsic::ID id, + llvm::Type *ret_type, + llvm::ArrayRef arg_types) { llvm::Module *module = m_; if (!llvm::Intrinsic::isOverloaded(id)) { @@ -1398,7 +1543,8 @@ llvm::Function *CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, auto try_match = [&](llvm::FunctionType *f_ty, bool var_arg) { overload_types.clear(); llvm::ArrayRef ref(infos); - auto match = llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types); + auto match = + llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types); if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) { if (llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref)) { return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg; @@ -1464,7 +1610,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::BuiltinIntrin *op) { } llvm::Intrinsic::ID id = op->id; - int64_t num_signature = op->arg_nums; + int64_t num_signature = op->arg_nums; std::vector arg_value; std::vector arg_type; for (size_t i = 0; i < op->args.size(); ++i) { @@ -1475,8 +1621,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::BuiltinIntrin *op) { } CHECK(!op->args.empty()); llvm::Type *return_type = CinnTypeToLLVMType(op->type(), m_, true); - llvm::Function *fn = GetIntrinsicDecl(id, return_type, arg_type); - CHECK(fn) << "Cannot find intrinsic declaration, possible type mismatch: " << llvm::Intrinsic::getName(id, {}); + llvm::Function *fn = GetIntrinsicDecl(id, return_type, arg_type); + CHECK(fn) << "Cannot find intrinsic declaration, possible type mismatch: " + << llvm::Intrinsic::getName(id, {}); return b_->CreateCall(fn, arg_value); } diff --git a/paddle/cinn/backends/llvm/codegen_llvm.h b/paddle/cinn/backends/llvm/codegen_llvm.h index aba39d22e5073..facf13d05147b 100644 --- a/paddle/cinn/backends/llvm/codegen_llvm.h +++ b/paddle/cinn/backends/llvm/codegen_llvm.h @@ -50,7 +50,8 @@ class LLVMIRVisitor : public ir::IRVisitorBase { }; /** - * Tell whether a variable called \p \var_name will be lowered to a pointer type in LLVM. + * Tell whether a variable called \p \var_name will be lowered to a pointer type in + * LLVM. * @param var_name name of the variable. * @return a boolean.
*/ @@ -97,7 +98,10 @@ class SymbolTable { }; struct SymbolTableGuard { - explicit SymbolTableGuard(SymbolTable &symbol_table) : symbol_table_(symbol_table) { symbol_table.PushScope(); } + explicit SymbolTableGuard(SymbolTable &symbol_table) + : symbol_table_(symbol_table) { + symbol_table.PushScope(); + } ~SymbolTableGuard() { symbol_table_.PopScope(); } @@ -110,10 +114,11 @@ struct SymbolTableGuard { */ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin { public: - explicit CodeGenLLVM(llvm::Module *m, - llvm::IRBuilder<> *b, - const std::shared_ptr &symbol_table = nullptr, - const Target &target = common::DefaultHostTarget()); + explicit CodeGenLLVM( + llvm::Module *m, + llvm::IRBuilder<> *b, + const std::shared_ptr &symbol_table = nullptr, + const Target &target = common::DefaultHostTarget()); // Common llvm types // @{ @@ -130,23 +135,39 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin { inline llvm::Type *ll_uint32_ty() const { return llvm_type_of(m_); } inline llvm::Type *ll_uint64_ty() const { return llvm_type_of(m_); } - inline llvm::Type *ll_bf16_ty() const { return llvm_type_of(m_); } - inline llvm::Type *ll_fp16_ty() const { return llvm_type_of(m_); } + inline llvm::Type *ll_bf16_ty() const { + return llvm_type_of(m_); + } + inline llvm::Type *ll_fp16_ty() const { + return llvm_type_of(m_); + } inline llvm::Type *ll_fp32_ty() const { return llvm_type_of(m_); } inline llvm::Type *ll_fp64_ty() const { return llvm_type_of(m_); } - inline llvm::Type *ll_cinn_buffer_p_ty() const { return llvm_type_of(m_); } - inline llvm::Type *ll_cinn_pod_ty() const { return llvm_type_of(m_); } - inline llvm::Type *ll_cinn_pod_p_ty() const { return llvm_type_of(m_); } + inline llvm::Type *ll_cinn_buffer_p_ty() const { + return llvm_type_of(m_); + } + inline llvm::Type *ll_cinn_pod_ty() const { + return llvm_type_of(m_); + } + inline llvm::Type *ll_cinn_pod_p_ty() const { + return llvm_type_of(m_); + } // @} //! get a llvm type equivalent to a CINN type. - inline llvm::Type *ll_type_of(Type type) { return CinnTypeToLLVMType(type, m_); } + inline llvm::Type *ll_type_of(Type type) { + return CinnTypeToLLVMType(type, m_); + } // Common methods to get a constant // @{ - inline llvm::Constant *ll_const_int32(int v) const { return llvm::ConstantInt::get(b_->getInt32Ty(), v); } - inline llvm::Constant *ll_const_int64(int v) const { return llvm::ConstantInt::get(b_->getInt64Ty(), v); } + inline llvm::Constant *ll_const_int32(int v) const { + return llvm::ConstantInt::get(b_->getInt32Ty(), v); + } + inline llvm::Constant *ll_const_int64(int v) const { + return llvm::ConstantInt::get(b_->getInt64Ty(), v); + } // @} //! Get the bound LLVM module. 
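// Reference sketch (not part of this patch): SymbolTableGuard above is plain
// RAII over a scope stack. A self-contained toy equivalent, with invented
// names (ToySymbolTable/ToyScopeGuard), shows the intended push/pop pairing:
#include <map>
#include <string>
#include <vector>

struct ToySymbolTable {
  std::vector<std::map<std::string, int>> scopes;
  void PushScope() { scopes.emplace_back(); }
  void PopScope() { scopes.pop_back(); }
  void Set(const std::string &k, int v) { scopes.back()[k] = v; }
  // Lookup walks scopes innermost-first, which is what gives shadowing.
  const int *Get(const std::string &k) const {
    for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) {
      auto f = it->find(k);
      if (f != it->end()) return &f->second;
    }
    return nullptr;
  }
};

struct ToyScopeGuard {
  explicit ToyScopeGuard(ToySymbolTable &t) : t_(t) { t_.PushScope(); }
  ~ToyScopeGuard() { t_.PopScope(); }  // the scope is popped even on early return
  ToySymbolTable &t_;
};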
@@ -171,7 +192,8 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin { std::shared_ptr named_vars() { return symbol_table_; } - llvm::FunctionType *GenFunctionTypeFromCinnFunction(const ir::_LoweredFunc_ *func, bool with_buffer_type); + llvm::FunctionType *GenFunctionTypeFromCinnFunction( + const ir::_LoweredFunc_ *func, bool with_buffer_type); virtual llvm::Value *GetVar(const std::string &name, bool lazy = true); @@ -181,13 +203,16 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin { // Constants // @{ - inline llvm::Value *llvm_int32_constant(int v) { return llvm::ConstantInt::get(ll_int32_ty(), v); } + inline llvm::Value *llvm_int32_constant(int v) { + return llvm::ConstantInt::get(ll_int32_ty(), v); + } // @} virtual ~CodeGenLLVM(); protected: - // TODO(Superjomn) When to clear the existing local variables when switch to another function? + // TODO(Superjomn) When to clear the existing local variables when switch to + // another function? llvm::Value *SetVar(const std::string &name, llvm::Value *val); llvm::Value *EmitVectorSlice(llvm::Value *vec, int begin, int extent); llvm::Value *EmitVectorPad(llvm::Value *vec, int lanes); @@ -202,26 +227,35 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin { llvm::Value *EmitCall_debug_info(const ir::Call *op); // @} - llvm::Value *EmitBinaryOp(llvm::Value *lhs, llvm::Value *rhs, char opcode, bool is_integral, bool is_signed = true); + llvm::Value *EmitBinaryOp(llvm::Value *lhs, + llvm::Value *rhs, + char opcode, + bool is_integral, + bool is_signed = true); llvm::Value *LLVMGenGlobalStringVar(const std::string &data); llvm::Value *CreateBufferPtr(Type t, llvm::Value *buffer, llvm::Value *index); - llvm::Value *CreateBufferVecPtr(Type t, llvm::Value *buffer, llvm::Value *index); + llvm::Value *CreateBufferVecPtr(Type t, + llvm::Value *buffer, + llvm::Value *index); llvm::Value *CreateVecSlice(llvm::Value *vec, int begin, int lanes); llvm::Value *DenseVectorLoad(const ir::Load *load); llvm::Value *CreateSerialFor(const ir::For *op, int stride = 1); /** - * Mark a load or store with type-based-alias-analysis metadata so that LLVM can optimize by reordering loads and - * stores across different buffers. + * Mark a load or store with type-based-alias-analysis metadata so that LLVM + * can optimize by reordering loads and stores across different buffers. 
*/ - void AddTbaaMetadata(llvm::Instruction *inst, absl::string_view buffer, Expr index); + void AddTbaaMetadata(llvm::Instruction *inst, + absl::string_view buffer, + Expr index); void InitTarget(const Target &target); - void Scalarize(const Expr &e, std::function flambda); + void Scalarize(const Expr &e, + std::function flambda); llvm::Module *m_; llvm::IRBuilder<> *b_; @@ -230,7 +264,8 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin { std::unique_ptr md_builder_; - // std::shared_ptr> named_vars_; + // std::shared_ptr> + // named_vars_; std::shared_ptr symbol_table_; std::unordered_set alias_vars_; diff --git a/paddle/cinn/backends/llvm/codegen_llvm_test.cc b/paddle/cinn/backends/llvm/codegen_llvm_test.cc index b0d9370f43555..aa6ca91af1b26 100644 --- a/paddle/cinn/backends/llvm/codegen_llvm_test.cc +++ b/paddle/cinn/backends/llvm/codegen_llvm_test.cc @@ -43,11 +43,12 @@ namespace { auto CreateCodeGenLLVMTestLLVM() { auto context = std::make_unique(); - auto b = std::make_unique>(*context); - auto m = std::make_unique("test_codegen_llvm", *context); + auto b = std::make_unique>(*context); + auto m = std::make_unique("test_codegen_llvm", *context); auto emitter = std::make_unique(m.get(), b.get()); - return std::make_tuple(std::move(m), std::move(b), std::move(context), std::move(emitter)); + return std::make_tuple( + std::move(m), std::move(b), std::move(context), std::move(emitter)); } auto CreateTensor() { @@ -60,22 +61,27 @@ auto CreateTensor() { lang::Buffer c_buf(common::Float(32)); - return std::make_tuple(std::move(a), std::move(b), std::move(c), std::move(c_buf)); + return std::make_tuple( + std::move(a), std::move(b), std::move(c), std::move(c_buf)); } auto CreateLLVMType(llvm::LLVMContext *context) { - llvm::Type *i8 = llvm::Type::getInt8Ty(*context); - llvm::Type *i32 = llvm::Type::getInt32Ty(*context); - llvm::Type *i64 = llvm::Type::getInt64Ty(*context); - llvm::Type *u32 = llvm::Type::getInt32Ty(*context); - llvm::Type *f32 = llvm::Type::getFloatTy(*context); - llvm::Type *f16 = llvm::Type::getHalfTy(*context); + llvm::Type *i8 = llvm::Type::getInt8Ty(*context); + llvm::Type *i32 = llvm::Type::getInt32Ty(*context); + llvm::Type *i64 = llvm::Type::getInt64Ty(*context); + llvm::Type *u32 = llvm::Type::getInt32Ty(*context); + llvm::Type *f32 = llvm::Type::getFloatTy(*context); + llvm::Type *f16 = llvm::Type::getHalfTy(*context); llvm::Type *bf16 = llvm::Type::getBFloatTy(*context); return std::make_tuple(i8, i32, i64, u32, f32, f16, bf16); } -template +template auto CreateBinaryOp(common::Type t, T1 x, T2 y) { auto px = std::make_unique(t, x); auto py = std::make_unique(t, y); @@ -86,7 +92,10 @@ auto CreateBinaryOp(common::Type t, T1 x, T2 y) { return std::make_unique(std::move(ex), std::move(ey)); } -auto CreateIrBuffer(common::Type t, std::string name, std::vector shape, int data_alignment = 0) { +auto CreateIrBuffer(common::Type t, + std::string name, + std::vector shape, + int data_alignment = 0) { CHECK_GE(data_alignment, 0); auto buffer = ir::_Buffer_::Make(std::move(name), std::move(t)); @@ -125,14 +134,14 @@ using cinn::common::float16; TEST(CodeGenLLVM, Imm) { auto context = std::make_unique(); - auto b = std::make_unique>(*context); - auto m = std::make_unique("test_codegen_llvm", *context); + auto b = std::make_unique>(*context); + auto m = std::make_unique("test_codegen_llvm", *context); auto emitter = std::make_unique(m.get(), b.get()); - llvm::Type *i32 = llvm::Type::getInt32Ty(*context); - llvm::Type *u32 = llvm::Type::getInt32Ty(*context); - 
llvm::Type *f32 = llvm::Type::getFloatTy(*context); - llvm::Type *f16 = llvm::Type::getHalfTy(*context); + llvm::Type *i32 = llvm::Type::getInt32Ty(*context); + llvm::Type *u32 = llvm::Type::getInt32Ty(*context); + llvm::Type *f32 = llvm::Type::getFloatTy(*context); + llvm::Type *f16 = llvm::Type::getHalfTy(*context); llvm::Type *bf16 = llvm::Type::getBFloatTy(*context); llvm::Value *value = nullptr; @@ -166,20 +175,20 @@ TEST(CodeGenLLVM, Imm) { TEST(CodeGenLLVM, Expr) { auto context = std::make_unique(); - auto b = std::make_unique>(*context); - auto m = std::make_unique("test_binary_op", *context); + auto b = std::make_unique>(*context); + auto m = std::make_unique("test_binary_op", *context); auto emitter = std::make_unique(m.get(), b.get()); - llvm::Type *i1 = llvm::Type::getInt1Ty(*context); - llvm::Type *i8 = llvm::Type::getInt8Ty(*context); - llvm::Type *i32 = llvm::Type::getInt32Ty(*context); - llvm::Type *i64 = llvm::Type::getInt64Ty(*context); - llvm::Type *u32 = llvm::Type::getInt32Ty(*context); - llvm::Type *f32 = llvm::Type::getFloatTy(*context); - llvm::Type *f16 = llvm::Type::getHalfTy(*context); + llvm::Type *i1 = llvm::Type::getInt1Ty(*context); + llvm::Type *i8 = llvm::Type::getInt8Ty(*context); + llvm::Type *i32 = llvm::Type::getInt32Ty(*context); + llvm::Type *i64 = llvm::Type::getInt64Ty(*context); + llvm::Type *u32 = llvm::Type::getInt32Ty(*context); + llvm::Type *f32 = llvm::Type::getFloatTy(*context); + llvm::Type *f16 = llvm::Type::getHalfTy(*context); llvm::Type *bf16 = llvm::Type::getBFloatTy(*context); - llvm::Value *value = nullptr; + llvm::Value *value = nullptr; llvm::Value *expect_value = nullptr; std::string outs; @@ -187,12 +196,12 @@ TEST(CodeGenLLVM, Expr) { // + do { - int x = 2; - int y = 3; + int x = 2; + int y = 3; auto op = CreateBinaryOp(common::Int(32), x, y); expect_value = llvm::ConstantInt::get(i32, x + y); - value = emitter->Visit(op.get()); + value = emitter->Visit(op.get()); ASSERT_EQ(value->getType(), i32); ASSERT_EQ(value, expect_value); // value->print(llvm::outs(), false); @@ -204,10 +213,11 @@ TEST(CodeGenLLVM, Expr) { do { float x = 2.5; float y = 3.5; - auto op = CreateBinaryOp(common::Float(32), x, y); + auto op = + CreateBinaryOp(common::Float(32), x, y); expect_value = llvm::ConstantFP::get(f32, x - y); - value = emitter->Visit(op.get()); + value = emitter->Visit(op.get()); ASSERT_EQ(value->getType(), f32); ASSERT_EQ(value, expect_value); } while (false); @@ -216,10 +226,11 @@ TEST(CodeGenLLVM, Expr) { do { float16 x{2.5}; float16 y{3.5}; - auto op = CreateBinaryOp(common::Float16(), x, y); + auto op = + CreateBinaryOp(common::Float16(), x, y); expect_value = llvm::ConstantFP::get(f16, x - y); - value = emitter->Visit(op.get()); + value = emitter->Visit(op.get()); ASSERT_EQ(value->getType(), f16); ASSERT_EQ(value, expect_value); } while (false); @@ -228,32 +239,34 @@ TEST(CodeGenLLVM, Expr) { do { bfloat16 x{2.5}; bfloat16 y{3.5}; - auto op = CreateBinaryOp(common::BFloat16(), x, y); + auto op = CreateBinaryOp( + common::BFloat16(), x, y); expect_value = llvm::ConstantFP::get(bf16, x - y); - value = emitter->Visit(op.get()); + value = emitter->Visit(op.get()); ASSERT_EQ(value->getType(), bf16); ASSERT_EQ(value, expect_value); } while (false); // * do { - int x = 5; - int y = 3; - auto op = CreateBinaryOp(common::Int(64), x, y); + int x = 5; + int y = 3; + auto op = CreateBinaryOp(common::Int(64), x, y); expect_value = llvm::ConstantInt::get(i64, x * y); - value = emitter->Visit(op.get()); + value = emitter->Visit(op.get()); 
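// Reference sketch (not part of this patch): the pointer-equality ASSERT_EQs
// in these tests rely on two LLVM guarantees: constants are uniqued per
// LLVMContext, and IRBuilder constant-folds operations on constant operands.
// With any llvm::IRBuilder<> `builder` in scope:
//
//   llvm::Value *sum = builder.CreateAdd(builder.getInt32(2),
//                                        builder.getInt32(3));
//   // No 'add' instruction is emitted; `sum` is the unique i32 constant 5:
//   assert(sum == llvm::ConstantInt::get(builder.getInt32Ty(), 5));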
ASSERT_EQ(value->getType(), i64); ASSERT_EQ(value, expect_value); } while (false); // / do { - float x = 6; - float y = 4; - auto op = CreateBinaryOp(common::Float(32), x, y); + float x = 6; + float y = 4; + auto op = + CreateBinaryOp(common::Float(32), x, y); expect_value = llvm::ConstantFP::get(f32, x / y); - value = emitter->Visit(op.get()); + value = emitter->Visit(op.get()); ASSERT_EQ(value->getType(), f32); ASSERT_EQ(value, expect_value); } while (false); @@ -262,9 +275,10 @@ TEST(CodeGenLLVM, Expr) { do { float16 x{6}; float16 y{4}; - auto op = CreateBinaryOp(common::Float16(), x, y); + auto op = + CreateBinaryOp(common::Float16(), x, y); expect_value = llvm::ConstantFP::get(f16, x / y); - value = emitter->Visit(op.get()); + value = emitter->Visit(op.get()); ASSERT_EQ(value->getType(), f16); ASSERT_EQ(value, expect_value); } while (false); @@ -273,31 +287,32 @@ TEST(CodeGenLLVM, Expr) { do { bfloat16 x{6}; bfloat16 y{4}; - auto op = CreateBinaryOp(common::BFloat16(), x, y); + auto op = CreateBinaryOp( + common::BFloat16(), x, y); expect_value = llvm::ConstantFP::get(bf16, x / y); - value = emitter->Visit(op.get()); + value = emitter->Visit(op.get()); ASSERT_EQ(value->getType(), bf16); ASSERT_EQ(value, expect_value); } while (false); // % do { - int x = 25; - int y = 7; - auto op = CreateBinaryOp(common::Int(32), x, y); + int x = 25; + int y = 7; + auto op = CreateBinaryOp(common::Int(32), x, y); expect_value = llvm::ConstantInt::get(i32, x % y); - value = emitter->Visit(op.get()); + value = emitter->Visit(op.get()); ASSERT_EQ(value->getType(), i32); ASSERT_EQ(value, expect_value); } while (false); // == do { - int x = 3; - int y = 3; - auto op = CreateBinaryOp(common::Int(32), x, y); + int x = 3; + int y = 3; + auto op = CreateBinaryOp(common::Int(32), x, y); expect_value = llvm::ConstantInt::get(i1, 1); - value = emitter->Visit(op.get()); + value = emitter->Visit(op.get()); ASSERT_EQ(value->getType(), i1); ASSERT_EQ(value, expect_value); } while (false); @@ -307,19 +322,20 @@ TEST(CodeGenLLVM, Expr) { float x = 3; float y = 3; - auto op = CreateBinaryOp(common::Float(32), x, y); + auto op = + CreateBinaryOp(common::Float(32), x, y); expect_value = llvm::ConstantInt::get(i1, 0); - value = emitter->Visit(op.get()); + value = emitter->Visit(op.get()); ASSERT_EQ(value->getType(), i1); ASSERT_EQ(value, expect_value); } while (false); // < do { - int x = 6; - int y = 6; - auto op = CreateBinaryOp(common::Int(32), x, y); - value = emitter->Visit(op.get()); + int x = 6; + int y = 6; + auto op = CreateBinaryOp(common::Int(32), x, y); + value = emitter->Visit(op.get()); expect_value = llvm::ConstantInt::get(i1, 0); ASSERT_EQ(value->getType(), i1); ASSERT_EQ(value, expect_value); @@ -327,10 +343,10 @@ TEST(CodeGenLLVM, Expr) { // <= do { - int x = 6; - int y = 6; - auto op = CreateBinaryOp(common::Int(32), x, y); - value = emitter->Visit(op.get()); + int x = 6; + int y = 6; + auto op = CreateBinaryOp(common::Int(32), x, y); + value = emitter->Visit(op.get()); expect_value = llvm::ConstantInt::get(i1, 1); ASSERT_EQ(value->getType(), i1); ASSERT_EQ(value, expect_value); @@ -338,10 +354,10 @@ TEST(CodeGenLLVM, Expr) { // > do { - int x = 6; - int y = 6; - auto op = CreateBinaryOp(common::Int(32), x, y); - value = emitter->Visit(op.get()); + int x = 6; + int y = 6; + auto op = CreateBinaryOp(common::Int(32), x, y); + value = emitter->Visit(op.get()); expect_value = llvm::ConstantInt::get(i1, 0); ASSERT_EQ(value->getType(), i1); ASSERT_EQ(value, expect_value); @@ -349,10 +365,10 @@ 
TEST(CodeGenLLVM, Expr) { // >= do { - int x = 6; - int y = 6; - auto op = CreateBinaryOp(common::Int(32), x, y); - value = emitter->Visit(op.get()); + int x = 6; + int y = 6; + auto op = CreateBinaryOp(common::Int(32), x, y); + value = emitter->Visit(op.get()); expect_value = llvm::ConstantInt::get(i1, 1); ASSERT_EQ(value->getType(), i1); ASSERT_EQ(value, expect_value); @@ -364,10 +380,10 @@ TEST(CodeGenLLVM, Expr) { // min do { - int x = 2; - int y = 3; - auto op = CreateBinaryOp(common::Int(32), x, y); - value = emitter->Visit(op.get()); + int x = 2; + int y = 3; + auto op = CreateBinaryOp(common::Int(32), x, y); + value = emitter->Visit(op.get()); expect_value = llvm::ConstantInt::get(i32, std::min(x, y)); ASSERT_EQ(value->getType(), i32); ASSERT_EQ(value, expect_value); @@ -375,10 +391,11 @@ TEST(CodeGenLLVM, Expr) { // max do { - float x = 2; - float y = 3; - auto op = CreateBinaryOp(common::Float(32), x, y); - value = emitter->Visit(op.get()); + float x = 2; + float y = 3; + auto op = + CreateBinaryOp(common::Float(32), x, y); + value = emitter->Visit(op.get()); expect_value = llvm::ConstantFP::get(f32, std::max(x, y)); ASSERT_EQ(value->getType(), f32); ASSERT_EQ(value, expect_value); @@ -394,33 +411,33 @@ TEST(CodeGenLLVM, Expr) { // i32 -> f32 LOG(INFO) << "test i32 -> f32"; - int v2 = 2; - auto x2 = std::make_unique(common::Int(32), v2); - auto ex2 = ir::Expr(x2.release()); - auto op2 = ir::Cast::Make(common::Float(32), std::move(ex2)); - value = emitter->Visit(&op2); + int v2 = 2; + auto x2 = std::make_unique(common::Int(32), v2); + auto ex2 = ir::Expr(x2.release()); + auto op2 = ir::Cast::Make(common::Float(32), std::move(ex2)); + value = emitter->Visit(&op2); expect_value = llvm::ConstantFP::get(f32, v2); ASSERT_EQ(value->getType(), f32); ASSERT_EQ(value, expect_value); // f32 -> i32 LOG(INFO) << "test f32 -> i32"; - float v3 = 3; - auto x3 = std::make_unique(common::Float(32), v3); - auto ex3 = ir::Expr(x3.release()); - auto op3 = ir::Cast::Make(common::Int(32), std::move(ex3)); - value = emitter->Visit(&op3); + float v3 = 3; + auto x3 = std::make_unique(common::Float(32), v3); + auto ex3 = ir::Expr(x3.release()); + auto op3 = ir::Cast::Make(common::Int(32), std::move(ex3)); + value = emitter->Visit(&op3); expect_value = llvm::ConstantInt::get(i32, v3); ASSERT_EQ(value->getType(), i32); ASSERT_EQ(value, expect_value); // i32 -> f16 LOG(INFO) << "test i32 -> f16"; - int v4 = 4; - auto x4 = std::make_unique(common::Int(32), v4); - auto ex4 = ir::Expr(x4.release()); - auto op4 = ir::Cast::Make(common::Float16(), std::move(ex4)); - value = emitter->Visit(&op4); + int v4 = 4; + auto x4 = std::make_unique(common::Int(32), v4); + auto ex4 = ir::Expr(x4.release()); + auto op4 = ir::Cast::Make(common::Float16(), std::move(ex4)); + value = emitter->Visit(&op4); expect_value = llvm::ConstantFP::get(f16, v4); ASSERT_EQ(value->getType(), f16); ASSERT_EQ(value, expect_value); @@ -428,21 +445,21 @@ TEST(CodeGenLLVM, Expr) { // f16 -> f32 LOG(INFO) << "test f16 -> f32"; float16 v5{5}; - auto x5 = std::make_unique(common::Float16(), v5); - auto ex5 = ir::Expr(x5.release()); - auto op5 = ir::Cast::Make(common::Float(32), std::move(ex5)); - value = emitter->Visit(&op5); + auto x5 = std::make_unique(common::Float16(), v5); + auto ex5 = ir::Expr(x5.release()); + auto op5 = ir::Cast::Make(common::Float(32), std::move(ex5)); + value = emitter->Visit(&op5); expect_value = llvm::ConstantFP::get(f32, v5); ASSERT_EQ(value->getType(), f32); ASSERT_EQ(value, expect_value); // i32 -> bf16 LOG(INFO) << 
"test i32 -> bf16"; - int v6 = 4; - auto x6 = std::make_unique(common::Int(32), v6); - auto ex6 = ir::Expr(x6.release()); - auto op6 = ir::Cast::Make(common::BFloat16(), std::move(ex6)); - value = emitter->Visit(&op6); + int v6 = 4; + auto x6 = std::make_unique(common::Int(32), v6); + auto ex6 = ir::Expr(x6.release()); + auto op6 = ir::Cast::Make(common::BFloat16(), std::move(ex6)); + value = emitter->Visit(&op6); expect_value = llvm::ConstantFP::get(bf16, v6); ASSERT_EQ(value->getType(), bf16); ASSERT_EQ(value, expect_value); @@ -450,10 +467,10 @@ TEST(CodeGenLLVM, Expr) { // bf16 -> f32 LOG(INFO) << "test bf16 -> f32"; bfloat16 v7{5}; - auto x7 = std::make_unique(common::BFloat16(), v7); - auto ex7 = ir::Expr(x7.release()); - auto op7 = ir::Cast::Make(common::Float(32), std::move(ex7)); - value = emitter->Visit(&op7); + auto x7 = std::make_unique(common::BFloat16(), v7); + auto ex7 = ir::Expr(x7.release()); + auto op7 = ir::Cast::Make(common::Float(32), std::move(ex7)); + value = emitter->Visit(&op7); expect_value = llvm::ConstantFP::get(f32, v7); ASSERT_EQ(value->getType(), f32); ASSERT_EQ(value, expect_value); @@ -466,52 +483,56 @@ TEST(CodeGenLLVM, Statement) { llvm::raw_string_ostream ss(outs); do { - auto _m_b_context_emitter_ = CreateCodeGenLLVMTestLLVM(); // NOLINT - auto &m = std::get<0>(_m_b_context_emitter_); - auto &b = std::get<1>(_m_b_context_emitter_); - auto &context = std::get<2>(_m_b_context_emitter_); - auto &emitter = std::get<3>(_m_b_context_emitter_); - auto _i8_i32_i64_u32_f32_f16_ = CreateLLVMType(context.get()); // NOLINT - auto &i8 = std::get<0>(_i8_i32_i64_u32_f32_f16_); - auto &i32 = std::get<1>(_i8_i32_i64_u32_f32_f16_); - auto &i64 = std::get<2>(_i8_i32_i64_u32_f32_f16_); - auto &u32 = std::get<3>(_i8_i32_i64_u32_f32_f16_); - auto &f32 = std::get<4>(_i8_i32_i64_u32_f32_f16_); - auto &f16 = std::get<4>(_i8_i32_i64_u32_f32_f16_); + auto _m_b_context_emitter_ = CreateCodeGenLLVMTestLLVM(); // NOLINT + auto &m = std::get<0>(_m_b_context_emitter_); + auto &b = std::get<1>(_m_b_context_emitter_); + auto &context = std::get<2>(_m_b_context_emitter_); + auto &emitter = std::get<3>(_m_b_context_emitter_); + auto _i8_i32_i64_u32_f32_f16_ = CreateLLVMType(context.get()); // NOLINT + auto &i8 = std::get<0>(_i8_i32_i64_u32_f32_f16_); + auto &i32 = std::get<1>(_i8_i32_i64_u32_f32_f16_); + auto &i64 = std::get<2>(_i8_i32_i64_u32_f32_f16_); + auto &u32 = std::get<3>(_i8_i32_i64_u32_f32_f16_); + auto &f32 = std::get<4>(_i8_i32_i64_u32_f32_f16_); + auto &f16 = std::get<4>(_i8_i32_i64_u32_f32_f16_); llvm::FunctionType *function_type = llvm::FunctionType::get(i32, {}, false); - llvm::Function *function = llvm::Function::Create( - function_type, llvm::Function::ExternalLinkage, "codegen_llvm_test.Alloc_Store_Load_Free", m.get()); + llvm::Function *function = + llvm::Function::Create(function_type, + llvm::Function::ExternalLinkage, + "codegen_llvm_test.Alloc_Store_Load_Free", + m.get()); std::string module_str; module_str += "; ModuleID = 'test_codegen_llvm'"; module_str += "\nsource_filename = \"test_codegen_llvm\"\n"; module_str += "\ndefine i32 @codegen_llvm_test.Alloc_Store_Load_Free()"; - llvm::BasicBlock *entry = llvm::BasicBlock::Create(*context, "entry", function); + llvm::BasicBlock *entry = + llvm::BasicBlock::Create(*context, "entry", function); b->SetInsertPoint(entry); module_str += " {\nentry:"; // ir::Tensor - auto tensor_op = CreateIrTensor("x", {2, 3}); + auto tensor_op = CreateIrTensor("x", {2, 3}); tensor_op->buffer = CreateIrBuffer(common::Int(32), "", {2, 
3});

    // ir::Alloc
-    auto alloc_op = std::make_unique<ir::Alloc>();
+    auto alloc_op = std::make_unique<ir::Alloc>();
    alloc_op->destination = ir::Expr(tensor_op->buffer);
    // ir::Store
-    auto store_op = std::make_unique<ir::Store>();
+    auto store_op = std::make_unique<ir::Store>();
    store_op->tensor = ir::Expr(tensor_op);
    for (int i : {1, 1}) {
      auto pi = std::make_unique<ir::IntImm>(common::Int(32), std::move(i));
      store_op->indices.emplace_back(pi.release());
    }
    auto store_value = std::make_unique<ir::IntImm>(common::Int(32), 5);
-    store_op->value = ir::Expr(store_value.release());
+    store_op->value = ir::Expr(store_value.release());
    // ir::Load
-    auto load_op = std::make_unique<ir::Load>();
+    auto load_op = std::make_unique<ir::Load>();
    load_op->tensor = ir::Expr(tensor_op);
    for (int i : {1, 1}) {
      auto pi = std::make_unique<ir::IntImm>(common::Int(32), std::move(i));
@@ -519,20 +540,23 @@ TEST(CodeGenLLVM, Statement) {
    }
    // ir::Free
-    auto free_op = std::make_unique<ir::Free>();
+    auto free_op = std::make_unique<ir::Free>();
    free_op->destination = ir::Expr(tensor_op->buffer);
    // ir::Call
-    auto call_op = std::make_unique<ir::Call>(common::Int(32));
+    auto call_op = std::make_unique<ir::Call>(common::Int(32));
    call_op->name = "codegen_llvm_test.Alloc_Store_Load_Free";
    // Emit llvm ir
-    auto *alloc_inst = llvm::dyn_cast<llvm::AllocaInst>(emitter->Visit(alloc_op.get()));
+    auto *alloc_inst =
+        llvm::dyn_cast<llvm::AllocaInst>(emitter->Visit(alloc_op.get()));
    module_str += "\n %0 = alloca [6 x i32]";
-    auto *store_inst = llvm::dyn_cast<llvm::StoreInst>(emitter->Visit(store_op.get()));
+    auto *store_inst =
+        llvm::dyn_cast<llvm::StoreInst>(emitter->Visit(store_op.get()));
    module_str += "\n %1 = getelementptr [6 x i32], [6 x i32]* %0, i32 1";
    module_str += "\n store i32 5, [6 x i32]* %1";
-    auto *load_inst = llvm::dyn_cast<llvm::LoadInst>(emitter->Visit(load_op.get()));
+    auto *load_inst =
+        llvm::dyn_cast<llvm::LoadInst>(emitter->Visit(load_op.get()));
    module_str += "\n %2 = getelementptr [6 x i32], [6 x i32]* %0, i32 1";
    module_str += "\n %3 = load [6 x i32], [6 x i32]* %2";
@@ -573,21 +597,21 @@ TEST(CodeGenLLVM, LowerFunc) {
  auto emitter = std::make_unique<CodeGenLLVM>(m.get(), b.get());
  auto _i8_i32_i64_u32_f32_f16_ = CreateLLVMType(context.get());  // NOLINT
-  auto &i8 = std::get<0>(_i8_i32_i64_u32_f32_f16_);
-  auto &i32 = std::get<1>(_i8_i32_i64_u32_f32_f16_);
-  auto &i64 = std::get<2>(_i8_i32_i64_u32_f32_f16_);
-  auto &u32 = std::get<3>(_i8_i32_i64_u32_f32_f16_);
-  auto &f32 = std::get<4>(_i8_i32_i64_u32_f32_f16_);
-  auto &f16 = std::get<5>(_i8_i32_i64_u32_f32_f16_);
-  auto _x_y_z_z_buf_ = CreateTensor();  // NOLINT
-  auto &x = std::get<0>(_x_y_z_z_buf_);
-  auto &y = std::get<1>(_x_y_z_z_buf_);
-  auto &z = std::get<2>(_x_y_z_z_buf_);
-  auto &z_buf = std::get<3>(_x_y_z_z_buf_);
+  auto &i8 = std::get<0>(_i8_i32_i64_u32_f32_f16_);
+  auto &i32 = std::get<1>(_i8_i32_i64_u32_f32_f16_);
+  auto &i64 = std::get<2>(_i8_i32_i64_u32_f32_f16_);
+  auto &u32 = std::get<3>(_i8_i32_i64_u32_f32_f16_);
+  auto &f32 = std::get<4>(_i8_i32_i64_u32_f32_f16_);
+  auto &f16 = std::get<5>(_i8_i32_i64_u32_f32_f16_);
+  auto _x_y_z_z_buf_ = CreateTensor();  // NOLINT
+  auto &x = std::get<0>(_x_y_z_z_buf_);
+  auto &y = std::get<1>(_x_y_z_z_buf_);
+  auto &z = std::get<2>(_x_y_z_z_buf_);
+  auto &z_buf = std::get<3>(_x_y_z_z_buf_);
  z->Bind(z_buf);
-  auto stages = CreateStages({x, y, z});
+  auto stages = CreateStages({x, y, z});
  auto function = lang::Lower("add1", stages, {x, y, z});
  ir::Expr func_expr(function);
diff --git a/paddle/cinn/backends/llvm/codegen_x86.cc b/paddle/cinn/backends/llvm/codegen_x86.cc
index 8fd38489d345f..bc0ec7493f164 100644
--- a/paddle/cinn/backends/llvm/codegen_x86.cc
+++ b/paddle/cinn/backends/llvm/codegen_x86.cc
@@ -20,25 +20,28 @@
 #include
 #include
+#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/Support/Casting.h" #include "paddle/cinn/backends/llvm/codegen_llvm.h" #include "paddle/cinn/common/target.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_operators.h" #include "paddle/cinn/optim/collect_undefined_vars.h" #include "paddle/cinn/runtime/intrinsic.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/Support/Casting.h" namespace cinn::backends { -CodeGenX86::CodeGenX86(llvm::Module* m, llvm::IRBuilder<>* b, const std::shared_ptr& vars) +CodeGenX86::CodeGenX86(llvm::Module* m, + llvm::IRBuilder<>* b, + const std::shared_ptr& vars) : CodeGenLLVM(m, b, vars) {} CodeGenX86::~CodeGenX86() {} -llvm::Value* CodeGenX86::PackVars(const std::vector& vars, uint64_t* num_bytes) { +llvm::Value* CodeGenX86::PackVars(const std::vector& vars, + uint64_t* num_bytes) { if (vars.empty()) { *num_bytes = 0U; return llvm::Constant::getNullValue(ll_void_p_ty()); @@ -48,26 +51,37 @@ llvm::Value* CodeGenX86::PackVars(const std::vector& vars, uint64_t types.push_back(GetVar(v, false)->getType()); } llvm::StructType* t_data = llvm::StructType::create(types); - llvm::Value* data = b_->CreateAlloca(t_data, llvm_int32_constant(1)); + llvm::Value* data = b_->CreateAlloca(t_data, llvm_int32_constant(1)); for (size_t i = 0; i < vars.size(); ++i) { - b_->CreateStore(GetVar(vars[i]), b_->CreateInBoundsGEP(data, {llvm_int32_constant(0), llvm_int32_constant(i)})); + b_->CreateStore( + GetVar(vars[i]), + b_->CreateInBoundsGEP( + data, {llvm_int32_constant(0), llvm_int32_constant(i)})); } - *num_bytes = m_->getDataLayout().getTypeAllocSize(llvm::cast(data->getType())->getElementType()); + *num_bytes = m_->getDataLayout().getTypeAllocSize( + llvm::cast(data->getType())->getElementType()); return data; } -void CodeGenX86::UnpackVars(const std::vector& vars, llvm::Value* data) { +void CodeGenX86::UnpackVars(const std::vector& vars, + llvm::Value* data) { for (size_t i = 0; i < vars.size(); ++i) { - SetVar(vars[i], b_->CreateLoad(b_->CreateInBoundsGEP(data, {llvm_int32_constant(0), llvm_int32_constant(i)}))); + SetVar(vars[i], + b_->CreateLoad(b_->CreateInBoundsGEP( + data, {llvm_int32_constant(0), llvm_int32_constant(i)}))); } } llvm::BasicBlock* CodeGenX86::CheckCallSuccess(llvm::Value* retcode) { llvm::BasicBlock* fail_block = - llvm::BasicBlock::Create(b_->getContext(), "call_fail", b_->GetInsertBlock()->getParent(), nullptr); - llvm::BasicBlock* end_block = - llvm::BasicBlock::Create(b_->getContext(), "call_end", b_->GetInsertBlock()->getParent(), nullptr); - llvm::Value* succ = b_->CreateICmpEQ(retcode, llvm::ConstantInt::get(ll_int32_ty(), 0)); + llvm::BasicBlock::Create(b_->getContext(), + "call_fail", + b_->GetInsertBlock()->getParent(), + nullptr); + llvm::BasicBlock* end_block = llvm::BasicBlock::Create( + b_->getContext(), "call_end", b_->GetInsertBlock()->getParent(), nullptr); + llvm::Value* succ = + b_->CreateICmpEQ(retcode, llvm::ConstantInt::get(ll_int32_ty(), 0)); b_->CreateCondBr(succ, end_block, fail_block); b_->SetInsertPoint(fail_block); RetVoid(); @@ -76,36 +90,48 @@ llvm::BasicBlock* CodeGenX86::CheckCallSuccess(llvm::Value* retcode) { } void CodeGenX86::CreateParallelLaunch(Expr body, int num_task) { - auto ftype_parallel_lambda = - llvm::FunctionType::get(ll_int32_ty(), {ll_int32_ty(), ll_int32_ty(), ll_type_of(Float(32).PointerOf())}, false); - llvm::Function* f = - 
llvm::Function::Create(ftype_parallel_lambda, llvm::Function::PrivateLinkage, "__parallel_lambda", m_); + auto ftype_parallel_lambda = llvm::FunctionType::get( + ll_int32_ty(), + {ll_int32_ty(), ll_int32_ty(), ll_type_of(Float(32).PointerOf())}, + false); + llvm::Function* f = llvm::Function::Create(ftype_parallel_lambda, + llvm::Function::PrivateLinkage, + "__parallel_lambda", + m_); std::vector vars = optim::CollectUndefinedVars(&body); uint64_t nbytes; auto* data = PackVars(vars, &nbytes); - auto ftype_parallel_launch = llvm::FunctionType::get( - ll_int32_ty(), {ftype_parallel_lambda->getPointerTo(), ll_type_of(Float(32).PointerOf()), ll_int32_ty()}, false); + auto ftype_parallel_launch = + llvm::FunctionType::get(ll_int32_ty(), + {ftype_parallel_lambda->getPointerTo(), + ll_type_of(Float(32).PointerOf()), + ll_int32_ty()}, + false); auto* launch_callee = llvm::dyn_cast( - m_->getOrInsertFunction(runtime::intrinsic::parallel_launch, ftype_parallel_launch).getCallee()); + m_->getOrInsertFunction(runtime::intrinsic::parallel_launch, + ftype_parallel_launch) + .getCallee()); launch_callee->setCallingConv(llvm::CallingConv::C); auto* launch_end = CheckCallSuccess(b_->CreateCall( launch_callee, - {f, b_->CreatePointerCast(data, ll_type_of(Float(32).PointerOf())), llvm_int32_constant(num_task)})); + {f, + b_->CreatePointerCast(data, ll_type_of(Float(32).PointerOf())), + llvm_int32_constant(num_task)})); auto* flambda = llvm::BasicBlock::Create(b_->getContext(), "flambda", f); b_->SetInsertPoint(flambda); - auto it = f->arg_begin(); + auto it = f->arg_begin(); auto* task_id = &(*it++); - auto* penv = &(*it++); - data = b_->CreatePointerCast(&(*it++), data->getType()); + auto* penv = &(*it++); + data = b_->CreatePointerCast(&(*it++), data->getType()); symbol_table_->PushScope(); UnpackVars(vars, data); ParallelEnv par_env; - auto task_id_name = common::UniqName("task_id"); + auto task_id_name = common::UniqName("task_id"); auto num_task_name = common::UniqName("num_task"); - par_env.task_id = ir::Var(task_id_name, Int(32)); - par_env.num_task = ir::Var(num_task_name, Int(32)); + par_env.task_id = ir::Var(task_id_name, Int(32)); + par_env.num_task = ir::Var(num_task_name, Int(32)); SetVar(task_id_name, task_id); SetVar(num_task_name, penv); par_env.penv = penv; @@ -118,7 +144,8 @@ void CodeGenX86::CreateParallelLaunch(Expr body, int num_task) { symbol_table_->PopScope(); std::swap(parallel_env_, par_env); std::swap(f_, f); - CHECK_NE(par_env.parallel_loop_count, 0) << "find no parallel loop within parallel launch"; + CHECK_NE(par_env.parallel_loop_count, 0) + << "find no parallel loop within parallel launch"; b_->SetInsertPoint(launch_end); } @@ -126,28 +153,43 @@ llvm::Value* CodeGenX86::Visit(const ir::For* op) { if (op->is_parallel()) { VLOG(3) << "parallel forloop"; if (parallel_env_.penv == nullptr) { - CreateParallelLaunch( - ir::For::Make( - op->loop_var, op->min, op->extent, op->for_type(), op->device_api, op->body, op->vectorize_info()), - 0); + CreateParallelLaunch(ir::For::Make(op->loop_var, + op->min, + op->extent, + op->for_type(), + op->device_api, + op->body, + op->vectorize_info()), + 0); } else { Expr num_task = parallel_env_.num_task; - Expr task_id = parallel_env_.task_id; - CHECK(!parallel_env_.in_parallel_loop) << "Nested parallel loop is not supported, try to fuse them instead"; + Expr task_id = parallel_env_.task_id; + CHECK(!parallel_env_.in_parallel_loop) + << "Nested parallel loop is not supported, try to fuse them instead"; parallel_env_.in_parallel_loop = true; if 
(parallel_env_.stride_pattern) { - auto new_for = ir::For::Make( - op->loop_var, task_id, op->extent, op->for_type(), op->device_api, op->body, op->vectorize_info()); + auto new_for = ir::For::Make(op->loop_var, + task_id, + op->extent, + op->for_type(), + op->device_api, + op->body, + op->vectorize_info()); auto for_node = new_for.As(); CHECK(for_node); CreateSerialFor(for_node, num_task.as_int32()); } else { Expr extent = op->extent; - Expr step = (extent + num_task - Expr(1)) / num_task; - Expr begin = min(task_id * step, op->extent); - Expr end = min((task_id + Expr(1)) * step, op->extent); - auto new_for = - ir::For::Make(op->loop_var, begin, end, op->for_type(), op->device_api, op->body, op->vectorize_info()); + Expr step = (extent + num_task - Expr(1)) / num_task; + Expr begin = min(task_id * step, op->extent); + Expr end = min((task_id + Expr(1)) * step, op->extent); + auto new_for = ir::For::Make(op->loop_var, + begin, + end, + op->for_type(), + op->device_api, + op->body, + op->vectorize_info()); auto for_node = new_for.As(); CHECK(for_node); CreateSerialFor(for_node); diff --git a/paddle/cinn/backends/llvm/codegen_x86.h b/paddle/cinn/backends/llvm/codegen_x86.h index 72ba4bc88c1e5..d9fd127249e03 100644 --- a/paddle/cinn/backends/llvm/codegen_x86.h +++ b/paddle/cinn/backends/llvm/codegen_x86.h @@ -27,7 +27,9 @@ namespace cinn::backends { class CodeGenX86 : public CodeGenLLVM { public: - explicit CodeGenX86(llvm::Module* m, llvm::IRBuilder<>* b, const std::shared_ptr& vars = nullptr); + explicit CodeGenX86(llvm::Module* m, + llvm::IRBuilder<>* b, + const std::shared_ptr& vars = nullptr); virtual ~CodeGenX86(); using LLVMIRVisitor::Visit; @@ -49,7 +51,8 @@ class CodeGenX86 : public CodeGenLLVM { // Create parallel launch void CreateParallelLaunch(Expr body, int num_task); - llvm::Value* PackVars(const std::vector& vars, uint64_t* num_bytes); + llvm::Value* PackVars(const std::vector& vars, + uint64_t* num_bytes); void UnpackVars(const std::vector& vars, llvm::Value* data); llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode); // Current parallel environment scope. 
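The non-stride branch above carves the loop range into num_task contiguous chunks with a ceiling division, clamping both bounds so the arithmetic stays valid when the extent does not divide evenly. A minimal standalone sketch of the same chunking (TaskRange is a hypothetical helper, not CINN API):

```cpp
#include <algorithm>
#include <cassert>
#include <utility>

// Mirrors the codegen above: step = ceil(extent / num_task); the min()
// clamps keep begin/end inside [0, extent] for the trailing tasks.
std::pair<int, int> TaskRange(int task_id, int num_task, int extent) {
  int step = (extent + num_task - 1) / num_task;
  int begin = std::min(task_id * step, extent);
  int end = std::min((task_id + 1) * step, extent);
  return {begin, end};
}

int main() {
  // extent = 10, num_task = 4 -> chunks [0,3) [3,6) [6,9) [9,10).
  int covered = 0;
  for (int t = 0; t < 4; ++t) {
    auto [begin, end] = TaskRange(t, 4, 10);
    covered += end - begin;
  }
  assert(covered == 10);  // every iteration assigned exactly once
  return 0;
}
```

The clamping also means a surplus task whose begin would pass the extent simply receives an empty range instead of indexing out of bounds.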
diff --git a/paddle/cinn/backends/llvm/codegen_x86_test.cc b/paddle/cinn/backends/llvm/codegen_x86_test.cc index 1287ffcd5e6be..42cd0f171435d 100644 --- a/paddle/cinn/backends/llvm/codegen_x86_test.cc +++ b/paddle/cinn/backends/llvm/codegen_x86_test.cc @@ -53,9 +53,16 @@ TEST(Vectorize, basic) { auto* fn_ptr = reinterpret_cast(fn_); - auto* A_buf = common::BufferBuilder(Float(32), {1024}).set_random().set_align(64).Build(); - auto* B_buf = common::BufferBuilder(Float(32), {1024}).set_random().set_align(64).Build(); - auto* C_buf = common::BufferBuilder(Float(32), {1024}).set_zero().set_align(64).Build(); + auto* A_buf = common::BufferBuilder(Float(32), {1024}) + .set_random() + .set_align(64) + .Build(); + auto* B_buf = common::BufferBuilder(Float(32), {1024}) + .set_random() + .set_align(64) + .Build(); + auto* C_buf = + common::BufferBuilder(Float(32), {1024}).set_zero().set_align(64).Build(); auto args = common::ArgsBuilder().Add(A_buf).Add(B_buf).Add(C_buf).Build(); diff --git a/paddle/cinn/backends/llvm/execution_engine.cc b/paddle/cinn/backends/llvm/execution_engine.cc index 549675269bf36..050fd4e0d8389 100644 --- a/paddle/cinn/backends/llvm/execution_engine.cc +++ b/paddle/cinn/backends/llvm/execution_engine.cc @@ -87,15 +87,19 @@ void InitializeLLVMPasses() { // llvm::initializeCodeGenPreparePass(registry); } } // namespace -void NaiveObjectCache::notifyObjectCompiled(const llvm::Module *m, llvm::MemoryBufferRef obj_buffer) { +void NaiveObjectCache::notifyObjectCompiled(const llvm::Module *m, + llvm::MemoryBufferRef obj_buffer) { cached_objects_[m->getModuleIdentifier()] = - llvm::MemoryBuffer::getMemBufferCopy(obj_buffer.getBuffer(), obj_buffer.getBufferIdentifier()); + llvm::MemoryBuffer::getMemBufferCopy(obj_buffer.getBuffer(), + obj_buffer.getBufferIdentifier()); } -std::unique_ptr NaiveObjectCache::getObject(const llvm::Module *m) { +std::unique_ptr NaiveObjectCache::getObject( + const llvm::Module *m) { auto it = cached_objects_.find(m->getModuleIdentifier()); if (it == cached_objects_.end()) { - VLOG(1) << "No object for " << m->getModuleIdentifier() << " in cache. Compiling."; + VLOG(1) << "No object for " << m->getModuleIdentifier() + << " in cache. 
Compiling."; return nullptr; } @@ -103,13 +107,15 @@ std::unique_ptr NaiveObjectCache::getObject(const llvm::Modu return llvm::MemoryBuffer::getMemBuffer(it->second->getMemBufferRef()); } -/*static*/ std::unique_ptr ExecutionEngine::Create(const ExecutionOptions &config) { +/*static*/ std::unique_ptr ExecutionEngine::Create( + const ExecutionOptions &config) { return Create(config, {}); } -/*static*/ std::unique_ptr ExecutionEngine::Create(const ExecutionOptions &config, - RuntimeSymbols &&module_symbols) { - VLOG(1) << "===================== Create CINN ExecutionEngine begin ===================="; +/*static*/ std::unique_ptr ExecutionEngine::Create( + const ExecutionOptions &config, RuntimeSymbols &&module_symbols) { + VLOG(1) << "===================== Create CINN ExecutionEngine begin " + "===================="; VLOG(1) << "initialize llvm config"; VLOG(1) << "llvm version: " << LLVM_VERSION_STRING; VLOG(1) << "llvm default target triple: " << LLVM_DEFAULT_TARGET_TRIPLE; @@ -117,20 +123,26 @@ std::unique_ptr NaiveObjectCache::getObject(const llvm::Modu static std::once_flag flag; std::call_once(flag, InitializeLLVMPasses); - auto engine = std::make_unique(/*enable_object_cache=*/true, std::move(module_symbols)); + auto engine = std::make_unique(/*enable_object_cache=*/true, + std::move(module_symbols)); - auto compile_layer_creator = [&engine](llvm::orc::JITTargetMachineBuilder jtmb) - -> llvm::Expected> { + auto compile_layer_creator = + [&engine](llvm::orc::JITTargetMachineBuilder jtmb) + -> llvm::Expected< + std::unique_ptr> { auto machine = llvm::cantFail(jtmb.createTargetMachine()); VLOG(1) << "create llvm compile layer"; VLOG(1) << "Target Name: " << machine->getTarget().getName(); VLOG(1) << "Target CPU: " << machine->getTargetCPU().str() << std::endl; - return std::make_unique(std::move(machine), engine->cache_.get()); + return std::make_unique( + std::move(machine), engine->cache_.get()); }; - auto object_layer_creator = [&](llvm::orc::ExecutionSession &session, const llvm::Triple &triple) { + auto object_layer_creator = [&](llvm::orc::ExecutionSession &session, + const llvm::Triple &triple) { auto object_layer = std::make_unique( - session, []() { return std::make_unique(); }); + session, + []() { return std::make_unique(); }); llvm::orc::JITDylib *main_jd = session.getJITDylibByName("
"); if (!main_jd) { main_jd = &llvm::cantFail(session.createJITDylib("
")); @@ -139,18 +151,21 @@ std::unique_ptr NaiveObjectCache::getObject(const llvm::Modu }; VLOG(2) << "create jit execution engine"; - engine->jit_ = llvm::cantFail(llvm::orc::LLJITBuilder() - .setCompileFunctionCreator(compile_layer_creator) - .setObjectLinkingLayerCreator(object_layer_creator) - .create()); + engine->jit_ = + llvm::cantFail(llvm::orc::LLJITBuilder() + .setCompileFunctionCreator(compile_layer_creator) + .setObjectLinkingLayerCreator(object_layer_creator) + .create()); engine->jit_->getMainJITDylib().addGenerator(llvm::cantFail( - llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(engine->jit_->getDataLayout().getGlobalPrefix()))); + llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( + engine->jit_->getDataLayout().getGlobalPrefix()))); VLOG(2) << "register runtime call symbols"; engine->RegisterRuntimeSymbols(); - VLOG(2) << "===================== Create CINN ExecutionEngine end ===================="; + VLOG(2) << "===================== Create CINN ExecutionEngine end " + "===================="; return engine; } @@ -158,27 +173,31 @@ template void ExecutionEngine::Link(const ir::Module &module) { utils::RecordEvent("ExecutionEngine Link", utils::EventType::kOrdinary); llvm::SMDiagnostic error; - auto ctx = std::make_unique(); - auto m = llvm::parseAssemblyString(AsStringRef(backends::kRuntimeLlvmIr), error, *ctx); - auto b = std::make_unique>(*ctx); + auto ctx = std::make_unique(); + auto m = llvm::parseAssemblyString( + AsStringRef(backends::kRuntimeLlvmIr), error, *ctx); + auto b = std::make_unique>(*ctx); auto ir_emitter = std::make_unique(m.get(), b.get()); VLOG(3) << "ir_emitter->Compile(module) Begin"; ir_emitter->Compile(module); VLOG(3) << "ir_emitter->Compile(module) Succeed!"; CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid module found"; - auto machine = - std::move(llvm::cantFail(llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()).createTargetMachine())); + auto machine = std::move(llvm::cantFail( + llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()) + .createTargetMachine())); LLVMModuleOptimizer optimize(machine.get(), 3, {}, true); optimize(m.get()); - CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid optimized module detected"; + CHECK(!llvm::verifyModule(*m, &llvm::errs())) + << "Invalid optimized module detected"; for (auto &f : *m) { VLOG(5) << "function: " << DumpToString(f); } llvm::raw_svector_ostream rawstream(buffer_); llvm::legacy::PassManager pass_manager; - machine->addPassesToEmitFile(pass_manager, rawstream, nullptr, llvm::CGFT_ObjectFile); + machine->addPassesToEmitFile( + pass_manager, rawstream, nullptr, llvm::CGFT_ObjectFile); pass_manager.run(*m); CHECK(AddModule(std::move(m), std::move(ctx))); @@ -194,7 +213,8 @@ void ExecutionEngine::Link(const ir::Module &module) { } } -bool ExecutionEngine::AddModule(std::unique_ptr module, std::unique_ptr context) { +bool ExecutionEngine::AddModule(std::unique_ptr module, + std::unique_ptr context) { utils::RecordEvent("ExecutionEngine AddModule", utils::EventType::kOrdinary); module->setDataLayout(jit_->getDataLayout()); if (VLOG_IS_ON(5)) { @@ -230,16 +250,21 @@ void *ExecutionEngine::Lookup(absl::string_view name) { } void ExecutionEngine::RegisterRuntimeSymbols() { - utils::RecordEvent("ExecutionEngine RegisterRuntimeSymbols", utils::EventType::kOrdinary); + utils::RecordEvent("ExecutionEngine RegisterRuntimeSymbols", + utils::EventType::kOrdinary); const auto ®istry = GlobalSymbolRegistry::Global(); - auto *session = 
&jit_->getExecutionSession(); + auto *session = &jit_->getExecutionSession(); for (const auto &sym : registry.All()) { llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols( - {{session->intern(sym.first), {llvm::pointerToJITTargetAddress(sym.second), llvm::JITSymbolFlags::None}}}))); + {{session->intern(sym.first), + {llvm::pointerToJITTargetAddress(sym.second), + llvm::JITSymbolFlags::None}}}))); } for (const auto &sym : module_symbols_.All()) { llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols( - {{session->intern(sym.first), {llvm::pointerToJITTargetAddress(sym.second), llvm::JITSymbolFlags::None}}}))); + {{session->intern(sym.first), + {llvm::pointerToJITTargetAddress(sym.second), + llvm::JITSymbolFlags::None}}}))); } } diff --git a/paddle/cinn/backends/llvm/execution_engine.h b/paddle/cinn/backends/llvm/execution_engine.h index 31f6e144fbcc9..63f9427a53edb 100644 --- a/paddle/cinn/backends/llvm/execution_engine.h +++ b/paddle/cinn/backends/llvm/execution_engine.h @@ -52,7 +52,8 @@ namespace cinn::backends { class NaiveObjectCache : public llvm::ObjectCache { public: - void notifyObjectCompiled(const llvm::Module *, llvm::MemoryBufferRef) override; + void notifyObjectCompiled(const llvm::Module *, + llvm::MemoryBufferRef) override; std::unique_ptr getObject(const llvm::Module *) override; private: @@ -69,9 +70,11 @@ struct ExecutionOptions { class ExecutionEngine { public: - static std::unique_ptr Create(const ExecutionOptions &config); + static std::unique_ptr Create( + const ExecutionOptions &config); - static std::unique_ptr Create(const ExecutionOptions &config, RuntimeSymbols &&module_symbols); + static std::unique_ptr Create( + const ExecutionOptions &config, RuntimeSymbols &&module_symbols); void *Lookup(absl::string_view name); @@ -80,18 +83,22 @@ class ExecutionEngine { void ExportObject(const std::string &path); - bool AddModule(std::unique_ptr module, std::unique_ptr context); + bool AddModule(std::unique_ptr module, + std::unique_ptr context); protected: - explicit ExecutionEngine(bool enable_object_cache, RuntimeSymbols &&module_symbols) - : cache_(std::make_unique()), module_symbols_(std::move(module_symbols)) {} + explicit ExecutionEngine(bool enable_object_cache, + RuntimeSymbols &&module_symbols) + : cache_(std::make_unique()), + module_symbols_(std::move(module_symbols)) {} void RegisterRuntimeSymbols(); bool SetupTargetTriple(llvm::Module *module); // This may not be a compatible implementation. 
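Both loops in RegisterRuntimeSymbols reduce to ORC's absolute-symbols pattern: intern a name in the execution session, pair it with a raw host address, and define it in a JITDylib so JITed code can resolve calls back into the process. A rough sketch of that pattern against a plain LLJIT (RegisterHostSymbol is a hypothetical helper; exact headers and symbol flags vary across LLVM versions):

```cpp
#include <math.h>

#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "llvm/Support/Error.h"

// Map the name "sinf" to the host process's sinf so JITed modules can
// call it. absoluteSymbols() materializes the address directly; nothing
// is compiled for this definition.
llvm::Error RegisterHostSymbol(llvm::orc::LLJIT &jit) {
  llvm::orc::MangleAndInterner mangle(jit.getExecutionSession(),
                                      jit.getDataLayout());
  llvm::orc::SymbolMap symbols;
  symbols[mangle("sinf")] = llvm::JITEvaluatedSymbol(
      llvm::pointerToJITTargetAddress(&sinf), llvm::JITSymbolFlags::None);
  return jit.getMainJITDylib().define(
      llvm::orc::absoluteSymbols(std::move(symbols)));
}
```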
- friend std::unique_ptr std::make_unique(bool &&, cinn::backends::RuntimeSymbols &&); + friend std::unique_ptr std::make_unique( + bool &&, cinn::backends::RuntimeSymbols &&); private: mutable std::mutex mu_; diff --git a/paddle/cinn/backends/llvm/execution_engine_test.cc b/paddle/cinn/backends/llvm/execution_engine_test.cc index 55a2bbedb1133..7adca52f34ca7 100644 --- a/paddle/cinn/backends/llvm/execution_engine_test.cc +++ b/paddle/cinn/backends/llvm/execution_engine_test.cc @@ -58,10 +58,12 @@ bool RegisterKnownSymbols() { decltype(auto) registry = GlobalSymbolRegistry::Global(); registry.RegisterFn("sinf", reinterpret_cast(&sinf)); - registry.RegisterFn("sin", reinterpret_cast(static_cast(&sin))); + registry.RegisterFn( + "sin", reinterpret_cast(static_cast(&sin))); registry.RegisterFn("cosf", reinterpret_cast(&cosf)); - registry.RegisterFn("cos", reinterpret_cast(static_cast(&cos))); + registry.RegisterFn( + "cos", reinterpret_cast(static_cast(&cos))); return true; } @@ -71,9 +73,12 @@ constexpr int kM = 100; constexpr int kN = 32; auto CreateTestBuffer() { - auto *A = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32); - auto *B = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32); - auto *C = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32); + auto *A = cinn_buffer_t::new_( + cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32); + auto *B = cinn_buffer_t::new_( + cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32); + auto *C = cinn_buffer_t::new_( + cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32); cinn_buffer_malloc(nullptr, A); cinn_buffer_malloc(nullptr, B); cinn_buffer_malloc(nullptr, C); @@ -105,11 +110,11 @@ auto CreateTestCinnModule() { common::Target target; target.arch = common::Target::Arch::X86; target.bits = common::Target::Bit::k32; - target.os = common::Target::OS::Linux; + target.os = common::Target::OS::Linux; ir::Module::Builder builder("module1", target); auto stages = CreateStages({C}); - auto funcs = lang::Lower("elementwise_add", stages, {A, B, C}); + auto funcs = lang::Lower("elementwise_add", stages, {A, B, C}); // auto func = optim::Optimize(funcs); @@ -123,9 +128,9 @@ TEST(llvm_test01, elementwise_add) { auto engine = backends::ExecutionEngine::Create({1}); auto _a_b_c_ = CreateTestBuffer(); // NOLINT - auto &a = std::get<0>(_a_b_c_); - auto &b = std::get<1>(_a_b_c_); - auto &c = std::get<2>(_a_b_c_); + auto &a = std::get<0>(_a_b_c_); + auto &b = std::get<1>(_a_b_c_); + auto &c = std::get<2>(_a_b_c_); auto module = CreateTestCinnModule(); @@ -133,7 +138,8 @@ TEST(llvm_test01, elementwise_add) { auto elementwise_add_addr = engine->Lookup("elementwise_add"); return; - auto elementwise_add = reinterpret_cast(elementwise_add_addr); + auto elementwise_add = + reinterpret_cast(elementwise_add_addr); cinn_pod_value_t a_arg(a), b_arg(b), c_arg(c); cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg}; elementwise_add(args, 3); @@ -158,7 +164,7 @@ TEST(llvm, module_call_lowered_func) { {M, N}, [&](auto i, auto j) { return a(i, j) + b(i, j); }, "C"); auto stages = CreateStages({c}); - auto fn = lang::Lower("elementwise_add", stages, {a, b, c}, {}); + auto fn = lang::Lower("elementwise_add", stages, {a, b, c}, {}); builder.AddFunction(fn); } @@ -166,26 +172,28 @@ TEST(llvm, module_call_lowered_func) { lang::Placeholder a("A", {M, N}); lang::Placeholder b("B", {M, N}); - std::vector 
ret_types({lang::ReturnType{Float(32), {M, N}, "c_out"}}); + std::vector ret_types( + {lang::ReturnType{Float(32), {M, N}, "c_out"}}); auto call_outs = lang::CallLowered("elementwise_add", {a, b}, ret_types); - auto c = call_outs[0]; + auto c = call_outs[0]; // here we must call the output, so that it cal output something. - auto stages = CreateStages({c}); + auto stages = CreateStages({c}); auto main_fn = lang::Lower("main", stages, {a, b, c}, {}); builder.AddFunction(main_fn); CodeGenC codegen(common::DefaultHostTarget()); codegen.SetInlineBuiltinCodes(false); - LOG(INFO) << "module:\n" << codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + LOG(INFO) << "module:\n" + << codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); } auto _ab_bb_cb_ = CreateTestBuffer(); // NOLINT - auto &ab = std::get<0>(_ab_bb_cb_); - auto &bb = std::get<1>(_ab_bb_cb_); - auto &cb = std::get<2>(_ab_bb_cb_); + auto &ab = std::get<0>(_ab_bb_cb_); + auto &bb = std::get<1>(_ab_bb_cb_); + auto &cb = std::get<2>(_ab_bb_cb_); do { // call the function auto engine = backends::ExecutionEngine::Create({1}); @@ -194,7 +202,8 @@ TEST(llvm, module_call_lowered_func) { auto cos_fn = (double (*)(double))engine->Lookup("cos"); LOG(INFO) << "=> LLVM JIT cos(0) = " << cos_fn(0); auto elementwise_add_addr = engine->Lookup("elementwise_add"); - auto elementwise_add = reinterpret_cast(elementwise_add_addr); + auto elementwise_add = + reinterpret_cast(elementwise_add_addr); LOG(INFO) << "JIT get elementwise_add_addr"; break; @@ -216,19 +225,24 @@ TEST(llvm, module_call_lowered_func) { TEST(ExecutionEngine, custom_runtime_symbols) { auto context = std::make_unique(); - auto module = std::make_unique("test_llvm_cpu_runtime", *context); + auto module = + std::make_unique("test_llvm_cpu_runtime", *context); auto builder = std::make_unique>(*context); auto call_custom_target = [&](std::string name, llvm::Type *ty) { llvm::FunctionType *fn_type = llvm::FunctionType::get(ty, {ty}, false); llvm::Function *function = - llvm::Function::Create(fn_type, llvm::Function::ExternalLinkage, "_call_custom_" + name, module.get()); + llvm::Function::Create(fn_type, + llvm::Function::ExternalLinkage, + "_call_custom_" + name, + module.get()); function->setCallingConv(llvm::CallingConv::C); - llvm::BasicBlock *entry = llvm::BasicBlock::Create(module->getContext(), "entry", function); + llvm::BasicBlock *entry = + llvm::BasicBlock::Create(module->getContext(), "entry", function); builder->SetInsertPoint(entry); llvm::Argument *arg = &*function->args().begin(); - llvm::Function *custom_function = - llvm::dyn_cast(module->getOrInsertFunction(name, fn_type).getCallee()); + llvm::Function *custom_function = llvm::dyn_cast( + module->getOrInsertFunction(name, fn_type).getCallee()); custom_function->setCallingConv(llvm::CallingConv::C); llvm::Value *ret = builder->CreateCall(custom_function, {arg}); builder->CreateRet(ret); @@ -252,7 +266,8 @@ TEST(ExecutionEngine, custom_runtime_symbols) { int random_y = dis(mt); decltype(auto) registry = GlobalSymbolRegistry::Global(); - // registry.Register("dereference_f64_ptr", (void *)+[](double *x) { return *x; }); + // registry.Register("dereference_f64_ptr", (void *)+[](double *x) { return + // *x; }); for (size_t i = 0; i < angle.size(); i++) { registry.RegisterVar("theta_" + std::to_string(i), angle[i]); @@ -261,10 +276,14 @@ TEST(ExecutionEngine, custom_runtime_symbols) { auto engine = cinn::backends::ExecutionEngine::Create({1}); engine->AddModule(std::move(module), std::move(context)); - auto 
*call_cosf = reinterpret_cast(engine->Lookup("_call_custom_cosf")); - auto *call_cos = reinterpret_cast(engine->Lookup("_call_custom_cos")); - auto *call_sinf = reinterpret_cast(engine->Lookup("_call_custom_sinf")); - auto *call_sin = reinterpret_cast(engine->Lookup("_call_custom_sin")); + auto *call_cosf = + reinterpret_cast(engine->Lookup("_call_custom_cosf")); + auto *call_cos = + reinterpret_cast(engine->Lookup("_call_custom_cos")); + auto *call_sinf = + reinterpret_cast(engine->Lookup("_call_custom_sinf")); + auto *call_sin = + reinterpret_cast(engine->Lookup("_call_custom_sin")); ASSERT_TRUE(call_cosf && call_cos && call_sinf && call_sin); @@ -288,7 +307,11 @@ TEST(ExecutionEngine, call_extern) { {M, N}, [=](Var i, Var j) { return x(i, j) + y(i, j); }, "add_out"); ir::Tensor res = Compute( - {M, N}, [&](Var i, Var j) -> Expr { return lang::CallExtern("tanh", {add_out(i, j)}); }, "res"); + {M, N}, + [&](Var i, Var j) -> Expr { + return lang::CallExtern("tanh", {add_out(i, j)}); + }, + "res"); auto stages = CreateStages({add_out, res}); @@ -303,12 +326,12 @@ TEST(ExecutionEngine, call_extern) { engine->Link(builder.Build()); auto _ab_bb_cb_ = CreateTestBuffer(); // NOLINT - auto &ab = std::get<0>(_ab_bb_cb_); - auto &bb = std::get<1>(_ab_bb_cb_); - auto &cb = std::get<2>(_ab_bb_cb_); + auto &ab = std::get<0>(_ab_bb_cb_); + auto &bb = std::get<1>(_ab_bb_cb_); + auto &cb = std::get<2>(_ab_bb_cb_); auto comp_addr = engine->Lookup("comp"); - auto comp = reinterpret_cast(comp_addr); + auto comp = reinterpret_cast(comp_addr); cinn_pod_value_t a_arg(ab), b_arg(bb), c_arg(cb); cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg}; diff --git a/paddle/cinn/backends/llvm/ir_builder_mixin.h b/paddle/cinn/backends/llvm/ir_builder_mixin.h index 42b1e9663afbb..c897ff34dd9c7 100644 --- a/paddle/cinn/backends/llvm/ir_builder_mixin.h +++ b/paddle/cinn/backends/llvm/ir_builder_mixin.h @@ -300,7 +300,9 @@ class IrBuilderMixin { } private: - llvm::IRBuilder<> *mixin_builder() { return static_cast(this)->b(); } + llvm::IRBuilder<> *mixin_builder() { + return static_cast(this)->b(); + } }; } // namespace backends } // namespace cinn diff --git a/paddle/cinn/backends/llvm/llvm_intrin_rule.h b/paddle/cinn/backends/llvm/llvm_intrin_rule.h index 7e912fcdcc8b6..77d22349ed258 100644 --- a/paddle/cinn/backends/llvm/llvm_intrin_rule.h +++ b/paddle/cinn/backends/llvm/llvm_intrin_rule.h @@ -33,21 +33,24 @@ namespace codegen { template inline void MakeFloatIntrinOp(lang::Args args, lang::RetValue *rv) { CHECK_GE(args.size(), 1U); - Expr arg = args[0]; + Expr arg = args[0]; ir::Call *node = arg->as(); CHECK(node); CHECK_GE(node->read_args.size(), arg_nums); if (add_float_suffix) { CHECK(node->type().is_float()); - *rv = ir::intrinsics::BuiltinIntrin::Make(node->name + "f", node->read_args, id, arg_nums, node->type()); + *rv = ir::intrinsics::BuiltinIntrin::Make( + node->name + "f", node->read_args, id, arg_nums, node->type()); } else { - *rv = ir::intrinsics::BuiltinIntrin::Make(node->name, node->read_args, id, arg_nums, node->type()); + *rv = ir::intrinsics::BuiltinIntrin::Make( + node->name, node->read_args, id, arg_nums, node->type()); } } void RegisterCpuIntrinRule() { -#define __(intrin_name__, id) \ - ir::Registry::Register("lower_cpu_intrinsic_" #intrin_name__, true).SetBody(MakeFloatIntrinOp); +#define __(intrin_name__, id) \ + ir::Registry::Register("lower_cpu_intrinsic_" #intrin_name__, true) \ + .SetBody(MakeFloatIntrinOp); __(exp, ::llvm::Intrinsic::exp) __(exp2, ::llvm::Intrinsic::exp2) __(sqrt, 
::llvm::Intrinsic::sqrt) @@ -64,114 +67,128 @@ void RegisterCpuIntrinRule() { #undef __ // set id -1 if not llvm intrinsics -#define RegisterBitwise(intrin_name__) \ - ir::Registry::Register("lower_cpu_intrinsic_" #intrin_name__, true).SetBody(MakeFloatIntrinOp<-1, 2, false>); - RegisterBitwise(bitwise_or) RegisterBitwise(bitwise_xor) RegisterBitwise(bitwise_and) RegisterBitwise(left_shift) - RegisterBitwise(right_shift) +#define RegisterBitwise(intrin_name__) \ + ir::Registry::Register("lower_cpu_intrinsic_" #intrin_name__, true) \ + .SetBody(MakeFloatIntrinOp<-1, 2, false>); + RegisterBitwise(bitwise_or) RegisterBitwise(bitwise_xor) RegisterBitwise( + bitwise_and) RegisterBitwise(left_shift) RegisterBitwise(right_shift) #undef RegisterBitwise - ir::Registry::Register("lower_cpu_intrinsic_fma", true) - .SetBody(MakeFloatIntrinOp<::llvm::Intrinsic::fmuladd, 3, false>); - - ir::Registry::Register("lower_cpu_intrinsic_bitwise_not", true).SetBody(MakeFloatIntrinOp<-1, 1, false>); - - ir::Registry::Register("lower_cpu_intrinsic_isnan", true).SetBody(MakeFloatIntrinOp<-1, 1, false>); - - ir::Registry::Register("lower_cpu_intrinsic_isfinite", true).SetBody([](lang::Args args, lang::RetValue *rv) { - CHECK_GE(args.size(), 1U); - Expr arg0 = args[0]; - ir::Call *node = arg0->as(); - CHECK(node); - CHECK(!node->read_args.empty()); - Expr arg = node->read_args[0]; - *rv = !(lang::IsInf(arg)) && !(lang::IsNan(arg)); - }); - - ir::Registry::Register("lower_cpu_intrinsic_isinf", true).SetBody([](lang::Args args, lang::RetValue *rv) { - CHECK_GE(args.size(), 1U); - Expr arg0 = args[0]; - ir::Call *node = arg0->as(); - CHECK(node); - CHECK(!node->read_args.empty()); - Expr arg = node->read_args[0]; - Type type = arg->type(); - if (type.is_int() || type.is_uint()) { - *rv = common::make_bool(false, type.lanes()); - } else if (type.is_float()) { - *rv = ir::EQ::Make(lang::Abs(arg), lang::Infinity(type)) && !(lang::IsNan(arg)); - } - }); - - ir::Registry::Register("lower_cpu_intrinsic_rsqrt", true).SetBody([](lang::Args args, lang::RetValue *rv) { - CHECK_GE(args.size(), 1U); - Expr arg0 = args[0]; - ir::Call *node = arg0->as(); - CHECK(node); - CHECK(!node->read_args.empty()); - Expr arg = node->read_args[0]; - *rv = make_const(arg->type(), 1) / lang::Sqrt(arg); - }); - - ir::Registry::Register("lower_cpu_intrinsic_exp10", true).SetBody([](lang::Args args, lang::RetValue *rv) { - CHECK_GE(args.size(), 1U); - Expr arg0 = args[0]; - ir::Call *node = arg0->as(); - CHECK(node); - CHECK(!node->read_args.empty()); - Expr arg = node->read_args[0]; - Expr ln10 = make_const(arg->type(), 2.302585093); - *rv = lang::Exp(arg * ln10); - }); - - ir::Registry::Register("lower_cpu_intrinsic_tan", true).SetBody([](lang::Args args, lang::RetValue *rv) { - CHECK_GE(args.size(), 1U); - Expr arg0 = args[0]; - ir::Call *node = arg0->as(); - CHECK(node); - CHECK(!node->read_args.empty()); - Expr arg = node->read_args[0]; - *rv = lang::Sin(arg) / lang::Cos(arg); - }); - - ir::Registry::Register("lower_cpu_intrinsic_tanh", true).SetBody([](lang::Args args, lang::RetValue *rv) { - CHECK_GE(args.size(), 1U); - Expr arg0 = args[0]; - ir::Call *node = arg0->as(); - CHECK(node); - CHECK(!node->read_args.empty()); - Expr arg = node->read_args[0]; - Expr zero = make_const(arg->type(), 0); - Expr one = make_const(arg->type(), 1); - Expr two = make_const(arg->type(), 2); - Expr neg_two = make_const(arg->type(), -2); - - Expr exp_neg2x = lang::Exp(neg_two * arg); - Expr exp_pos2x = lang::Exp(two * arg); - - Expr tanh_pos = (one - exp_neg2x) 
/ (one + exp_neg2x); - Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); - *rv = ir::Select::Make(arg >= zero, tanh_pos, tanh_neg); - }); - - ir::Registry::Register("lower_cpu_intrinsic_cosh", true).SetBody([](lang::Args args, lang::RetValue *rv) { - CHECK_GE(args.size(), 1U); - Expr arg0 = args[0]; - ir::Call *node = arg0->as(); - CHECK(node); - CHECK(!node->read_args.empty()); - Expr arg = node->read_args[0]; - *rv = (lang::Exp(arg) + lang::Exp(arg * make_const(arg->type(), -1))) / make_const(arg->type(), 2); - }); - - ir::Registry::Register("lower_cpu_intrinsic_sinh", true).SetBody([](lang::Args args, lang::RetValue *rv) { - CHECK_GE(args.size(), 1U); - Expr arg0 = args[0]; - ir::Call *node = arg0->as(); - CHECK(node); - CHECK(!node->read_args.empty()); - Expr arg = node->read_args[0]; - *rv = (lang::Exp(arg) - lang::Exp(arg * make_const(arg->type(), -1))) / make_const(arg->type(), 2); - }); + ir::Registry::Register("lower_cpu_intrinsic_fma", true) + .SetBody(MakeFloatIntrinOp<::llvm::Intrinsic::fmuladd, 3, false>); + + ir::Registry::Register("lower_cpu_intrinsic_bitwise_not", true) + .SetBody(MakeFloatIntrinOp<-1, 1, false>); + + ir::Registry::Register("lower_cpu_intrinsic_isnan", true) + .SetBody(MakeFloatIntrinOp<-1, 1, false>); + + ir::Registry::Register("lower_cpu_intrinsic_isfinite", true) + .SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + *rv = !(lang::IsInf(arg)) && !(lang::IsNan(arg)); + }); + + ir::Registry::Register("lower_cpu_intrinsic_isinf", true) + .SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + Type type = arg->type(); + if (type.is_int() || type.is_uint()) { + *rv = common::make_bool(false, type.lanes()); + } else if (type.is_float()) { + *rv = ir::EQ::Make(lang::Abs(arg), lang::Infinity(type)) && + !(lang::IsNan(arg)); + } + }); + + ir::Registry::Register("lower_cpu_intrinsic_rsqrt", true) + .SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + *rv = make_const(arg->type(), 1) / lang::Sqrt(arg); + }); + + ir::Registry::Register("lower_cpu_intrinsic_exp10", true) + .SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + Expr ln10 = make_const(arg->type(), 2.302585093); + *rv = lang::Exp(arg * ln10); + }); + + ir::Registry::Register("lower_cpu_intrinsic_tan", true) + .SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + *rv = lang::Sin(arg) / lang::Cos(arg); + }); + + ir::Registry::Register("lower_cpu_intrinsic_tanh", true) + .SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + Expr zero = make_const(arg->type(), 0); + Expr one = make_const(arg->type(), 1); + Expr two = 
make_const(arg->type(), 2); + Expr neg_two = make_const(arg->type(), -2); + + Expr exp_neg2x = lang::Exp(neg_two * arg); + Expr exp_pos2x = lang::Exp(two * arg); + + Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); + Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); + *rv = ir::Select::Make(arg >= zero, tanh_pos, tanh_neg); + }); + + ir::Registry::Register("lower_cpu_intrinsic_cosh", true) + .SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + *rv = (lang::Exp(arg) + lang::Exp(arg * make_const(arg->type(), -1))) / + make_const(arg->type(), 2); + }); + + ir::Registry::Register("lower_cpu_intrinsic_sinh", true) + .SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + *rv = (lang::Exp(arg) - lang::Exp(arg * make_const(arg->type(), -1))) / + make_const(arg->type(), 2); + }); } } // namespace codegen } // namespace cinn diff --git a/paddle/cinn/backends/llvm/llvm_optimizer.cc b/paddle/cinn/backends/llvm/llvm_optimizer.cc index 3fd11ea8c1731..e64fb9f42ee0b 100644 --- a/paddle/cinn/backends/llvm/llvm_optimizer.cc +++ b/paddle/cinn/backends/llvm/llvm_optimizer.cc @@ -74,11 +74,13 @@ class CustomPassManager : public PassManagerT { void add(llvm::Pass *pass) override { if (print_passes_) { if (is_function_pass_manager_) { - VLOG(1) << "llvm run function pass[" << std::string(pass->getPassName()) << "]"; + VLOG(1) << "llvm run function pass[" << std::string(pass->getPassName()) + << "]"; } if (is_module_pass_manager_) { - VLOG(1) << "llvm run module pass[" << std::string(pass->getPassName()) << "]"; + VLOG(1) << "llvm run module pass[" << std::string(pass->getPassName()) + << "]"; } } // static bool add_pass = true; @@ -107,12 +109,14 @@ class CustomPassManager : public PassManagerT { private: static constexpr bool is_function_pass_manager_ = std::is_same::value; - static constexpr bool is_module_pass_manager_ = std::is_same::value; + static constexpr bool is_module_pass_manager_ = + std::is_same::value; bool print_passes_; }; -using CustomFunctionPassManager = CustomPassManager; -using CustomModulePassManager = CustomPassManager; +using CustomFunctionPassManager = + CustomPassManager; +using CustomModulePassManager = CustomPassManager; } // namespace LLVMModuleOptimizer::LLVMModuleOptimizer(llvm::TargetMachine *machine, @@ -122,8 +126,9 @@ LLVMModuleOptimizer::LLVMModuleOptimizer(llvm::TargetMachine *machine, : opt_level_(opt_level), print_passes_(print_passes), machine_(machine) {} void LLVMModuleOptimizer::operator()(llvm::Module *m) { - auto machine = - std::move(llvm::cantFail(llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()).createTargetMachine())); + auto machine = std::move(llvm::cantFail( + llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()) + .createTargetMachine())); auto fpm = std::make_unique(print_passes_, m); // fpm->add(llvm::createTargetTransformInfoWrapperPass(llvm::TargetIRAnalysis())); // fpm->add(llvm::createInstructionCombiningPass()); @@ -141,15 +146,18 @@ void LLVMModuleOptimizer::operator()(llvm::Module *m) { auto mpm = std::make_unique(print_passes_); // mpm->add(llvm::createTargetTransformInfoWrapperPass(llvm::TargetIRAnalysis())); - // LOG(INFO) << "llvm run pass: target machine: name[" << 
machine_->getTarget().getName() << "]"; - // LOG(INFO) << "llvm run pass: target machine: cpu[" << machine_->getTargetCPU().str() << "]"; - fpm->add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis())); - mpm->add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis())); - auto builder = std::make_unique(); - builder->OptLevel = opt_level_; - builder->Inliner = llvm::createFunctionInliningPass(); + // LOG(INFO) << "llvm run pass: target machine: name[" << + // machine_->getTarget().getName() << "]"; LOG(INFO) << "llvm run pass: target + // machine: cpu[" << machine_->getTargetCPU().str() << "]"; + fpm->add(llvm::createTargetTransformInfoWrapperPass( + machine->getTargetIRAnalysis())); + mpm->add(llvm::createTargetTransformInfoWrapperPass( + machine->getTargetIRAnalysis())); + auto builder = std::make_unique(); + builder->OptLevel = opt_level_; + builder->Inliner = llvm::createFunctionInliningPass(); builder->LoopVectorize = true; - builder->SLPVectorize = true; + builder->SLPVectorize = true; #if LLVM_VERSION_MAJOR >= 11 machine->adjustPassManager(*builder); #endif diff --git a/paddle/cinn/backends/llvm/llvm_util.cc b/paddle/cinn/backends/llvm/llvm_util.cc index 1fe056e94d406..32256ecc5c9ca 100644 --- a/paddle/cinn/backends/llvm/llvm_util.cc +++ b/paddle/cinn/backends/llvm/llvm_util.cc @@ -26,7 +26,9 @@ namespace backends { using cinn::common::bfloat16; using cinn::common::float16; -llvm::Type *CinnTypeToLLVMType(common::Type type, llvm::Module *m, bool is_vec) { +llvm::Type *CinnTypeToLLVMType(common::Type type, + llvm::Module *m, + bool is_vec) { llvm::Type *ir_type = nullptr; if (type.is_cpp_const()) { // TODO(fc500110) support it latter. @@ -36,21 +38,22 @@ llvm::Type *CinnTypeToLLVMType(common::Type type, llvm::Module *m, bool is_vec) llvm::Type *i1 = llvm::Type::getInt1Ty(m->getContext()); - llvm::Type *i8 = llvm::Type::getInt8Ty(m->getContext()); + llvm::Type *i8 = llvm::Type::getInt8Ty(m->getContext()); llvm::Type *i16 = llvm::Type::getInt16Ty(m->getContext()); llvm::Type *i32 = llvm::Type::getInt32Ty(m->getContext()); llvm::Type *i64 = llvm::Type::getInt64Ty(m->getContext()); - llvm::Type *u8 = llvm::Type::getInt8Ty(m->getContext()); + llvm::Type *u8 = llvm::Type::getInt8Ty(m->getContext()); llvm::Type *u16 = llvm::Type::getInt16Ty(m->getContext()); llvm::Type *u32 = llvm::Type::getInt32Ty(m->getContext()); llvm::Type *u64 = llvm::Type::getInt64Ty(m->getContext()); llvm::Type *bf16 = llvm::Type::getBFloatTy(m->getContext()); - llvm::Type *f16 = llvm::Type::getHalfTy(m->getContext()); - llvm::Type *f32 = llvm::Type::getFloatTy(m->getContext()); - llvm::Type *f64 = llvm::Type::getDoubleTy(m->getContext()); - llvm::Type *arr = llvm::Type::getPrimitiveType(m->getContext(), llvm::Type::ArrayTyID); + llvm::Type *f16 = llvm::Type::getHalfTy(m->getContext()); + llvm::Type *f32 = llvm::Type::getFloatTy(m->getContext()); + llvm::Type *f64 = llvm::Type::getDoubleTy(m->getContext()); + llvm::Type *arr = + llvm::Type::getPrimitiveType(m->getContext(), llvm::Type::ArrayTyID); if (type.is_void() && type.is_cpp_handle()) { return llvm::PointerType::getUnqual(i8); } diff --git a/paddle/cinn/backends/llvm/llvm_util.h b/paddle/cinn/backends/llvm/llvm_util.h index 41e3523f3a17c..dd1a79768ab02 100644 --- a/paddle/cinn/backends/llvm/llvm_util.h +++ b/paddle/cinn/backends/llvm/llvm_util.h @@ -44,9 +44,13 @@ std::string DumpToString(const T &entity) { // return "\033[33m" + buffer + "\033[0m"; // Green } -inline llvm::StringRef AsStringRef(absl::string_view str) 
{ return llvm::StringRef(str.data(), str.size()); } +inline llvm::StringRef AsStringRef(absl::string_view str) { + return llvm::StringRef(str.data(), str.size()); +} -llvm::Type *CinnTypeToLLVMType(common::Type t, llvm::Module *m, bool is_vec = false); +llvm::Type *CinnTypeToLLVMType(common::Type t, + llvm::Module *m, + bool is_vec = false); template llvm::Type *llvm_type_of(llvm::Module *m); diff --git a/paddle/cinn/backends/llvm/runtime_symbol_registry.cc b/paddle/cinn/backends/llvm/runtime_symbol_registry.cc index 08d2c5e8b50bc..0265b72d50a6a 100644 --- a/paddle/cinn/backends/llvm/runtime_symbol_registry.cc +++ b/paddle/cinn/backends/llvm/runtime_symbol_registry.cc @@ -19,8 +19,8 @@ #include -#include "paddle/cinn/runtime/flags.h" #include "gflags/gflags_declare.h" +#include "paddle/cinn/runtime/flags.h" DECLARE_bool(verbose_function_register); @@ -51,7 +51,8 @@ void RuntimeSymbols::Register(const std::string &name, void *address) { std::lock_guard lock(mu_); auto it = symbols_.find(name); if (it != symbols_.end()) { - CHECK_EQ(it->second, address) << "Duplicate register symbol [" << name << "]"; + CHECK_EQ(it->second, address) + << "Duplicate register symbol [" << name << "]"; return; } diff --git a/paddle/cinn/backends/llvm/runtime_symbol_registry.h b/paddle/cinn/backends/llvm/runtime_symbol_registry.h index dd416bdbe76d4..03eaad4b3e99f 100644 --- a/paddle/cinn/backends/llvm/runtime_symbol_registry.h +++ b/paddle/cinn/backends/llvm/runtime_symbol_registry.h @@ -36,7 +36,7 @@ class RuntimeSymbols { RuntimeSymbols(const RuntimeSymbols &) = delete; RuntimeSymbols(RuntimeSymbols &&rhs) { - symbols_ = std::move(rhs.symbols_); + symbols_ = std::move(rhs.symbols_); scalar_holder_ = std::move(rhs.scalar_holder_); } @@ -45,7 +45,9 @@ class RuntimeSymbols { * @param name Name of the symbol. * @param address Address of the function. */ - void RegisterFn(const std::string &name, void *address) { Register(name, address); } + void RegisterFn(const std::string &name, void *address) { + Register(name, address); + } /** * Register scalar. @@ -85,7 +87,8 @@ class RuntimeSymbols { private: /** - * Register external symbol to the registry, the symbols in the registry will finally registered to JIT . + * Register external symbol to the registry, the symbols in the registry will + * finally registered to JIT . * @param name Name of the symbol in the JIT. * @param address The address of the variable in external space. 
*/ diff --git a/paddle/cinn/backends/llvm/simple_jit.cc b/paddle/cinn/backends/llvm/simple_jit.cc index 8806c2c7f3dbb..1966542fd6fbf 100755 --- a/paddle/cinn/backends/llvm/simple_jit.cc +++ b/paddle/cinn/backends/llvm/simple_jit.cc @@ -49,7 +49,8 @@ void SimpleJIT::AddModule(std::unique_ptr module, bool optimize) { LOG(INFO) << "fn:\n" << DumpToString(fn); } */ - CHECK(!llvm::verifyModule(*module, &llvm::errs())) << "Transformation resulted in an invalid module\n\nmodule:\n"; + CHECK(!llvm::verifyModule(*module, &llvm::errs())) + << "Transformation resulted in an invalid module\n\nmodule:\n"; bool debug = false; if (optimize) { @@ -63,16 +64,20 @@ void SimpleJIT::AddModule(std::unique_ptr module, bool optimize) { pass_builder.registerCGSCCAnalyses(cgscc_analysis_manager); pass_builder.registerFunctionAnalyses(function_analysis_manager); pass_builder.registerLoopAnalyses(loop_analysis_manager); - pass_builder.crossRegisterProxies( - loop_analysis_manager, function_analysis_manager, cgscc_analysis_manager, module_analysis_manager); + pass_builder.crossRegisterProxies(loop_analysis_manager, + function_analysis_manager, + cgscc_analysis_manager, + module_analysis_manager); llvm::ModulePassManager module_pass_manager = - pass_builder.buildPerModuleDefaultPipeline(llvm::PassBuilder::OptimizationLevel::O3); + pass_builder.buildPerModuleDefaultPipeline( + llvm::PassBuilder::OptimizationLevel::O3); module_pass_manager.run(*module, module_analysis_manager); } VLOG(3) << "jit target: " << jit_->getDataLayout().getStringRepresentation(); - VLOG(3) << "module target: " << module->getDataLayout().getStringRepresentation(); + VLOG(3) << "module target: " + << module->getDataLayout().getStringRepresentation(); llvm::orc::ThreadSafeModule tsm(std::move(module), context_); llvm::cantFail(jit_->addIRModule(std::move(tsm))); @@ -97,15 +102,19 @@ SimpleJIT::SimpleJIT() : context_(std::make_unique()) { CHECK(jit_) << "JIT create failed"; auto proc_symbols_generator = llvm::cantFail( - llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(jit_->getDataLayout().getGlobalPrefix())); + llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( + jit_->getDataLayout().getGlobalPrefix())); jit_->getMainJITDylib().addGenerator(std::move(proc_symbols_generator)); - llvm::orc::MangleAndInterner mangle(jit_->getExecutionSession(), jit_->getDataLayout()); + llvm::orc::MangleAndInterner mangle(jit_->getExecutionSession(), + jit_->getDataLayout()); for (auto &item : GlobalSymbolRegistry::Global().All()) { VLOG(2) << "Insert [" << item.first << "] to SimpleJIT"; llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols( - {{mangle(item.first), {llvm::pointerToJITTargetAddress(item.second), llvm::JITSymbolFlags::None}}}))); + {{mangle(item.first), + {llvm::pointerToJITTargetAddress(item.second), + llvm::JITSymbolFlags::None}}}))); } } @@ -126,7 +135,8 @@ void SimpleJIT::Link(ir::Module module, bool optimize) { } template void SimpleJIT::Link(ir::Module module, bool optimize); -template void SimpleJIT::Link(ir::Module module, bool optimize); +template void SimpleJIT::Link(ir::Module module, + bool optimize); } // namespace backends diff --git a/paddle/cinn/backends/llvm/simple_jit.h b/paddle/cinn/backends/llvm/simple_jit.h index 039bfe17417d8..5d70f98e55697 100755 --- a/paddle/cinn/backends/llvm/simple_jit.h +++ b/paddle/cinn/backends/llvm/simple_jit.h @@ -50,7 +50,9 @@ namespace backends { class SimpleJIT { public: - static std::unique_ptr Create() { return std::unique_ptr(new SimpleJIT); } + static std::unique_ptr 
Create() { + return std::unique_ptr(new SimpleJIT); + } /** * Runtime link to a module. @@ -61,7 +63,9 @@ class SimpleJIT { template void Link(ir::Module module, bool optimize = true); - void Link(llvm::orc::ThreadSafeModule m, bool optimize = true) { llvm::cantFail(jit_->addIRModule(std::move(m))); } + void Link(llvm::orc::ThreadSafeModule m, bool optimize = true) { + llvm::cantFail(jit_->addIRModule(std::move(m))); + } llvm::JITTargetAddress Lookup(absl::string_view name) { return llvm::cantFail(jit_->lookup(AsStringRef(name))).getAddress(); diff --git a/paddle/cinn/backends/modular.cc b/paddle/cinn/backends/modular.cc index 428da27776f56..41a74643d6616 100644 --- a/paddle/cinn/backends/modular.cc +++ b/paddle/cinn/backends/modular.cc @@ -21,9 +21,12 @@ namespace backends { class ModularEvaluator : public ir::IRVisitorBase { public: - explicit ModularEvaluator(const std::map& mod_map) : mod_map_(mod_map) {} + explicit ModularEvaluator(const std::map& mod_map) + : mod_map_(mod_map) {} - ModularEntry Eval(const Expr& e) { return ir::IRVisitorBase::Visit(&e); } + ModularEntry Eval(const Expr& e) { + return ir::IRVisitorBase::Visit(&e); + } ModularEntry Visit(const ir::IntImm* op) { if (op->value < std::numeric_limits::max()) { @@ -51,7 +54,7 @@ class ModularEvaluator : public ir::IRVisitorBase { auto b = Eval(op->b()); ModularEntry ret; ret.coeff = gcd(a.coeff, b.coeff); - ret.base = BaseSimplify(a.base + b.base, ret.coeff); + ret.base = BaseSimplify(a.base + b.base, ret.coeff); return ret; } @@ -61,7 +64,7 @@ class ModularEvaluator : public ir::IRVisitorBase { ModularEntry ret; ret.coeff = gcd(a.coeff, b.coeff); - ret.base = BaseSimplify(a.base - b.base, ret.coeff); + ret.base = BaseSimplify(a.base - b.base, ret.coeff); return ret; } @@ -75,7 +78,7 @@ class ModularEvaluator : public ir::IRVisitorBase { ModularEntry ret; ret.coeff = gcd(pq, gcd(pm, qn)); - ret.base = BaseSimplify(a.base * b.base, ret.coeff); + ret.base = BaseSimplify(a.base * b.base, ret.coeff); return ret; } @@ -86,7 +89,7 @@ class ModularEvaluator : public ir::IRVisitorBase { if (b.coeff % b.base == 0) { ModularEntry ret; ret.coeff = a.coeff / b.base; - ret.base = 0; + ret.base = 0; return ret; } @@ -120,7 +123,7 @@ class ModularEvaluator : public ir::IRVisitorBase { ModularEntry ModularEntry::Add(const ModularEntry& a, const ModularEntry& b) { ModularEntry ret; ret.coeff = ModularEvaluator::gcd(a.coeff, b.coeff); - ret.base = ModularEvaluator::BaseSimplify(a.base + b.base, ret.coeff); + ret.base = ModularEvaluator::BaseSimplify(a.base + b.base, ret.coeff); return ret; } diff --git a/paddle/cinn/backends/modular.h b/paddle/cinn/backends/modular.h index c3a657b6360ba..666bd4ba50036 100644 --- a/paddle/cinn/backends/modular.h +++ b/paddle/cinn/backends/modular.h @@ -34,7 +34,8 @@ struct ModularEntry { static ModularEntry Add(const ModularEntry& a, const ModularEntry& b); }; -ModularEntry EvalModular(const Expr& e, const std::map& mod_map); +ModularEntry EvalModular(const Expr& e, + const std::map& mod_map); } // namespace backends } // namespace cinn diff --git a/paddle/cinn/backends/nvrtc/header_generator.cc b/paddle/cinn/backends/nvrtc/header_generator.cc index 9eba1cf2f0a09..328b0ce1f53be 100644 --- a/paddle/cinn/backends/nvrtc/header_generator.cc +++ b/paddle/cinn/backends/nvrtc/header_generator.cc @@ -27,7 +27,8 @@ HeaderGeneratorBase& JitSafeHeaderGenerator::GetInstance() { } const size_t JitSafeHeaderGenerator::size() const { - CHECK_EQ(include_names_.size(), headers_.size()) << "Internal error in size of header 
files."; + CHECK_EQ(include_names_.size(), headers_.size()) + << "Internal error in size of header files."; return include_names_.size(); } diff --git a/paddle/cinn/backends/nvrtc/header_generator.h b/paddle/cinn/backends/nvrtc/header_generator.h index 1e6e57665857e..171e0c994af07 100644 --- a/paddle/cinn/backends/nvrtc/header_generator.h +++ b/paddle/cinn/backends/nvrtc/header_generator.h @@ -22,8 +22,8 @@ namespace cinn { namespace backends { class HeaderGeneratorBase { public: - virtual const size_t size() const = 0; - virtual const std::vector& headers() const = 0; + virtual const size_t size() const = 0; + virtual const std::vector& headers() const = 0; virtual const std::vector& include_names() const = 0; }; @@ -34,7 +34,9 @@ class JitSafeHeaderGenerator : public HeaderGeneratorBase { static HeaderGeneratorBase& GetInstance(); const size_t size() const; const std::vector& headers() const override { return headers_; } - const std::vector& include_names() const override { return include_names_; } + const std::vector& include_names() const override { + return include_names_; + } private: JitSafeHeaderGenerator(); diff --git a/paddle/cinn/backends/nvrtc/nvrtc_util.cc b/paddle/cinn/backends/nvrtc/nvrtc_util.cc index 3ba93ece28477..90674e65f9c8d 100644 --- a/paddle/cinn/backends/nvrtc/nvrtc_util.cc +++ b/paddle/cinn/backends/nvrtc/nvrtc_util.cc @@ -37,7 +37,8 @@ namespace cinn { namespace backends { namespace nvrtc { -std::string Compiler::operator()(const std::string& code, bool include_headers) { +std::string Compiler::operator()(const std::string& code, + bool include_headers) { if (runtime::CanUseNvccCompiler()) { return CompileWithNvcc(code); } @@ -74,22 +75,28 @@ std::vector Compiler::FindCUDAIncludePaths() { } #endif LOG(FATAL) << "Cannot find cuda include path." - << "CUDA_PATH is not set or CUDA is not installed in the default installation path." + << "CUDA_PATH is not set or CUDA is not installed in the default " + "installation path." 
<< "In other than linux, it is necessary to set CUDA_PATH."; return {cuda_include_path}; } -std::vector Compiler::FindCINNRuntimeIncludePaths() { return {Context::Global().runtime_include_dir()}; } +std::vector Compiler::FindCINNRuntimeIncludePaths() { + return {Context::Global().runtime_include_dir()}; +} -std::string Compiler::CompileCudaSource(const std::string& code, bool include_headers) { +std::string Compiler::CompileCudaSource(const std::string& code, + bool include_headers) { const auto& header_gen = JitSafeHeaderGenerator::GetInstance(); std::vector compile_options; std::vector param_cstrings{}; nvrtcProgram prog; std::string cc = "30"; int major, minor; - cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0); - cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0); + cudaError_t e1 = + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0); + cudaError_t e2 = + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0); if (e1 == cudaSuccess && e2 == cudaSuccess) { cc = std::to_string(major) + std::to_string(minor); @@ -115,16 +122,22 @@ std::string Compiler::CompileCudaSource(const std::string& code, bool include_he for (auto& header : cinn_headers) { include_paths.push_back("--include-path=" + header); } - compile_options.insert(std::end(compile_options), include_paths.begin(), include_paths.end()); + compile_options.insert( + std::end(compile_options), include_paths.begin(), include_paths.end()); } for (const auto& option : compile_options) { param_cstrings.push_back(option.c_str()); } VLOG(3) << "compile options: " << utils::Join(compile_options, " "); - NVRTC_CALL(nvrtcCreateProgram( - &prog, code.c_str(), nullptr, header_gen.size(), header_gen.headers().data(), header_gen.include_names().data())); - nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); + NVRTC_CALL(nvrtcCreateProgram(&prog, + code.c_str(), + nullptr, + header_gen.size(), + header_gen.headers().data(), + header_gen.include_names().data())); + nvrtcResult compile_res = + nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); { // get log size_t log_size; @@ -173,10 +186,11 @@ std::string Compiler::CompileWithNvcc(const std::string& cuda_c) { return prefix_name_ + ".cubin"; } -// std::string Compiler::GetPtx() { return ReadFile(prefix_name_ + ".ptx", std::ios::in); } +// std::string Compiler::GetPtx() { return ReadFile(prefix_name_ + ".ptx", +// std::ios::in); } void Compiler::CompileToPtx() { - auto include_dir = common::Context::Global().runtime_include_dir(); + auto include_dir = common::Context::Global().runtime_include_dir(); std::string include_dir_str = ""; for (auto dir : include_dir) { if (include_dir_str.empty()) { @@ -187,7 +201,8 @@ void Compiler::CompileToPtx() { } std::string options = std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path + - std::string(":$PATH && nvcc -std=c++14 --ptx -O3 -I ") + include_dir_str; + std::string(":$PATH && nvcc -std=c++14 --ptx -O3 -I ") + + include_dir_str; options += " -arch=" + GetDeviceArch(); options += " -o " + prefix_name_ + ".ptx"; options += " " + prefix_name_ + ".cu"; @@ -197,8 +212,8 @@ void Compiler::CompileToPtx() { } void Compiler::CompileToCubin() { - std::string options = - std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path + std::string(":$PATH && nvcc --cubin -O3"); + std::string options = std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path + + std::string(":$PATH && nvcc --cubin -O3"); 
options += " -arch=" + GetDeviceArch(); options += " -o " + prefix_name_ + ".cubin"; options += " " + prefix_name_ + ".ptx"; @@ -209,8 +224,10 @@ void Compiler::CompileToCubin() { std::string Compiler::GetDeviceArch() { int major = 0, minor = 0; - if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0) == cudaSuccess && - cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0) == cudaSuccess) { + if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0) == + cudaSuccess && + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0) == + cudaSuccess) { return "sm_" + std::to_string(major) + std::to_string(minor); } else { LOG(WARNING) << "cannot detect compute capability from your device, " @@ -219,7 +236,8 @@ std::string Compiler::GetDeviceArch() { } } -std::string Compiler::ReadFile(const std::string& file_name, std::ios_base::openmode mode) { +std::string Compiler::ReadFile(const std::string& file_name, + std::ios_base::openmode mode) { // open cubin file std::ifstream ifs(file_name, mode); CHECK(ifs.is_open()) << "Fail to open file " << file_name; diff --git a/paddle/cinn/backends/nvrtc/nvrtc_util.h b/paddle/cinn/backends/nvrtc/nvrtc_util.h index b13c24c550a63..9b406d775a1f5 100644 --- a/paddle/cinn/backends/nvrtc/nvrtc_util.h +++ b/paddle/cinn/backends/nvrtc/nvrtc_util.h @@ -36,7 +36,8 @@ class Compiler { /** * Compile the \p code and get PTX string. * @param code The CUDA source code. - * @param include_headers Whether to include the headers of CUDA and CINN runtime modules. + * @param include_headers Whether to include the headers of CUDA and CINN + * runtime modules. * @return Compiled PTX code string. */ std::string operator()(const std::string& code, bool include_headers = true); @@ -67,7 +68,8 @@ class Compiler { std::string CompileCudaSource(const std::string& code, bool include_headers); /** - * whether to compile the source code into cubin, only works with cuda version > 11.1 + * whether to compile the source code into cubin, only works with cuda version + * > 11.1 */ bool compile_to_cubin_{false}; diff --git a/paddle/cinn/backends/outputs.cc b/paddle/cinn/backends/outputs.cc index 1be075cadbb57..5f37f440bd1bc 100644 --- a/paddle/cinn/backends/outputs.cc +++ b/paddle/cinn/backends/outputs.cc @@ -18,31 +18,32 @@ namespace cinn { namespace lang {} // namespace lang backends::Outputs backends::Outputs::object(const std::string &name) const { - Outputs updated = *this; + Outputs updated = *this; updated.object_name = name; return updated; } backends::Outputs backends::Outputs::bitcode(const std::string &name) const { - Outputs updated = *this; + Outputs updated = *this; updated.bitcode_name = name; return updated; } backends::Outputs backends::Outputs::c_header(const std::string &name) const { - Outputs updated = *this; + Outputs updated = *this; updated.c_header_name = name; return updated; } backends::Outputs backends::Outputs::c_source(const std::string &name) const { - Outputs updated = *this; + Outputs updated = *this; updated.c_source_name = name; return updated; } -backends::Outputs backends::Outputs::cuda_source(const std::string &name) const { - Outputs updated = *this; +backends::Outputs backends::Outputs::cuda_source( + const std::string &name) const { + Outputs updated = *this; updated.cuda_source_name = name; return updated; } diff --git a/paddle/cinn/backends/raw_cuda_code_test.cu b/paddle/cinn/backends/raw_cuda_code_test.cu index 829b9f16f117c..f79ef22cf6389 100644 --- 
a/paddle/cinn/backends/raw_cuda_code_test.cu +++ b/paddle/cinn/backends/raw_cuda_code_test.cu @@ -26,7 +26,8 @@ __global__ void elementwise_add_kernel(const float* __restrict__ A, if ((threadIdx.x < 1024)) { { C[((1024 * blockIdx.x) + threadIdx.x)] = - (A[((1024 * blockIdx.x) + threadIdx.x)] + B[((1024 * blockIdx.x) + threadIdx.x)]); + (A[((1024 * blockIdx.x) + threadIdx.x)] + + B[((1024 * blockIdx.x) + threadIdx.x)]); } } } diff --git a/paddle/cinn/cinn.h b/paddle/cinn/cinn.h index 6fc210fe1c58c..e535478df6bcb 100644 --- a/paddle/cinn/cinn.h +++ b/paddle/cinn/cinn.h @@ -13,7 +13,8 @@ // limitations under the License. /** - * This file exposes some internal APIs to global cinn namespace to make usage more friendly. + * This file exposes some internal APIs to global cinn namespace to make usage + * more friendly. */ #pragma once #include "paddle/cinn/backends/codegen_c.h" diff --git a/paddle/cinn/common/arithmatic.cc b/paddle/cinn/common/arithmatic.cc index 596d8a4466530..44ed4846e782e 100644 --- a/paddle/cinn/common/arithmatic.cc +++ b/paddle/cinn/common/arithmatic.cc @@ -40,14 +40,14 @@ using namespace ir; // NOLINT #endif std::string ExprToGinacConverter::Repr(const ir::Expr& expr) { - auto* load_n = expr.As(); - auto* var_n = expr.As<_Var_>(); + auto* load_n = expr.As(); + auto* var_n = expr.As<_Var_>(); auto* broadcast_n = expr.As(); - auto* mod_n = expr.As(); - auto* min_n = expr.As(); - auto* max_n = expr.As(); - auto* div_n = expr.As
<Div>(); - auto* frac_n = expr.As<FracOp>(); + auto* mod_n = expr.As<Mod>(); + auto* min_n = expr.As<Min>(); + auto* max_n = expr.As<Max>(); + auto* div_n = expr.As<Div>(); + auto* frac_n = expr.As<FracOp>(); if (load_n || broadcast_n || mod_n || min_n || max_n || div_n || frac_n) { std::string repr = GetStreamCnt(expr); Replace(&repr, "[", "lsq_"); @@ -61,7 +61,7 @@ std::string ExprToGinacConverter::Repr(const ir::Expr& expr) { Replace(&repr, "/", "_div_"); // remove the spaces auto fields = utils::Split(repr, " "); - repr = utils::Join(fields, "_"); + repr = utils::Join(fields, "_"); return repr; } else if (var_n) { return utils::GetStreamCnt(expr); @@ -69,29 +69,33 @@ std::string ExprToGinacConverter::Repr(const ir::Expr& expr) { return ""; } -void ExprToGinacConverter::RecordExpr(const ir::Expr& expr) { repr_to_expr_[Repr(expr)] = expr; } +void ExprToGinacConverter::RecordExpr(const ir::Expr& expr) { + repr_to_expr_[Repr(expr)] = expr; +} GiNaC::ex ExprToGinacConverter::BuildHelper(ir::Expr expr) { - auto* load_n = expr.As<Load>(); - auto* var_n = expr.As<_Var_>(); - auto* int_n = expr.As<IntImm>(); - auto* float_n = expr.As<FloatImm>(); - auto* add_n = expr.As<Add>(); - auto* sub_n = expr.As<Sub>(); - auto* mul_n = expr.As<Mul>(); - auto* div_n = expr.As<Div>(); - auto* minus_n = expr.As<Minus>(); + auto* load_n = expr.As<Load>(); + auto* var_n = expr.As<_Var_>(); + auto* int_n = expr.As<IntImm>(); + auto* float_n = expr.As<FloatImm>(); + auto* add_n = expr.As<Add>(); + auto* sub_n = expr.As<Sub>(); + auto* mul_n = expr.As<Mul>(); + auto* div_n = expr.As<Div>
(); + auto* minus_n = expr.As(); auto* broadcast_n = expr.As(); - auto* mod_n = expr.As(); - auto* frac_n = expr.As(); - auto* min_n = expr.As(); - auto* max_n = expr.As(); + auto* mod_n = expr.As(); + auto* frac_n = expr.As(); + auto* min_n = expr.As(); + auto* max_n = expr.As(); bool is_integer_math = expr.type().is_int(); - bool is_invalid_arith = load_n || var_n || broadcast_n || mod_n || min_n || max_n; + bool is_invalid_arith = + load_n || var_n || broadcast_n || mod_n || min_n || max_n; if (is_integer_math) - is_invalid_arith = is_invalid_arith || div_n || frac_n; // GiNac can't deal with integer division. + is_invalid_arith = is_invalid_arith || div_n || + frac_n; // GiNac can't deal with integer division. if (is_invalid_arith) { RecordExpr(expr); @@ -143,8 +147,9 @@ GiNaC::ex ExprToGinacConverter::operator()(Expr expr) { n->As(); }); - CHECK(complex_nodes.empty()) - << "Ginac converter can only deal with simple math expression, but get some complex nodes" << expr; + CHECK(complex_nodes.empty()) << "Ginac converter can only deal with simple " + "math expression, but get some complex nodes" + << expr; return BuildHelper(expr); } @@ -175,7 +180,8 @@ class GiNaCToExprVisitor : public GiNaC::symbol::visitor, ir::Expr cur; public: - explicit GiNaCToExprVisitor(std::map& repr_to_expr) : repr_to_expr(repr_to_expr) {} + explicit GiNaCToExprVisitor(std::map& repr_to_expr) + : repr_to_expr(repr_to_expr) {} Expr operator()(GiNaC::ex ex) { ex.accept(*this); @@ -184,7 +190,8 @@ class GiNaCToExprVisitor : public GiNaC::symbol::visitor, void visit(const GiNaC::symbol& node) override { auto it = repr_to_expr.find(node.get_name()); - CHECK(it != repr_to_expr.end()) << "node [" << node.get_name() << "] not found"; + CHECK(it != repr_to_expr.end()) + << "node [" << node.get_name() << "] not found"; cur = it->second; } @@ -254,7 +261,9 @@ bool IsPureMath(Expr expr) { IrNodeTy ::Minus, }); - auto complex_nodes = ir::CollectIRNodes(expr, [&](const Expr* n) { return !valid_node_tys.count(n->node_type()); }); + auto complex_nodes = ir::CollectIRNodes(expr, [&](const Expr* n) { + return !valid_node_tys.count(n->node_type()); + }); #ifdef CINN_DEBUG for (auto& node : complex_nodes) { VLOG(3) << "Found " << node->node_type() << " " << Expr(node); @@ -268,7 +277,8 @@ bool MathContainsSymbol(Expr expr, Var symbol) { ExprToGinacConverter expr_converter; auto expr_ex = expr_converter(expr); if (!expr_converter.HasSymbol(symbol->name)) return false; - return !ginac::diff(expr_ex, expr_converter.GetSymbol(symbol->name)).is_zero(); + return !ginac::diff(expr_ex, expr_converter.GetSymbol(symbol->name)) + .is_zero(); } // lhs >= rhs. @@ -291,7 +301,7 @@ std::tuple Solve(Expr lhs, Expr rhs, Var var) { Expr value = converter.GinacToExpr(item.op(1)); // tell the symbol - auto diff = lhs_ex - rhs_ex; + auto diff = lhs_ex - rhs_ex; auto diff_res = ginac::diff(diff, symbol); CHECK(!diff_res.is_zero()); diff --git a/paddle/cinn/common/arithmatic.h b/paddle/cinn/common/arithmatic.h index 5d4e7ea872baa..51bd98a57577d 100644 --- a/paddle/cinn/common/arithmatic.h +++ b/paddle/cinn/common/arithmatic.h @@ -13,7 +13,8 @@ // limitations under the License. /** - * This file includes some arithmatic utilities, such as simplifying/solving a math equation/CINN expression. + * This file includes some arithmatic utilities, such as simplifying/solving a + * math equation/CINN expression. */ #pragma once @@ -38,12 +39,13 @@ namespace common { namespace ginac = GiNaC; -//! 
Tell whether the expression \p expr contains only simple math calculations, like i*32+j is true, while Load(buf, -//! i)+1 is not due to the Load Node is not math related. +//! Tell whether the expression \p expr contains only simple math calculations, +//! like i*32+j is true, while Load(buf, i)+1 is not due to the Load Node is not +//! math related. bool IsPureMath(Expr expr); -//! Tell whether the expression \p expr contains the expression \symbol, e.g. i*32+32 contains `i`, it also contains -//! `i+1`. +//! Tell whether the expression \p expr contains the expression \symbol, e.g. +//! i*32+32 contains `i`, it also contains `i+1`. bool MathContainsSymbol(Expr expr, Var symbol); //! Solve the equation \p lhs == \p rhs on symbol \p symbol. @@ -64,8 +66,12 @@ struct ExprToGinacConverter { //! Convert GiNaC ex back to CINN expression, should call operator() first. Expr GinacToExpr(const GiNaC::ex& ex); - bool HasSymbol(const std::string& name) const { return repr_to_ginac_.count(name); } - const ginac::symbol& GetSymbol(const std::string& name) const { return repr_to_ginac_.at(name); } + bool HasSymbol(const std::string& name) const { + return repr_to_ginac_.count(name); + } + const ginac::symbol& GetSymbol(const std::string& name) const { + return repr_to_ginac_.at(name); + } private: std::string Repr(const Expr& expr); diff --git a/paddle/cinn/common/arithmatic_test.cc b/paddle/cinn/common/arithmatic_test.cc index 8122196aa6ed6..774d158c29369 100644 --- a/paddle/cinn/common/arithmatic_test.cc +++ b/paddle/cinn/common/arithmatic_test.cc @@ -43,10 +43,10 @@ TEST(GiNaC, simplify) { TEST(GiNaC, diff) { using namespace GiNaC; // NOLINT symbol x("x"), y("y"); - ex e = (x + 1); + ex e = (x + 1); ex e1 = (y + 1); - e = diff(e, x); + e = diff(e, x); e1 = diff(e1, x); LOG(INFO) << "e: " << eval(e); LOG(INFO) << "e1: " << eval(e1); diff --git a/paddle/cinn/common/axis.cc b/paddle/cinn/common/axis.cc index ed18d03934d81..bb1d45e9c239f 100644 --- a/paddle/cinn/common/axis.cc +++ b/paddle/cinn/common/axis.cc @@ -53,7 +53,7 @@ std::string axis_name(int level) { return kAxises[level]; } // upper level - int repeat_num = 1 + (level / kAxises.size()); + int repeat_num = 1 + (level / kAxises.size()); const auto& base_axis = kAxises[level % kAxises.size()]; // if the level greater than kAxis, repeat the axis, like: diff --git a/paddle/cinn/common/bfloat16.h b/paddle/cinn/common/bfloat16.h index 27501008bf5bf..40ed0fed07cd2 100644 --- a/paddle/cinn/common/bfloat16.h +++ b/paddle/cinn/common/bfloat16.h @@ -69,17 +69,17 @@ struct CINN_ALIGN(2) bfloat16 { #ifdef __cplusplus // Constructors - bfloat16() = default; + bfloat16() = default; bfloat16(const bfloat16& o) = default; bfloat16& operator=(const bfloat16& o) = default; - bfloat16(bfloat16&& o) = default; + bfloat16(bfloat16&& o) = default; bfloat16& operator=(bfloat16&& o) = default; - ~bfloat16() = default; + ~bfloat16() = default; __host__ __device__ inline explicit bfloat16(float val) { #if defined(CINN_CUDA_BF16) __nv_bfloat16 tmp = __float2bfloat16(val); - x = *reinterpret_cast(&tmp); + x = *reinterpret_cast(&tmp); #else std::memcpy(&x, reinterpret_cast(&val) + 2, 2); #endif @@ -92,7 +92,8 @@ struct CINN_ALIGN(2) bfloat16 { #endif template - __host__ __device__ inline explicit bfloat16(const T& val) : x(bfloat16(static_cast(val)).x) {} + __host__ __device__ inline explicit bfloat16(const T& val) + : x(bfloat16(static_cast(val)).x) {} // Assignment operators #if defined(CINN_CUDA_BF16) @@ -162,9 +163,10 @@ struct CINN_ALIGN(2) bfloat16 { #ifdef 
CINN_CUDA_BF16 return __bfloat162float(*reinterpret_cast(&x)); #else - float val = 0.f; + float val = 0.f; uint16_t temp = x; - std::memcpy(reinterpret_cast(&val) + 2, reinterpret_cast(&temp), 2); + std::memcpy( + reinterpret_cast(&val) + 2, reinterpret_cast(&temp), 2); return val; #endif } @@ -175,9 +177,13 @@ struct CINN_ALIGN(2) bfloat16 { } #endif - __host__ __device__ inline explicit operator bool() const { return (x & 0x7fff) != 0; } + __host__ __device__ inline explicit operator bool() const { + return (x & 0x7fff) != 0; + } - __host__ __device__ inline explicit operator int8_t() const { return static_cast(static_cast(*this)); } + __host__ __device__ inline explicit operator int8_t() const { + return static_cast(static_cast(*this)); + } __host__ __device__ inline explicit operator uint8_t() const { return static_cast(static_cast(*this)); @@ -207,11 +213,14 @@ struct CINN_ALIGN(2) bfloat16 { return static_cast(static_cast(*this)); } - __host__ __device__ inline operator double() const { return static_cast(static_cast(*this)); } + __host__ __device__ inline operator double() const { + return static_cast(static_cast(*this)); + } #endif // __cplusplus }; -__host__ __device__ inline bfloat16 operator+(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bfloat16 operator+(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return bfloat16(__hadd(a.to_nv_bfloat16(), b.to_nv_bfloat16())); #else @@ -219,7 +228,8 @@ __host__ __device__ inline bfloat16 operator+(const bfloat16& a, const bfloat16& #endif } -__host__ __device__ inline bfloat16 operator-(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bfloat16 operator-(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return bfloat16(__hsub(a.to_nv_bfloat16(), b.to_nv_bfloat16())); #else @@ -227,7 +237,8 @@ __host__ __device__ inline bfloat16 operator-(const bfloat16& a, const bfloat16& #endif } -__host__ __device__ inline bfloat16 operator*(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bfloat16 operator*(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return bfloat16(__hmul(a.to_nv_bfloat16(), b.to_nv_bfloat16())); #else @@ -235,7 +246,8 @@ __host__ __device__ inline bfloat16 operator*(const bfloat16& a, const bfloat16& #endif } -__host__ __device__ inline bfloat16 operator/(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bfloat16 operator/(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return bfloat16(__hdiv(a.to_nv_bfloat16(), b.to_nv_bfloat16())); #else @@ -280,7 +292,8 @@ __host__ __device__ inline bfloat16 raw_uint16_to_bfloat16(uint16_t a) { } // Comparison operators -__host__ __device__ inline bool operator==(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bool operator==(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __heq(a.to_nv_bfloat16(), b.to_nv_bfloat16()); #else @@ -288,7 +301,8 @@ __host__ __device__ inline bool operator==(const bfloat16& a, const bfloat16& b) #endif } -__host__ __device__ inline bool operator!=(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bool operator!=(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && 
defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __hne(a.to_nv_bfloat16(), b.to_nv_bfloat16()); #else @@ -296,7 +310,8 @@ __host__ __device__ inline bool operator!=(const bfloat16& a, const bfloat16& b) #endif } -__host__ __device__ inline bool operator<(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bool operator<(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __hlt(a.to_nv_bfloat16(), b.to_nv_bfloat16()); #else @@ -304,7 +319,8 @@ __host__ __device__ inline bool operator<(const bfloat16& a, const bfloat16& b) #endif } -__host__ __device__ inline bool operator<=(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bool operator<=(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __hle(a.to_nv_bfloat16(), b.to_nv_bfloat16()); #else @@ -312,7 +328,8 @@ __host__ __device__ inline bool operator<=(const bfloat16& a, const bfloat16& b) #endif } -__host__ __device__ inline bool operator>(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bool operator>(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __hgt(a.to_nv_bfloat16(), b.to_nv_bfloat16()); #else @@ -320,7 +337,8 @@ __host__ __device__ inline bool operator>(const bfloat16& a, const bfloat16& b) #endif } -__host__ __device__ inline bool operator>=(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bool operator>=(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __hge(a.to_nv_bfloat16(), b.to_nv_bfloat16()); #else @@ -344,7 +362,9 @@ __host__ __device__ inline bool(isinf)(const bfloat16& a) { #endif } -__host__ __device__ inline bool(isfinite)(const bfloat16& a) { return !((isnan)(a)) && !((isinf)(a)); } +__host__ __device__ inline bool(isfinite)(const bfloat16& a) { + return !((isnan)(a)) && !((isinf)(a)); +} __host__ __device__ inline bfloat16(abs)(const bfloat16& a) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 @@ -365,36 +385,43 @@ __device__ inline cinn::common::bfloat16 __shfl_sync(unsigned mask, cinn::common::bfloat16 var, int srcLane, int width = warpSize) { - return cinn::common::bfloat16(__shfl_sync(mask, var.to_nv_bfloat16(), srcLane, width)); + return cinn::common::bfloat16( + __shfl_sync(mask, var.to_nv_bfloat16(), srcLane, width)); } -__device__ inline cinn::common::bfloat16 __shfl_up_sync(unsigned mask, - cinn::common::bfloat16 var, - unsigned int delta, - int width = warpSize) { - return cinn::common::bfloat16(__shfl_up_sync(mask, var.to_nv_bfloat16(), delta, width)); +__device__ inline cinn::common::bfloat16 __shfl_up_sync( + unsigned mask, + cinn::common::bfloat16 var, + unsigned int delta, + int width = warpSize) { + return cinn::common::bfloat16( + __shfl_up_sync(mask, var.to_nv_bfloat16(), delta, width)); } -__device__ inline cinn::common::bfloat16 __shfl_down_sync(unsigned mask, - cinn::common::bfloat16 var, - unsigned int delta, - int width = warpSize) { - return cinn::common::bfloat16(__shfl_down_sync(mask, var.to_nv_bfloat16(), delta, width)); +__device__ inline cinn::common::bfloat16 __shfl_down_sync( + unsigned mask, + cinn::common::bfloat16 var, + unsigned int delta, + int width = warpSize) { + return cinn::common::bfloat16( + __shfl_down_sync(mask, var.to_nv_bfloat16(), delta, width)); } -__device__ 
inline cinn::common::bfloat16 __shfl_xor_sync(unsigned mask, - cinn::common::bfloat16 var, - int laneMask, - int width = warpSize) { - return cinn::common::bfloat16(__shfl_xor_sync(mask, var.to_nv_bfloat16(), laneMask, width)); +__device__ inline cinn::common::bfloat16 __shfl_xor_sync( + unsigned mask, + cinn::common::bfloat16 var, + int laneMask, + int width = warpSize) { + return cinn::common::bfloat16( + __shfl_xor_sync(mask, var.to_nv_bfloat16(), laneMask, width)); } -__host__ __device__ inline cinn::common::bfloat16 max(const cinn::common::bfloat16& a, - const cinn::common::bfloat16& b) { +__host__ __device__ inline cinn::common::bfloat16 max( + const cinn::common::bfloat16& a, const cinn::common::bfloat16& b) { return a > b ? a : b; } -__host__ __device__ inline cinn::common::bfloat16 min(const cinn::common::bfloat16& a, - const cinn::common::bfloat16& b) { +__host__ __device__ inline cinn::common::bfloat16 min( + const cinn::common::bfloat16& a, const cinn::common::bfloat16& b) { return a < b ? a : b; } #endif // __cplusplus && CINN_CUDA_FP16 diff --git a/paddle/cinn/common/cas.cc b/paddle/cinn/common/cas.cc index 2111c244f33fa..9b8f1cf0f551b 100644 --- a/paddle/cinn/common/cas.cc +++ b/paddle/cinn/common/cas.cc @@ -34,7 +34,9 @@ namespace cinn { namespace common { using namespace ir; // NOLINT -Expr AutoSimplify(Expr u, const absl::flat_hash_map& var_intervals) { +Expr AutoSimplify( + Expr u, + const absl::flat_hash_map& var_intervals) { VLOG(7) << "Begin AutoSimplify: " << u; u = detail::ConvertCinnToCAS(u); absl::flat_hash_map s_var_intervals; @@ -44,7 +46,8 @@ Expr AutoSimplify(Expr u, const absl::flat_hash_map& v Expr e_r = detail::ConvertCinnToCAS(item.second.e_r); s_var_intervals.emplace(item.first, CasInterval(e_l, e_r)); } else { - s_var_intervals.emplace(item.first, CasInterval(item.second.l, item.second.r)); + s_var_intervals.emplace(item.first, + CasInterval(item.second.l, item.second.r)); } } u = CasSimplify(u, s_var_intervals); @@ -70,7 +73,8 @@ int gcd(int a, int b) { return gcd(a, b - a); } -//////// All the following symbolic computation methods are implemented referencing to the book template @@ -141,7 +145,9 @@ bool IsDivisible(const Sum* a, int b); bool IsDivisible(int a, const Product* b) { if (a < 0) return false; for (auto& item : b->operands()) { - if (item.As() && item.As()->value > 0 && IsDivisible(a, item.As()->value)) return true; + if (item.As() && item.As()->value > 0 && + IsDivisible(a, item.As()->value)) + return true; } return false; } @@ -193,13 +199,13 @@ Expr Divide(const Sum* a, int b) { } Expr Divide(const Product* a, int b) { std::vector args; - int i = 0; - int times = -1; + int i = 0; + int times = -1; bool is_divisible = false; for (i = 0; i < a->operands().size(); i++) { auto* a_i = a->operand(i).As(); if (a_i && a_i->value % b == 0) { - times = a_i->value / b; + times = a_i->value / b; is_divisible = true; break; } @@ -259,7 +265,8 @@ Expr CasSimplifyMutator::SimplifyRationalNumber(Expr u) { if (dv > 0) { return FracOp::Make(make_const(Iquot(nv, g)), make_const(Iquot(dv, g))); } else { - return FracOp::Make(make_const(Iquot(-nv, g)), make_const(Iquot(-dv, g))); + return FracOp::Make(make_const(Iquot(-nv, g)), + make_const(Iquot(-dv, g))); } } } @@ -268,7 +275,7 @@ Expr CasSimplifyMutator::SimplifyRationalNumber(Expr u) { Expr SumOrProductGetSingleElementsRec(Expr u) { auto* product = u.As(); - auto* sum = u.As(); + auto* sum = u.As(); if (product && product->operands().size() == 1) { return 
SumOrProductGetSingleElementsRec(u->operands.front()); } @@ -291,11 +298,12 @@ bool ExprPosCmp::operator()(const Expr& a, const Expr& b) { return a.As<_Var_>()->name < b.As<_Var_>()->name; } - // O-3, if a and b are either both products or both sums, compare by each element similar to lexicographical order. + // O-3, if a and b are either both products or both sums, compare by each + // element similar to lexicographical order. if ((a.As() && b.As()) || (a.As() && b.As())) { auto& aoprs = a->operands; auto& boprs = b->operands; - int m = std::min(aoprs.size(), boprs.size()); + int m = std::min(aoprs.size(), boprs.size()); for (int i = 0; i < m; i++) { // ugly compare representation in string. @@ -357,12 +365,18 @@ bool ExprPosCmp::operator()(const Expr& a, const Expr& b) { return false; } -std::vector CasSimplifyMutator::MergeProduct(const std::vector& p, const std::vector& q) { - return MergeExprs( - p, q, std::bind(&CasSimplifyMutator::SimplifyBinaryProduct, this, std::placeholders::_1, std::placeholders::_2)); +std::vector CasSimplifyMutator::MergeProduct(const std::vector& p, + const std::vector& q) { + return MergeExprs(p, + q, + std::bind(&CasSimplifyMutator::SimplifyBinaryProduct, + this, + std::placeholders::_1, + std::placeholders::_2)); } -std::vector CasSimplifyMutator::SimplifyBinaryProduct(Expr left, Expr right) { +std::vector CasSimplifyMutator::SimplifyBinaryProduct(Expr left, + Expr right) { // SPRDREC-1 if (!left.As() && !right.As()) { auto a = left; @@ -388,23 +402,23 @@ std::vector CasSimplifyMutator::SimplifyBinaryProduct(Expr left, Expr righ Expr cmp_oper; int const_value; if (ai) { - const_oper = a; - cmp_oper = b; + const_oper = a; + cmp_oper = b; const_value = ai->value; } if (af) { - const_oper = a; - cmp_oper = b; + const_oper = a; + cmp_oper = b; const_value = af->value; } if (bi) { - const_oper = b; - cmp_oper = a; + const_oper = b; + cmp_oper = a; const_value = bi->value; } if (bf) { - const_oper = b; - cmp_oper = a; + const_oper = b; + cmp_oper = a; const_value = bf->value; } if (const_value == 0) { @@ -415,44 +429,60 @@ std::vector CasSimplifyMutator::SimplifyBinaryProduct(Expr left, Expr righ auto cmp_max = cmp_oper.As(); if (const_value > 0) { if (cmp_min) { - return {CasSimplify(Min::Make(CasSimplify(Product::Make({cmp_min->a(), const_oper}), var_intervals), - CasSimplify(Product::Make({cmp_min->b(), const_oper}), var_intervals)), - var_intervals)}; + return {CasSimplify( + Min::Make(CasSimplify(Product::Make({cmp_min->a(), const_oper}), + var_intervals), + CasSimplify(Product::Make({cmp_min->b(), const_oper}), + var_intervals)), + var_intervals)}; } if (cmp_max) { - return {CasSimplify(Max::Make(CasSimplify(Product::Make({cmp_max->a(), const_oper}), var_intervals), - CasSimplify(Product::Make({cmp_max->b(), const_oper}), var_intervals)), - var_intervals)}; + return {CasSimplify( + Max::Make(CasSimplify(Product::Make({cmp_max->a(), const_oper}), + var_intervals), + CasSimplify(Product::Make({cmp_max->b(), const_oper}), + var_intervals)), + var_intervals)}; } } else { if (cmp_min) { - return {CasSimplify(Max::Make(CasSimplify(Product::Make({cmp_min->b(), const_oper}), var_intervals), - CasSimplify(Product::Make({cmp_min->a(), const_oper}), var_intervals)), - var_intervals)}; + return {CasSimplify( + Max::Make(CasSimplify(Product::Make({cmp_min->b(), const_oper}), + var_intervals), + CasSimplify(Product::Make({cmp_min->a(), const_oper}), + var_intervals)), + var_intervals)}; } if (cmp_max) { - return 
{CasSimplify(Min::Make(CasSimplify(Product::Make({cmp_max->b(), const_oper}), var_intervals), - CasSimplify(Product::Make({cmp_max->a(), const_oper}), var_intervals)), - var_intervals)}; + return {CasSimplify( + Min::Make(CasSimplify(Product::Make({cmp_max->b(), const_oper}), + var_intervals), + CasSimplify(Product::Make({cmp_max->a(), const_oper}), + var_intervals)), + var_intervals)}; } } } } { // FracOp related constants. - // NOTE the integer division is weried in C language, 1/2 = 0, that is huge different from a real CAS. + // NOTE the integer division is weried in C language, 1/2 = 0, that is + // huge different from a real CAS. auto* af = a.As(); auto* bf = b.As(); // 1/2 * 2/3 if (af && bf && a->type().is_float()) { - return {CasSimplify(FracOp::Make(Product::Make({af->a(), bf->a()}), Product::Make({af->b(), bf->b()})), + return {CasSimplify(FracOp::Make(Product::Make({af->a(), bf->a()}), + Product::Make({af->b(), bf->b()})), var_intervals)}; } if (af && !bf && a->type().is_float()) { - return {CasSimplify(FracOp::Make(Product::Make({af->a(), b}), af->b()), var_intervals)}; + return {CasSimplify(FracOp::Make(Product::Make({af->a(), b}), af->b()), + var_intervals)}; } if (!af && bf && a->type().is_float()) { - return {CasSimplify(FracOp::Make(Product::Make({bf->a(), a}), bf->b()), var_intervals)}; + return {CasSimplify(FracOp::Make(Product::Make({bf->a(), a}), bf->b()), + var_intervals)}; } } @@ -521,10 +551,12 @@ std::vector CasSimplifyMutator::SimplifyBinaryProduct(Expr left, Expr righ return {left, right}; } -std::vector CasSimplifyMutator::SimplifyProductRec(const std::vector& operands) { - if (operands.size() < 2) return {CasSimplify(operands.front(), var_intervals)}; - auto mid_it = operands.begin() + operands.size() / 2; - auto&& left = SimplifyProductRec(std::vector(operands.begin(), mid_it)); +std::vector CasSimplifyMutator::SimplifyProductRec( + const std::vector& operands) { + if (operands.size() < 2) + return {CasSimplify(operands.front(), var_intervals)}; + auto mid_it = operands.begin() + operands.size() / 2; + auto&& left = SimplifyProductRec(std::vector(operands.begin(), mid_it)); auto&& right = SimplifyProductRec(std::vector(mid_it, operands.end())); return MergeProduct(left, right); } @@ -589,15 +621,16 @@ Expr CasSimplifyMutator::SimplifySum(Expr u) { return Sum::Make(args); } -std::vector CasSimplifyMutator::MergeExprs(const std::vector& p, - const std::vector& q, - const std::function(Expr, Expr)>& binary_merge) { +std::vector CasSimplifyMutator::MergeExprs( + const std::vector& p, + const std::vector& q, + const std::function(Expr, Expr)>& binary_merge) { std::vector res; int li = 0, lj = 0; while (li < p.size() && lj < q.size()) { auto&& p1 = p[li]; auto&& q1 = q[lj]; - auto&& h = binary_merge(p1, q1); + auto&& h = binary_merge(p1, q1); if (h.size() == 2 && h[0] == p1 && h[1] == q1) { ++li; res.emplace_back(std::move(h.front())); @@ -617,7 +650,8 @@ std::vector CasSimplifyMutator::MergeExprs(const std::vector& p, } // This implementation is similar to MergeProduct -std::vector CasSimplifyMutator::MergeSum(const std::vector& p, const std::vector& q) { +std::vector CasSimplifyMutator::MergeSum(const std::vector& p, + const std::vector& q) { #ifdef CINN_DEBUG { std::stringstream ss; @@ -666,24 +700,28 @@ std::vector CasSimplifyMutator::SimplifyBinarySum(Expr left, Expr right) { auto* b_min = b.As(); auto* b_max = b.As(); if (a_min) { - return {CasSimplify(Min::Make(CasSimplify(Sum::Make({a_min->a(), b}), var_intervals), - CasSimplify(Sum::Make({a_min->b(), b}), 
var_intervals)), - var_intervals)}; + return {CasSimplify( + Min::Make(CasSimplify(Sum::Make({a_min->a(), b}), var_intervals), + CasSimplify(Sum::Make({a_min->b(), b}), var_intervals)), + var_intervals)}; } if (a_max) { - return {CasSimplify(Max::Make(CasSimplify(Sum::Make({a_max->a(), b}), var_intervals), - CasSimplify(Sum::Make({a_max->b(), b}), var_intervals)), - var_intervals)}; + return {CasSimplify( + Max::Make(CasSimplify(Sum::Make({a_max->a(), b}), var_intervals), + CasSimplify(Sum::Make({a_max->b(), b}), var_intervals)), + var_intervals)}; } if (b_min) { - return {CasSimplify(Min::Make(CasSimplify(Sum::Make({b_min->a(), a}), var_intervals), - CasSimplify(Sum::Make({b_min->b(), a}), var_intervals)), - var_intervals)}; + return {CasSimplify( + Min::Make(CasSimplify(Sum::Make({b_min->a(), a}), var_intervals), + CasSimplify(Sum::Make({b_min->b(), a}), var_intervals)), + var_intervals)}; } if (b_max) { - return {CasSimplify(Max::Make(CasSimplify(Sum::Make({b_max->a(), a}), var_intervals), - CasSimplify(Sum::Make({b_max->b(), a}), var_intervals)), - var_intervals)}; + return {CasSimplify( + Max::Make(CasSimplify(Sum::Make({b_max->a(), a}), var_intervals), + CasSimplify(Sum::Make({b_max->b(), a}), var_intervals)), + var_intervals)}; } // case 2 @@ -699,20 +737,25 @@ std::vector CasSimplifyMutator::SimplifyBinarySum(Expr left, Expr right) { auto* am = a.As(); auto* bm = b.As(); if (am && bm) { - if (am->b() == bm->b() && ProductGetNonConstantPart(am->a()) == ProductGetNonConstantPart(bm->a())) { - return {CasSimplify(Mod::Make(Sum::Make({am->a(), bm->a()}), am->b()), var_intervals)}; + if (am->b() == bm->b() && ProductGetNonConstantPart(am->a()) == + ProductGetNonConstantPart(bm->a())) { + return {CasSimplify(Mod::Make(Sum::Make({am->a(), bm->a()}), am->b()), + var_intervals)}; } } } // case 3 - // Here is different from SimplifySumRec, to deal with cases like 3x + (-2x) = 2x + // Here is different from SimplifySumRec, to deal with cases like 3x + (-2x) + // = 2x auto a_non_constant = ProductGetNonConstantPart(a); auto b_non_constant = ProductGetNonConstantPart(b); - if (a_non_constant.defined() && b_non_constant.defined() && a_non_constant == b_non_constant) { + if (a_non_constant.defined() && b_non_constant.defined() && + a_non_constant == b_non_constant) { VLOG(7) << "a " << a; VLOG(7) << "b " << b; - Expr s = SimplifySum(Sum::Make({ProductGetConstantPart(a), ProductGetConstantPart(b)})); + Expr s = SimplifySum( + Sum::Make({ProductGetConstantPart(a), ProductGetConstantPart(b)})); Expr p = Product::Make({s, ProductGetNonConstantPart(a)}); return {CasSimplify(p, var_intervals)}; } @@ -755,7 +798,8 @@ std::vector CasSimplifyMutator::SimplifyBinarySum(Expr left, Expr right) { } // The implementation is similar to SimplifyProductRec -std::vector CasSimplifyMutator::SimplifySumRec(const std::vector& operands) { +std::vector CasSimplifyMutator::SimplifySumRec( + const std::vector& operands) { #ifdef CINN_DEBUG { std::stringstream ss; @@ -766,9 +810,10 @@ std::vector CasSimplifyMutator::SimplifySumRec(const std::vector& op } #endif CHECK(!operands.empty()); - if (operands.size() < 2) return {CasSimplify(operands.front(), var_intervals)}; - auto mid_it = operands.begin() + operands.size() / 2; - auto&& left = SimplifySumRec(std::vector(operands.begin(), mid_it)); + if (operands.size() < 2) + return {CasSimplify(operands.front(), var_intervals)}; + auto mid_it = operands.begin() + operands.size() / 2; + auto&& left = SimplifySumRec(std::vector(operands.begin(), mid_it)); auto&& right = 
SimplifySumRec(std::vector(mid_it, operands.end())); return MergeSum(left, right); } @@ -782,7 +827,10 @@ void CasSimplifyMutator::AddBaseAndSimplify(Expr* base, Expr bound) { *base = CasSimplify(*base, var_intervals); } -void CasSimplifyMutator::UnfoldBound(Expr* lower_bound, Expr* upper_bound, Expr var, bool unfold_const_bound) { +void CasSimplifyMutator::UnfoldBound(Expr* lower_bound, + Expr* upper_bound, + Expr var, + bool unfold_const_bound) { CHECK(lower_bound); CHECK(upper_bound); auto v_var = var.As<_Var_>(); @@ -810,12 +858,15 @@ void CasSimplifyMutator::UnfoldBound(Expr* lower_bound, Expr* upper_bound, Expr } } -bool CasSimplifyMutator::GetVarBound(Expr* lower_bound, Expr* upper_bound, Expr var, bool unfold_const_bound) { +bool CasSimplifyMutator::GetVarBound(Expr* lower_bound, + Expr* upper_bound, + Expr var, + bool unfold_const_bound) { CHECK(lower_bound); CHECK(upper_bound); - auto v_var = var.As<_Var_>(); + auto v_var = var.As<_Var_>(); auto v_product = var.As(); - auto v_frac = var.As(); + auto v_frac = var.As(); if (v_var && (var_intervals.count(v_var->name) || !unfold_const_bound)) { UnfoldBound(lower_bound, upper_bound, var, unfold_const_bound); return true; @@ -823,9 +874,9 @@ bool CasSimplifyMutator::GetVarBound(Expr* lower_bound, Expr* upper_bound, Expr // only deal with 2*x Expr p_lower_bound; Expr p_upper_bound; - Expr const_oper = ProductGetConstantPart(var); + Expr const_oper = ProductGetConstantPart(var); Expr non_const_oper = ProductGetNonConstantPart(var); - auto v_var = non_const_oper.As<_Var_>(); + auto v_var = non_const_oper.As<_Var_>(); if (v_var && var_intervals.count(v_var->name)) { Expr v_lower, v_upper; UnfoldBound(&v_lower, &v_upper, non_const_oper, unfold_const_bound); @@ -847,8 +898,8 @@ bool CasSimplifyMutator::GetVarBound(Expr* lower_bound, Expr* upper_bound, Expr Expr p_lower_bound; Expr p_upper_bound; Expr non_const_oper = v_frac->a(); - Expr const_oper = v_frac->b(); - auto v_var = non_const_oper.As<_Var_>(); + Expr const_oper = v_frac->b(); + auto v_var = non_const_oper.As<_Var_>(); if (v_var && var_intervals.count(v_var->name)) { Expr v_lower, v_upper; UnfoldBound(&v_lower, &v_upper, non_const_oper, unfold_const_bound); @@ -869,7 +920,10 @@ bool CasSimplifyMutator::GetVarBound(Expr* lower_bound, Expr* upper_bound, Expr return false; } -bool CasSimplifyMutator::GetOperandBound(Expr* lower_bound, Expr* upper_bound, Expr v, bool unfold_const_bound) { +bool CasSimplifyMutator::GetOperandBound(Expr* lower_bound, + Expr* upper_bound, + Expr v, + bool unfold_const_bound) { // only support simple operand of int, var and var's product with int CHECK(lower_bound); CHECK(upper_bound); @@ -884,7 +938,10 @@ bool CasSimplifyMutator::GetOperandBound(Expr* lower_bound, Expr* upper_bound, E return false; } -bool CasSimplifyMutator::GetSumBound(Expr* lower_bound, Expr* upper_bound, Expr sum, bool unfold_const_bound) { +bool CasSimplifyMutator::GetSumBound(Expr* lower_bound, + Expr* upper_bound, + Expr sum, + bool unfold_const_bound) { // only support sum of int, var and var's product with int CHECK(lower_bound); CHECK(upper_bound); @@ -894,7 +951,8 @@ bool CasSimplifyMutator::GetSumBound(Expr* lower_bound, Expr* upper_bound, Expr Expr sum_lower_bound, sum_upper_bound; if (bound_sum) { for (Expr& v : bound_sum->operands()) { - if (!GetOperandBound(&sum_lower_bound, &sum_upper_bound, v, unfold_const_bound)) { + if (!GetOperandBound( + &sum_lower_bound, &sum_upper_bound, v, unfold_const_bound)) { get_bound = false; break; } @@ -908,8 +966,12 @@ bool 
CasSimplifyMutator::GetSumBound(Expr* lower_bound, Expr* upper_bound, Expr return false; } -bool CasSimplifyMutator::GetExprBound(Expr* lower_bound, Expr* upper_bound, Expr expr, bool unfold_const_bound) { - // only support min's operands as sum, int or var or var's product with int or min/max +bool CasSimplifyMutator::GetExprBound(Expr* lower_bound, + Expr* upper_bound, + Expr expr, + bool unfold_const_bound) { + // only support min's operands as sum, int or var or var's product with int or + // min/max auto bound_sum = expr.As(); auto bound_min = expr.As(); auto bound_max = expr.As(); @@ -920,37 +982,57 @@ bool CasSimplifyMutator::GetExprBound(Expr* lower_bound, Expr* upper_bound, Expr get_bound = GetMinBound(lower_bound, upper_bound, expr, unfold_const_bound); } else if (bound_max) { get_bound = GetMaxBound(lower_bound, upper_bound, expr, unfold_const_bound); - } else if (!GetOperandBound(lower_bound, upper_bound, expr, unfold_const_bound)) { + } else if (!GetOperandBound( + lower_bound, upper_bound, expr, unfold_const_bound)) { return false; } return get_bound; } -bool CasSimplifyMutator::GetMinBound(Expr* lower_bound, Expr* upper_bound, Expr min, bool unfold_const_bound) { - // only support min's operands as sum, int or var or var's product with int or min/max +bool CasSimplifyMutator::GetMinBound(Expr* lower_bound, + Expr* upper_bound, + Expr min, + bool unfold_const_bound) { + // only support min's operands as sum, int or var or var's product with int or + // min/max auto bound_min = min.As(); CHECK(bound_min); bool get_bound = true; Expr a_lower_bound, a_upper_bound, b_lower_bound, b_upper_bound; - get_bound = get_bound && GetExprBound(&a_lower_bound, &a_upper_bound, bound_min->a(), unfold_const_bound) && - GetExprBound(&b_lower_bound, &b_upper_bound, bound_min->b(), unfold_const_bound); + get_bound = + get_bound && + GetExprBound( + &a_lower_bound, &a_upper_bound, bound_min->a(), unfold_const_bound) && + GetExprBound( + &b_lower_bound, &b_upper_bound, bound_min->b(), unfold_const_bound); if (get_bound) { - *lower_bound = CasSimplify(Min::Make(a_lower_bound, b_lower_bound), var_intervals); - *upper_bound = CasSimplify(Min::Make(a_upper_bound, b_upper_bound), var_intervals); + *lower_bound = + CasSimplify(Min::Make(a_lower_bound, b_lower_bound), var_intervals); + *upper_bound = + CasSimplify(Min::Make(a_upper_bound, b_upper_bound), var_intervals); } return get_bound; } -bool CasSimplifyMutator::GetMaxBound(Expr* lower_bound, Expr* upper_bound, Expr max, bool unfold_const_bound) { +bool CasSimplifyMutator::GetMaxBound(Expr* lower_bound, + Expr* upper_bound, + Expr max, + bool unfold_const_bound) { auto bound_max = max.As(); CHECK(bound_max); bool get_bound = true; Expr a_lower_bound, a_upper_bound, b_lower_bound, b_upper_bound; - get_bound = get_bound && GetExprBound(&a_lower_bound, &a_upper_bound, bound_max->a(), unfold_const_bound) && - GetExprBound(&b_lower_bound, &b_upper_bound, bound_max->b(), unfold_const_bound); + get_bound = + get_bound && + GetExprBound( + &a_lower_bound, &a_upper_bound, bound_max->a(), unfold_const_bound) && + GetExprBound( + &b_lower_bound, &b_upper_bound, bound_max->b(), unfold_const_bound); if (get_bound) { - *lower_bound = CasSimplify(Max::Make(a_lower_bound, b_lower_bound), var_intervals); - *upper_bound = CasSimplify(Max::Make(a_upper_bound, b_upper_bound), var_intervals); + *lower_bound = + CasSimplify(Max::Make(a_lower_bound, b_lower_bound), var_intervals); + *upper_bound = + CasSimplify(Max::Make(a_upper_bound, b_upper_bound), var_intervals); } 
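The GetMinBound/GetMaxBound hunks above encode a simple interval rule: the lower and upper bounds of Min(a, b) are the Min of the operands' lower bounds and the Min of their upper bounds, and symmetrically for Max. A plain-integer sketch of the rule (Interval is an illustrative type, not CINN's representation):

```cpp
// Plain-integer sketch of the bound rule implemented by GetMinBound /
// GetMaxBound above: bounds of Min(a, b) are the element-wise Min of the
// operands' bounds, and symmetrically for Max.
#include <algorithm>
#include <cassert>

struct Interval {
  int lower;
  int upper;
};

Interval MinBound(Interval a, Interval b) {
  return {std::min(a.lower, b.lower), std::min(a.upper, b.upper)};
}

Interval MaxBound(Interval a, Interval b) {
  return {std::max(a.lower, b.lower), std::max(a.upper, b.upper)};
}

int main() {
  Interval x{0, 10};
  Interval y{5, 7};
  Interval mn = MinBound(x, y);  // Min(x, y) is known to lie in [0, 7]
  assert(mn.lower == 0 && mn.upper == 7);
  Interval mx = MaxBound(x, y);  // Max(x, y) is known to lie in [5, 10]
  assert(mx.lower == 5 && mx.upper == 10);
  return 0;
}
```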
return get_bound; } @@ -959,27 +1041,32 @@ bool CasSimplifyMutator::SimplifySpecificSumMod(Expr* result, Expr a, Expr b) { // case1: (32+(-x))%33 = 32-x%33 (0<=x<=32) // case2: (x-32)%33 = x%33 - 32%33 (0<=x<=32) auto a_sum = a.As(); - auto b_i = b.As(); + auto b_i = b.As(); if (!a_sum || !b_i) { return false; } // if 0 < b < 3, (3a+b) % 6 = (3a % 6) + (b % 6) if (a_sum->operands().size() == 2) { a_sum->operands()[0] = CasSimplify(a_sum->operands()[0], var_intervals); - auto sum_a_prod = a_sum->operands()[0].As(); - auto sum_b_var = a_sum->operands()[1].As<_Var_>(); + auto sum_a_prod = a_sum->operands()[0].As(); + auto sum_b_var = a_sum->operands()[1].As<_Var_>(); if (sum_a_prod && sum_b_var && var_intervals.count(sum_b_var->name)) { auto sum_a_prod_b_int = sum_a_prod->operand(1).As(); - if (sum_a_prod_b_int) std::swap(sum_a_prod->operand(0), sum_a_prod->operand(1)); + if (sum_a_prod_b_int) + std::swap(sum_a_prod->operand(0), sum_a_prod->operand(1)); auto sum_a_prod_a_int = sum_a_prod->operand(0).As(); - auto& interval = var_intervals.at(sum_b_var->name); - int b_abs = std::abs(b_i->value); - int sum_prod_a_abs = std::abs(sum_a_prod_a_int->value); + auto& interval = var_intervals.at(sum_b_var->name); + int b_abs = std::abs(b_i->value); + int sum_prod_a_abs = std::abs(sum_a_prod_a_int->value); if (sum_a_prod_a_int && (b_abs % sum_prod_a_abs == 0)) { - if (std::abs(interval.l) < sum_prod_a_abs && std::abs(interval.r) < sum_prod_a_abs) { - *result = CasSimplify(Sum::Make({CasSimplify(Mod::Make(a_sum->operands()[0], b), var_intervals), - CasSimplify(Mod::Make(a_sum->operands()[1], b), var_intervals)}), - var_intervals); + if (std::abs(interval.l) < sum_prod_a_abs && + std::abs(interval.r) < sum_prod_a_abs) { + *result = CasSimplify( + Sum::Make({CasSimplify(Mod::Make(a_sum->operands()[0], b), + var_intervals), + CasSimplify(Mod::Make(a_sum->operands()[1], b), + var_intervals)}), + var_intervals); return true; } } @@ -994,7 +1081,7 @@ bool CasSimplifyMutator::SimplifySpecificSumMod(Expr* result, Expr a, Expr b) { Expr upper_bound; Expr rest_oper; bool can_simplify = true; - bool has_int = false; + bool has_int = false; // fold only the expr bound(may contains the var) and try to simplify the var Expr unfolded_lower_bound, unfolded_upper_bound; for (Expr& v : a_sum->operands()) { @@ -1009,9 +1096,12 @@ bool CasSimplifyMutator::SimplifySpecificSumMod(Expr* result, Expr a, Expr b) { break; } } - can_simplify = can_simplify && has_int && std::abs(const_value) % b_i->value == b_i->value - 1 && - lower_bound.defined() && upper_bound.defined() && rest_oper.defined(); - // further infer the vars' bound by the intervals infos, try to get the constant + can_simplify = can_simplify && has_int && + std::abs(const_value) % b_i->value == b_i->value - 1 && + lower_bound.defined() && upper_bound.defined() && + rest_oper.defined(); + // further infer the vars' bound by the intervals infos, try to get the + // constant if (can_simplify) { std::vector bounds = {lower_bound, upper_bound}; for (int i = 0; i < bounds.size(); ++i) { @@ -1031,9 +1121,11 @@ bool CasSimplifyMutator::SimplifySpecificSumMod(Expr* result, Expr a, Expr b) { // case1: (32+(-x))%33 = 32-x%33 (0<=x<=32) // case2: (x-32)%33 = x%33 - 32%33 (0<=x<=32) can_simplify = can_simplify && lower_bound.is_constant(); - bool case1 = can_simplify && const_value >= 0 && lower_bound.get_constant() >= -const_value && + bool case1 = can_simplify && const_value >= 0 && + lower_bound.get_constant() >= -const_value && upper_bound.is_constant() && 
upper_bound.get_constant() <= 0; - bool case2 = can_simplify && const_value <= 0 && lower_bound.get_constant() >= 0 && upper_bound.is_constant() && + bool case2 = can_simplify && const_value <= 0 && + lower_bound.get_constant() >= 0 && upper_bound.is_constant() && upper_bound.get_constant() <= -const_value; can_simplify = can_simplify && (case1 || case2); if (can_simplify) { @@ -1043,7 +1135,10 @@ bool CasSimplifyMutator::SimplifySpecificSumMod(Expr* result, Expr a, Expr b) { } else { const_expr = make_const(b->type(), const_value % b_i->value); } - *result = CasSimplify(Sum::Make({const_expr, CasSimplify(Mod::Make(rest_oper, b), var_intervals)}), var_intervals); + *result = CasSimplify( + Sum::Make( + {const_expr, CasSimplify(Mod::Make(rest_oper, b), var_intervals)}), + var_intervals); return true; } return false; @@ -1051,14 +1146,17 @@ bool CasSimplifyMutator::SimplifySpecificSumMod(Expr* result, Expr a, Expr b) { } // Return if the var's interval is nonnegative. -inline bool IsVarNonnegative(const absl::flat_hash_map& var_intervals, - const std::string& var_name) { +inline bool IsVarNonnegative( + const absl::flat_hash_map& var_intervals, + const std::string& var_name) { return var_intervals.count(var_name) && var_intervals.at(var_name).l >= 0; } -// Return if the var is binded with thread or block in cuda(which implies it is non-negative). +// Return if the var is binded with thread or block in cuda(which implies it is +// non-negative). inline bool IsVarBinded(const std::string& var_name) { - return utils::Startswith(var_name, "threadIdx") || utils::Startswith(var_name, "blockIdx"); + return utils::Startswith(var_name, "threadIdx") || + utils::Startswith(var_name, "blockIdx"); } /** @@ -1068,14 +1166,17 @@ inline bool IsVarBinded(const std::string& var_name) { * @param var_intervals intervals of each var. * @return if exprs are still all nonnegative vars. */ -inline bool IsVarAllNonnegative(bool all_nonnegative_var, - _Var_* arg_var, - const absl::flat_hash_map& var_intervals) { - // All exprs all nonnegative vars if previous exprs are nonnegative vars(all_nonnegative_var == true) and this expr is - // a var (arg_var != nullptr) and (this var's interval is nonnegative or this var is binded to thread or block in - // cuda). +inline bool IsVarAllNonnegative( + bool all_nonnegative_var, + _Var_* arg_var, + const absl::flat_hash_map& var_intervals) { + // All exprs all nonnegative vars if previous exprs are nonnegative + // vars(all_nonnegative_var == true) and this expr is a var (arg_var != + // nullptr) and (this var's interval is nonnegative or this var is binded to + // thread or block in cuda). 
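SimplifySpecificSumMod above rewrites a sum modulo a constant using the two cases quoted in its comments. A standalone numeric check of those identities over the stated range (illustrative only; it relies on C++'s truncating %):

```cpp
// Standalone numeric check of the two rewrites targeted by
// SimplifySpecificSumMod, over the stated range 0 <= x <= 32:
//   case1: (32 + (-x)) % 33 == 32 - x % 33
//   case2: (x - 32) % 33 == x % 33 - 32 % 33
// C++'s % truncates toward zero, which is what makes case2 hold for the
// negative left-hand sides.
#include <cassert>

int main() {
  for (int x = 0; x <= 32; ++x) {
    assert((32 + (-x)) % 33 == 32 - x % 33);    // case1
    assert((x - 32) % 33 == x % 33 - 32 % 33);  // case2
  }
  return 0;
}
```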
return all_nonnegative_var && arg_var && - (IsVarNonnegative(var_intervals, arg_var->name) || IsVarBinded(arg_var->name)); + (IsVarNonnegative(var_intervals, arg_var->name) || + IsVarBinded(arg_var->name)); } Expr CasSimplifyMutator::SimplifyMod(Expr u) { @@ -1086,12 +1187,12 @@ Expr CasSimplifyMutator::SimplifyMod(Expr u) { auto a = CasSimplify(node->a(), var_intervals); auto b = CasSimplify(node->b(), var_intervals); - auto* a_i = a.As(); + auto* a_i = a.As(); auto* a_product = a.As(); - auto* a_sum = a.As(); - auto* a_var = a.As<_Var_>(); - auto* a_mod = a.As(); - auto* a_add = a.As(); + auto* a_sum = a.As(); + auto* a_var = a.As<_Var_>(); + auto* a_mod = a.As(); + auto* a_add = a.As(); auto* b_i = b.As(); @@ -1115,10 +1216,12 @@ Expr CasSimplifyMutator::SimplifyMod(Expr u) { if (a_op_int % b_i->value == 0) return make_const(a_product->type(), 0); // case: (x * y * 2) % 6 = ((x * y) % 3) * 2 if (b_i->value % a_op_int == 0) { - int new_b = b_i->value / a_op_int; + int new_b = b_i->value / a_op_int; std::vector a_operands = a_product->operands(); a_operands.erase(a_operands.begin() + i); - return Product::Make({SimplifyMod(Mod::Make(Product::Make(a_operands), Expr(new_b))), Expr(a_op_int)}); + return Product::Make( + {SimplifyMod(Mod::Make(Product::Make(a_operands), Expr(new_b))), + Expr(a_op_int)}); } } } @@ -1140,11 +1243,12 @@ Expr CasSimplifyMutator::SimplifyMod(Expr u) { if (b_i && a_var && var_intervals.count(a_var->name)) { auto& interval = var_intervals.at(a_var->name); - int b_abs = std::abs(b_i->value); + int b_abs = std::abs(b_i->value); // x\in[1, 3] % 4 = x if (std::abs(interval.l) < b_abs && std::abs(interval.r) < b_abs) return a; // [3,3] % 3 = 0 - if (interval.l == interval.r && interval.l % b_abs == 0) return make_const(b_i->type(), 0); + if (interval.l == interval.r && interval.l % b_abs == 0) + return make_const(b_i->type(), 0); } if (a_product && b_i) { @@ -1173,13 +1277,16 @@ Expr CasSimplifyMutator::SimplifyMod(Expr u) { bool all_nonnegative_var = true; bool all_nonnegative_int = true; for (int i = 0; i < sum_args.size(); i++) { - auto* arg_var = sum_args[i].As<_Var_>(); - all_nonnegative_var = IsVarAllNonnegative(all_nonnegative_var, arg_var, var_intervals); - auto* arg_int = sum_args[i].As(); - all_nonnegative_int = all_nonnegative_int && arg_int && arg_int->value >= 0; + auto* arg_var = sum_args[i].As<_Var_>(); + all_nonnegative_var = + IsVarAllNonnegative(all_nonnegative_var, arg_var, var_intervals); + auto* arg_int = sum_args[i].As(); + all_nonnegative_int = + all_nonnegative_int && arg_int && arg_int->value >= 0; } VLOG(4) << all_nonnegative_var << " " << all_nonnegative_int; - if (all_nonnegative_var) return SimplifyMod(Mod::Make(Sum::Make(sum_args), b)); + if (all_nonnegative_var) + return SimplifyMod(Mod::Make(Sum::Make(sum_args), b)); if (all_nonnegative_int) { int sum_value = 0; for (auto& i : sum_args) sum_value += i.As()->value; @@ -1207,8 +1314,8 @@ Expr CasSimplifyMutator::SimplifyMinAndMax(Expr u) { auto* u_max = u.As(); auto* u_min = u.As(); if (u_max) { - Expr a = CasSimplify(u_max->a(), var_intervals); - Expr b = CasSimplify(u_max->b(), var_intervals); + Expr a = CasSimplify(u_max->a(), var_intervals); + Expr b = CasSimplify(u_max->b(), var_intervals); bool is_a_const = a.is_constant(); bool is_b_const = b.is_constant(); if (is_a_const && is_b_const) { @@ -1217,40 +1324,44 @@ Expr CasSimplifyMutator::SimplifyMinAndMax(Expr u) { Expr lower_bound, upper_bound; Expr const_operand, non_const_operand; if (is_a_const) { - const_operand = a; + 
const_operand = a;
       non_const_operand = b;
     }
     if (is_b_const) {
-      const_operand     = b;
+      const_operand = b;
       non_const_operand = a;
     }
     if (const_operand.defined() && non_const_operand.defined()) {
       auto const_size = const_operand.get_constant();
       // unfold var with bounds
       if (GetExprBound(&lower_bound, &upper_bound, non_const_operand, true)) {
-        // if non_const_operand's lower_bound is larger than const_operand, then non_const_operand must be larger than
-        // const_operand
-        if (lower_bound.is_constant() && const_size <= lower_bound.get_constant()) {
+        // if non_const_operand's lower_bound is larger than const_operand,
+        // then non_const_operand must be larger than const_operand
+        if (lower_bound.is_constant() &&
+            const_size <= lower_bound.get_constant()) {
           return non_const_operand;
         }
-        // if non_const_operand's upper_bound is smaller than a, then const_operand must be larger than
-        // non_const_operand
-        if (upper_bound.is_constant() && const_size >= upper_bound.get_constant()) {
+        // if non_const_operand's upper_bound is smaller than const_operand,
+        // then const_operand must be larger than non_const_operand
+        if (upper_bound.is_constant() &&
+            const_size >= upper_bound.get_constant()) {
           return const_operand;
         }
       }
       // do not unfold the var, for it may be eliminated in the calculation
       if (GetExprBound(&lower_bound, &upper_bound, non_const_operand, false)) {
-        // if non_const_operand's lower_bound is larger than const_operand, then non_const_operand must be larger than
-        // const_operand
+        // if non_const_operand's lower_bound is larger than const_operand,
+        // then non_const_operand must be larger than const_operand
         lower_bound = CasSimplify(lower_bound, var_intervals);
         upper_bound = CasSimplify(upper_bound, var_intervals);
-        if (lower_bound.is_constant() && const_size <= lower_bound.get_constant()) {
+        if (lower_bound.is_constant() &&
+            const_size <= lower_bound.get_constant()) {
           return non_const_operand;
         }
-        // if non_const_operand's upper_bound is smaller than a, then const_operand must be larger than
-        // non_const_operand
-        if (upper_bound.is_constant() && const_size >= upper_bound.get_constant()) {
+        // if non_const_operand's upper_bound is smaller than const_operand,
+        // then const_operand must be larger than non_const_operand
+        if (upper_bound.is_constant() &&
+            const_size >= upper_bound.get_constant()) {
           return const_operand;
         }
       }
@@ -1259,8 +1370,8 @@ Expr CasSimplifyMutator::SimplifyMinAndMax(Expr u) {
   }

   if (u_min) {
-    Expr a          = CasSimplify(u_min->a(), var_intervals);
-    Expr b          = CasSimplify(u_min->b(), var_intervals);
+    Expr a = CasSimplify(u_min->a(), var_intervals);
+    Expr b = CasSimplify(u_min->b(), var_intervals);
     bool is_a_const = a.is_constant();
     bool is_b_const = b.is_constant();
     if (is_a_const && is_b_const) {
@@ -1269,36 +1380,40 @@ Expr CasSimplifyMutator::SimplifyMinAndMax(Expr u) {
     Expr lower_bound, upper_bound;
     Expr const_operand, non_const_operand;
     if (is_a_const) {
-      const_operand     = a;
+      const_operand = a;
       non_const_operand = b;
     }
     if (is_b_const) {
-      const_operand     = b;
+      const_operand = b;
       non_const_operand = a;
     }
     if (const_operand.defined() && non_const_operand.defined()) {
       auto const_size = const_operand.get_constant();
       if (GetExprBound(&lower_bound, &upper_bound, non_const_operand, true)) {
-        // if non_const_operand's lower_bound is larger than const_operand, then non_const_operand must be larger than
-        // const_operand
-        if (lower_bound.is_constant() && const_size <= lower_bound.get_constant()) {
+        // if non_const_operand's lower_bound is larger than const_operand,
+        // then non_const_operand must be larger than const_operand
+        if (lower_bound.is_constant() &&
+            const_size <= lower_bound.get_constant()) {
          return const_operand;
         }
-        // if non_const_operand's upper_bound is smaller than a, then const_operand must be larger than
-        // non_const_operand
-        if (upper_bound.is_constant() && const_size >= upper_bound.get_constant()) {
+        // if non_const_operand's upper_bound is smaller than const_operand,
+        // then const_operand must be larger than non_const_operand
+        if (upper_bound.is_constant() &&
+            const_size >= upper_bound.get_constant()) {
           return non_const_operand;
         }
       }
       if (GetExprBound(&lower_bound, &upper_bound, non_const_operand, false)) {
-        // if non_const_operand's lower_bound is larger than const_operand, then non_const_operand must be larger than
-        // const_operand
-        if (lower_bound.is_constant() && const_size <= lower_bound.get_constant()) {
+        // if non_const_operand's lower_bound is larger than const_operand,
+        // then non_const_operand must be larger than const_operand
+        if (lower_bound.is_constant() &&
+            const_size <= lower_bound.get_constant()) {
           return const_operand;
         }
-        // if non_const_operand's upper_bound is smaller than a, then const_operand must be larger than
-        // non_const_operand
-        if (upper_bound.is_constant() && const_size >= upper_bound.get_constant()) {
+        // if non_const_operand's upper_bound is smaller than const_operand,
+        // then const_operand must be larger than non_const_operand
+        if (upper_bound.is_constant() &&
+            const_size >= upper_bound.get_constant()) {
           return non_const_operand;
         }
       }
@@ -1333,9 +1448,10 @@ Expr CasSimplifyMutator::SimplifyCmp(Expr u) {
 }

 /**
- * deal with index's div-mod add simplification, tempory solution, not cover all situations.
- * case 1: (m / n) * n + m % n = m (m, n's type is int)
- * case 2: (m / n1) * n3 + (n2 * m) % n3 = n2 * m if n3 = n1 * n2 (m, n1, n2, n3's type is int)
+ * deal with index's div-mod add simplification; a temporary solution that
+ * does not cover all situations.
+ * case 1: (m / n) * n + m % n = m (m, n's type is int)
+ * case 2: (m / n1) * n3 + (n2 * m) % n3 = n2 * m if n3 = n1 * n2
+ *         (m, n1, n2, n3's type is int)
 */
 Expr CasSimplifyMutator::SimplifySpecificSum(Expr tmp) {
   auto sum = tmp.As<Sum>();
@@ -1343,44 +1459,46 @@ Expr CasSimplifyMutator::SimplifySpecificSum(Expr tmp) {
     return tmp;
   }
   if (sum->operands().size() == 1U) return sum->operand(0);
-  Expr left      = sum->operand(0);
-  Expr right     = sum->operand(1);
-  auto left_mod  = left.As<Mod>();
+  Expr left = sum->operand(0);
+  Expr right = sum->operand(1);
+  auto left_mod = left.As<Mod>();
   auto right_mod = right.As<Mod>();
-  auto left_mul  = left.As<Product>();
+  auto left_mul = left.As<Product>();
   auto right_mul = right.As<Product>();
-  auto left_div  = left.As<FracOp>();
+  auto left_div = left.As<FracOp>();
   auto right_div = right.As<FracOp>();
   // normalize to left mul and right mod
   if (right_mul && left_mod) {
-    left_mul  = right_mul;
+    left_mul = right_mul;
     right_mod = left_mod;
   }
   // normalize to left div and right mod
   if (right_div && left_mod) {
-    left_div  = right_div;
+    left_div = right_div;
     right_mod = left_mod;
   }
   if (!right_mod || (!left_mul && !left_div)) {
     return tmp;
   }
   CHECK_GE(right_mod->operands().size(), 2U);
-  Expr mod_left  = right_mod->operand(0);
+  Expr mod_left = right_mod->operand(0);
   Expr mod_right = right_mod->operand(1);
   if (!mod_left->type().is_integer() || !mod_right->type().is_integer()) {
     return tmp;
   }
   if (left_mul) {
     // case 1: (m / n) * n + m % n = m (m, n's type is int)
-    // case 2: (m / n1) * n3 + (n2 * m) % n3 = n2 * m if n3 = n1 * n2 (m, n1, n2, n3's type is int)
+    // case 2: (m / n1) * n3 + (n2 * m) % n3 = n2 * m if n3 = n1 * n2 (m, n1,
+    // n2, n3's type is int)
     CHECK_GE(left_mul->operands().size(), 2U);
-    Expr mul_left  = left_mul->operand(0);
+    Expr mul_left = left_mul->operand(0);
     Expr mul_right = left_mul->operand(1);
     // handle the case1 : n * (m / n) + m % n = (m / n) * n + m % n = m
-    // handle the case2 : n3 * (m / n1) + (n2 * m) % n3 = (m / n1) * n3 + (n2 * m) % n3 = n2 * m if n3 = n1 * n2
+    // handle the case2 : n3 * (m / n1) + (n2 * m) % n3 = (m / n1) * n3 + (n2 *
+    // m) % n3 = n2 * m if n3 = n1 * n2
     if (MathEqual(mod_right, mul_left)) {
-      mul_left  = left_mul->operand(1);
+      mul_left = left_mul->operand(1);
       mul_right = left_mul->operand(0);
     } else if (!MathEqual(mod_right, mul_right)) {
       return tmp;
@@ -1390,7 +1508,7 @@ Expr CasSimplifyMutator::SimplifySpecificSum(Expr tmp) {
       return tmp;
     }
     CHECK_GE(div->operands().size(), 2U);
-    Expr div_left  = div->operand(0);
+    Expr div_left = div->operand(0);
     Expr div_right = div->operand(1);
     if (!div_left->type().is_integer() || !div_right->type().is_integer()) {
       return tmp;
@@ -1416,7 +1534,7 @@ Expr CasSimplifyMutator::operator()(Expr u) {
   if (u.is_constant() || u.As<_Var_>()) return u;

   if (u.As<FracOp>()) {
-    u        = SimplifyFracOp(u);
+    u = SimplifyFracOp(u);
     auto tmp = FurtherSimplifyFracWithInterval(u, var_intervals);
     if (!tmp.same_as(u)) return operator()(tmp);
     return u;
@@ -1428,10 +1546,11 @@ Expr CasSimplifyMutator::operator()(Expr u) {

   if (u.As<Sum>()) {
     auto tmp = detail::SumOrProductGetSingleElementsRec(SimplifySum(u));
-    // deal with index's div-mod add simplification, tempory solution, not cover all situations.
-    // case 1: (m / n) * n + m % n = m (m, n's type is int)
-    // case 2: (m / n1) * n3 + (n2 * m) % n3 = n2 * m if n3 = n1 * n2 (m, n1, n2, n3's type is int)
-    // case 3: m / n2 + (n1 * m) % n3 = n1 * m if n3 = n1 * n2 (m, n1, n2, n3's type is int)
+    // deal with index's div-mod add simplification; a temporary solution that
+    // does not cover all situations.
case 1: (m / n) * n + m % n = m (m, n's type is int) case + // 2: (m / n1) * n3 + (n2 * m) % n3 = n2 * m if n3 = n1 * n2 (m, n1, n2, + // n3's type is int) case 3: m / n2 + (n1 * m) % n3 = n1 * m if n3 = n1 * n2 + // (m, n1, n2, n3's type is int) return SimplifySpecificSum(tmp); } @@ -1456,8 +1575,8 @@ Expr CasSimplifyMutator::operator()(Expr u) { } bool CASasSymbol(Expr expr) { - auto* load_n = expr.As(); - auto* var_n = expr.As<_Var_>(); + auto* load_n = expr.As(); + auto* var_n = expr.As<_Var_>(); auto* broadcast_n = expr.As(); return load_n || var_n || broadcast_n; @@ -1536,7 +1655,7 @@ Expr ConvertCinnToCAS(Expr expr) { return; } - b = Product::Make({make_const(b->type(), -1), b}); + b = Product::Make({make_const(b->type(), -1), b}); *expr = Sum::Make({a, b}); } @@ -1585,10 +1704,10 @@ Expr ConvertCinnToCAS(Expr expr) { } /** - * @brief Given an expr, visit it. If there is an ir::Min and its operands are 1 constant value and 1 inconstant value, - * return the constant min value. - * For example, if a < min(5, b), then we get a < 5 and a < b. Using a < 5 to simplify the condition ensures - * correctness, though not sufficient. + * @brief Given an expr, visit it. If there is an ir::Min and its operands are 1 + * constant value and 1 inconstant value, return the constant min value. For + * example, if a < min(5, b), then we get a < 5 and a < b. Using a < 5 to + * simplify the condition ensures correctness, though not sufficient. */ Expr ReplaceMinToConstant(Expr expr) { Expr copied = optim::IRCopy(expr); @@ -1620,8 +1739,8 @@ Expr ReplaceMinToConstant(Expr expr) { } /** - * @brief Given an expr, visit it. If there is an ir::Max and its operands are 1 constant value and 1 inconstant value, - * return the constant max value. + * @brief Given an expr, visit it. If there is an ir::Max and its operands are 1 + * constant value and 1 inconstant value, return the constant max value. */ Expr ReplaceMaxToConstant(Expr expr) { Expr copied = optim::IRCopy(expr); @@ -1746,7 +1865,8 @@ Expr ConvertCasToCinn(Expr expr) { bool IsExprCasCompatible(Expr expr) { auto teller = [](const Expr* expr) { - return expr->As() || expr->As() || expr->As() || expr->As
(); + return expr->As() || expr->As() || expr->As() || + expr->As
(); }; return ir::CollectIRNodes(expr, teller).empty(); } @@ -1756,10 +1876,12 @@ Expr DividePartially(Sum* a, int b) { std::vector external_sum_args, sum_args; for (auto& item : a->operands()) { - if (item.As() && (IsDivisible(item.As(), b) || IsDivisible(b, item.As()))) { + if (item.As() && (IsDivisible(item.As(), b) || + IsDivisible(b, item.As()))) { external_sum_args.push_back(Divide(item.As(), b)); } else if (item.As() && IsDivisible(item.As()->value, b)) { - external_sum_args.push_back(make_const(item.type(), item.As()->value / b)); + external_sum_args.push_back( + make_const(item.type(), item.As()->value / b)); } else { sum_args.push_back(item); } @@ -1767,8 +1889,9 @@ Expr DividePartially(Sum* a, int b) { if (!external_sum_args.empty()) { if (sum_args.empty()) return Sum::Make(external_sum_args); - Expr internal_sum = sum_args.size() == 1 ? sum_args[0] : Sum::Make(sum_args); - Expr new_frac = FracOp::Make(internal_sum, make_const(a->type(), b)); + Expr internal_sum = + sum_args.size() == 1 ? sum_args[0] : Sum::Make(sum_args); + Expr new_frac = FracOp::Make(internal_sum, make_const(a->type(), b)); return Sum::Make(Concat(external_sum_args, {new_frac})); } return Expr(a); @@ -1787,9 +1910,11 @@ bool IsMonotonical(Expr u, Var v) { return false; } -// Should be called after SimplifyFracOp. If y is integer and $y\in \[0, 3\]$, then y/4=0 +// Should be called after SimplifyFracOp. If y is integer and $y\in \[0, 3\]$, +// then y/4=0 Expr CasSimplifyMutator::FurtherSimplifyFracWithInterval( - Expr expr, const absl::flat_hash_map& var_intervals) { + Expr expr, + const absl::flat_hash_map& var_intervals) { auto* node = expr.As(); if (!node) return expr; auto a = CasSimplify(node->a(), var_intervals); @@ -1804,7 +1929,8 @@ Expr CasSimplifyMutator::FurtherSimplifyFracWithInterval( if (bi) { if (av) { auto it = var_intervals.find(av->name); - if (it != var_intervals.end() && std::abs(it->second.r) < std::abs(bi->value) && + if (it != var_intervals.end() && + std::abs(it->second.r) < std::abs(bi->value) && std::abs(it->second.l) < std::abs(bi->value)) return make_const(a.type(), 0); } @@ -1812,13 +1938,14 @@ Expr CasSimplifyMutator::FurtherSimplifyFracWithInterval( // case: 1/y, y\in(2, 100) if (ai) { if (bv) { - auto it = var_intervals.find(bv->name); + auto it = var_intervals.find(bv->name); auto ai_abs = std::abs(ai->value); if (it != var_intervals.end()) { VLOG(7) << "found " << bv->name << " " << it->second << " " << " ai " << ai_abs; } - if (it != var_intervals.end() && std::abs(it->second.r) > ai_abs && std::abs(it->second.l) > ai_abs) { + if (it != var_intervals.end() && std::abs(it->second.r) > ai_abs && + std::abs(it->second.l) > ai_abs) { return make_const(a.type(), 0); } } @@ -1855,8 +1982,8 @@ Expr SimplifyConstantFrac(FracOp* node) { Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) { VLOG(7) << "CAS simplify Frac " << expr; auto* node = expr.As(); - auto a = CasSimplify(node->a(), var_intervals); - auto b = CasSimplify(node->b(), var_intervals); + auto a = CasSimplify(node->a(), var_intervals); + auto b = CasSimplify(node->b(), var_intervals); // update frac op node expr = ir::FracOp::Make(a, b); @@ -1885,12 +2012,13 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) { // case 2 // sum/x or product/x is divisible if (bi) { - auto* a_sum = a.As(); + auto* a_sum = a.As(); auto* a_product = a.As(); // divisible if (a_sum && IsDivisible(a_sum, bi->value)) return Divide(a_sum, bi->value); if (a_product) { - if (IsDivisible(a_product, bi->value) || IsDivisible(bi->value, a_product)) 
{ + if (IsDivisible(a_product, bi->value) || + IsDivisible(bi->value, a_product)) { return Divide(a_product, bi->value); } else { return FracOp::Make(a, b); @@ -1900,20 +2028,25 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) { // if 0 < b < 3, (3a+b) / 6 = (3a / 6) + (b / 6) if (a_sum && a_sum->operands().size() == 2) { a_sum->operands()[0] = CasSimplify(a_sum->operands()[0], var_intervals); - auto sum_a_prod = a_sum->operands()[0].As(); - auto sum_b_var = a_sum->operands()[1].As<_Var_>(); + auto sum_a_prod = a_sum->operands()[0].As(); + auto sum_b_var = a_sum->operands()[1].As<_Var_>(); if (sum_a_prod && sum_b_var && var_intervals.count(sum_b_var->name)) { auto sum_a_prod_b_int = sum_a_prod->operand(1).As(); - if (sum_a_prod_b_int) std::swap(sum_a_prod->operand(0), sum_a_prod->operand(1)); + if (sum_a_prod_b_int) + std::swap(sum_a_prod->operand(0), sum_a_prod->operand(1)); auto sum_a_prod_a_int = sum_a_prod->operand(0).As(); - auto& interval = var_intervals.at(sum_b_var->name); - int b_abs = std::abs(bi->value); - int sum_prod_a_abs = std::abs(sum_a_prod_a_int->value); + auto& interval = var_intervals.at(sum_b_var->name); + int b_abs = std::abs(bi->value); + int sum_prod_a_abs = std::abs(sum_a_prod_a_int->value); if (sum_a_prod_a_int && (b_abs % sum_prod_a_abs == 0)) { - if (std::abs(interval.l) < sum_prod_a_abs && std::abs(interval.r) < sum_prod_a_abs) { - return CasSimplify(Sum::Make({CasSimplify(FracOp::Make(a_sum->operands()[0], b), var_intervals), - CasSimplify(FracOp::Make(a_sum->operands()[1], b), var_intervals)}), - var_intervals); + if (std::abs(interval.l) < sum_prod_a_abs && + std::abs(interval.r) < sum_prod_a_abs) { + return CasSimplify( + Sum::Make({CasSimplify(FracOp::Make(a_sum->operands()[0], b), + var_intervals), + CasSimplify(FracOp::Make(a_sum->operands()[1], b), + var_intervals)}), + var_intervals); } } } @@ -1933,22 +2066,25 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) { auto cmp_min = a.As(); auto cmp_max = a.As(); if (cmp_min) { - return {CasSimplify(Min::Make(CasSimplify(FracOp::Make(cmp_min->a(), b), var_intervals), - CasSimplify(FracOp::Make(cmp_min->b(), b), var_intervals)), - var_intervals)}; + return {CasSimplify( + Min::Make(CasSimplify(FracOp::Make(cmp_min->a(), b), var_intervals), + CasSimplify(FracOp::Make(cmp_min->b(), b), var_intervals)), + var_intervals)}; } if (cmp_max) { - return {CasSimplify(Max::Make(CasSimplify(FracOp::Make(cmp_max->a(), b), var_intervals), - CasSimplify(FracOp::Make(cmp_max->b(), b), var_intervals)), - var_intervals)}; + return {CasSimplify( + Max::Make(CasSimplify(FracOp::Make(cmp_max->a(), b), var_intervals), + CasSimplify(FracOp::Make(cmp_max->b(), b), var_intervals)), + var_intervals)}; } } if (av && bi) { if (var_intervals.count(av->name)) { auto& interval = var_intervals.at(av->name); - int b_abs = std::abs(bi->value); - if (std::abs(interval.l) < b_abs && std::abs(interval.r) < b_abs) return make_const(bi->type(), 0); + int b_abs = std::abs(bi->value); + if (std::abs(interval.l) < b_abs && std::abs(interval.r) < b_abs) + return make_const(bi->type(), 0); return FracOp::Make(a, b); } } @@ -1959,9 +2095,13 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) { std::vector internal_sum_args; for (auto& e : as->operands()) { if (IsDivisible(e, bi->value)) { - if (e.As()) external_sum_args.push_back(Divide(e.As(), bi->value)); - if (e.As()) external_sum_args.push_back(make_const(bi->type(), e.As()->value / bi->value)); - if (e.As()) external_sum_args.push_back(Divide(e.As(), bi->value)); + if (e.As()) + 
external_sum_args.push_back(Divide(e.As<Sum>(), bi->value));
+      if (e.As<IntImm>())
+        external_sum_args.push_back(
+            make_const(bi->type(), e.As<IntImm>()->value / bi->value));
+      if (e.As<Product>())
+        external_sum_args.push_back(Divide(e.As<Product>(), bi->value));
       } else {
         internal_sum_args.push_back(e);
       }
@@ -1980,7 +2120,8 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) {
     }

     if (external_sum.defined() && internal_sum.defined()) {
-      return CasSimplify(Sum::Make({external_sum, internal_sum}), var_intervals);
+      return CasSimplify(Sum::Make({external_sum, internal_sum}),
+                         var_intervals);
     }
     if (external_sum.defined()) return CasSimplify(external_sum, var_intervals);
     return internal_sum;
@@ -1988,7 +2129,8 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) {

   // solve the case: 2abc / b
   // Both avs and bvs should be sorted first.
-  auto reduce_product_div_product = [](const std::vector<Expr>& avs, const std::vector<Expr>& bvs) {
+  auto reduce_product_div_product = [](const std::vector<Expr>& avs,
+                                       const std::vector<Expr>& bvs) {
     std::vector<Expr> avs1, bvs1;
     int i = 0;
     int j = 0;
@@ -2005,14 +2147,14 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) {
     auto* bf = b.As<FloatImm>();
     if (ai) {
       CHECK(bi);
-      int g   = gcd(ai->value, bi->value);
+      int g = gcd(ai->value, bi->value);
       int a_d = ai->value / g;
       int b_d = bi->value / g;
       avs1.push_back(make_const(a.type(), a_d));
       if (b_d != 1) bvs1.push_back(make_const(b.type(), b_d));
     } else if (af || bf) {
-      double value      = af->value / bf->value;
+      double value = af->value / bf->value;
       const auto& ftype = af ? af->type() : bf->type();
       avs1.push_back(make_const(ftype, value));
     } else {
@@ -2075,7 +2217,7 @@ Expr CasSimplifyMutator::SimplifyCond(Expr u) {
     // -------------------------- NOT -----------------------------
     case ir::IrNodeTy::Not: {
       auto* node = u.As<ir::Not>();
-      Expr v     = operator()(node->v());
+      Expr v = operator()(node->v());
       switch (v.node_type()) {
         // Not 1 = (1 == 0)
         case ir::IrNodeTy::IntImm:
@@ -2139,7 +2281,9 @@ Expr CasSimplifyMutator::SimplifyCond(Expr u) {

 }  // namespace detail

-Expr CasSimplify(Expr u, const absl::flat_hash_map<std::string, CasInterval>& var_intervals) {
+Expr CasSimplify(
+    Expr u,
+    const absl::flat_hash_map<std::string, CasInterval>& var_intervals) {
   return detail::CasSimplifyMutator(var_intervals)(u);
 }

@@ -2168,8 +2312,8 @@ Expr SolveInequality(Expr inequality, Var val) {
   //  if (common::IsPureMath(a) && common::IsPureMath(b)) {
   if (true) {
     auto _res_positive_ = common::Solve(a, b, val);  // NOLINT
-    auto& res           = std::get<0>(_res_positive_);
-    auto& positive      = std::get<1>(_res_positive_);
+    auto& res = std::get<0>(_res_positive_);
+    auto& positive = std::get<1>(_res_positive_);
     // Simplify it with CAS to avoid random result from GiNac.
     res = AutoSimplify(res);
     res = common::cast(res, val->type());
diff --git a/paddle/cinn/common/cas.h b/paddle/cinn/common/cas.h
index eca803d3b6aa6..c7c4517d63524 100755
--- a/paddle/cinn/common/cas.h
+++ b/paddle/cinn/common/cas.h
@@ -41,9 +41,11 @@ struct CasInterval {
   }

   /**
-   * @brief When iterator's upper_bound is an ir::Min of a constant value and a inconstant value, choose the constant
-   * value. When iterator's lower_bound is an ir::Max of a constant value and a inconstant value, choose the constant
-   * value. E.g: expr_l = max(x, 1) and expr_r = min(y,5): max(x, 1) <= iterator_i <= min(y,5)
+   * @brief When the iterator's upper_bound is an ir::Min of a constant value
+   * and a non-constant value, choose the constant value. When the iterator's
+   * lower_bound is an ir::Max of a constant value and a non-constant value,
+   * choose the constant value. E.g. expr_l = max(x, 1) and
+   * expr_r = min(y, 5): max(x, 1) <= iterator_i <= min(y, 5)
   *
   * the bounds will be simplified to e_l = 1 and e_r = 5:
   * 1 <= iterator_i <= 5
@@ -54,7 +56,8 @@ struct CasInterval {
     expr_l = detail::ReplaceMaxToConstant(expr_l);
     optim::Simplify(&expr_l);
     optim::Simplify(&expr_r);
-    VLOG(2) << "After simplify, CasInterval is : [" << expr_l << ", " << expr_r << "].";
+    VLOG(2) << "After simplify, CasInterval is : [" << expr_l << ", " << expr_r
+            << "].";

     if (expr_l.is_constant() && expr_r.is_constant()) {
       CHECK(expr_l->type().is_integer());
@@ -82,16 +85,21 @@ struct CasInterval {

 using cas_intervals_t = absl::flat_hash_map<std::string, CasInterval>;

-Expr AutoSimplify(Expr u, const absl::flat_hash_map<std::string, CasInterval>& var_intervals = {});
+Expr AutoSimplify(
+    Expr u,
+    const absl::flat_hash_map<std::string, CasInterval>& var_intervals = {});

 //! Simplify a CAS expression.
-Expr CasSimplify(Expr u, const absl::flat_hash_map<std::string, CasInterval>& var_intervals = {});
+Expr CasSimplify(
+    Expr u,
+    const absl::flat_hash_map<std::string, CasInterval>& var_intervals = {});

 /**
  * \brief Solve an inequality.
  * Currently this is a naive implementation using GiNaC.
  *
- * @param inequality The inequality expression containing an LE or LT or GT or GE, such as 2x-1<3
+ * @param inequality The inequality expression containing an LE or LT or GT or
+ * GE, such as 2x-1<3
  * @param val The target variable.
  * @return a copied expression that looks like x < 100.
  */
@@ -100,11 +108,14 @@ Expr SolveInequalityInt(Expr inequality, Var val);

 namespace detail {

-//! Whether to treat this expression as a symbol. e.g. Load, Min, Max are treated as symbol to avoid confusing the CAS.
+//! Whether to treat this expression as a symbol, e.g. Load, Min, Max are
+//! treated as symbols to avoid confusing the CAS.
 bool CASasSymbol(Expr expr);
-//! Convert some nodes to CAS representation, e.g. convert Mul, Add to Product and Sum.
+//! Convert some nodes to CAS representation, e.g. convert Mul, Add to Product
+//! and Sum.
 Expr ConvertCinnToCAS(Expr expr);
-//! Convert the CAS representation to CINN expression, e.g. convert Product and Sum to Mul and Add.
+//! Convert the CAS representation to a CINN expression, e.g. convert Product
+//! and Sum to Mul and Add.
 Expr ConvertCasToCinn(Expr expr);
 //! Tell whether this expression is acceptable by CAS.
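//
// Editorial sketch (not part of the original patch): the typical round trip
// through these helpers, as exercised by cas_test.cc later in this diff, is:
//
//   Expr body = ...;                        // plain CINN IR (Add/Mul nodes)
//   body = detail::ConvertCinnToCAS(body);  // rewrite into Sum/Product form
//   body = CasSimplify(body);               // simplify in CAS form
//   body = detail::ConvertCasToCinn(body);  // convert back for codegen
//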
bool IsExprCasCompatible(Expr expr);
@@ -114,7 +125,8 @@ struct ExprPosCmp {
 };

 struct CasSimplifyMutator {
-  explicit CasSimplifyMutator(const absl::flat_hash_map<std::string, CasInterval> var_intervals)
+  explicit CasSimplifyMutator(
+      const absl::flat_hash_map<std::string, CasInterval> var_intervals)
       : var_intervals(var_intervals) {}

   Expr operator()(Expr u);
@@ -130,33 +142,60 @@ struct CasSimplifyMutator {
   Expr SimplifyMod(Expr u);
   Expr SimplifyFracOp(Expr expr);
   Expr SimplifyCond(Expr u);
-  Expr FurtherSimplifyFracWithInterval(Expr expr, const absl::flat_hash_map<std::string, CasInterval>& var_intervals);
+  Expr FurtherSimplifyFracWithInterval(
+      Expr expr,
+      const absl::flat_hash_map<std::string, CasInterval>& var_intervals);
   Expr SimplifyIntegerPower(Expr u);

   void AddBaseAndSimplify(Expr* base, Expr bound);
-  void UnfoldBound(Expr* lower_bound, Expr* upper_bound, Expr var, bool unfold_const_bound = true);
-  bool GetVarBound(Expr* lower_bound, Expr* upper_bound, Expr var, bool unfold_const_bound = true);
-  bool GetOperandBound(Expr* lower_bound, Expr* upper_bound, Expr var, bool unfold_const_bound = true);
-  bool GetSumBound(Expr* lower_bound, Expr* upper_bound, Expr sum, bool unfold_const_bound = true);
-  bool GetMinBound(Expr* lower_bound, Expr* upper_bound, Expr min, bool unfold_const_bound = true);
-  bool GetMaxBound(Expr* lower_bound, Expr* upper_bound, Expr max, bool unfold_const_bound = true);
-  bool GetExprBound(Expr* lower_bound, Expr* upper_bound, Expr min, bool unfold_const_bound = true);
+  void UnfoldBound(Expr* lower_bound,
+                   Expr* upper_bound,
+                   Expr var,
+                   bool unfold_const_bound = true);
+  bool GetVarBound(Expr* lower_bound,
+                   Expr* upper_bound,
+                   Expr var,
+                   bool unfold_const_bound = true);
+  bool GetOperandBound(Expr* lower_bound,
+                       Expr* upper_bound,
+                       Expr var,
+                       bool unfold_const_bound = true);
+  bool GetSumBound(Expr* lower_bound,
+                   Expr* upper_bound,
+                   Expr sum,
+                   bool unfold_const_bound = true);
+  bool GetMinBound(Expr* lower_bound,
+                   Expr* upper_bound,
+                   Expr min,
+                   bool unfold_const_bound = true);
+  bool GetMaxBound(Expr* lower_bound,
+                   Expr* upper_bound,
+                   Expr max,
+                   bool unfold_const_bound = true);
+  bool GetExprBound(Expr* lower_bound,
+                    Expr* upper_bound,
+                    Expr min,
+                    bool unfold_const_bound = true);
   bool SimplifySpecificSumMod(Expr* u, Expr a, Expr b);
   Expr SimplifySpecificSum(Expr u);

 private:
  std::vector<Expr> SimplifyBinaryProduct(Expr left, Expr right);
-  std::vector<Expr> MergeProduct(const std::vector<Expr>& p, const std::vector<Expr>& q);
+  std::vector<Expr> MergeProduct(const std::vector<Expr>& p,
+                                 const std::vector<Expr>& q);

  std::vector<Expr> SimplifyBinarySum(Expr left, Expr right);
-  std::vector<Expr> MergeSum(const std::vector<Expr>& p, const std::vector<Expr>& q);
+  std::vector<Expr> MergeSum(const std::vector<Expr>& p,
+                             const std::vector<Expr>& q);

-  std::vector<Expr> MergeExprs(const std::vector<Expr>& p,
-                               const std::vector<Expr>& q,
-                               const std::function<std::vector<Expr>(Expr, Expr)>& binary_merge);
+  std::vector<Expr> MergeExprs(
+      const std::vector<Expr>& p,
+      const std::vector<Expr>& q,
+      const std::function<std::vector<Expr>(Expr, Expr)>& binary_merge);

  const absl::flat_hash_map<std::string, CasInterval> var_intervals;

-  // Computation based on integer if set true(1/2 get 0), false if treat as rational number in mathematics(1/2 is still
-  // 1/2), currently it only works with true.
+  // If true, arithmetic is performed on integers (1/2 yields 0); if false,
+  // operands are treated as rational numbers in mathematics (1/2 stays 1/2).
+  // Currently only the integer mode (true) is supported.
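  // Editorial illustration (not part of the original patch, names taken from
  // this header and from cas_test.cc): under the integer semantics, interval
  // information can make a fraction fold away entirely. Assuming a Var y of
  // type Int(32):
  //   cas_intervals_t intervals;
  //   intervals.emplace("y", CasInterval{0, 3});
  //   Expr r = AutoSimplify(ir::FracOp::Make(Expr(y), Expr(4)), intervals);
  // r folds to 0, the case handled by FurtherSimplifyFracWithInterval above.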
bool int_compute_{true}; }; diff --git a/paddle/cinn/common/cas_test.cc b/paddle/cinn/common/cas_test.cc index 224ffd36be38c..b13f8c8a2e97f 100644 --- a/paddle/cinn/common/cas_test.cc +++ b/paddle/cinn/common/cas_test.cc @@ -34,7 +34,8 @@ using namespace ir; // NOLINT TEST(CAS, number_cal) { // 1 * 100 * -1 + 0 + 1001 - auto u1 = Sum::Make({Product::Make({Expr(1), Expr(100), Expr(-1)}), Expr(0), Expr(1001)}); + auto u1 = Sum::Make( + {Product::Make({Expr(1), Expr(100), Expr(-1)}), Expr(0), Expr(1001)}); LOG(INFO) << u1; } @@ -49,11 +50,16 @@ TEST(CAS, cmp) { EXPECT_EQ(cmp(Expr(1), x), true); // x * y * z > x * y - EXPECT_EQ(cmp(ir::Product::Make({x, y, z}), ir::Product::Make({x, y})), false); + EXPECT_EQ(cmp(ir::Product::Make({x, y, z}), ir::Product::Make({x, y})), + false); // x * y * z > 10 * y * z - EXPECT_EQ(cmp(ir::Product::Make({x, y, z}), ir::Product::Make({Expr(10), y, z})), false); + EXPECT_EQ( + cmp(ir::Product::Make({x, y, z}), ir::Product::Make({Expr(10), y, z})), + false); // 1 * y * z < 10 * y * z - EXPECT_EQ(cmp(ir::Product::Make({Expr(1), y, z}), ir::Product::Make({Expr(10), y, z})), true); + EXPECT_EQ(cmp(ir::Product::Make({Expr(1), y, z}), + ir::Product::Make({Expr(10), y, z})), + true); } TEST(CAS, SimplifySum) { @@ -67,7 +73,8 @@ TEST(CAS, SimplifySum) { // z + 1 + y + x + zx auto u3 = CasSimplify(Sum::Make({z, Expr(1), y, x, Product::Make({z, x})})); // z + 1 + y + 3 + x + 0 + zx - auto u4 = CasSimplify(Sum::Make({z, Expr(1), y, Expr(3), x, Expr(0), Product::Make({z, x})})); + auto u4 = CasSimplify( + Sum::Make({z, Expr(1), y, Expr(3), x, Expr(0), Product::Make({z, x})})); // x2 + 3zy + -3*yz + -2x + 1 auto u5 = CasSimplify(Sum::Make({Product::Make({x, Expr(2)}), Product::Make({z, y, Expr(3)}), @@ -103,7 +110,9 @@ TEST(CAS, SimplifyMod) { // (x+y+z) % 2 = x%2 + y%2 + z%2 auto u2 = CasSimplify(Mod::Make(Sum::Make({x, y, z}), Expr(2))); // x%2 + 1%2 + x%2 - auto u3 = CasSimplify(Sum::Make({Mod::Make(x, Expr(2)), Mod::Make(Expr(1), Expr(2)), Mod::Make(x, Expr(2))})); + auto u3 = CasSimplify(Sum::Make({Mod::Make(x, Expr(2)), + Mod::Make(Expr(1), Expr(2)), + Mod::Make(x, Expr(2))})); EXPECT_EQ(GetStreamCnt(u1), "0"); EXPECT_EQ(GetStreamCnt(u2), "((x + y + z) % 2)"); @@ -122,10 +131,12 @@ TEST(CAS, SimplifyModForVectorize) { // = (8*x) - (x//8)*(8*8) // = 8*(x-(x//8)*8) // since mod definition // = 8*(x%8) - auto u1 = CasSimplify(Mod::Make( - Mod::Make(Mod::Make(Sum::Make({Product::Make({x, Expr(8)}), Product::Make({y, Expr(1024)})}), Expr(802816)), - Expr(7168)), - Expr(64))); + auto u1 = CasSimplify( + Mod::Make(Mod::Make(Mod::Make(Sum::Make({Product::Make({x, Expr(8)}), + Product::Make({y, Expr(1024)})}), + Expr(802816)), + Expr(7168)), + Expr(64))); std::cout << GetStreamCnt(u1); EXPECT_EQ(GetStreamCnt(u1), "((x % 8) * 8)"); } @@ -136,7 +147,9 @@ TEST(CAS, ConvertCinnToCAS) { auto C = Compute( {Expr(10), Expr(10)}, - [&](Expr i, Expr j) { return A(i, j) + 0.f + 1.f + 2.f * B(i, j) + 0.f * B(i, j) * A(i, j); }, + [&](Expr i, Expr j) { + return A(i, j) + 0.f + 1.f + 2.f * B(i, j) + 0.f * B(i, j) * A(i, j); + }, "C"); Expr body = C->body(); @@ -144,9 +157,11 @@ TEST(CAS, ConvertCinnToCAS) { body = detail::ConvertCinnToCAS(body); body = CasSimplify(body); - EXPECT_EQ(GetStreamCnt(body), "(1.00000000f + A[i, j] + (2.00000000f * B[i, j]))"); + EXPECT_EQ(GetStreamCnt(body), + "(1.00000000f + A[i, j] + (2.00000000f * B[i, j]))"); body = detail::ConvertCasToCinn(body); - EXPECT_EQ(GetStreamCnt(body), "(1.00000000f + (A[i, j] + (2.00000000f * B[i, j])))"); + 
EXPECT_EQ(GetStreamCnt(body), + "(1.00000000f + (A[i, j] + (2.00000000f * B[i, j])))"); } TEST(CAS, FracOp) { @@ -212,10 +227,15 @@ TEST(CAS, Mod) { // u = AutoSimplify((x + 20 * y + 5) % 5, var_intervals0); // OUTPUT_EQUAL("x") - u = AutoSimplify((x % 32) + ((32768 * (x / 32)) + ((32768 * y) + ((32 * z) + (128 * k))))); - OUTPUT_EQUAL("((32768 * (x / 32)) + ((x % 32) + ((128 * k) + ((32768 * y) + (32 * z)))))"); + u = AutoSimplify( + (x % 32) + ((32768 * (x / 32)) + ((32768 * y) + ((32 * z) + (128 * k))))); + OUTPUT_EQUAL( + "((32768 * (x / 32)) + ((x % 32) + ((128 * k) + ((32768 * y) + (32 * " + "z)))))"); - u = AutoSimplify((x % 32) + ((32768 * (x / 32)) + ((32768 * y) + ((32 * z) + (128 * k)))), var_intervals0); + u = AutoSimplify( + (x % 32) + ((32768 * (x / 32)) + ((32768 * y) + ((32 * z) + (128 * k)))), + var_intervals0); OUTPUT_EQUAL("((128 * k) + (x + ((32768 * y) + (32 * z))))") // (2x+y+z) % 2 = (y+z) % 2 @@ -289,18 +309,21 @@ TEST(SolveInequality, basic) { Var x("x", Int(32)); Var y("y", Int(32)); -#define TEST_SOLVE(expr__, str__) EXPECT_EQ(GetStreamCnt(SolveInequality(expr__, x)), str__); +#define TEST_SOLVE(expr__, str__) \ + EXPECT_EQ(GetStreamCnt(SolveInequality(expr__, x)), str__); TEST_SOLVE(x * -1 + 20 < 0, "(x > 20)"); TEST_SOLVE(x * 2 + 3 < x * 10 - 20, "(x > 2)"); TEST_SOLVE(x * -1 < -1, "(x > 1)"); TEST_SOLVE(Expr(2) * x * -1 - x < x + 200, "(x > -50)"); - TEST_SOLVE(Expr(2) * x + 30 - x * 3 + y * 23 < 2, "(x > int32((28 + (23 * y))))"); - TEST_SOLVE(x + ir::Min::Make(Expr(2), Expr(3) * y) < 100, "(x < int32(cinn_max((100 + (-3 * y)), 98)))"); + TEST_SOLVE(Expr(2) * x + 30 - x * 3 + y * 23 < 2, + "(x > int32((28 + (23 * y))))"); + TEST_SOLVE(x + ir::Min::Make(Expr(2), Expr(3) * y) < 100, + "(x < int32(cinn_max((100 + (-3 * y)), 98)))"); } TEST(CAS, SimplifyCompoundMod) { { // (-a % 4) * (-1) - Var x = ir::_Var_::Make("x", Int(32)); + Var x = ir::_Var_::Make("x", Int(32)); auto p0 = ir::Product::Make({ir::Mod::Make(-x, Expr(4)), Expr(-1)}); LOG(INFO) << "p0 " << p0; auto p2 = AutoSimplify(p0); @@ -308,8 +331,9 @@ TEST(CAS, SimplifyCompoundMod) { EXPECT_EQ(GetStreamCnt(p2), "(-1 * ((-1 * x) % 4))"); } { // (33 + x % 34) + -33 - Var x = ir::_Var_::Make("x", Int(32)); - auto p0 = ir::Sum::Make({Expr(33), ir::Sum::Make({ir::Mod::Make(x, Expr(4)), Expr(-33)})}); + Var x = ir::_Var_::Make("x", Int(32)); + auto p0 = ir::Sum::Make( + {Expr(33), ir::Sum::Make({ir::Mod::Make(x, Expr(4)), Expr(-33)})}); LOG(INFO) << "p0 " << p0; auto p2 = AutoSimplify(p0); LOG(INFO) << "simplified " << p2; @@ -317,17 +341,20 @@ TEST(CAS, SimplifyCompoundMod) { } { // 33 + (x % 2 + (-16)) Var x = ir::_Var_::Make("x", Int(32)); - auto p0 = - ir::Sum::Make({Expr(33), ir::Sum::Make({ir::Mod::Make(x, Expr(2)), ir::Product::Make({Expr(-1), Expr(16)})})}); + auto p0 = ir::Sum::Make( + {Expr(33), + ir::Sum::Make({ir::Mod::Make(x, Expr(2)), + ir::Product::Make({Expr(-1), Expr(16)})})}); LOG(INFO) << "p0 " << p0; auto p2 = AutoSimplify(p0); LOG(INFO) << "simplified " << p2; EXPECT_EQ(GetStreamCnt(p2), "(17 + (x % 2))"); } { // (32- x1 - 16 * x2) % 33 - Var x1 = ir::_Var_::Make("x1", Int(32)); - Var x2 = ir::_Var_::Make("x2", Int(32)); - auto p0 = ir::Mod::Make(ir::Sum::Make({Expr(32), -x1, Expr(16) * -x2}), Expr(33)); + Var x1 = ir::_Var_::Make("x1", Int(32)); + Var x2 = ir::_Var_::Make("x2", Int(32)); + auto p0 = + ir::Mod::Make(ir::Sum::Make({Expr(32), -x1, Expr(16) * -x2}), Expr(33)); LOG(INFO) << "p0 " << p0; absl::flat_hash_map var_intervals; var_intervals.emplace("x1", CasInterval{0, 15}); @@ 
-343,7 +370,7 @@ TEST(CAS, SimplifyCompoundMod) { } TEST(CAS, SimplifyNegtive) { { // (-1*x) /2 - Var x = ir::_Var_::Make("x", Int(32)); + Var x = ir::_Var_::Make("x", Int(32)); auto p0 = ir::FracOp::Make(-x, Expr(2)); LOG(INFO) << "p0 " << p0; auto p2 = AutoSimplify(p0); @@ -361,7 +388,7 @@ TEST(CAS, SimplifyNegtive) { TEST(CAS, SimplifyMinMax) { { // 1+cinn_min(15, x) - Var x = ir::_Var_::Make("x", Int(32)); + Var x = ir::_Var_::Make("x", Int(32)); auto p0 = ir::Sum::Make({Expr(1), ir::Min::Make(Expr(15), x)}); LOG(INFO) << "p0 " << p0; auto p2 = CasSimplify(p0); @@ -369,7 +396,7 @@ TEST(CAS, SimplifyMinMax) { EXPECT_EQ(GetStreamCnt(p2), "cinn_min(16, (1 + x))"); } { // 2*cinn_min(15, x) - Var x = ir::_Var_::Make("x", Int(32)); + Var x = ir::_Var_::Make("x", Int(32)); auto p0 = ir::Product::Make({Expr(2), ir::Min::Make(Expr(15), x)}); LOG(INFO) << "p0 " << p0; auto p2 = CasSimplify(p0); @@ -377,7 +404,7 @@ TEST(CAS, SimplifyMinMax) { EXPECT_EQ(GetStreamCnt(p2), "cinn_min(30, (2 * x))"); } { // cinn_min(15, x)/2 - Var x = ir::_Var_::Make("x", Int(32)); + Var x = ir::_Var_::Make("x", Int(32)); auto p0 = ir::FracOp::Make(ir::Min::Make(Expr(15), x), Expr(2)); LOG(INFO) << "p0 " << p0; auto p2 = CasSimplify(p0); @@ -385,17 +412,22 @@ TEST(CAS, SimplifyMinMax) { EXPECT_EQ(GetStreamCnt(p2), "cinn_min(7, (x / 2))"); } { // -(cinn_min(16, 3400-x-1)-1)/2 + x - Var x = ir::_Var_::Make("x", Int(32)); - auto p0 = ir::FracOp::Make(ir::Min::Make(Expr(16), 3400 - x - 1) - 1, Expr(2)); - p0 = -p0 + x; + Var x = ir::_Var_::Make("x", Int(32)); + auto p0 = + ir::FracOp::Make(ir::Min::Make(Expr(16), 3400 - x - 1) - 1, Expr(2)); + p0 = -p0 + x; LOG(INFO) << "p0 " << p0; auto p2 = AutoSimplify(p0); LOG(INFO) << "simplified " << p2; - EXPECT_EQ(GetStreamCnt(p2), "cinn_max((-1699 + ((-1 * ((-1 * x) / 2)) + x)), (-7 + x))"); + EXPECT_EQ(GetStreamCnt(p2), + "cinn_max((-1699 + ((-1 * ((-1 * x) / 2)) + x)), (-7 + x))"); } { // cinn_max((-1 * (3399 + (-16 * i_j_fused_outer))), -15) - Var x = ir::_Var_::Make("x", Int(32)); - auto p0 = ir::Max::Make(ir::Product::Make({Expr(-1), ir::Sum::Make({Expr(3399), Expr(-16) * x})}), Expr(-15)); + Var x = ir::_Var_::Make("x", Int(32)); + auto p0 = ir::Max::Make( + ir::Product::Make( + {Expr(-1), ir::Sum::Make({Expr(3399), Expr(-16) * x})}), + Expr(-15)); LOG(INFO) << "p0 " << p0; auto p2 = AutoSimplify(p0); LOG(INFO) << "simplified " << p2; diff --git a/paddle/cinn/common/cinn_value.cc b/paddle/cinn/common/cinn_value.cc index e705cfdac24fd..3b25f93201333 100644 --- a/paddle/cinn/common/cinn_value.cc +++ b/paddle/cinn/common/cinn_value.cc @@ -38,7 +38,8 @@ namespace common { return code__; \ } __m(std::nullptr_t, -1); -__m(char *, 20); // start from a larger number to avoid duplicate id with cinn_pod_value_t +__m(char *, 20); // start from a larger number to avoid duplicate id with + // cinn_pod_value_t __m(char const *, 21); __m(ir::Expr, 22); __m(ir::Var, 23); @@ -106,18 +107,24 @@ cinn_value_t ToValue(char const *v) { } // @} -bool CINNValue::is_string() const { return type_code_ == TypeCode(); } +bool CINNValue::is_string() const { + return type_code_ == TypeCode(); +} bool CINNValue::is_var() const { return type_code_ == TypeCode(); } bool CINNValue::is_expr() const { - return type_code_ == TypeCode() && !absl::any_cast(shared_).as_tensor(); + return type_code_ == TypeCode() && + !absl::any_cast(shared_).as_tensor(); } -bool CINNValue::is_stagemap() const { return type_code_ == TypeCode(); } +bool CINNValue::is_stagemap() const { + return type_code_ == TypeCode(); +} bool 
CINNValue::is_tensor() const { - return type_code_ == TypeCode() && absl::any_cast(shared_).as_tensor(); + return type_code_ == TypeCode() && + absl::any_cast(shared_).as_tensor(); } CINNValue::operator std::string() const { @@ -140,24 +147,30 @@ CINNValue::operator poly::StageMap() const { CHECK_EQ(type_code(), TypeCode()); return absl::any_cast(shared_); } -CINNValue::CINNValue(char *value) : cinn_pod_value_t(ToValue(value), TypeCode()) {} +CINNValue::CINNValue(char *value) + : cinn_pod_value_t(ToValue(value), TypeCode()) {} -CINNValue::CINNValue(const std::string &value) : cinn_pod_value_t(cinn_value_t(), TypeCode()) { +CINNValue::CINNValue(const std::string &value) + : cinn_pod_value_t(cinn_value_t(), TypeCode()) { shared_ = value; } -CINNValue::CINNValue(const Var &value) : cinn_pod_value_t(cinn_value_t(), TypeCode()) { +CINNValue::CINNValue(const Var &value) + : cinn_pod_value_t(cinn_value_t(), TypeCode()) { CHECK(value.defined()); shared_ = value; } -CINNValue::CINNValue(const Expr &value) : cinn_pod_value_t(cinn_value_t(), TypeCode()) { +CINNValue::CINNValue(const Expr &value) + : cinn_pod_value_t(cinn_value_t(), TypeCode()) { CHECK(value.defined()); shared_ = value; } -CINNValue::CINNValue(const CINNValuePack &value) : cinn_pod_value_t(cinn_value_t(), TypeCode()) { +CINNValue::CINNValue(const CINNValuePack &value) + : cinn_pod_value_t(cinn_value_t(), TypeCode()) { CHECK(value.defined()); shared_ = value; } -CINNValue::CINNValue(const poly::StageMap &value) : cinn_pod_value_t(cinn_value_t(), TypeCode()) { +CINNValue::CINNValue(const poly::StageMap &value) + : cinn_pod_value_t(cinn_value_t(), TypeCode()) { CHECK(value.defined()); shared_ = value; } diff --git a/paddle/cinn/common/cinn_value.h b/paddle/cinn/common/cinn_value.h index 5db64e41bfc90..587a79ec71c6f 100755 --- a/paddle/cinn/common/cinn_value.h +++ b/paddle/cinn/common/cinn_value.h @@ -85,9 +85,12 @@ struct _CINNValuePack_ : public common::Object { struct CINNValuePack : public Shared<_CINNValuePack_> { explicit CINNValuePack(_CINNValuePack_* ptr) : Shared<_CINNValuePack_>(ptr) {} - explicit CINNValuePack(const std::vector& array) : Shared<_CINNValuePack_>(_CINNValuePack_::Make(array)) {} + explicit CINNValuePack(const std::vector& array) + : Shared<_CINNValuePack_>(_CINNValuePack_::Make(array)) {} CINNValue& operator[](int offset) { return (*operator->())[offset]; } - const CINNValue& operator[](int offset) const { return (*operator->())[offset]; } + const CINNValue& operator[](int offset) const { + return (*operator->())[offset]; + } size_t size() const { return (*operator->()).size(); } @@ -108,22 +111,38 @@ struct CINNValuePack : public Shared<_CINNValuePack_> { }; /** - * Handler for value types in CINN system. It supports two kinds of values: the POD and Shared. + * Handler for value types in CINN system. It supports two kinds of values: the + * POD and Shared. 
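 * For example (editorial note, not in the original patch): CINNValue(1.0f)
 * stores a POD float directly in the underlying cinn_value_t, while
 * CINNValue(Expr(...)) keeps the ir::Expr alive through the shared_ member.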
*/ class CINNValue : public cinn_pod_value_t { public: static constexpr int kNull = -1; CINNValue() : cinn_pod_value_t(cinn_value_t(), kNull) {} - CINNValue(cinn_value_t value, int type_code) : cinn_pod_value_t(value, type_code) {} - - explicit CINNValue(bool value) : cinn_pod_value_t(value) { type_code_ = ::cinn_type_code(); } - explicit CINNValue(int32_t value) : cinn_pod_value_t(value) { type_code_ = ::cinn_type_code(); } - explicit CINNValue(int64_t value) : cinn_pod_value_t(value) { type_code_ = ::cinn_type_code(); } - explicit CINNValue(float value) : cinn_pod_value_t(value) { type_code_ = ::cinn_type_code(); } - explicit CINNValue(bfloat16 value) : cinn_pod_value_t(value) { type_code_ = ::cinn_type_code(); } - explicit CINNValue(float16 value) : cinn_pod_value_t(value) { type_code_ = ::cinn_type_code(); } - explicit CINNValue(double value) : cinn_pod_value_t(value) { type_code_ = ::cinn_type_code(); } + CINNValue(cinn_value_t value, int type_code) + : cinn_pod_value_t(value, type_code) {} + + explicit CINNValue(bool value) : cinn_pod_value_t(value) { + type_code_ = ::cinn_type_code(); + } + explicit CINNValue(int32_t value) : cinn_pod_value_t(value) { + type_code_ = ::cinn_type_code(); + } + explicit CINNValue(int64_t value) : cinn_pod_value_t(value) { + type_code_ = ::cinn_type_code(); + } + explicit CINNValue(float value) : cinn_pod_value_t(value) { + type_code_ = ::cinn_type_code(); + } + explicit CINNValue(bfloat16 value) : cinn_pod_value_t(value) { + type_code_ = ::cinn_type_code(); + } + explicit CINNValue(float16 value) : cinn_pod_value_t(value) { + type_code_ = ::cinn_type_code(); + } + explicit CINNValue(double value) : cinn_pod_value_t(value) { + type_code_ = ::cinn_type_code(); + } explicit CINNValue(char* value); explicit CINNValue(cinn_buffer_t* value) : cinn_pod_value_t(value) {} explicit CINNValue(void* value) : cinn_pod_value_t(value) {} diff --git a/paddle/cinn/common/common.h b/paddle/cinn/common/common.h index 25e371caf7824..e54d8aad4b31d 100644 --- a/paddle/cinn/common/common.h +++ b/paddle/cinn/common/common.h @@ -63,7 +63,8 @@ static void CheckVarNameValid(const absl::string_view name) { name.find('\n') == std::string::npos && // name.find('\r') == std::string::npos) << "Some invalid character found"; - CHECK(!common::IsAxisNameReserved(std::string(name))) << "The name [" << name << "] is reserved for internal axis"; + CHECK(!common::IsAxisNameReserved(std::string(name))) + << "The name [" << name << "] is reserved for internal axis"; } } // namespace cinn diff --git a/paddle/cinn/common/context.cc b/paddle/cinn/common/context.cc index 2f985f26f05e5..d243fef9e263d 100644 --- a/paddle/cinn/common/context.cc +++ b/paddle/cinn/common/context.cc @@ -50,8 +50,10 @@ const std::vector& Context::runtime_include_dir() { VLOG(4) << "get runtime_include_dir from env: " << env; runtime_include_dir_ = cinn::utils::Split(env, ":"); } else if (defined_runtime_include_dir) { - VLOG(4) << "get runtime_include_dir from RUNTIME_INCLUDE_DIR: " << defined_runtime_include_dir; - runtime_include_dir_ = cinn::utils::Split(defined_runtime_include_dir, ":"); + VLOG(4) << "get runtime_include_dir from RUNTIME_INCLUDE_DIR: " + << defined_runtime_include_dir; + runtime_include_dir_ = + cinn::utils::Split(defined_runtime_include_dir, ":"); } } return runtime_include_dir_; @@ -76,5 +78,7 @@ std::string NameGenerator::New(const std::string& name_hint) { } // namespace common -DEFINE_bool(cinn_runtime_display_debug_info, false, "Whether to display debug information in runtime"); 
+DEFINE_bool(cinn_runtime_display_debug_info, + false, + "Whether to display debug information in runtime"); } // namespace cinn diff --git a/paddle/cinn/common/context.h b/paddle/cinn/common/context.h index 4a5d774fae3ee..1c7202ff6dc97 100644 --- a/paddle/cinn/common/context.h +++ b/paddle/cinn/common/context.h @@ -60,7 +60,9 @@ class Context { * Generate a new unique name. * @param name_hint The prefix. */ - std::string NewName(const std::string& name_hint) { return name_generator_.New(name_hint); } + std::string NewName(const std::string& name_hint) { + return name_generator_.New(name_hint); + } void ResetNameId() { name_generator_.ResetID(); } @@ -89,7 +91,9 @@ class Context { static thread_local DebugManager debug_mgr_; }; -static std::string UniqName(const std::string& prefix) { return Context::Global().NewName(prefix); } +static std::string UniqName(const std::string& prefix) { + return Context::Global().NewName(prefix); +} } // namespace common } // namespace cinn diff --git a/paddle/cinn/common/cost_model.h b/paddle/cinn/common/cost_model.h index 6c5f4cc79babc..d95cd7b519c9a 100644 --- a/paddle/cinn/common/cost_model.h +++ b/paddle/cinn/common/cost_model.h @@ -25,11 +25,14 @@ namespace auto_schedule { */ class CostModel { public: - virtual void Train(const std::vector>& samples, const std::vector& labels) = 0; + virtual void Train(const std::vector>& samples, + const std::vector& labels) = 0; - virtual std::vector Predict(const std::vector>& samples) const = 0; + virtual std::vector Predict( + const std::vector>& samples) const = 0; - virtual void Update(const std::vector>& samples, const std::vector& labels) = 0; + virtual void Update(const std::vector>& samples, + const std::vector& labels) = 0; virtual void Save(const std::string& path) = 0; diff --git a/paddle/cinn/common/cuda_test_helper.cc b/paddle/cinn/common/cuda_test_helper.cc index 8c9b67985a10a..f43678266daa5 100644 --- a/paddle/cinn/common/cuda_test_helper.cc +++ b/paddle/cinn/common/cuda_test_helper.cc @@ -25,10 +25,12 @@ namespace cinn { namespace common { #ifdef CINN_WITH_CUDA -void CudaModuleTester::Compile(const ir::Module& m, const std::string& rewrite_cuda_code) { - auto _host_module_device_module_ = backends::SplitCudaAndHostModule(m); // NOLINT - auto& host_module = std::get<0>(_host_module_device_module_); - auto& device_module = std::get<1>(_host_module_device_module_); +void CudaModuleTester::Compile(const ir::Module& m, + const std::string& rewrite_cuda_code) { + auto _host_module_device_module_ = + backends::SplitCudaAndHostModule(m); // NOLINT + auto& host_module = std::get<0>(_host_module_device_module_); + auto& device_module = std::get<1>(_host_module_device_module_); CHECK(!host_module.functions().empty()); CHECK(!device_module.functions().empty()); @@ -44,16 +46,19 @@ void CudaModuleTester::Compile(const ir::Module& m, const std::string& rewrite_c else ptx = compiler(rewrite_cuda_code); - cuda_module_ = new runtime::cuda::CUDAModule(ptx, runtime::cuda::CUDAModule::Kind::PTX); + cuda_module_ = + new runtime::cuda::CUDAModule(ptx, runtime::cuda::CUDAModule::Kind::PTX); for (auto& fn : device_module.functions()) { std::string kernel_fn_name = fn->name; - auto fn_kernel = reinterpret_cast(cuda_module_)->GetFunction(0, kernel_fn_name); + auto fn_kernel = reinterpret_cast(cuda_module_) + ->GetFunction(0, kernel_fn_name); CHECK(fn_kernel); kernel_handles_.push_back(fn_kernel); - backends::GlobalSymbolRegistry::Global().RegisterFn(kernel_fn_name + "_ptr_", - reinterpret_cast(&kernel_handles_.back())); + 
backends::GlobalSymbolRegistry::Global().RegisterFn( + kernel_fn_name + "_ptr_", + reinterpret_cast(&kernel_handles_.back())); } jit_ = backends::SimpleJIT::Create(); @@ -68,20 +73,26 @@ void* CudaModuleTester::CreateDeviceBuffer(const cinn_buffer_t* host_buffer) { CUdeviceptr data; cuMemAlloc(&data, num_bytes); - CUDA_CALL(cudaMemcpy(reinterpret_cast(data), host_buffer->memory, num_bytes, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy(reinterpret_cast(data), + host_buffer->memory, + num_bytes, + cudaMemcpyHostToDevice)); return reinterpret_cast(data); } CudaModuleTester::CudaModuleTester() {} -void CudaModuleTester::operator()(const std::string& fn_name, void* args, int arg_num) { - auto fn = jit_->Lookup(fn_name); +void CudaModuleTester::operator()(const std::string& fn_name, + void* args, + int arg_num) { + auto fn = jit_->Lookup(fn_name); auto fnp = reinterpret_cast(fn); (*fnp)(args, arg_num, stream_); } void* CudaModuleTester::LookupKernel(const std::string& name) { - return reinterpret_cast(cuda_module_)->GetFunction(0, name); + return reinterpret_cast(cuda_module_) + ->GetFunction(0, name); } CudaModuleTester::~CudaModuleTester() { diff --git a/paddle/cinn/common/debug_manager.cc b/paddle/cinn/common/debug_manager.cc index 9a1642e5c82c6..aad6b48517481 100644 --- a/paddle/cinn/common/debug_manager.cc +++ b/paddle/cinn/common/debug_manager.cc @@ -17,34 +17,41 @@ namespace cinn { namespace common { -inline std::vector> &GetVec(absl::any &data) { // NOLINT +inline std::vector> &GetVec( + absl::any &data) { // NOLINT return absl::any_cast> &>(data); } //! AppendTypeSuffix for multiple types. // @{ template <> -inline std::string DebugManager::AppendTypeSuffix(const std::string &key) { +inline std::string DebugManager::AppendTypeSuffix( + const std::string &key) { return key + "_i32"; } template <> -inline std::string DebugManager::AppendTypeSuffix(const std::string &key) { +inline std::string DebugManager::AppendTypeSuffix( + const std::string &key) { return key + "_i64"; } template <> -inline std::string DebugManager::AppendTypeSuffix(const std::string &key) { +inline std::string DebugManager::AppendTypeSuffix( + const std::string &key) { return key + "_f32"; } template <> -inline std::string DebugManager::AppendTypeSuffix(const std::string &key) { +inline std::string DebugManager::AppendTypeSuffix( + const std::string &key) { return key + "_f64"; } template <> -inline std::string DebugManager::AppendTypeSuffix(const std::string &key) { +inline std::string DebugManager::AppendTypeSuffix( + const std::string &key) { return key + "_b"; } template <> -inline std::string DebugManager::AppendTypeSuffix(const std::string &key) { +inline std::string DebugManager::AppendTypeSuffix( + const std::string &key) { return key + "_s"; } // @} @@ -53,13 +60,15 @@ void DebugManager::Append(const std::string &key, absl::any value) { GetVec(data_).push_back(std::make_pair(key, value)); } void DebugManager::Append(const std::string &key, int32_t value) { - GetVec(data_).push_back(std::make_pair(AppendTypeSuffix(key), value)); + GetVec(data_).push_back( + std::make_pair(AppendTypeSuffix(key), value)); } void DebugManager::Append(const std::string &key, bool value) { GetVec(data_).push_back(std::make_pair(AppendTypeSuffix(key), value)); } void DebugManager::Append(const std::string &key, const std::string &value) { - GetVec(data_).push_back(std::make_pair(AppendTypeSuffix(key), value)); + GetVec(data_).push_back( + std::make_pair(AppendTypeSuffix(key), value)); } void DebugManager::Clear() { 
GetVec(data_).clear(); } diff --git a/paddle/cinn/common/debug_manager.h b/paddle/cinn/common/debug_manager.h index 934965f13c1ef..001fdb1a61267 100644 --- a/paddle/cinn/common/debug_manager.h +++ b/paddle/cinn/common/debug_manager.h @@ -24,7 +24,8 @@ namespace common { /** * Container for debug info. - * DebugManager is integrated into the global Context, and used to log something(but not print to stdout directly). + * DebugManager is integrated into the global Context, and used to log + * something(but not print to stdout directly). */ class DebugManager { public: diff --git a/paddle/cinn/common/float16.h b/paddle/cinn/common/float16.h index 4bf8c64614b17..15bd2cee3fc69 100644 --- a/paddle/cinn/common/float16.h +++ b/paddle/cinn/common/float16.h @@ -19,7 +19,8 @@ #pragma once #endif // __cplusplus -#if defined(_M_X64) || defined(__x86_64__) || defined(_M_IX86) || defined(__i386__) +#if defined(_M_X64) || defined(__x86_64__) || defined(_M_IX86) || \ + defined(__i386__) #define __CINN_x86__ #include #endif @@ -74,12 +75,12 @@ struct CINN_ALIGN(2) float16 { #ifdef __cplusplus // The following defaulted special class member functions // are added to make float16 pass the std::is_trivial test - float16() = default; + float16() = default; float16(const float16& o) = default; float16& operator=(const float16& o) = default; - float16(float16&& o) = default; + float16(float16&& o) = default; float16& operator=(float16&& o) = default; - ~float16() = default; + ~float16() = default; // Constructors #ifdef CINN_CUDA_FP16 @@ -95,7 +96,7 @@ struct CINN_ALIGN(2) float16 { __host__ __device__ inline explicit float16(float val) { #if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) half tmp = __float2half(val); - x = *reinterpret_cast(&tmp); + x = *reinterpret_cast(&tmp); #elif defined(__F16C__) && defined(__CINN_x86__) x = _cvtss_sh(val, 0); @@ -104,7 +105,7 @@ struct CINN_ALIGN(2) float16 { // Conversion routine adapted from // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion Bits v, s; - v.f = val; + v.f = val; uint32_t sign = v.si & sigN; v.si ^= sign; sign >>= shiftSign; // logical shift @@ -124,7 +125,8 @@ struct CINN_ALIGN(2) float16 { __host__ __device__ inline explicit float16(bool b) : x(b ? 
0x3c00 : 0) {} template - __host__ __device__ inline explicit float16(const T& val) : x(float16(static_cast(val)).x) {} + __host__ __device__ inline explicit float16(const T& val) + : x(float16(static_cast(val)).x) {} // Assignment operators #ifdef CINN_CUDA_FP16 @@ -220,7 +222,7 @@ struct CINN_ALIGN(2) float16 { // Conversion routine adapted from // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion Bits v; - v.ui = this->x; + v.ui = this->x; int32_t sign = v.si & sigC; v.si ^= sign; sign <<= shiftSign; @@ -238,9 +240,13 @@ struct CINN_ALIGN(2) float16 { #endif } - __host__ __device__ inline explicit operator bool() const { return (x & 0x7fff) != 0; } + __host__ __device__ inline explicit operator bool() const { + return (x & 0x7fff) != 0; + } - __host__ __device__ inline explicit operator int8_t() const { return static_cast(static_cast(*this)); } + __host__ __device__ inline explicit operator int8_t() const { + return static_cast(static_cast(*this)); + } __host__ __device__ inline explicit operator uint8_t() const { return static_cast(static_cast(*this)); @@ -270,7 +276,9 @@ struct CINN_ALIGN(2) float16 { return static_cast(static_cast(*this)); } - __host__ __device__ inline operator double() const { return static_cast(static_cast(*this)); } + __host__ __device__ inline operator double() const { + return static_cast(static_cast(*this)); + } private: union Bits { @@ -279,7 +287,7 @@ struct CINN_ALIGN(2) float16 { uint32_t ui; }; - static const int shift = 13; + static const int shift = 13; static const int shiftSign = 16; static const int32_t infN = 0x7F800000; @@ -288,7 +296,8 @@ struct CINN_ALIGN(2) float16 { static const int32_t sigN = 0x80000000; // sign bit static constexpr int32_t infC = infN >> shift; - static constexpr int32_t nanN = (infC + 1) << shift; // minimum flt16 nan as float32 + static constexpr int32_t nanN = (infC + 1) + << shift; // minimum flt16 nan as float32 static constexpr int32_t maxC = maxN >> shift; static constexpr int32_t minC = minN >> shift; static constexpr int32_t sigC = sigN >> shiftSign; @@ -353,7 +362,7 @@ __device__ inline half operator*(const half& a, const half& b) { __device__ inline half operator/(const half& a, const half& b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 - float num = __half2float(a); + float num = __half2float(a); float denom = __half2float(b); return __float2half(num / denom); #else @@ -442,7 +451,8 @@ __device__ inline bool operator>=(const half& a, const half& b) { #endif // CINN_CUDA_FP16 // Arithmetic operators for float16 on GPU -__host__ __device__ inline float16 operator+(const float16& a, const float16& b) { +__host__ __device__ inline float16 operator+(const float16& a, + const float16& b) { #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hadd(a.to_half(), b.to_half())); #else @@ -450,7 +460,8 @@ __host__ __device__ inline float16 operator+(const float16& a, const float16& b) #endif } -__host__ __device__ inline float16 operator-(const float16& a, const float16& b) { +__host__ __device__ inline float16 operator-(const float16& a, + const float16& b) { #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hsub(a.to_half(), b.to_half())); #else @@ -458,7 +469,8 @@ __host__ __device__ inline float16 operator-(const float16& a, const float16& b) #endif } -__host__ __device__ inline float16 operator*(const float16& a, const float16& b) { +__host__ __device__ inline float16 operator*(const 
float16& a, + const float16& b) { #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hmul(a.to_half(), b.to_half())); #else @@ -466,10 +478,11 @@ __host__ __device__ inline float16 operator*(const float16& a, const float16& b) #endif } -__host__ __device__ inline float16 operator/(const float16& a, const float16& b) { +__host__ __device__ inline float16 operator/(const float16& a, + const float16& b) { #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 // TODO(kexinzhao): check which cuda version starts to support __hdiv - float num = __half2float(a.to_half()); + float num = __half2float(a.to_half()); float denom = __half2float(b.to_half()); return float16(num / denom); #else @@ -487,22 +500,26 @@ __host__ __device__ inline float16 operator-(const float16& a) { #endif } -__host__ __device__ inline float16& operator+=(float16& a, const float16& b) { // NOLINT +__host__ __device__ inline float16& operator+=(float16& a, + const float16& b) { // NOLINT a = a + b; return a; } -__host__ __device__ inline float16& operator-=(float16& a, const float16& b) { // NOLINT +__host__ __device__ inline float16& operator-=(float16& a, + const float16& b) { // NOLINT a = a - b; return a; } -__host__ __device__ inline float16& operator*=(float16& a, const float16& b) { // NOLINT +__host__ __device__ inline float16& operator*=(float16& a, + const float16& b) { // NOLINT a = a * b; return a; } -__host__ __device__ inline float16& operator/=(float16& a, const float16& b) { // NOLINT +__host__ __device__ inline float16& operator/=(float16& a, + const float16& b) { // NOLINT a = a / b; return a; } @@ -570,9 +587,13 @@ __host__ __device__ inline bool(isnan)(const float16& a) { #endif } -__host__ __device__ inline bool(isinf)(const float16& a) { return (a.x & 0x7fff) == 0x7c00; } +__host__ __device__ inline bool(isinf)(const float16& a) { + return (a.x & 0x7fff) == 0x7c00; +} -__host__ __device__ inline bool(isfinite)(const float16& a) { return !((isnan)(a)) && !((isinf)(a)); } +__host__ __device__ inline bool(isfinite)(const float16& a) { + return !((isnan)(a)) && !((isinf)(a)); +} __host__ __device__ inline float16(abs)(const float16& a) { #if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) @@ -582,7 +603,9 @@ __host__ __device__ inline float16(abs)(const float16& a) { #endif } -__host__ __device__ inline float16(log)(const float16& a) { return float16(std::log(static_cast(a))); } +__host__ __device__ inline float16(log)(const float16& a) { + return float16(std::log(static_cast(a))); +} #ifdef __cplusplus } // namespace common @@ -594,34 +617,43 @@ __device__ inline cinn::common::float16 __shfl_sync(unsigned mask, cinn::common::float16 var, int srcLane, int width = warpSize) { - return cinn::common::float16(__shfl_sync(mask, var.to_half(), srcLane, width)); + return cinn::common::float16( + __shfl_sync(mask, var.to_half(), srcLane, width)); } -__device__ inline cinn::common::float16 __shfl_up_sync(unsigned mask, - cinn::common::float16 var, - unsigned int delta, - int width = warpSize) { - return cinn::common::float16(__shfl_up_sync(mask, var.to_half(), delta, width)); +__device__ inline cinn::common::float16 __shfl_up_sync( + unsigned mask, + cinn::common::float16 var, + unsigned int delta, + int width = warpSize) { + return cinn::common::float16( + __shfl_up_sync(mask, var.to_half(), delta, width)); } -__device__ inline cinn::common::float16 __shfl_down_sync(unsigned mask, - cinn::common::float16 var, - unsigned 
int delta, - int width = warpSize) { - return cinn::common::float16(__shfl_down_sync(mask, var.to_half(), delta, width)); +__device__ inline cinn::common::float16 __shfl_down_sync( + unsigned mask, + cinn::common::float16 var, + unsigned int delta, + int width = warpSize) { + return cinn::common::float16( + __shfl_down_sync(mask, var.to_half(), delta, width)); } -__device__ inline cinn::common::float16 __shfl_xor_sync(unsigned mask, - cinn::common::float16 var, - int laneMask, - int width = warpSize) { - return cinn::common::float16(__shfl_xor_sync(mask, var.to_half(), laneMask, width)); +__device__ inline cinn::common::float16 __shfl_xor_sync( + unsigned mask, + cinn::common::float16 var, + int laneMask, + int width = warpSize) { + return cinn::common::float16( + __shfl_xor_sync(mask, var.to_half(), laneMask, width)); } -__host__ __device__ inline cinn::common::float16 max(const cinn::common::float16& a, const cinn::common::float16& b) { +__host__ __device__ inline cinn::common::float16 max( + const cinn::common::float16& a, const cinn::common::float16& b) { return a > b ? a : b; } -__host__ __device__ inline cinn::common::float16 min(const cinn::common::float16& a, const cinn::common::float16& b) { +__host__ __device__ inline cinn::common::float16 min( + const cinn::common::float16& a, const cinn::common::float16& b) { return a < b ? a : b; } #endif // __cplusplus && CINN_CUDA_FP16 diff --git a/paddle/cinn/common/float16_bfloat16_cuda_test.cu b/paddle/cinn/common/float16_bfloat16_cuda_test.cu index a7b2e82939850..932208b1a9d69 100644 --- a/paddle/cinn/common/float16_bfloat16_cuda_test.cu +++ b/paddle/cinn/common/float16_bfloat16_cuda_test.cu @@ -62,7 +62,9 @@ class CudaMem { return reinterpret_cast(data()); } - void MemcpyFromHost(const void* src, size_t bytes, cudaStream_t stream = nullptr) { + void MemcpyFromHost(const void* src, + size_t bytes, + cudaStream_t stream = nullptr) { CHECK_LE(bytes, bytes_) << "Too many data need copy"; CUDA_CALL(cudaMemcpyAsync(ptr, src, bytes, cudaMemcpyHostToDevice, stream)); } @@ -84,21 +86,28 @@ class CudaMem { size_t bytes_{0}; }; -__global__ void cast_fp32_to_fp16_cuda_kernel(const float* input, const int num, float16* out) { +__global__ void cast_fp32_to_fp16_cuda_kernel(const float* input, + const int num, + float16* out) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < num) { out[idx] = float16(input[idx]); } } -__global__ void cast_fp16_to_fp32_cuda_kernel(const float16* input, const int num, float* out) { +__global__ void cast_fp16_to_fp32_cuda_kernel(const float16* input, + const int num, + float* out) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < num) { out[idx] = float(input[idx]); } } -__global__ void test_fp16_cuda_kernel(const float16* x, const float16* y, const int num, float16* out) { +__global__ void test_fp16_cuda_kernel(const float16* x, + const float16* y, + const int num, + float16* out) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < num) { float16 x_i = x[idx], y_i = y[idx]; @@ -108,21 +117,28 @@ __global__ void test_fp16_cuda_kernel(const float16* x, const float16* y, const } } -__global__ void cast_fp32_to_bf16_cuda_kernel(const float* input, const int num, bfloat16* out) { +__global__ void cast_fp32_to_bf16_cuda_kernel(const float* input, + const int num, + bfloat16* out) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < num) { out[idx] = bfloat16(input[idx]); } } -__global__ void cast_bf16_to_fp32_cuda_kernel(const bfloat16* input, const int num, float* out) { +__global__ 
void cast_bf16_to_fp32_cuda_kernel(const bfloat16* input, + const int num, + float* out) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < num) { out[idx] = float(input[idx]); } } -__global__ void test_bf16_cuda_kernel(const bfloat16* x, const bfloat16* y, const int num, bfloat16* out) { +__global__ void test_bf16_cuda_kernel(const bfloat16* x, + const bfloat16* y, + const int num, + bfloat16* out) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < num) { bfloat16 x_i = x[idx], y_i = y[idx]; @@ -132,7 +148,10 @@ __global__ void test_bf16_cuda_kernel(const bfloat16* x, const bfloat16* y, cons } } -__global__ void test_fp32_cuda_kernel(const float* x, const float* y, const int num, float* out) { +__global__ void test_fp32_cuda_kernel(const float* x, + const float* y, + const int num, + float* out) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < num) { float x_i = x[idx], y_i = y[idx]; @@ -153,7 +172,7 @@ TEST(FP16_BF16, basic_cuda) { CUDA_CALL(cudaStreamCreate(&stream)); dim3 block = 1024; - dim3 grid = (num + block.x - 1) / block.x; + dim3 grid = (num + block.x - 1) / block.x; std::vector x_fp32_host(num), y_fp32_host(num); { // step1 : generate input data @@ -169,38 +188,47 @@ TEST(FP16_BF16, basic_cuda) { CudaMem x_fp32_device, y_fp32_device, out_fp32_device; { // step2 : compute fp32 result - auto x_fp32_ptr = x_fp32_device.mutable_data(num); - auto y_fp32_ptr = y_fp32_device.mutable_data(num); + auto x_fp32_ptr = x_fp32_device.mutable_data(num); + auto y_fp32_ptr = y_fp32_device.mutable_data(num); auto out_fp32_ptr = out_fp32_device.mutable_data(num); - x_fp32_device.MemcpyFromHost(x_fp32_host.data(), num * sizeof(float), stream); - y_fp32_device.MemcpyFromHost(y_fp32_host.data(), num * sizeof(float), stream); + x_fp32_device.MemcpyFromHost( + x_fp32_host.data(), num * sizeof(float), stream); + y_fp32_device.MemcpyFromHost( + y_fp32_host.data(), num * sizeof(float), stream); - test_fp32_cuda_kernel<<>>(x_fp32_ptr, y_fp32_ptr, num, out_fp32_ptr); + test_fp32_cuda_kernel<<>>( + x_fp32_ptr, y_fp32_ptr, num, out_fp32_ptr); } CudaMem x_fp16_device, y_fp16_device, out_fp16_device; CudaMem x_bf16_device, y_bf16_device, out_bf16_device; { // step3 : compute fp16/bf16 result // step3.1 : compute fp16 result - auto x_fp16_ptr = x_fp16_device.mutable_data(num); - auto y_fp16_ptr = y_fp16_device.mutable_data(num); + auto x_fp16_ptr = x_fp16_device.mutable_data(num); + auto y_fp16_ptr = y_fp16_device.mutable_data(num); auto out_fp16_ptr = out_fp16_device.mutable_data(num); - cast_fp32_to_fp16_cuda_kernel<<>>(x_fp32_device.data(), num, x_fp16_ptr); - cast_fp32_to_fp16_cuda_kernel<<>>(y_fp32_device.data(), num, y_fp16_ptr); + cast_fp32_to_fp16_cuda_kernel<<>>( + x_fp32_device.data(), num, x_fp16_ptr); + cast_fp32_to_fp16_cuda_kernel<<>>( + y_fp32_device.data(), num, y_fp16_ptr); - test_fp16_cuda_kernel<<>>(x_fp16_ptr, y_fp16_ptr, num, out_fp16_ptr); + test_fp16_cuda_kernel<<>>( + x_fp16_ptr, y_fp16_ptr, num, out_fp16_ptr); // step3.2 : compute bf16 result - auto x_bf16_ptr = x_bf16_device.mutable_data(num); - auto y_bf16_ptr = y_bf16_device.mutable_data(num); + auto x_bf16_ptr = x_bf16_device.mutable_data(num); + auto y_bf16_ptr = y_bf16_device.mutable_data(num); auto out_bf16_ptr = out_bf16_device.mutable_data(num); - cast_fp32_to_bf16_cuda_kernel<<>>(x_fp32_device.data(), num, x_bf16_ptr); - cast_fp32_to_bf16_cuda_kernel<<>>(y_fp32_device.data(), num, y_bf16_ptr); + cast_fp32_to_bf16_cuda_kernel<<>>( + x_fp32_device.data(), num, x_bf16_ptr); + 
cast_fp32_to_bf16_cuda_kernel<<>>( + y_fp32_device.data(), num, y_bf16_ptr); - test_bf16_cuda_kernel<<>>(x_bf16_ptr, y_bf16_ptr, num, out_bf16_ptr); + test_bf16_cuda_kernel<<>>( + x_bf16_ptr, y_bf16_ptr, num, out_bf16_ptr); } CudaMem fp32res_fp16_device; @@ -208,18 +236,23 @@ TEST(FP16_BF16, basic_cuda) { { // step4 : cast fp16/bf16 result to fp32 result // step4.1 : cast fp16 result to fp32 result auto fp32res_fp16_ptr = fp32res_fp16_device.mutable_data(num); - cast_fp16_to_fp32_cuda_kernel<<>>(out_fp16_device.data(), num, fp32res_fp16_ptr); + cast_fp16_to_fp32_cuda_kernel<<>>( + out_fp16_device.data(), num, fp32res_fp16_ptr); // step4.2 : cast bf16 result to fp32 result auto fp32res_bf16_ptr = fp32res_bf16_device.mutable_data(num); - cast_bf16_to_fp32_cuda_kernel<<>>(out_bf16_device.data(), num, fp32res_bf16_ptr); + cast_bf16_to_fp32_cuda_kernel<<>>( + out_bf16_device.data(), num, fp32res_bf16_ptr); } std::vector out_fp32_host(num), out_fp16_host(num), out_bf16_host(num); { // step5 : copy result from device to host - out_fp32_device.MemcpyToHost(out_fp32_host.data(), num * sizeof(float), stream); - fp32res_fp16_device.MemcpyToHost(out_fp16_host.data(), num * sizeof(float), stream); - fp32res_bf16_device.MemcpyToHost(out_bf16_host.data(), num * sizeof(float), stream); + out_fp32_device.MemcpyToHost( + out_fp32_host.data(), num * sizeof(float), stream); + fp32res_fp16_device.MemcpyToHost( + out_fp16_host.data(), num * sizeof(float), stream); + fp32res_bf16_device.MemcpyToHost( + out_bf16_host.data(), num * sizeof(float), stream); } CUDA_CALL(cudaStreamSynchronize(stream)); diff --git a/paddle/cinn/common/float16_bfloat16_host_test.cc b/paddle/cinn/common/float16_bfloat16_host_test.cc index 5072501c095c9..3a4cc905a8798 100644 --- a/paddle/cinn/common/float16_bfloat16_host_test.cc +++ b/paddle/cinn/common/float16_bfloat16_host_test.cc @@ -24,7 +24,9 @@ namespace cinn { namespace common { -std::vector test_fp16_host_kernel(const float16* x, const float16* y, const int num) { +std::vector test_fp16_host_kernel(const float16* x, + const float16* y, + const int num) { std::vector out(num); for (int idx = 0; idx < num; ++idx) { float16 x_i = x[idx], y_i = y[idx]; @@ -35,7 +37,9 @@ std::vector test_fp16_host_kernel(const float16* x, const float16* y, c return out; } -std::vector test_bf16_host_kernel(const bfloat16* x, const bfloat16* y, const int num) { +std::vector test_bf16_host_kernel(const bfloat16* x, + const bfloat16* y, + const int num) { std::vector out(num); for (int idx = 0; idx < num; ++idx) { bfloat16 x_i = x[idx], y_i = y[idx]; @@ -46,7 +50,9 @@ std::vector test_bf16_host_kernel(const bfloat16* x, const bfloat16* y return out; } -std::vector test_fp32_host_kernel(const float* x, const float* y, const int num) { +std::vector test_fp32_host_kernel(const float* x, + const float* y, + const int num) { std::vector out(num); for (int idx = 0; idx < num; ++idx) { float x_i = x[idx], y_i = y[idx]; diff --git a/paddle/cinn/common/float16_bfloat16_utils.h b/paddle/cinn/common/float16_bfloat16_utils.h index b80c6a518e6ef..7801eb8ec2518 100644 --- a/paddle/cinn/common/float16_bfloat16_utils.h +++ b/paddle/cinn/common/float16_bfloat16_utils.h @@ -33,15 +33,17 @@ namespace std { // for float16 template <> struct is_pod { - static const bool value = - is_trivial::value && is_standard_layout::value; + static const bool value = is_trivial::value && + is_standard_layout::value; }; template <> struct is_floating_point : std::integral_constant< bool, - std::is_same::type>::value> {}; + std::is_same< + 
cinn::common::float16, + typename std::remove_cv::type>::value> {}; template <> struct is_signed { static const bool value = true; @@ -52,65 +54,92 @@ struct is_unsigned { static const bool value = false; }; -__host__ __device__ inline cinn::common::float16 abs(const cinn::common::float16& a) { return cinn::common::abs(a); } +__host__ __device__ inline cinn::common::float16 abs( + const cinn::common::float16& a) { + return cinn::common::abs(a); +} -inline bool isnan(const cinn::common::float16& a) { return cinn::common::isnan(a); } +inline bool isnan(const cinn::common::float16& a) { + return cinn::common::isnan(a); +} -inline bool isinf(const cinn::common::float16& a) { return cinn::common::isinf(a); } +inline bool isinf(const cinn::common::float16& a) { + return cinn::common::isinf(a); +} -inline bool isfinite(const cinn::common::float16& a) { return cinn::common::isfinite(a); } +inline bool isfinite(const cinn::common::float16& a) { + return cinn::common::isfinite(a); +} template <> struct numeric_limits { - static const bool is_specialized = true; - static const bool is_signed = true; - static const bool is_integer = false; - static const bool is_exact = false; - static const bool has_infinity = true; - static const bool has_quiet_NaN = true; - static const bool has_signaling_NaN = true; - static const float_denorm_style has_denorm = denorm_present; - static const bool has_denorm_loss = false; + static const bool is_specialized = true; + static const bool is_signed = true; + static const bool is_integer = false; + static const bool is_exact = false; + static const bool has_infinity = true; + static const bool has_quiet_NaN = true; + static const bool has_signaling_NaN = true; + static const float_denorm_style has_denorm = denorm_present; + static const bool has_denorm_loss = false; static const std::float_round_style round_style = std::round_to_nearest; - static const bool is_iec559 = false; - static const bool is_bounded = false; - static const bool is_modulo = false; - static const int digits = 11; - static const int digits10 = 3; - static const int max_digits10 = 5; - static const int radix = 2; - static const int min_exponent = -13; - static const int min_exponent10 = -4; - static const int max_exponent = 16; - static const int max_exponent10 = 4; - static const bool traps = true; - static const bool tinyness_before = false; - - __host__ __device__ static cinn::common::float16(min)() { return cinn::common::raw_uint16_to_float16(0x400); } - __host__ __device__ static cinn::common::float16 lowest() { return cinn::common::raw_uint16_to_float16(0xfbff); } - __host__ __device__ static cinn::common::float16(max)() { return cinn::common::raw_uint16_to_float16(0x7bff); } - __host__ __device__ static cinn::common::float16 epsilon() { return cinn::common::raw_uint16_to_float16(0x0800); } - __host__ __device__ static cinn::common::float16 round_error() { return cinn::common::float16(0.5); } - __host__ __device__ static cinn::common::float16 infinity() { return cinn::common::raw_uint16_to_float16(0x7c00); } - __host__ __device__ static cinn::common::float16 quiet_NaN() { return cinn::common::raw_uint16_to_float16(0x7e00); } + static const bool is_iec559 = false; + static const bool is_bounded = false; + static const bool is_modulo = false; + static const int digits = 11; + static const int digits10 = 3; + static const int max_digits10 = 5; + static const int radix = 2; + static const int min_exponent = -13; + static const int min_exponent10 = -4; + static const int max_exponent = 16; + static 
const int max_exponent10 = 4; + static const bool traps = true; + static const bool tinyness_before = false; + + __host__ __device__ static cinn::common::float16(min)() { + return cinn::common::raw_uint16_to_float16(0x400); + } + __host__ __device__ static cinn::common::float16 lowest() { + return cinn::common::raw_uint16_to_float16(0xfbff); + } + __host__ __device__ static cinn::common::float16(max)() { + return cinn::common::raw_uint16_to_float16(0x7bff); + } + __host__ __device__ static cinn::common::float16 epsilon() { + return cinn::common::raw_uint16_to_float16(0x0800); + } + __host__ __device__ static cinn::common::float16 round_error() { + return cinn::common::float16(0.5); + } + __host__ __device__ static cinn::common::float16 infinity() { + return cinn::common::raw_uint16_to_float16(0x7c00); + } + __host__ __device__ static cinn::common::float16 quiet_NaN() { + return cinn::common::raw_uint16_to_float16(0x7e00); + } __host__ __device__ static cinn::common::float16 signaling_NaN() { return cinn::common::raw_uint16_to_float16(0x7e00); } - __host__ __device__ static cinn::common::float16 denorm_min() { return cinn::common::raw_uint16_to_float16(0x1); } + __host__ __device__ static cinn::common::float16 denorm_min() { + return cinn::common::raw_uint16_to_float16(0x1); + } }; // for bfloat16 template <> struct is_pod { - static const bool value = - is_trivial::value && is_standard_layout::value; + static const bool value = is_trivial::value && + is_standard_layout::value; }; template <> struct is_floating_point : std::integral_constant< bool, - std::is_same::type>::value> {}; + std::is_same< + cinn::common::bfloat16, + typename std::remove_cv::type>::value> {}; template <> struct is_signed { static const bool value = true; @@ -121,43 +150,61 @@ struct is_unsigned { static const bool value = false; }; -inline bool isnan(const cinn::common::bfloat16& a) { return cinn::common::isnan(a); } +inline bool isnan(const cinn::common::bfloat16& a) { + return cinn::common::isnan(a); +} -inline bool isinf(const cinn::common::bfloat16& a) { return cinn::common::isinf(a); } +inline bool isinf(const cinn::common::bfloat16& a) { + return cinn::common::isinf(a); +} template <> struct numeric_limits { - static const bool is_specialized = true; - static const bool is_signed = true; - static const bool is_integer = false; - static const bool is_exact = false; - static const bool has_infinity = true; - static const bool has_quiet_NaN = true; - static const bool has_signaling_NaN = true; - static const float_denorm_style has_denorm = denorm_present; - static const bool has_denorm_loss = false; + static const bool is_specialized = true; + static const bool is_signed = true; + static const bool is_integer = false; + static const bool is_exact = false; + static const bool has_infinity = true; + static const bool has_quiet_NaN = true; + static const bool has_signaling_NaN = true; + static const float_denorm_style has_denorm = denorm_present; + static const bool has_denorm_loss = false; static const std::float_round_style round_style = std::round_to_nearest; - static const bool is_iec559 = false; - static const bool is_bounded = false; - static const bool is_modulo = false; - static const int digits = 8; - static const int digits10 = 2; - static const int max_digits10 = 9; - static const int radix = 2; - static const int min_exponent = -125; - static const int min_exponent10 = -37; - static const int max_exponent = 128; - static const int max_exponent10 = 38; - static const bool traps = true; - static const bool 
tinyness_before = false; - - __host__ __device__ static cinn::common::bfloat16(min)() { return cinn::common::raw_uint16_to_bfloat16(0x007f); } - __host__ __device__ static cinn::common::bfloat16 lowest() { return cinn::common::raw_uint16_to_bfloat16(0xff7f); } - __host__ __device__ static cinn::common::bfloat16(max)() { return cinn::common::raw_uint16_to_bfloat16(0x7f7f); } - __host__ __device__ static cinn::common::bfloat16 epsilon() { return cinn::common::raw_uint16_to_bfloat16(0x3400); } - __host__ __device__ static cinn::common::bfloat16 round_error() { return cinn::common::bfloat16(0.5); } - __host__ __device__ static cinn::common::bfloat16 infinity() { return cinn::common::raw_uint16_to_bfloat16(0x7f80); } - __host__ __device__ static cinn::common::bfloat16 quiet_NaN() { return cinn::common::raw_uint16_to_bfloat16(0xffc1); } + static const bool is_iec559 = false; + static const bool is_bounded = false; + static const bool is_modulo = false; + static const int digits = 8; + static const int digits10 = 2; + static const int max_digits10 = 9; + static const int radix = 2; + static const int min_exponent = -125; + static const int min_exponent10 = -37; + static const int max_exponent = 128; + static const int max_exponent10 = 38; + static const bool traps = true; + static const bool tinyness_before = false; + + __host__ __device__ static cinn::common::bfloat16(min)() { + return cinn::common::raw_uint16_to_bfloat16(0x007f); + } + __host__ __device__ static cinn::common::bfloat16 lowest() { + return cinn::common::raw_uint16_to_bfloat16(0xff7f); + } + __host__ __device__ static cinn::common::bfloat16(max)() { + return cinn::common::raw_uint16_to_bfloat16(0x7f7f); + } + __host__ __device__ static cinn::common::bfloat16 epsilon() { + return cinn::common::raw_uint16_to_bfloat16(0x3400); + } + __host__ __device__ static cinn::common::bfloat16 round_error() { + return cinn::common::bfloat16(0.5); + } + __host__ __device__ static cinn::common::bfloat16 infinity() { + return cinn::common::raw_uint16_to_bfloat16(0x7f80); + } + __host__ __device__ static cinn::common::bfloat16 quiet_NaN() { + return cinn::common::raw_uint16_to_bfloat16(0xffc1); + } __host__ __device__ static cinn::common::bfloat16 signaling_NaN() { return cinn::common::raw_uint16_to_bfloat16(0xff81); } diff --git a/paddle/cinn/common/graph_utils.cc b/paddle/cinn/common/graph_utils.cc index 1ec34f5262311..a2b6861b899b4 100755 --- a/paddle/cinn/common/graph_utils.cc +++ b/paddle/cinn/common/graph_utils.cc @@ -38,7 +38,8 @@ std::vector DFSSort(const std::vector &nodes) { } // namespace -std::set Graph::dependencies(const std::vector &targets) { +std::set Graph::dependencies( + const std::vector &targets) { // A naive implementation. 
std::set _targets(targets.begin(), targets.end()); std::set res; @@ -69,7 +70,8 @@ std::vector Graph::nodes() { return res; } -std::tuple, std::vector> Graph::topological_order() const { +std::tuple, std::vector> +Graph::topological_order() const { std::vector node_order; std::vector edge_order; std::deque queue; @@ -105,12 +107,16 @@ std::tuple, std::vector> Graph::topologica } } - CHECK_EQ(node_order.size(), nodes().size()) << "circle detected in the schedule graph:\n\n" << Visualize(); + CHECK_EQ(node_order.size(), nodes().size()) + << "circle detected in the schedule graph:\n\n" + << Visualize(); return std::make_tuple(node_order, edge_order); } -std::vector Graph::dfs_order() { return std::vector(); } +std::vector Graph::dfs_order() { + return std::vector(); +} std::vector Graph::start_points() const { std::vector res; @@ -143,7 +149,9 @@ GraphNode *Graph::RetrieveNode(size_t key) const { return it == registry_.end() ? nullptr : it->second; } -GraphNode *Graph::RetrieveNode(const std::string &key) const { return RetrieveNode(std::hash()(key)); } +GraphNode *Graph::RetrieveNode(const std::string &key) const { + return RetrieveNode(std::hash()(key)); +} std::string Graph::Visualize() const { utils::DotLang dot; @@ -163,9 +171,10 @@ std::string Graph::Visualize() const { return dot(); } -void Graph::ClearUnlinkedNodes(absl::flat_hash_map> *shape_dict, - absl::flat_hash_map *type_dict, - absl::flat_hash_map *layout_dict) { +void Graph::ClearUnlinkedNodes( + absl::flat_hash_map> *shape_dict, + absl::flat_hash_map *type_dict, + absl::flat_hash_map *layout_dict) { CHECK(shape_dict); CHECK(type_dict); CHECK(layout_dict); @@ -190,7 +199,8 @@ void Graph::ClearUnlinkedNodes(absl::flat_hash_map const char *GraphNode::__type_info__ = "GraphNode"; -bool GraphEdgeCompare::operator()(const Shared &a, const Shared &b) const { +bool GraphEdgeCompare::operator()(const Shared &a, + const Shared &b) const { if (a->source()->id() == b->source()->id()) { if (a->sink()->id() == b->sink()->id()) { return a->index() < b->index(); @@ -200,7 +210,8 @@ bool GraphEdgeCompare::operator()(const Shared &a, const Sharedsource()->id() < b->source()->id(); } -std::set Graph::CollectNodes(std::function &&teller) { +std::set Graph::CollectNodes( + std::function &&teller) { std::set res; for (auto *node : nodes()) { if (teller(node)) res.insert(node); diff --git a/paddle/cinn/common/graph_utils.h b/paddle/cinn/common/graph_utils.h index cd01be7d44bb5..cb144e1c901c7 100644 --- a/paddle/cinn/common/graph_utils.h +++ b/paddle/cinn/common/graph_utils.h @@ -46,7 +46,8 @@ class GraphNode; */ class GraphEdge : public Object { public: - GraphEdge(GraphNode* source, GraphNode* sink, int index = -1) : source_(source), sink_(sink), index_(index) {} + GraphEdge(GraphNode* source, GraphNode* sink, int index = -1) + : source_(source), sink_(sink), index_(index) {} GraphNode* source() const { return source_; } GraphNode* sink() const { return sink_; } @@ -65,7 +66,8 @@ class GraphEdge : public Object { }; struct GraphEdgeCompare { - bool operator()(const common::Shared& a, const common::Shared& b) const; + bool operator()(const common::Shared& a, + const common::Shared& b) const; }; /** @@ -86,7 +88,8 @@ class GraphNode : public Object { CHECK(other); CHECK_NE(other, this) << "Cannot link to itself"; auto outlink_edge = make_shared(this, other, index_outlinks); - auto inlink_edge = make_shared(this, other, other->index_inlinks); + auto inlink_edge = + make_shared(this, other, other->index_inlinks); index_outlinks++; 
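    // Each logical edge is recorded twice with the same (source, sink)
    // pair: once in this node's outlinks_ and once in `other`'s inlinks_.
    // The two index counters bumped here preserve insertion order so that
    // GraphEdgeCompare can break ties between parallel edges
    // deterministically.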
other->index_inlinks++; outlinks_.insert(outlink_edge); @@ -111,7 +114,7 @@ class GraphNode : public Object { void Controls(GraphNode* other) { bool outlink_linked = false; - bool inlink_linked = false; + bool inlink_linked = false; for (auto& item : outlinks_) { if (item->sink()->id() == other->id()) { outlink_linked = true; @@ -135,26 +138,33 @@ class GraphNode : public Object { if (other == this) return; // remove all this node's outlink { - auto it = std::find_if(outlinks_.begin(), outlinks_.end(), [&](const Shared& x) { - return x->source() == this && x->sink() == other; - }); + auto it = std::find_if( + outlinks_.begin(), outlinks_.end(), [&](const Shared& x) { + return x->source() == this && x->sink() == other; + }); while (it != outlinks_.end()) { outlinks_.erase(it); - it = std::find_if(outlinks_.begin(), outlinks_.end(), [&](const Shared& x) { - return x->source() == this && x->sink() == other; - }); + it = std::find_if(outlinks_.begin(), + outlinks_.end(), + [&](const Shared& x) { + return x->source() == this && x->sink() == other; + }); } } // remove all other node's inlink { - auto it = std::find_if(other->inlinks_.begin(), other->inlinks_.end(), [&](const Shared& x) { - return x->source() == this && x->sink() == other; - }); + auto it = std::find_if(other->inlinks_.begin(), + other->inlinks_.end(), + [&](const Shared& x) { + return x->source() == this && x->sink() == other; + }); while (it != other->inlinks_.end()) { other->inlinks_.erase(it); - it = std::find_if(other->inlinks_.begin(), other->inlinks_.end(), [&](const Shared& x) { - return x->source() == this && x->sink() == other; - }); + it = std::find_if(other->inlinks_.begin(), + other->inlinks_.end(), + [&](const Shared& x) { + return x->source() == this && x->sink() == other; + }); } } } @@ -163,16 +173,19 @@ class GraphNode : public Object { if (other == this) return; // remove single outlink { - auto it = std::find_if(outlinks_.begin(), outlinks_.end(), [&](const Shared& x) { - return x->source() == this && x->sink() == other; - }); + auto it = std::find_if( + outlinks_.begin(), outlinks_.end(), [&](const Shared& x) { + return x->source() == this && x->sink() == other; + }); if (it != outlinks_.end()) outlinks_.erase(it); } // remove single inlink { - auto it = std::find_if(other->inlinks_.begin(), other->inlinks_.end(), [&](const Shared& x) { - return x->source() == this && x->sink() == other; - }); + auto it = std::find_if(other->inlinks_.begin(), + other->inlinks_.end(), + [&](const Shared& x) { + return x->source() == this && x->sink() == other; + }); if (it != other->inlinks_.end()) other->inlinks_.erase(it); } } @@ -185,14 +198,21 @@ class GraphNode : public Object { } //! Get the input links of the node. - virtual const std::set, GraphEdgeCompare>& inlinks() const { return inlinks_; } + virtual const std::set, GraphEdgeCompare>& inlinks() const { + return inlinks_; + } //! Get the output links of the node. - virtual const std::set, GraphEdgeCompare>& outlinks() const { return outlinks_; } + virtual const std::set, GraphEdgeCompare>& outlinks() + const { + return outlinks_; + } //! Reset graph traversal meta info. 
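  //! A node counts as visited() once VisitOnce() has been called as many
  //! times as it has inlinks (or immediately, if it has none), which lets
  //! topological_order() release a node exactly when all of its
  //! predecessors have been emitted.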
void ResetVisitMeta() { visited_time_ = 0; } void VisitOnce() const { visited_time_++; } - bool visited() const { return inlinks_.empty() || visited_time_ == inlinks_.size(); } + bool visited() const { + return inlinks_.empty() || visited_time_ == inlinks_.size(); + } const char* type_info() const override { return __type_info__; } @@ -202,10 +222,12 @@ class GraphNode : public Object { protected: //! The input links of the node. - //! \note We record the raw pointer rather than the shared pointer to avoid cycle reference. + //! \note We record the raw pointer rather than the shared pointer to avoid + //! cycle reference. std::set, GraphEdgeCompare> inlinks_; //! The output links of the node. - //! \note We record the raw pointer rather than the shared pointer to avoid cycle reference. + //! \note We record the raw pointer rather than the shared pointer to avoid + //! cycle reference. std::set, GraphEdgeCompare> outlinks_; mutable int visited_time_{}; @@ -240,7 +262,8 @@ class Graph { std::vector start_points(); //! Return the graph's nodes and edges(visited) in topological order. - std::tuple, std::vector> topological_order() const; + std::tuple, std::vector> + topological_order() const; //! Return the graph's DFS order. std::vector dfs_order(); @@ -252,10 +275,12 @@ class Graph { std::vector nodes(); //! Collect the nodes match the condition defined by \p teller in the graph. - std::set CollectNodes(std::function&& teller); + std::set CollectNodes( + std::function&& teller); void DropNode(GraphNode* n) { - auto it = std::find_if(nodes_.begin(), nodes_.end(), [&](auto& x) { return x.get() == n; }); + auto it = std::find_if( + nodes_.begin(), nodes_.end(), [&](auto& x) { return x.get() == n; }); if (it != nodes_.end()) { nodes_.erase(it); } @@ -264,14 +289,16 @@ class Graph { //! Get a string representation to visualize a graph. std::string Visualize() const; - void ClearUnlinkedNodes(absl::flat_hash_map>* shape_dict, - absl::flat_hash_map* type_dict, - absl::flat_hash_map* layout_dict); + void ClearUnlinkedNodes( + absl::flat_hash_map>* shape_dict, + absl::flat_hash_map* type_dict, + absl::flat_hash_map* layout_dict); size_t num_nodes() const { return nodes_.size(); } protected: - //! A lookup table that map from hash key to graph node, note that it doesn't own the graph node. + //! A lookup table that map from hash key to graph node, note that it doesn't + //! own the graph node. std::map registry_; //! A list owns the graph nodes. 
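  //! (Unlike registry_ above, which merely indexes nodes by hash key, this
  //! vector holds the owning shared pointers.)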
std::vector> nodes_; @@ -283,7 +310,9 @@ class Graph { namespace std { template <> struct hash { - size_t operator()(const cinn::common::GraphNode& x) { return reinterpret_cast(hash()(x.id())); } + size_t operator()(const cinn::common::GraphNode& x) { + return reinterpret_cast(hash()(x.id())); + } }; } // namespace std diff --git a/paddle/cinn/common/graph_utils_test.cc b/paddle/cinn/common/graph_utils_test.cc index 228abaa7f7894..69490214a920b 100644 --- a/paddle/cinn/common/graph_utils_test.cc +++ b/paddle/cinn/common/graph_utils_test.cc @@ -80,7 +80,8 @@ TEST(Graph, simple) { LOG(INFO) << "graph1 " << graph->Visualize(); - std::vector node_order_target({graph->RetrieveNode("B"), graph->RetrieveNode("A")}); + std::vector node_order_target( + {graph->RetrieveNode("B"), graph->RetrieveNode("A")}); ASSERT_EQ(node_order.size(), node_order_target.size()); for (int i = 0; i < node_order.size(); i++) { diff --git a/paddle/cinn/common/ir_util.cc b/paddle/cinn/common/ir_util.cc index a42186c023717..c590b9443905a 100755 --- a/paddle/cinn/common/ir_util.cc +++ b/paddle/cinn/common/ir_util.cc @@ -70,7 +70,7 @@ Expr RampRelatedAdd(ir::Ramp *ramp, ir::Ramp *other) { CHECK(ramp); CHECK(other); if (ramp->lanes == other->lanes) { - Expr base_add = common::AutoSimplify(ramp->base + other->base); + Expr base_add = common::AutoSimplify(ramp->base + other->base); Expr stride_add = common::AutoSimplify(ramp->stride + other->stride); VLOG(2) << base_add; VLOG(2) << stride_add; @@ -81,15 +81,16 @@ Expr RampRelatedAdd(ir::Ramp *ramp, ir::Ramp *other) { } Expr RampRelatedAdd(Expr a, Expr b) { - auto *a_ramp = a.As(); - auto *b_ramp = b.As(); + auto *a_ramp = a.As(); + auto *b_ramp = b.As(); auto *a_broadcast = a.As(); auto *b_broadcast = b.As(); if (a_ramp && !b_ramp && (b->type().lanes() == 1 || b_broadcast)) { return RampRelatedAdd(a_ramp, b); } else if (!a_ramp && b_ramp && (a->type().lanes() == 1 || a_broadcast)) { return RampRelatedAdd(b_ramp, a); - } else if (!a_ramp && !b_ramp && !a->type().is_vector() && !b->type().is_vector()) { + } else if (!a_ramp && !b_ramp && !a->type().is_vector() && + !b->type().is_vector()) { return a + b; } else if (a_ramp && b_ramp) { // a_ramp && b_ramp return RampRelatedAdd(a_ramp, b_ramp); @@ -99,22 +100,24 @@ Expr RampRelatedAdd(Expr a, Expr b) { return RampRelatedAdd(b_broadcast, a); } else if (a_broadcast && b_broadcast) { CHECK_EQ(a_broadcast->lanes, b_broadcast->lanes); - return ir::Broadcast::Make(a_broadcast->value + b_broadcast->value, a_broadcast->lanes); + return ir::Broadcast::Make(a_broadcast->value + b_broadcast->value, + a_broadcast->lanes); } else { CINN_NOT_IMPLEMENTED } } Expr RampRelatedMul(Expr a, Expr b) { - auto *a_ramp = a.As(); - auto *b_ramp = b.As(); + auto *a_ramp = a.As(); + auto *b_ramp = b.As(); auto *a_broadcast = a.As(); auto *b_broadcast = b.As(); if (a_ramp && !b_ramp && (!b->type().is_vector() || b_broadcast)) { return RampRelatedMul(a_ramp, b); } else if (!a_ramp && b_ramp && (a->type().is_vector() || a_broadcast)) { return RampRelatedMul(b_ramp, a); - } else if (!a_ramp && !b_ramp && !a->type().is_vector() && !b->type().is_vector()) { + } else if (!a_ramp && !b_ramp && !a->type().is_vector() && + !b->type().is_vector()) { return a * b; } else if (a_ramp && b_ramp) { // a_ramp && b_ramp return RampRelatedMul(a_ramp, b_ramp); @@ -124,7 +127,8 @@ Expr RampRelatedMul(Expr a, Expr b) { return RampRelatedMul(b_broadcast, a); } else if (a_broadcast && b_broadcast) { CHECK_EQ(a_broadcast->lanes, b_broadcast->lanes); - return 
ir::Broadcast::Make(a_broadcast->value * b_broadcast->value, a_broadcast->lanes); + return ir::Broadcast::Make(a_broadcast->value * b_broadcast->value, + a_broadcast->lanes); } else { VLOG(3) << "a,b: " << a << " " << b; CINN_NOT_IMPLEMENTED @@ -133,7 +137,8 @@ Expr RampRelatedMul(Expr a, Expr b) { } // namespace -Expr IndiceToAbsOffset(const std::vector &shape, const std::vector &indices) { +Expr IndiceToAbsOffset(const std::vector &shape, + const std::vector &indices) { VLOG(3) << "Begin IndiceToAbsOffset"; VLOG(3) << "shape is : " << utils::Join(shape, ","); VLOG(3) << "indices is : " << utils::Join(indices, ","); @@ -155,13 +160,15 @@ Expr IndiceToAbsOffset(const std::vector &shape, const std::vector & return common::AutoSimplify(res); } -Expr IndiceToAbsOffset(const std::vector &shape, const std::vector &indices) { +Expr IndiceToAbsOffset(const std::vector &shape, + const std::vector &indices) { std::vector shape_; for (int v : shape) shape_.push_back(Expr(v)); return IndiceToAbsOffset(shape, indices); } -Expr PrecedingAxisToAbsOffset(const std::vector &shape, int preceding_n_axis) { +Expr PrecedingAxisToAbsOffset(const std::vector &shape, + int preceding_n_axis) { std::vector indices; for (int i = 0; i < preceding_n_axis; i++) indices.push_back(shape[i]); return IndiceToAbsOffset(shape, indices); @@ -200,8 +207,8 @@ void Substitute(Expr *expr, const std::map &var_map) { } bool is_zero(Expr v) { - v = AutoSimplify(v); - auto *int_n = v.As(); + v = AutoSimplify(v); + auto *int_n = v.As(); auto *float_n = v.As(); if (int_n) return int_n->value == 0; @@ -216,11 +223,13 @@ Expr CastIfNeeded(Expr body, Type type) { bool MathEqual(const Expr &a, const Expr &b) { auto c = a - b; - c = AutoSimplify(c); + c = AutoSimplify(c); return is_zero(c); } -Expr select(Expr cond, Expr true_value, Expr false_value) { return ir::Select::Make(cond, true_value, false_value); } +Expr select(Expr cond, Expr true_value, Expr false_value) { + return ir::Select::Make(cond, true_value, false_value); +} Expr and_all(const std::vector &conds) { CHECK(!conds.empty()); @@ -241,7 +250,8 @@ Expr or_all(const std::vector &conds) { } void CheckTensorUniqueInExpr(Expr expr) { - auto tensor_uniq = ir::CollectIRNodes(expr, [](const Expr *x) { return x->as_tensor(); }); + auto tensor_uniq = + ir::CollectIRNodes(expr, [](const Expr *x) { return x->as_tensor(); }); absl::flat_hash_map tensor_names; for (auto &t : tensor_uniq) { auto *tp = t.as_tensor(); @@ -249,7 +259,8 @@ void CheckTensorUniqueInExpr(Expr expr) { tensor_names[tp->name] = tp; } else { CHECK_EQ(tensor_names[tp->name], tp) - << "Found tensor not unique [" << tp->name << "]\nThe original expression is \n" + << "Found tensor not unique [" << tp->name + << "]\nThe original expression is \n" << expr; } } @@ -259,8 +270,10 @@ void CheckBufferUniqueInExpr(Expr expr) { // the buffers exists in tensor and lowered functions. 
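  // That is: gather every tensor and every lowered function in the
  // expression, then verify below that no buffer name maps to two distinct
  // ir::_Buffer_ objects.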
CheckTensorUniqueInExpr(expr); - auto tensors = ir::CollectIRNodes(expr, [](const Expr *x) { return x->as_tensor(); }); - auto funcs = ir::CollectIRNodes(expr, [](const Expr *x) { return x->as_lowered_func(); }); + auto tensors = + ir::CollectIRNodes(expr, [](const Expr *x) { return x->as_tensor(); }); + auto funcs = ir::CollectIRNodes( + expr, [](const Expr *x) { return x->as_lowered_func(); }); absl::flat_hash_map buffer_name; auto check_buffer_uniq = [&](const ir::_Buffer_ *b) { @@ -323,12 +336,14 @@ Expr cast(Expr e, Type type) { return ir::Cast::Make(type, e); } -std::vector GatherItersToTensorProducer(const std::string &target_tensor_name, Expr *expr) { +std::vector GatherItersToTensorProducer( + const std::string &target_tensor_name, Expr *expr) { struct Visitor : public ir::IRMutator<> { std::vector iters; const std::string &target_tensor_name; - explicit Visitor(const std::string &target_tensor_name) : target_tensor_name(target_tensor_name) {} + explicit Visitor(const std::string &target_tensor_name) + : target_tensor_name(target_tensor_name) {} std::vector operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); @@ -339,7 +354,7 @@ std::vector GatherItersToTensorProducer(const std::string &target_t if (op->tensor.as_tensor()->name == target_tensor_name) { CHECK(iters.empty()); for (auto &e : for_stack) { - auto *for_n = e->As(); + auto *for_n = e->As(); auto *polyfor_n = e->As(); if (for_n) { iters.push_back(for_n->loop_var->name); @@ -367,7 +382,8 @@ std::vector GatherItersToTensorProducer(const std::string &target_t return Visitor(target_tensor_name)(expr); } -std::vector GetForloopStackToStore(Expr *expr, const std::string &tensor_name) { +std::vector GetForloopStackToStore(Expr *expr, + const std::string &tensor_name) { VLOG(4) << "search store " << tensor_name << " in expr:\n"; VLOG(4) << *expr; struct Mutator : public ir::IRMutator<> { @@ -376,7 +392,8 @@ std::vector GetForloopStackToStore(Expr *expr, const std::string &tensor std::string tensor_name; - explicit Mutator(const std::string &tensor_name) : tensor_name(tensor_name) {} + explicit Mutator(const std::string &tensor_name) + : tensor_name(tensor_name) {} std::vector operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); @@ -397,7 +414,9 @@ std::vector GetForloopStackToStore(Expr *expr, const std::string &tensor if (!found) forloop_stack.pop_back(); } - void Visit(const ir::Store *op, Expr *expr) { found = op->tensor.as_tensor()->name == tensor_name; } + void Visit(const ir::Store *op, Expr *expr) { + found = op->tensor.as_tensor()->name == tensor_name; + } }; return Mutator(tensor_name)(expr); diff --git a/paddle/cinn/common/ir_util.h b/paddle/cinn/common/ir_util.h index 8194c27c5e226..179c5dfd0d124 100644 --- a/paddle/cinn/common/ir_util.h +++ b/paddle/cinn/common/ir_util.h @@ -27,10 +27,13 @@ namespace cinn { namespace common { -Expr IndiceToAbsOffset(const std::vector &shape, const std::vector &indices); -Expr IndiceToAbsOffset(const std::vector &shape, const std::vector &indices); +Expr IndiceToAbsOffset(const std::vector &shape, + const std::vector &indices); +Expr IndiceToAbsOffset(const std::vector &shape, + const std::vector &indices); -Expr PrecedingAxisToAbsOffset(const std::vector &shape, int preceding_n_axis); +Expr PrecedingAxisToAbsOffset(const std::vector &shape, + int preceding_n_axis); Expr CastIfNeeded(Expr body, Type type); @@ -39,8 +42,10 @@ Expr CastIfNeeded(Expr body, Type type); //! @param var_map The map from variables to the target expressions. 
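//!
//! A minimal usage sketch (the names `loop_var` and `body` are hypothetical;
//! the map key is the raw variable node):
//! \code
//!   const ir::_Var_* v = loop_var.As<ir::_Var_>();
//!   std::map<const ir::_Var_*, Expr> var_map;
//!   var_map[v] = Expr(0);                // pin the variable to zero
//!   common::Substitute(&body, var_map);  // rewrites `body` in place
//! \endcode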
void Substitute(Expr *expr, const std::map &var_map); -//! Get a stack of forloops(For and PolyFor nodes) to a Store node target to \p tensor_name -std::vector GetForloopStackToStore(Expr *expr, const std::string &tensor_name); +//! Get a stack of forloops(For and PolyFor nodes) to a Store node target to \p +//! tensor_name +std::vector GetForloopStackToStore(Expr *expr, + const std::string &tensor_name); // make const // @{ @@ -63,8 +68,12 @@ template inline Expr make_one() { return make_const(static_cast(1)); } -inline Expr make_bool(bool x) { return common::make_shared(Bool(), x); } -inline Expr make_bool(bool x, int lanes) { return common::make_shared(Bool(lanes), x); } +inline Expr make_bool(bool x) { + return common::make_shared(Bool(), x); +} +inline Expr make_bool(bool x, int lanes) { + return common::make_shared(Bool(lanes), x); +} // @} /** @@ -77,7 +86,8 @@ void CheckTensorUniqueInExpr(Expr expr); */ void CheckBufferUniqueInExpr(Expr expr); -std::vector GatherItersToTensorProducer(const std::string &target_tensor_name, Expr *expr); +std::vector GatherItersToTensorProducer( + const std::string &target_tensor_name, Expr *expr); bool is_zero(Expr v); @@ -103,13 +113,21 @@ template Expr make_const(Type t, T v) { if (t.is_vector()) { if (t.is_int()) { - return ir::Broadcast::Make(make_shared(t.ElementOf(), static_cast(v)), t.lanes()); + return ir::Broadcast::Make( + make_shared(t.ElementOf(), static_cast(v)), + t.lanes()); } else if (t.is_uint()) { - return ir::Broadcast::Make(make_shared(t.ElementOf(), static_cast(v)), t.lanes()); + return ir::Broadcast::Make( + make_shared(t.ElementOf(), static_cast(v)), + t.lanes()); } else if (t.is_float()) { - return ir::Broadcast::Make(make_shared(t.ElementOf(), static_cast(v)), t.lanes()); + return ir::Broadcast::Make( + make_shared(t.ElementOf(), static_cast(v)), + t.lanes()); } else if (t.is_bool()) { - return ir::Broadcast::Make(make_shared(t.ElementOf(), static_cast(v)), t.lanes()); + return ir::Broadcast::Make( + make_shared(t.ElementOf(), static_cast(v)), + t.lanes()); } else { CINN_NOT_IMPLEMENTED } diff --git a/paddle/cinn/common/macros.h b/paddle/cinn/common/macros.h index fce0d19292ec3..2b9b75064bc07 100644 --- a/paddle/cinn/common/macros.h +++ b/paddle/cinn/common/macros.h @@ -40,9 +40,10 @@ * CINN_USE_REGISTER(some_key); */ #define CINN_REGISTER_HELPER(symbol__) bool __cinn__##symbol__##__registrar() -#define CINN_USE_REGISTER(symbol__) \ - extern bool __cinn__##symbol__##__registrar(); \ - [[maybe_unused]] static bool __cinn_extern_registrar_##symbol__ = __cinn__##symbol__##__registrar(); +#define CINN_USE_REGISTER(symbol__) \ + extern bool __cinn__##symbol__##__registrar(); \ + [[maybe_unused]] static bool __cinn_extern_registrar_##symbol__ = \ + __cinn__##symbol__##__registrar(); #if __cplusplus >= 201703L #define CINN_NODISCARD [[nodiscard]] diff --git a/paddle/cinn/common/object.h b/paddle/cinn/common/object.h index 625e7d43c6534..a28ac47a12bd0 100644 --- a/paddle/cinn/common/object.h +++ b/paddle/cinn/common/object.h @@ -23,7 +23,8 @@ namespace common { template class Shared; /** - * Object is the basic element in the CINN, with `Shared` wrapper, the object can be shared across the system. + * Object is the basic element in the CINN, with `Shared` wrapper, the object + * can be shared across the system. */ struct Object { //! Get the type representation of this object. 
@@ -71,7 +72,7 @@ struct Object { mutable RefCount __ref_count__; }; -using object_ptr = Object*; +using object_ptr = Object*; using shared_object = Shared; } // namespace common diff --git a/paddle/cinn/common/python_interpreter_guard.cc b/paddle/cinn/common/python_interpreter_guard.cc index a465adda2c558..07067c22a730f 100644 --- a/paddle/cinn/common/python_interpreter_guard.cc +++ b/paddle/cinn/common/python_interpreter_guard.cc @@ -19,9 +19,13 @@ namespace cinn { namespace common { -PythonInterpreterGuard::PythonInterpreterGuard() { pybind11::initialize_interpreter(); } +PythonInterpreterGuard::PythonInterpreterGuard() { + pybind11::initialize_interpreter(); +} -PythonInterpreterGuard::~PythonInterpreterGuard() { pybind11::finalize_interpreter(); } +PythonInterpreterGuard::~PythonInterpreterGuard() { + pybind11::finalize_interpreter(); +} PythonInterpreterGuard& PythonInterpreterGuard::Guard() { static PythonInterpreterGuard guard; diff --git a/paddle/cinn/common/python_interpreter_guard.h b/paddle/cinn/common/python_interpreter_guard.h index 8c7961af81c36..4628298934db8 100644 --- a/paddle/cinn/common/python_interpreter_guard.h +++ b/paddle/cinn/common/python_interpreter_guard.h @@ -18,13 +18,14 @@ namespace cinn { namespace common { /** - * Singleton to handle Python interpreter life time, since pybind11::initialize_interpreter and - * pybind11::finalize_interpreter cannot be called initialization again after finalization, this - * singleton calls pybind11::finalize_interpreter when it constructs and calls finalization when + * Singleton to handle the Python interpreter's lifetime. Since + * pybind11::initialize_interpreter cannot be called again after + * pybind11::finalize_interpreter, this singleton calls + * pybind11::initialize_interpreter when it constructs and pybind11::finalize_interpreter when * it destructs. * - * In this case, every caller can call this guard to make sure the pybind11 Python interpreter - * is alive. + * In this case, every caller can use this guard to make sure the pybind11 + * Python interpreter is alive. */ class PythonInterpreterGuard { public: diff --git a/paddle/cinn/common/shared.h b/paddle/cinn/common/shared.h index 1b33512c984aa..6c2e042ca6364 100644 --- a/paddle/cinn/common/shared.h +++ b/paddle/cinn/common/shared.h @@ -23,7 +23,7 @@ namespace common { class RefCount { public: using value_type = int32_t; - RefCount() = default; + RefCount() = default; value_type Inc() { return ++count_; } value_type Dec() { return --count_; } @@ -37,7 +37,8 @@ class RefCount { class Object; /** - * The templated methods are used to unify the way to get the RefCount instance in client classes. + * The templated methods are used to unify the way to get the RefCount instance + * in client classes. */ template RefCount& ref_count(const T* t) { @@ -109,8 +110,8 @@ void Shared::DecRef(T* p) { template Shared& Shared::operator=(const Shared& other) { if (other.p_ == p_) return *this; - // Other can be inside of something owned by this, so we should be careful to incref other before we decref - // ourselves. + // Other can be inside of something owned by this, so we should be careful to + // incref other before we decref ourselves.
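  // (Example of the hazard: if *this holds the sole reference to an object
  // that transitively owns other.p_, running DecRef(p_) first could destroy
  // that object and free other.p_ before we take our own reference;
  // increfing first rules that out.)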
T* tmp = other.p_; IncRef(tmp); DecRef(p_); diff --git a/paddle/cinn/common/target.cc b/paddle/cinn/common/target.cc index c60ff05a06736..42a130163a130 100644 --- a/paddle/cinn/common/target.cc +++ b/paddle/cinn/common/target.cc @@ -53,33 +53,40 @@ int Target::runtime_arch() const { } int Target::max_num_threads() const { - CHECK(arch == Arch::NVGPU) << "The target is not NVGPU! Cannot get max number of threads."; + CHECK(arch == Arch::NVGPU) + << "The target is not NVGPU! Cannot get max number of threads."; return 1024; } int Target::get_multi_processor_count() const { - CHECK(arch == Arch::NVGPU) << "The target is not NVGPU! Cannot get multi processor count"; + CHECK(arch == Arch::NVGPU) + << "The target is not NVGPU! Cannot get multi processor count"; int num_sm = 0; #ifdef CINN_WITH_CUDA - cudaDeviceGetAttribute(&num_sm, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0); + cudaDeviceGetAttribute( + &num_sm, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0); #endif return num_sm; } int Target::get_max_threads_per_sm() const { - CHECK(arch == Arch::NVGPU) << "The target is not NVGPU! Cannot get max threads per stream processor"; + CHECK(arch == Arch::NVGPU) + << "The target is not NVGPU! Cannot get max threads per stream processor"; int max_thread = 0; #ifdef CINN_WITH_CUDA - cudaDeviceGetAttribute(&max_thread, cudaDeviceAttr::cudaDevAttrMaxThreadsPerMultiProcessor, 0); + cudaDeviceGetAttribute( + &max_thread, cudaDeviceAttr::cudaDevAttrMaxThreadsPerMultiProcessor, 0); #endif return max_thread; } int Target::get_max_blocks_per_sm() const { - CHECK(arch == Arch::NVGPU) << "The target is not NVGPU! Cannot get max blocks per stream processor"; + CHECK(arch == Arch::NVGPU) + << "The target is not NVGPU! Cannot get max blocks per stream processor"; int max_blocks = 1; #ifdef CINN_WITH_CUDA - cudaDeviceGetAttribute(&max_blocks, cudaDeviceAttr::cudaDevAttrMaxBlocksPerMultiprocessor, 0); + cudaDeviceGetAttribute( + &max_blocks, cudaDeviceAttr::cudaDevAttrMaxBlocksPerMultiprocessor, 0); #endif return max_blocks; } @@ -173,16 +180,19 @@ std::ostream &operator<<(std::ostream &os, Target::Arch arch) { } const Target &UnkTarget() { - static Target target(Target::OS::Unk, Target::Arch::Unk, Target::Bit::Unk, {}, {}); + static Target target( + Target::OS::Unk, Target::Arch::Unk, Target::Bit::Unk, {}, {}); return target; } const Target &DefaultHostTarget() { - static Target target(Target::OS::Linux, Target::Arch::X86, Target::Bit::k64, {}, {}); + static Target target( + Target::OS::Linux, Target::Arch::X86, Target::Bit::k64, {}, {}); return target; } const Target &DefaultNVGPUTarget() { - static Target target(Target::OS::Linux, Target::Arch::NVGPU, Target::Bit::k64, {}, {}); + static Target target( + Target::OS::Linux, Target::Arch::NVGPU, Target::Bit::k64, {}, {}); return target; } @@ -191,8 +201,10 @@ int GetMaxThreads() { int max_threads = 1; #ifdef CINN_WITH_CUDA int num_sm = 1; - cudaDeviceGetAttribute(&num_sm, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0); - cudaDeviceGetAttribute(&max_threads, cudaDeviceAttr::cudaDevAttrMaxThreadsPerMultiProcessor, 0); + cudaDeviceGetAttribute( + &num_sm, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0); + cudaDeviceGetAttribute( + &max_threads, cudaDeviceAttr::cudaDevAttrMaxThreadsPerMultiProcessor, 0); // multiplication num_sm max_threads *= (num_sm * 4); #endif @@ -204,8 +216,10 @@ int GetMaxBlocks() { int max_blocks = 1; #ifdef CINN_WITH_CUDA int num_sm = 1; - cudaDeviceGetAttribute(&num_sm, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0); - 
cudaDeviceGetAttribute(&max_blocks, cudaDeviceAttr::cudaDevAttrMaxBlocksPerMultiprocessor, 0); + cudaDeviceGetAttribute( + &num_sm, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0); + cudaDeviceGetAttribute( + &max_blocks, cudaDeviceAttr::cudaDevAttrMaxBlocksPerMultiprocessor, 0); // multiplication num_sm max_blocks *= num_sm; diff --git a/paddle/cinn/common/target.h b/paddle/cinn/common/target.h index ad858b4bb8c6f..6e4a59eab5aa6 100755 --- a/paddle/cinn/common/target.h +++ b/paddle/cinn/common/target.h @@ -23,7 +23,8 @@ namespace common { struct Target { /** - * The operating system used by the target. Determines which system calls to generate. + * The operating system used by the target. Determines which system calls to + * generate. */ enum class OS : int { Unk = -1, @@ -66,16 +67,19 @@ struct Target { std::vector features; std::vector libs; - explicit Target(OS o = OS::Linux, - Arch a = Arch::Unk, - Bit b = Bit::Unk, + explicit Target(OS o = OS::Linux, + Arch a = Arch::Unk, + Bit b = Bit::Unk, const std::vector& features = {}, - const std::vector& libs = {}) + const std::vector& libs = {}) : os(o), arch(a), bits(b), features(features), libs(libs) {} - bool defined() const { return os != OS::Unk && arch != Arch::Unk && bits != Bit::Unk; } + bool defined() const { + return os != OS::Unk && arch != Arch::Unk && bits != Bit::Unk; + } - //! Get the Runtime architecture, it is casted to integer to avoid header file depending. + //! Get the Runtime architecture, it is casted to integer to avoid header file + //! depending. int runtime_arch() const; int max_num_threads() const; diff --git a/paddle/cinn/common/test_helper.cc b/paddle/cinn/common/test_helper.cc index 257b92983eb42..1dd3997135ab1 100644 --- a/paddle/cinn/common/test_helper.cc +++ b/paddle/cinn/common/test_helper.cc @@ -35,7 +35,8 @@ cinn_buffer_t* BufferBuilder::Build() { CINN_NOT_IMPLEMENTED } - auto* buffer = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_type, shape_, align_); + auto* buffer = cinn_buffer_t::new_( + cinn_device_kind_t::cinn_x86_device, cinn_type, shape_, align_); cinn_buffer_malloc(nullptr, buffer); diff --git a/paddle/cinn/common/test_helper.h b/paddle/cinn/common/test_helper.h index 2a3c4625c3e0c..536ecbfe64e45 100644 --- a/paddle/cinn/common/test_helper.h +++ b/paddle/cinn/common/test_helper.h @@ -33,11 +33,12 @@ namespace common { */ struct BufferBuilder { enum class InitType { - kRandom = 0, - kZero = 1, + kRandom = 0, + kZero = 1, kSetValue = 2, }; - explicit BufferBuilder(Type type, const std::vector& shape) : type_(type), shape_(shape) {} + explicit BufferBuilder(Type type, const std::vector& shape) + : type_(type), shape_(shape) {} BufferBuilder& set_random() { init_type_ = InitType::kRandom; @@ -51,7 +52,7 @@ struct BufferBuilder { BufferBuilder& set_val(float x) { init_type_ = InitType::kSetValue; - init_val_ = x; + init_val_ = x; return *this; } @@ -75,7 +76,8 @@ struct BufferBuilder { void RandomInt(void* arr, int len) { auto* data = static_cast(arr); for (int i = 0; i < len; i++) { - data[i] = static_cast(rand() % std::numeric_limits::max()); // NOLINT + data[i] = + static_cast(rand() % std::numeric_limits::max()); // NOLINT } } diff --git a/paddle/cinn/common/type.cc b/paddle/cinn/common/type.cc index 1195f560bf4cf..ad5c8412d950b 100644 --- a/paddle/cinn/common/type.cc +++ b/paddle/cinn/common/type.cc @@ -26,7 +26,8 @@ namespace common { struct Type::Storage { Storage() = default; - Storage(type_t t, int b, int w, specific_type_t st) : type_(t), bits_(b), lanes_(w), 
specific_type_(st) {} + Storage(type_t t, int b, int w, specific_type_t st) + : type_(t), bits_(b), lanes_(w), specific_type_(st) {} type_t type_{type_t::Unk}; // distinguish FP16/BF16, or E5M2/E4M3 (when FP8 is supported) @@ -122,7 +123,7 @@ Type::Type(const Type &other) { Type Type::ElementOf() const { CheckTypeValid(); - auto type = *this; + auto type = *this; type.storage_->lanes_ = 1; return type; } @@ -130,8 +131,10 @@ Type Type::ElementOf() const { void Type::CheckTypeValid() const { CHECK_NE(GetStorage().type_, type_t::Unk); if (GetStorage().type_ == type_t::Float && GetStorage().bits_ == 16) { - CHECK(GetStorage().specific_type_ == specific_type_t::FP16 || GetStorage().specific_type_ == specific_type_t::BF16) - << "When creating a 16 bits Float, the specific_type_t must be FP16 or BF16."; + CHECK(GetStorage().specific_type_ == specific_type_t::FP16 || + GetStorage().specific_type_ == specific_type_t::BF16) + << "When creating a 16 bits Float, the specific_type_t must be FP16 or " + "BF16."; } } @@ -154,9 +157,11 @@ Type Type::ConstOf() const { } bool Type::is_supported() const { - return this->is_float(32) || this->is_float16() || this->is_bfloat16() || this->is_float(64) || this->is_bool() || - this->is_int(8) || this->is_int(16) || this->is_int(32) || this->is_int(64) || this->is_uint(8) || - this->is_uint(16) || this->is_uint(32) || this->is_uint(64); + return this->is_float(32) || this->is_float16() || this->is_bfloat16() || + this->is_float(64) || this->is_bool() || this->is_int(8) || + this->is_int(16) || this->is_int(32) || this->is_int(64) || + this->is_uint(8) || this->is_uint(16) || this->is_uint(32) || + this->is_uint(64); } Type Type::IgnoreConst() const { @@ -168,20 +173,20 @@ Type Type::IgnoreConst() const { Type Type::with_bits(int x) const { CHECK(is_primitive()); - Type type = *this; + Type type = *this; type.GetStorage().bits_ = x; return type; } Type Type::with_type(Type::type_t x) const { - Type type = *this; + Type type = *this; type.GetStorage().type_ = x; return type; } Type Type::with_lanes(int x) const { CHECK(valid()); - Type type = *this; + Type type = *this; type.GetStorage().lanes_ = x; return type; } @@ -203,7 +208,7 @@ Type &Type::set_cpp_const(bool is_const) { return *this; } Type &Type::set_customized_type(const std::string &t) { - GetStorage().type_ = type_t ::Customized; + GetStorage().type_ = type_t ::Customized; GetStorage().customized_type_ = t; return *this; @@ -225,25 +230,32 @@ bool Type::valid() const { return true; } -Type::Type(Type::type_t t, int b, int w, specific_type_t st) : storage_(new Storage(t, b, w, st)) { +Type::Type(Type::type_t t, int b, int w, specific_type_t st) + : storage_(new Storage(t, b, w, st)) { if (t == Type::type_t::Float && b == 16) { CHECK(st == specific_type_t::FP16 || st == specific_type_t::BF16) - << "When creating a 16 bits Float, the specific_type_t must be FP16 or BF16."; + << "When creating a 16 bits Float, the specific_type_t must be FP16 or " + "BF16."; } } -bool Type::is_primitive() const { return !is_unk() && type() != type_t::Customized; } -bool Type::is_customized() const { return !is_unk() && type() == type_t::Customized; } +bool Type::is_primitive() const { + return !is_unk() && type() != type_t::Customized; +} +bool Type::is_customized() const { + return !is_unk() && type() == type_t::Customized; +} bool Type::is_unk() const { return type() == type_t::Unk; } bool Type::is_bool() const { return type() == type_t::UInt && bits() == 1; } bool Type::is_void() const { return type() == type_t::Void; } 
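// A small illustrative sketch of how specific_type_t disambiguates the two
// 16-bit floats, which share type_t::Float and bits() == 16 (the results
// follow from the predicates defined below):
//
//   Type fp16(Type::type_t::Float, 16, 1, Type::specific_type_t::FP16);
//   Type bf16(Type::type_t::Float, 16, 1, Type::specific_type_t::BF16);
//   fp16.is_float16();                                // true
//   bf16.is_bfloat16();                               // true
//   fp16.is_float(16, Type::specific_type_t::BF16);   // false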
bool Type::is_vector() const { return lanes() > 1; } bool Type::is_scalar() const { return lanes() == 1; } -// Note: when calling is_float(16), 'st' can't be specific_type_t::None to distinguish FP16/BF16, or use -// is_float16()/is_bfloat16() for short +// Note: when calling is_float(16), 'st' can't be specific_type_t::None to +// distinguish FP16/BF16, or use is_float16()/is_bfloat16() for short bool Type::is_float(int bits, specific_type_t st) const { if (type() == type_t::Float && bits == 16) { - CHECK(st != specific_type_t::None) << "when calling is_float(16), 'st' can't be specific_type_t::None to " - "distinguish FP16/BF16, or use is_float16()/is_bfloat16() for short"; + CHECK(st != specific_type_t::None) + << "when calling is_float(16), 'st' can't be specific_type_t::None to " + "distinguish FP16/BF16, or use is_float16()/is_bfloat16() for short"; return st == this->specific_type(); } else { return type() == type_t::Float && (bits < 0 || bits == this->bits()); @@ -251,31 +263,48 @@ bool Type::is_float(int bits, specific_type_t st) const { } bool Type::is_float16() const { return is_float(16, specific_type_t::FP16); } bool Type::is_bfloat16() const { return is_float(16, specific_type_t::BF16); } -bool Type::is_uint(int bits) const { return type() == type_t::UInt && (bits < 0 || bits == this->bits()); } -bool Type::is_int(int bits) const { return type() == type_t::Int && (bits < 0 || bits == this->bits()); } +bool Type::is_uint(int bits) const { + return type() == type_t::UInt && (bits < 0 || bits == this->bits()); +} +bool Type::is_int(int bits) const { + return type() == type_t::Int && (bits < 0 || bits == this->bits()); +} bool Type::is_integer(int bits) const { - return (type() == type_t::Int || type() == type_t::UInt) && (bits < 0 || bits == this->bits()); + return (type() == type_t::Int || type() == type_t::UInt) && + (bits < 0 || bits == this->bits()); +} +bool Type::is_index_type() { + return is_int() && lanes() == 1 && (bits() == 32 || bits() == 64); } -bool Type::is_index_type() { return is_int() && lanes() == 1 && (bits() == 32 || bits() == 64); } bool Type::is_cpp_handle() const { - return static_cast(GetStorage().cpp_type_) & static_cast(cpp_type_t::Handle); + return static_cast(GetStorage().cpp_type_) & + static_cast(cpp_type_t::Handle); } bool Type::is_cpp_handle2() const { - return static_cast(GetStorage().cpp_type_) & static_cast(cpp_type_t::HandleHandle); + return static_cast(GetStorage().cpp_type_) & + static_cast(cpp_type_t::HandleHandle); } bool Type::is_cpp_const() const { - return static_cast(cpp_type_t::Const) & static_cast(GetStorage().cpp_type_); + return static_cast(cpp_type_t::Const) & + static_cast(GetStorage().cpp_type_); +} +const std::string &Type::customized_type() const { + return GetStorage().customized_type_; +} +bool Type::is_customized_type() const { + return !GetStorage().customized_type_.empty(); } -const std::string &Type::customized_type() const { return GetStorage().customized_type_; } -bool Type::is_customized_type() const { return !GetStorage().customized_type_.empty(); } Type::type_t Type::type() const { return GetStorage().type_; } -Type::specific_type_t Type::specific_type() const { return GetStorage().specific_type_; } +Type::specific_type_t Type::specific_type() const { + return GetStorage().specific_type_; +} int Type::bits() const { return GetStorage().bits_; } int Type::lanes() const { return GetStorage().lanes_; } Type::cpp_type_t Type::cpp_type() const { return GetStorage().cpp_type_; } bool Type::operator==(const Type &other) 
const { - return type() == other.type() && specific_type() == other.specific_type() && bits() == other.bits() && - lanes() == other.lanes() && GetStorage().cpp_type_ == other.GetStorage().cpp_type_ && + return type() == other.type() && specific_type() == other.specific_type() && + bits() == other.bits() && lanes() == other.lanes() && + GetStorage().cpp_type_ == other.GetStorage().cpp_type_ && customized_type() == other.customized_type(); } bool Type::is_string() const { return type() == type_t::String; } @@ -286,7 +315,7 @@ Type &Type::operator=(const Type &other) { other.GetStorage().bits_, other.GetStorage().lanes_, other.GetStorage().specific_type_)); - storage_->cpp_type_ = other.GetStorage().cpp_type_; + storage_->cpp_type_ = other.GetStorage().cpp_type_; storage_->customized_type_ = other.GetStorage().customized_type_; } return *this; @@ -380,7 +409,8 @@ struct TypeHash { int Type::bytes() const { // if the type is a pointer auto cpp_type = this->cpp_type(); - if (cpp_type == Type::cpp_type_t::Handle || cpp_type == Type::cpp_type_t::HandleHandle) { + if (cpp_type == Type::cpp_type_t::Handle || + cpp_type == Type::cpp_type_t::HandleHandle) { return sizeof(void *); } @@ -520,7 +550,8 @@ Type Str2Type(const std::string &type) { {"cinn_pod_value_p", type_of()}, }; - CHECK(str2type_map.find(type) != str2type_map.end()) << "Not support type [" << type << "] ! Please Check.\n"; + CHECK(str2type_map.find(type) != str2type_map.end()) + << "Not support type [" << type << "] ! Please Check.\n"; return str2type_map.at(type); } diff --git a/paddle/cinn/common/type.h b/paddle/cinn/common/type.h index 6a92d2f15c044..9e0c353d8d4cb 100644 --- a/paddle/cinn/common/type.h +++ b/paddle/cinn/common/type.h @@ -30,9 +30,10 @@ namespace cinn { namespace common { /** - * Types in the CINN type system. They can be ints, unsigned ints, or floats of various bit-widths. - * They can also be vectors of the same (by setting the `lanes` field to something larger than one). - * NOTE: Front-end code other than vectorize shouldn't use vector types. + * Types in the CINN type system. They can be ints, unsigned ints, or floats of + * various bit-widths. They can also be vectors of the same (by setting the + * `lanes` field to something larger than one). NOTE: Front-end code other than + * vectorize shouldn't use vector types. */ struct Type { enum class type_t { @@ -42,15 +43,17 @@ struct Type { Float, String, Void, - // stupid idea to mix the Customized with other primitive types, large refactor needs here. + // stupid idea to mix the Customized with other primitive types, large + // refactor needs here. Customized, // Customized type }; - // CINN use type_t and bits to distinguish data types, like is_float(64) for double, - // is_float(32) for float, but for Float16 and BFloat16, the bits are both 16, so we need - // some other info to distinguish them. + // CINN use type_t and bits to distinguish data types, like is_float(64) for + // double, is_float(32) for float, but for Float16 and BFloat16, the bits are + // both 16, so we need some other info to distinguish them. enum class specific_type_t { - // None for some cases we only care about the bits, e.g. vectorize for hardwares + // None for some cases we only care about the bits, e.g. vectorize for + // hardwares None = -1, FP16, BF16, @@ -61,9 +64,9 @@ struct Type { //! type decorators in C++, the different code can used together. enum class cpp_type_t : uint8_t { - None = 0, // None information. - Const = 1, // const. 
- Handle = 1 << 1, // pointer type, such as `cinn_buffer_t*`. + None = 0, // None information. + Const = 1, // const. + Handle = 1 << 1, // pointer type, such as `cinn_buffer_t*`. HandleHandle = 1 << 2, // pointer of pointer, such as `cinn_buffer_t**`. }; @@ -84,7 +87,8 @@ struct Type { CINN_NODISCARD bool is_bool() const; CINN_NODISCARD bool is_vector() const; CINN_NODISCARD bool is_scalar() const; - CINN_NODISCARD bool is_float(int bits = -1, specific_type_t st = specific_type_t::None) const; + CINN_NODISCARD bool is_float( + int bits = -1, specific_type_t st = specific_type_t::None) const; CINN_NODISCARD bool is_float16() const; CINN_NODISCARD bool is_bfloat16() const; CINN_NODISCARD bool is_int(int bits = -1) const; @@ -160,14 +164,26 @@ struct Type { }; // namespace common inline Type Void() { return Type(Type::type_t ::Void, 1, 0); } -inline Type Int(int bits, int lanes = 1) { return Type(Type::type_t ::Int, bits, lanes); } -inline Type UInt(int bits, int lanes = 1) { return Type(Type::type_t ::UInt, bits, lanes); } -inline Type BFloat16(int lanes = 1) { return Type(Type::type_t ::Float, 16, lanes, Type::specific_type_t::BF16); } -inline Type Float16(int lanes = 1) { return Type(Type::type_t ::Float, 16, lanes, Type::specific_type_t::FP16); } -inline Type Float(int bits, int lanes = 1, Type::specific_type_t st = Type::specific_type_t::None) { +inline Type Int(int bits, int lanes = 1) { + return Type(Type::type_t ::Int, bits, lanes); +} +inline Type UInt(int bits, int lanes = 1) { + return Type(Type::type_t ::UInt, bits, lanes); +} +inline Type BFloat16(int lanes = 1) { + return Type(Type::type_t ::Float, 16, lanes, Type::specific_type_t::BF16); +} +inline Type Float16(int lanes = 1) { + return Type(Type::type_t ::Float, 16, lanes, Type::specific_type_t::FP16); +} +inline Type Float(int bits, + int lanes = 1, + Type::specific_type_t st = Type::specific_type_t::None) { if (bits == 16) { - CHECK(st == Type::specific_type_t::FP16 || st == Type::specific_type_t::BF16) - << "When creating a 16 bits Float, the specific_type_t must be FP16 or BF16."; + CHECK(st == Type::specific_type_t::FP16 || + st == Type::specific_type_t::BF16) + << "When creating a 16 bits Float, the specific_type_t must be FP16 or " + "BF16."; } return Type(Type::type_t ::Float, bits, lanes, st); } @@ -273,10 +289,10 @@ std::ostream& operator<<(std::ostream& os, Type::type_t t); namespace customized_type { -static const char* kArgs_type_repr = "Args"; -static const char* kArgValue_type_repr = "ArgValue"; -static const char* kbuffer_t = "cinn_buffer_t"; -static const char* kpod_value_t = "cinn_pod_value_t"; +static const char* kArgs_type_repr = "Args"; +static const char* kArgValue_type_repr = "ArgValue"; +static const char* kbuffer_t = "cinn_buffer_t"; +static const char* kpod_value_t = "cinn_pod_value_t"; static const char* kcuda_builtin_vector_t = "CudaVectorType::"; } // namespace customized_type @@ -287,11 +303,16 @@ inline Type type_of() { } template <> inline Type type_of() { - return Type().set_customized_type(customized_type::kbuffer_t).set_cpp_handle(); + return Type() + .set_customized_type(customized_type::kbuffer_t) + .set_cpp_handle(); } template <> inline Type type_of() { - return Type().set_customized_type(customized_type::kbuffer_t).set_cpp_handle().set_cpp_const(); + return Type() + .set_customized_type(customized_type::kbuffer_t) + .set_cpp_handle() + .set_cpp_const(); } template <> inline Type type_of() { @@ -299,7 +320,9 @@ inline Type type_of() { } template <> inline Type type_of() { - return 
Type().set_customized_type(customized_type::kpod_value_t).set_cpp_handle(); + return Type() + .set_customized_type(customized_type::kpod_value_t) + .set_cpp_handle(); } Type Str2Type(const std::string& type); diff --git a/paddle/cinn/common/union_find.h b/paddle/cinn/common/union_find.h index b586ee8442488..c42a14683ae3d 100644 --- a/paddle/cinn/common/union_find.h +++ b/paddle/cinn/common/union_find.h @@ -13,7 +13,8 @@ // limitations under the License. /** - * \file This file implements a general UnionFind algorithm to help cluster something. + * \file This file implements a general UnionFind algorithm to help cluster + * something. */ #pragma once #include @@ -33,7 +34,7 @@ struct UnionFindNode : public Object { std::string cluster_info; std::tuple GetRoot() { - auto* p = this; + auto* p = this; int level = 0; while (p->parent) { p = p->parent; @@ -44,11 +45,11 @@ struct UnionFindNode : public Object { void Union(UnionFindNode* other) { auto _p0_l0_ = GetRoot(); - auto& p0 = std::get<0>(_p0_l0_); - auto& l0 = std::get<1>(_p0_l0_); + auto& p0 = std::get<0>(_p0_l0_); + auto& l0 = std::get<1>(_p0_l0_); auto _p1_l1_ = other->GetRoot(); - auto& p1 = std::get<0>(_p1_l1_); - auto& l1 = std::get<1>(_p1_l1_); + auto& p1 = std::get<0>(_p1_l1_); + auto& l1 = std::get<1>(_p1_l1_); if (p0 == p1) return; if (l0 < l1) { @@ -81,8 +82,8 @@ struct UnionFind { for (auto& n : nodes) { auto _root_l_ = n->GetRoot(); // NOLINT - auto& root = std::get<0>(_root_l_); - auto& l = std::get<1>(_root_l_); + auto& root = std::get<0>(_root_l_); + auto& l = std::get<1>(_root_l_); clusters[root].push_back(n.get()); } diff --git a/paddle/cinn/frontend/computation.cc b/paddle/cinn/frontend/computation.cc index a8a8b335582e6..868dc50807e9e 100644 --- a/paddle/cinn/frontend/computation.cc +++ b/paddle/cinn/frontend/computation.cc @@ -40,15 +40,16 @@ struct ComputationContext { std::unordered_map varmap_paddle2program; }; -std::shared_ptr CompileProgram(const Target &target, - Program &program, - const std::vector &outputs, - std::shared_ptr scope, - const CinnComputation::CompileOptions &options, - void *stream) { +std::shared_ptr CompileProgram( + const Target &target, + Program &program, + const std::vector &outputs, + std::shared_ptr scope, + const CinnComputation::CompileOptions &options, + void *stream) { std::shared_ptr ctx(new ComputationContext()); - ctx->stream = stream; - ctx->target = target; + ctx->stream = stream; + ctx->target = target; ctx->compile_options = options; if (ctx->compile_options.use_decomposer) { ProgramPass::Apply(&program, {}, target, {"Decomposer"}); @@ -71,14 +72,16 @@ std::shared_ptr CompileProgram(const Target &target, } ctx->scope = hlir::framework::BuildScope(target, ctx->graph, scope); - ctx->graph_compiler.reset(new hlir::framework::GraphCompiler(target, ctx->scope, ctx->graph)); + ctx->graph_compiler.reset( + new hlir::framework::GraphCompiler(target, ctx->scope, ctx->graph)); std::unordered_set fetch_var_ids; for (auto &out : outputs) { fetch_var_ids.insert(out->id); } - ctx->program = ctx->graph_compiler->Build(options, std::move(fetch_var_ids)).runtime_program; + ctx->program = ctx->graph_compiler->Build(options, std::move(fetch_var_ids)) + .runtime_program; if (ctx->compile_options.do_prerun) { ctx->program->PreRun(); } @@ -116,11 +119,12 @@ std::shared_ptr CinnComputation::CompilePaddleModel( for (int idx = 0; idx < input_names.size(); ++idx) { input_shape_map[input_names[idx]] = input_shapes[idx]; } - auto loadedProgram = LoadPaddleProgram(model_path, scope.get(), input_shape_map, 
params_combined, target); - auto &program = std::get<0>(loadedProgram); - auto &varmap = std::get<1>(loadedProgram); + auto loadedProgram = LoadPaddleProgram( + model_path, scope.get(), input_shape_map, params_combined, target); + auto &program = std::get<0>(loadedProgram); + auto &varmap = std::get<1>(loadedProgram); auto &varmap_paddle2program = std::get<2>(loadedProgram); - auto &fetch_names = std::get<3>(loadedProgram); + auto &fetch_names = std::get<3>(loadedProgram); // std::vector input_vars; // for (int i = 0; i < input_names.size(); i++) { @@ -137,7 +141,8 @@ std::shared_ptr CinnComputation::CompilePaddleModel( output_vars.push_back(varmap.at(name)); } - std::shared_ptr ctx = CompileProgram(target, *program, output_vars, scope, options, stream); + std::shared_ptr ctx = + CompileProgram(target, *program, output_vars, scope, options, stream); for (auto &v : varmap) { ctx->varmap[v.first] = v.second; } @@ -145,45 +150,52 @@ std::shared_ptr CinnComputation::CompilePaddleModel( ctx->varmap_paddle2program[v.first] = v.second; } - auto computation = std::make_shared(); + auto computation = std::make_shared(); computation->context_ = std::move(ctx); return computation; } -std::shared_ptr CinnComputation::BuildAndCompile(const Target &target, - NetBuilder &builder, - const CompileOptions &options, - const std::vector &outputs, - void *stream) { +std::shared_ptr CinnComputation::BuildAndCompile( + const Target &target, + NetBuilder &builder, + const CompileOptions &options, + const std::vector &outputs, + void *stream) { auto program = builder.Build(); return Compile(target, program, options, outputs, stream); } -std::shared_ptr CinnComputation::Compile(const Target &target, - Program &program, - const CompileOptions &options, - const std::vector &outputs, - void *stream) { +std::shared_ptr CinnComputation::Compile( + const Target &target, + Program &program, + const CompileOptions &options, + const std::vector &outputs, + void *stream) { std::vector output_vars = outputs; if (output_vars.empty()) { output_vars.push_back(program[program.size() - 1].GetOutput(0)); } - std::shared_ptr ctx = CompileProgram(target, program, output_vars, nullptr, options, stream); + std::shared_ptr ctx = + CompileProgram(target, program, output_vars, nullptr, options, stream); - auto computation = std::make_shared(); + auto computation = std::make_shared(); computation->context_ = std::move(ctx); return computation; } -void CinnComputation::SetTensorData(const std::string &tname, void *data, size_t size) { +void CinnComputation::SetTensorData(const std::string &tname, + void *data, + size_t size) { hlir::framework::Tensor t = GetTensor(tname); SetTensorData(t, data, size); } -void CinnComputation::SetTensorData(hlir::framework::Tensor &t, void *data, size_t size) { +void CinnComputation::SetTensorData(hlir::framework::Tensor &t, + void *data, + size_t size) { void *tdata = t->mutable_data(context_->target, t->type()); CHECK_EQ(size, t->shape().numel() * t->type().bytes()); if (context_->target.arch == Target::Arch::NVGPU) { @@ -198,7 +210,9 @@ void CinnComputation::SetTensorData(hlir::framework::Tensor &t, void *data, size CINN_NOT_IMPLEMENTED } } -void CinnComputation::GetTensorData(hlir::framework::Tensor &t, void *data, size_t size) { +void CinnComputation::GetTensorData(hlir::framework::Tensor &t, + void *data, + size_t size) { void *tdata = t->mutable_data(context_->target, t->type()); CHECK_EQ(size, t->shape().numel() * t->type().bytes()); if (context_->target.arch == Target::Arch::NVGPU) { @@ -214,14 
+228,20 @@ void CinnComputation::GetTensorData(hlir::framework::Tensor &t, void *data, size } } -void CinnComputation::GetTensorData(const std::string &tname, void *data, size_t size) { +void CinnComputation::GetTensorData(const std::string &tname, + void *data, + size_t size) { hlir::framework::Tensor t = GetTensor(tname); GetTensorData(t, data, size); } -std::vector CinnComputation::GetInputTensors() { return context_->inputs; } +std::vector CinnComputation::GetInputTensors() { + return context_->inputs; +} -std::vector CinnComputation::GetOutputTensors() { return context_->outputs; } +std::vector CinnComputation::GetOutputTensors() { + return context_->outputs; +} hlir::framework::Tensor CinnComputation::GetTensor(const std::string &tname) { if (context_->scope->FindVar(tname)) { @@ -230,12 +250,14 @@ hlir::framework::Tensor CinnComputation::GetTensor(const std::string &tname) { auto it = context_->varmap_paddle2program.find(tname); if (it == context_->varmap_paddle2program.end()) { LOG(FATAL) << "No variable called [" << tname - << "] found in computation\nThe existing vars: " << utils::Join(context_->scope->var_names(), ", "); + << "] found in computation\nThe existing vars: " + << utils::Join(context_->scope->var_names(), ", "); } return context_->scope->GetTensor(it->second); } -void CinnComputation::Execute(const std::map *name2podargs) { +void CinnComputation::Execute( + const std::map *name2podargs) { context_->program->Execute(name2podargs, context_->stream); } diff --git a/paddle/cinn/frontend/computation.h b/paddle/cinn/frontend/computation.h index 9dd1ec8e62270..c464cb40d2a5c 100644 --- a/paddle/cinn/frontend/computation.h +++ b/paddle/cinn/frontend/computation.h @@ -26,9 +26,10 @@ struct ComputationContext; class CinnComputation { public: - struct CompileOptions : public hlir::framework::GraphCompiler::CompileOptions { - bool use_decomposer = false; - bool do_prerun = true; + struct CompileOptions + : public hlir::framework::GraphCompiler::CompileOptions { + bool use_decomposer = false; + bool do_prerun = true; bool use_default_passes = true; std::vector passes; }; @@ -36,43 +37,49 @@ class CinnComputation { inline static CompileOptions DefaultCompileOptions() { CompileOptions options; options.with_instantiate_variables = true; - options.use_decomposer = false; - options.passes = {}; - options.do_prerun = true; - options.use_default_passes = true; + options.use_decomposer = false; + options.passes = {}; + options.do_prerun = true; + options.use_default_passes = true; return options; } /** - * build program from NetBuilder, then compile it. NetBuilder is normally NetBuilder or CINNBuilder. + * build program from NetBuilder, then compile it. NetBuilder is normally + * NetBuilder or CINNBuilder. 
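// ---------------------------------------------------------------------------
// [editorial sketch] How the DefaultCompileOptions() defaults above are
// typically adjusted before compiling. The option fields follow the
// CompileOptions struct in this header; the wrapper function and its
// parameters are assumed here purely for illustration.
#include <memory>
#include "paddle/cinn/frontend/computation.h"

std::shared_ptr<cinn::frontend::CinnComputation> CompileWithOptionsSketch(
    const cinn::common::Target& target, cinn::frontend::Program& program) {
  using cinn::frontend::CinnComputation;

  CinnComputation::CompileOptions options =
      CinnComputation::DefaultCompileOptions();
  options.use_decomposer = true;  // apply the "Decomposer" program pass
  options.do_prerun = false;      // skip the PreRun() warm-up after building

  return CinnComputation::Compile(target, program, options);
}
// ---------------------------------------------------------------------------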
* @param target the target to run the program * @param builder program builder (NetBuilder or CINNBuilder) * @param options CompileOptions, config the compilation steps - * @param outputs program output variables, if outputs is empty, then the output variable - * of the last instruction of the program is used - * @param stream CUDA stream, the value is meaningful only when target is NVGPU + * @param outputs program output variables, if outputs is empty, then the + * output variable of the last instruction of the program is used + * @param stream CUDA stream, the value is meaningful only when target is + * NVGPU * @return shared_ptr pointing to CinnComputation instance */ - static std::shared_ptr BuildAndCompile(const Target &target, - NetBuilder &builder, - const CompileOptions &options = DefaultCompileOptions(), - const std::vector &outputs = {}, - void *stream = nullptr); + static std::shared_ptr BuildAndCompile( + const Target &target, + NetBuilder &builder, + const CompileOptions &options = DefaultCompileOptions(), + const std::vector &outputs = {}, + void *stream = nullptr); /** * compile the program * @param target the target to run the program - * @param program program (usually generated by a Builder, or converted from Paddle model) + * @param program program (usually generated by a Builder, or converted from + * Paddle model) * @param options CompileOptions, config the compilation steps - * @param outputs program output variables, if outputs is empty, then the output variable - * of the last instruction of the program is used - * @param stream CUDA stream, the value is meaningful only when target is NVGpu + * @param outputs program output variables, if outputs is empty, then the + * output variable of the last instruction of the program is used + * @param stream CUDA stream, the value is meaningful only when target is + * NVGpu * @return shared_ptr pointing to CinnComputation instance */ - static std::shared_ptr Compile(const Target &target, - Program &program, - const CompileOptions &options = DefaultCompileOptions(), - const std::vector &outputs = {}, - void *stream = nullptr); + static std::shared_ptr Compile( + const Target &target, + Program &program, + const CompileOptions &options = DefaultCompileOptions(), + const std::vector &outputs = {}, + void *stream = nullptr); /** * convert a paddle model to program, then compile it. 
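// ---------------------------------------------------------------------------
// [editorial sketch] End-to-end use of BuildAndCompile as documented above,
// modeled on the unit tests in computation_test.cc later in this diff; the
// net_builder.h include path, shapes, and input names are illustrative.
#include <memory>
#include "paddle/cinn/frontend/computation.h"
#include "paddle/cinn/frontend/net_builder.h"

std::shared_ptr<cinn::frontend::CinnComputation> BuildAndCompileSketch() {
  using namespace cinn;  // NOLINT

  frontend::NetBuilder builder("sketch");
  auto a = builder.CreateInput(common::Float(32), {32, 24}, "A");
  auto b = builder.CreateInput(common::Float(32), {32, 24}, "B");
  // With `outputs` left empty below, the output of this last instruction
  // becomes the program output.
  auto out = builder.Add(builder.Relu(a), b);

  auto target = common::DefaultHostTarget();
  return frontend::CinnComputation::BuildAndCompile(target, builder);
}
// ---------------------------------------------------------------------------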
* @param target the target to run the program @@ -81,16 +88,18 @@ class CinnComputation { * @param input_shapes input variable shapes of paddle model * @param params_combined whether params are stored combined * @param options CompileOptions, config the compilation steps - * @param stream CUDA stream, the value is meaningful only when target is NVGpu + * @param stream CUDA stream, the value is meaningful only when target is + * NVGpu * @return shared_ptr pointing to CinnComputation instance */ - static std::shared_ptr CompilePaddleModel(const Target &target, - const std::string &model_path, - const std::vector &input_names, - const std::vector &input_shapes, - bool params_combined, - const CompileOptions &options = DefaultCompileOptions(), - void *stream = nullptr); + static std::shared_ptr CompilePaddleModel( + const Target &target, + const std::string &model_path, + const std::vector &input_names, + const std::vector &input_shapes, + bool params_combined, + const CompileOptions &options = DefaultCompileOptions(), + void *stream = nullptr); /** * get all variable names in the program @@ -123,8 +132,8 @@ class CinnComputation { void SetTensorData(hlir::framework::Tensor &t, void *data, size_t size); /** - * set the data of a tensor (specified by it's name) from user specified buffer. - * if tensor is in NVGPU device memory, cudaMemcpy is used. + * set the data of a tensor (specified by it's name) from user specified + * buffer. if tensor is in NVGPU device memory, cudaMemcpy is used. * @param tname name of the tensor * @param data address of the memory buffer to store tensor's data * @param size size of the memory buffer @@ -140,8 +149,8 @@ class CinnComputation { */ void GetTensorData(hlir::framework::Tensor &t, void *data, size_t size); /** - * copy the data of a tensor (specified by it's name) to user specified buffer. - * if tensor is in NVGPU device memory, cudaMemcpy is used. + * copy the data of a tensor (specified by it's name) to user specified + * buffer. if tensor is in NVGPU device memory, cudaMemcpy is used. 
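// ---------------------------------------------------------------------------
// [editorial sketch] Loading a Paddle model and exchanging tensor data with
// the compiled program, following the fc_execute tests further down in this
// diff; the input name "A" and shape {1, 30} are placeholders.
#include <string>
#include <vector>
#include "paddle/cinn/frontend/computation.h"

void PaddleModelSketch(const std::string& model_dir) {
  using namespace cinn;  // NOLINT

  auto target = common::DefaultHostTarget();
  auto compute = frontend::CinnComputation::CompilePaddleModel(
      target, model_dir, {"A"}, {{1, 30}}, /*params_combined=*/false);

  // Copy host data in (cudaMemcpy would be used for an NVGPU target).
  auto inputs = compute->GetInputTensors();
  std::vector<float> host_in(inputs[0]->shape().numel(), 1.0f);
  compute->SetTensorData(inputs[0],
                         reinterpret_cast<void*>(host_in.data()),
                         host_in.size() * sizeof(float));

  compute->Execute();

  // Copy the results back out.
  auto outputs = compute->GetOutputTensors();
  std::vector<float> host_out(outputs[0]->shape().numel());
  compute->GetTensorData(outputs[0],
                         reinterpret_cast<void*>(host_out.data()),
                         host_out.size() * sizeof(float));
}
// ---------------------------------------------------------------------------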
* @param tname name of the tensor * @param data address of the memory buffer to store tensor's data * @param size size of the memory buffer @@ -151,7 +160,8 @@ class CinnComputation { /** * run the compiled program */ - void Execute(const std::map *name2podargs = nullptr); + void Execute( + const std::map *name2podargs = nullptr); private: std::shared_ptr context_; diff --git a/paddle/cinn/frontend/computation_test.cc b/paddle/cinn/frontend/computation_test.cc index cdc8db6388fa3..f4b08ecb05397 100644 --- a/paddle/cinn/frontend/computation_test.cc +++ b/paddle/cinn/frontend/computation_test.cc @@ -62,10 +62,10 @@ Program CreateAddProgram() { constexpr int N = 24; NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {M, N}); - auto b = builder.CreateInput(Float(32), {M, N}); - auto c = builder.Relu(a); - auto d = builder.Add(b, c); + auto a = builder.CreateInput(Float(32), {M, N}); + auto b = builder.CreateInput(Float(32), {M, N}); + auto c = builder.Relu(a); + auto d = builder.Add(b, c); auto program = builder.Build(); return program; @@ -82,21 +82,27 @@ TEST(cinn_computation, basic_cpu) { auto d = builder.Add(a, c); auto target = common::DefaultHostTarget(); - auto comp = CinnComputation::BuildAndCompile(target, builder); + auto comp = CinnComputation::BuildAndCompile(target, builder); std::vector hostA(M * N); std::vector hostB(M * N); std::vector hostD(M * N); std::vector hostD_expected(M * N); for (int i = 0; i < M * N; i++) { - hostA[i] = static_cast(rand()) / INT_MAX; - hostB[i] = static_cast(rand()) / INT_MAX; + hostA[i] = static_cast(rand()) / INT_MAX; + hostB[i] = static_cast(rand()) / INT_MAX; hostD_expected[i] = hostA[i] * 2 + hostB[i]; } - comp->SetTensorData("A", reinterpret_cast(hostA.data()), hostA.size() * sizeof(float)); - comp->SetTensorData("B", reinterpret_cast(hostB.data()), hostB.size() * sizeof(float)); + comp->SetTensorData("A", + reinterpret_cast(hostA.data()), + hostA.size() * sizeof(float)); + comp->SetTensorData("B", + reinterpret_cast(hostB.data()), + hostB.size() * sizeof(float)); comp->Execute(); - comp->GetTensorData(d->id, reinterpret_cast(hostD.data()), hostD.size() * sizeof(float)); + comp->GetTensorData(d->id, + reinterpret_cast(hostD.data()), + hostD.size() * sizeof(float)); for (int i = 0; i < hostD.size(); i++) { ASSERT_NEAR(hostD[i], hostD_expected[i], 1e-5); } @@ -114,21 +120,27 @@ TEST(cinn_computation, basic_gpu) { auto d = builder.Add(a, c); auto target = common::DefaultNVGPUTarget(); - auto comp = CinnComputation::BuildAndCompile(target, builder); + auto comp = CinnComputation::BuildAndCompile(target, builder); std::vector hostA(M * N); std::vector hostB(M * N); std::vector hostD(M * N); std::vector hostD_expected(M * N); for (int i = 0; i < M * N; i++) { - hostA[i] = static_cast(rand()) / INT_MAX; - hostB[i] = static_cast(rand()) / INT_MAX; + hostA[i] = static_cast(rand()) / INT_MAX; + hostB[i] = static_cast(rand()) / INT_MAX; hostD_expected[i] = hostA[i] * 2 + hostB[i]; } - comp->SetTensorData("A", reinterpret_cast(hostA.data()), hostA.size() * sizeof(float)); - comp->SetTensorData("B", reinterpret_cast(hostB.data()), hostB.size() * sizeof(float)); + comp->SetTensorData("A", + reinterpret_cast(hostA.data()), + hostA.size() * sizeof(float)); + comp->SetTensorData("B", + reinterpret_cast(hostB.data()), + hostB.size() * sizeof(float)); comp->Execute(); - comp->GetTensorData(d->id, reinterpret_cast(hostD.data()), hostD.size() * sizeof(float)); + comp->GetTensorData(d->id, + reinterpret_cast(hostD.data()), + hostD.size() * 
sizeof(float)); for (int i = 0; i < hostD.size(); i++) { ASSERT_NEAR(hostD[i], hostD_expected[i], 1e-5); } @@ -137,9 +149,9 @@ TEST(cinn_computation, basic_gpu) { TEST(cinn_computation, net_builder_cpu) { auto program = CreateTestProgram(); - auto target = common::DefaultHostTarget(); + auto target = common::DefaultHostTarget(); auto compute = CinnComputation::Compile(target, program); - auto inputs = compute->GetInputTensors(); + auto inputs = compute->GetInputTensors(); ASSERT_EQ(inputs.size(), 2); auto tensorA = inputs[0]; auto tensorB = inputs[1]; @@ -171,9 +183,9 @@ TEST(cinn_computation, net_builder_cpu) { #ifdef CINN_WITH_CUDA TEST(cinn_computation, net_builder_gpu) { auto program = CreateTestProgram(); - auto target = common::DefaultNVGPUTarget(); + auto target = common::DefaultNVGPUTarget(); auto compute = CinnComputation::Compile(target, program); - auto inputs = compute->GetInputTensors(); + auto inputs = compute->GetInputTensors(); ASSERT_EQ(inputs.size(), 2); auto tensorA = inputs[0]; auto tensorB = inputs[1]; @@ -192,15 +204,20 @@ TEST(cinn_computation, net_builder_gpu) { // ... or async copy to device memory // ... not showed here - // assume tensorB is generated in host memory, needs copy to GPU memory (sync.) + // assume tensorB is generated in host memory, needs copy to GPU memory + // (sync.) std::vector hostB(32 * 24 / 2); - compute->SetTensorData(tensorB, reinterpret_cast(hostB.data()), hostB.size() * sizeof(float)); + compute->SetTensorData(tensorB, + reinterpret_cast(hostB.data()), + hostB.size() * sizeof(float)); // execute engine compute->Execute(); // get outputs std::vector hostOut(tensorOut->shape().numel()); - compute->GetTensorData(tensorOut, reinterpret_cast(hostOut.data()), hostOut.size() * sizeof(float)); + compute->GetTensorData(tensorOut, + reinterpret_cast(hostOut.data()), + hostOut.size() * sizeof(float)); } } #endif @@ -208,8 +225,9 @@ TEST(cinn_computation, net_builder_gpu) { TEST(cinn_computation, fc_execute_cpu) { auto target = common::DefaultHostTarget(); ASSERT_NE(FLAGS_model_dir, ""); - auto compute = CinnComputation::CompilePaddleModel(target, FLAGS_model_dir, {"A"}, {{1, 30}}, false); - auto inputs = compute->GetInputTensors(); + auto compute = CinnComputation::CompilePaddleModel( + target, FLAGS_model_dir, {"A"}, {{1, 30}}, false); + auto inputs = compute->GetInputTensors(); ASSERT_EQ(inputs.size(), 1); auto A = inputs[0]; ASSERT_EQ(A->shape().numel(), 1 * 30); @@ -223,7 +241,8 @@ TEST(cinn_computation, fc_execute_cpu) { TEST(cinn_computation, fc_execute_gpu) { auto target = common::DefaultNVGPUTarget(); ASSERT_NE(FLAGS_model_dir, ""); - auto compute = CinnComputation::CompilePaddleModel(target, FLAGS_model_dir, {"A"}, {{1, 30}}, false); + auto compute = CinnComputation::CompilePaddleModel( + target, FLAGS_model_dir, {"A"}, {{1, 30}}, false); auto inputs = compute->GetInputTensors(); ASSERT_EQ(inputs.size(), 1); @@ -235,62 +254,67 @@ TEST(cinn_computation, fc_execute_gpu) { std::vector hostA(30); for (float &v : hostA) v = static_cast(rand()) / INT_MAX; - compute->SetTensorData(A, reinterpret_cast(hostA.data()), hostA.size() * sizeof(float)); + compute->SetTensorData( + A, reinterpret_cast(hostA.data()), hostA.size() * sizeof(float)); compute->Execute(); std::vector hostOut(30); - compute->GetTensorData(out, reinterpret_cast(hostOut.data()), hostOut.size() * sizeof(float)); + compute->GetTensorData(out, + reinterpret_cast(hostOut.data()), + hostOut.size() * sizeof(float)); } #endif TEST(cinn_computation, decomposer_cpu) { // this test only 
shows the API usage - ASSERT_NE(cinn::frontend::ProgramPassRegistry::Global()->Find("Decomposer"), nullptr); + ASSERT_NE(cinn::frontend::ProgramPassRegistry::Global()->Find("Decomposer"), + nullptr); // without decomposer { - auto prog = CreateAddProgram(); - auto target = common::DefaultHostTarget(); - auto options = CinnComputation::DefaultCompileOptions(); + auto prog = CreateAddProgram(); + auto target = common::DefaultHostTarget(); + auto options = CinnComputation::DefaultCompileOptions(); options.use_decomposer = false; - auto compute = CinnComputation::Compile(target, prog, options); - auto names = compute->GetAllTensorNames(); + auto compute = CinnComputation::Compile(target, prog, options); + auto names = compute->GetAllTensorNames(); ASSERT_EQ(names.size(), 3); } // with decomposer { - auto prog = CreateAddProgram(); - auto target = common::DefaultHostTarget(); - auto options = CinnComputation::DefaultCompileOptions(); + auto prog = CreateAddProgram(); + auto target = common::DefaultHostTarget(); + auto options = CinnComputation::DefaultCompileOptions(); options.use_decomposer = true; - auto compute = CinnComputation::Compile(target, prog, options); - auto names = compute->GetAllTensorNames(); + auto compute = CinnComputation::Compile(target, prog, options); + auto names = compute->GetAllTensorNames(); } } #ifdef CINN_WITH_CUDA TEST(cinn_computation, gpu_stream) { // this test only shows the API usage - auto target = common::DefaultNVGPUTarget(); - auto prog = CreateAddProgram(); + auto target = common::DefaultNVGPUTarget(); + auto prog = CreateAddProgram(); auto options = CinnComputation::DefaultCompileOptions(); cudaStream_t streams[1]; cudaStreamCreate(&streams[0]); - auto compute = CinnComputation::Compile(target, prog, options, {}, static_cast(streams[0])); + auto compute = CinnComputation::Compile( + target, prog, options, {}, static_cast(streams[0])); compute->Execute(); } #endif TEST(cinn_computation, without_instantiate_variables) { // this test only shows the API usage - auto target = common::DefaultHostTarget(); - auto prog = CreateAddProgram(); - auto options = CinnComputation::DefaultCompileOptions(); + auto target = common::DefaultHostTarget(); + auto prog = CreateAddProgram(); + auto options = CinnComputation::DefaultCompileOptions(); options.with_instantiate_variables = false; auto compute = CinnComputation::Compile(target, prog, options); - auto names = compute->GetAllTensorNames(); + auto names = compute->GetAllTensorNames(); std::map pod2args; // compute->Execute(&pod2args); diff --git a/paddle/cinn/frontend/decomposer/activation.cc b/paddle/cinn/frontend/decomposer/activation.cc index 0244243be8822..040d1af9b1b98 100644 --- a/paddle/cinn/frontend/decomposer/activation.cc +++ b/paddle/cinn/frontend/decomposer/activation.cc @@ -20,45 +20,57 @@ namespace frontend { namespace decomposer { void relu(const Instruction& instr, const DecomposerContext& context) { - CHECK_EQ(instr->inputs.size(), 1UL) << " 1 input tensor for " << instr->op_type; - CHECK_EQ(instr->outputs.size(), 1UL) << "1 output tensor for " << instr->op_type; - auto x = instr->inputs[0]; - auto output = instr->outputs[0]; + CHECK_EQ(instr->inputs.size(), 1UL) + << " 1 input tensor for " << instr->op_type; + CHECK_EQ(instr->outputs.size(), 1UL) + << "1 output tensor for " << instr->op_type; + auto x = instr->inputs[0]; + auto output = instr->outputs[0]; auto* builder = context.builder(); - auto bcast_zero = builder->FillConstant(x->shape, 0.0f, common::UniqName("zero"), common::Type2Str(x->type)); - 
auto out = builder->Max(x, bcast_zero); + auto bcast_zero = builder->FillConstant( + x->shape, 0.0f, common::UniqName("zero"), common::Type2Str(x->type)); + auto out = builder->Max(x, bcast_zero); // map the the output of decomposed operator to the original. context.MapOutToOrigin(out, output); } void relu_grad(const Instruction& instr, const DecomposerContext& context) { - CHECK_EQ(instr->inputs.size(), 2UL) << " 2 input tensors for " << instr->op_type; - CHECK_EQ(instr->outputs.size(), 1UL) << "1 output tensor for " << instr->op_type; - auto dout = instr->inputs[0]; - auto out = instr->inputs[1]; - auto dx = instr->outputs[0]; + CHECK_EQ(instr->inputs.size(), 2UL) + << " 2 input tensors for " << instr->op_type; + CHECK_EQ(instr->outputs.size(), 1UL) + << "1 output tensor for " << instr->op_type; + auto dout = instr->inputs[0]; + auto out = instr->inputs[1]; + auto dx = instr->outputs[0]; auto* builder = context.builder(); - auto bcast_zero = builder->FillConstant(out->shape, 0.0f, common::UniqName("zero"), common::Type2Str(out->type)); - auto condition = builder->GreaterThan(out, bcast_zero); - auto res = builder->Select(condition, dout, bcast_zero); + auto bcast_zero = builder->FillConstant( + out->shape, 0.0f, common::UniqName("zero"), common::Type2Str(out->type)); + auto condition = builder->GreaterThan(out, bcast_zero); + auto res = builder->Select(condition, dout, bcast_zero); // map the the output of decomposed operator to the original. context.MapOutToOrigin(res, dx); } void gelu(const Instruction& instr, const DecomposerContext& context) { - CHECK_EQ(instr->inputs.size(), 1UL) << " 1 input tensor for " << instr->op_type; - CHECK_EQ(instr->outputs.size(), 1UL) << "1 output tensor for " << instr->op_type; - auto x = instr->inputs[0]; - auto output = instr->outputs[0]; + CHECK_EQ(instr->inputs.size(), 1UL) + << " 1 input tensor for " << instr->op_type; + CHECK_EQ(instr->outputs.size(), 1UL) + << "1 output tensor for " << instr->op_type; + auto x = instr->inputs[0]; + auto output = instr->outputs[0]; auto* builder = context.builder(); // x * (0.5 + 0.5 * erf(sqrtf(0.5) * x)) - auto p_5 = builder->FillConstant(x->shape, 0.5f, common::UniqName("p_5"), common::Type2Str(x->type)); - auto p_7 = builder->FillConstant(x->shape, std::sqrt(0.5), common::UniqName("p_7"), common::Type2Str(x->type)); + auto p_5 = builder->FillConstant( + x->shape, 0.5f, common::UniqName("p_5"), common::Type2Str(x->type)); + auto p_7 = builder->FillConstant(x->shape, + std::sqrt(0.5), + common::UniqName("p_7"), + common::Type2Str(x->type)); auto erf = builder->Erf(builder->Multiply(x, p_7)); auto cdf = builder->Add(p_5, builder->Multiply(p_5, erf)); auto out = builder->Multiply(x, cdf); @@ -68,10 +80,12 @@ void gelu(const Instruction& instr, const DecomposerContext& context) { } void softmax(const Instruction& instr, const DecomposerContext& context) { - CHECK_EQ(instr->inputs.size(), 1UL) << " 1 input tensor for " << instr->op_type; - CHECK_EQ(instr->outputs.size(), 1UL) << "1 output tensor for " << instr->op_type; - auto x = instr->inputs[0]; - auto output = instr->outputs[0]; + CHECK_EQ(instr->inputs.size(), 1UL) + << " 1 input tensor for " << instr->op_type; + CHECK_EQ(instr->outputs.size(), 1UL) + << "1 output tensor for " << instr->op_type; + auto x = instr->inputs[0]; + auto output = instr->outputs[0]; auto* builder = context.builder(); std::vector b_axes; @@ -88,7 +102,8 @@ void softmax(const Instruction& instr, const DecomposerContext& context) { } } - // When the rank of x is 1, broadcast axes will be 
empty, so we need to insert last dim as broadcast axis. + // When the rank of x is 1, broadcast axes will be empty, so we need to insert + // last dim as broadcast axis. if (b_axes.empty()) { b_axes.emplace_back(-1); } @@ -96,7 +111,8 @@ void softmax(const Instruction& instr, const DecomposerContext& context) { auto mode = instr.GetAttrs("mode"); if (mode == "fast") { // x_sum = sum(exp(x)) - auto x_sum = builder->BroadcastTo(builder->ReduceSum(builder->Exp(x), axes), x->shape, b_axes); + auto x_sum = builder->BroadcastTo( + builder->ReduceSum(builder->Exp(x), axes), x->shape, b_axes); // x_exp / x_sum auto out = builder->Divide(builder->Exp(x), x_sum); @@ -104,13 +120,16 @@ void softmax(const Instruction& instr, const DecomposerContext& context) { context.MapOutToOrigin(out, output); } else { // x = max(x) - auto x_max = builder->BroadcastTo(builder->ReduceMax(x, axes), x->shape, b_axes); + auto x_max = + builder->BroadcastTo(builder->ReduceMax(x, axes), x->shape, b_axes); // x_exp = exp(x - x_max) auto x_exp = builder->Exp(builder->Subtract(x, x_max)); // x_sum = sum(x_exp) - auto x_sum = builder->BroadcastTo(builder->ReduceSum(x_exp, axes), x->shape, b_axes); + auto x_sum = + builder->BroadcastTo(builder->ReduceSum(x_exp, axes), x->shape, b_axes); // x_exp / x_sum - auto out = builder->Divide(builder->Exp(builder->Subtract(x, x_max)), x_sum); + auto out = + builder->Divide(builder->Exp(builder->Subtract(x, x_max)), x_sum); // map the the output of decomposed operator to the original. context.MapOutToOrigin(out, output); diff --git a/paddle/cinn/frontend/decomposer/activation_test.cc b/paddle/cinn/frontend/decomposer/activation_test.cc index 5e2b5b18ea2b1..e0bd9a82a48e0 100644 --- a/paddle/cinn/frontend/decomposer/activation_test.cc +++ b/paddle/cinn/frontend/decomposer/activation_test.cc @@ -20,45 +20,49 @@ namespace cinn::frontend { TEST(Decomposer, relu) { NetBuilder builder("relu"); - auto x = builder.CreateInput(Float(32), {20, 10}, "x"); + auto x = builder.CreateInput(Float(32), {20, 10}, "x"); auto out = builder.Relu(x); - auto relu_cpu = [](const std::vector& lengths, const std::vector& ptrs) { - size_t n = lengths[0]; - float* x = static_cast(ptrs[0]); + auto relu_cpu = [](const std::vector& lengths, + const std::vector& ptrs) { + size_t n = lengths[0]; + float* x = static_cast(ptrs[0]); float* out = static_cast(ptrs[1]); for (size_t i = 0; i < n; ++i) { float tmp_0 = x[i]; - out[i] = tmp_0 > 0 ? tmp_0 : 0; + out[i] = tmp_0 > 0 ? 
tmp_0 : 0; } }; - std::vector input_names = {x.id().data()}; - std::vector output_names = {out->id}; + std::vector input_names = {x.id().data()}; + std::vector output_names = {out->id}; std::vector> output_shapes = {{20, 10}}; - RunAndCheck(builder, input_names, output_names, output_shapes, relu_cpu, -1, 1); + RunAndCheck( + builder, input_names, output_names, output_shapes, relu_cpu, -1, 1); } TEST(Decomposer, relu_grad) { NetBuilder builder("relu_grad"); auto dout = builder.CreateInput(Float(32), {20, 10}, "dout"); - auto out = builder.CreateInput(Float(32), {20, 10}, "out"); - auto dx = builder.ReluGrad(dout, out); + auto out = builder.CreateInput(Float(32), {20, 10}, "out"); + auto dx = builder.ReluGrad(dout, out); - auto relu_grad_cpu = [](const std::vector& lengths, const std::vector& ptrs) { - size_t n = lengths[0]; + auto relu_grad_cpu = [](const std::vector& lengths, + const std::vector& ptrs) { + size_t n = lengths[0]; float* dout = static_cast(ptrs[0]); - float* out = static_cast(ptrs[1]); - float* dx = static_cast(ptrs[2]); + float* out = static_cast(ptrs[1]); + float* dx = static_cast(ptrs[2]); for (size_t i = 0; i < n; ++i) { dx[i] = out[i] > 0 ? dout[i] : 0; } }; - std::vector input_names = {dout.id().data(), out.id().data()}; - std::vector output_names = {dx->id}; + std::vector input_names = {dout.id().data(), out.id().data()}; + std::vector output_names = {dx->id}; std::vector> output_shapes = {{20, 10}}; - RunAndCheck(builder, input_names, output_names, output_shapes, relu_grad_cpu, -1, 1); + RunAndCheck( + builder, input_names, output_names, output_shapes, relu_grad_cpu, -1, 1); } TEST(Decomposer, softmax_decomposer) { @@ -76,7 +80,8 @@ TEST(Decomposer, softmax_decomposer) { auto target = common::DefaultTarget(); RunDecomposer(&program, target); - auto graph = std::make_shared(program, output_names, target); + auto graph = + std::make_shared(program, output_names, target); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); @@ -90,7 +95,7 @@ TEST(Decomposer, softmax_decomposer) { for (auto& input : inputs) { scope->Var(input.first); auto tensor = scope->GetTensor(input.first); - auto* data = tensor->mutable_data(target); + auto* data = tensor->mutable_data(target); CopyFromVector(input.second, tensor, target); } run_program->Execute(); diff --git a/paddle/cinn/frontend/decomposer/batch_norm.cc b/paddle/cinn/frontend/decomposer/batch_norm.cc index ed421fdebd470..19c53bd506b8f 100644 --- a/paddle/cinn/frontend/decomposer/batch_norm.cc +++ b/paddle/cinn/frontend/decomposer/batch_norm.cc @@ -25,73 +25,88 @@ struct BatchNormHelper { const std::vector& arg_param_shape, std::string data_layout, std::string bn_op_type) { - CHECK_EQ(arg_x_shape.size(), 4UL) << "Only 4-D input tensor is supported, but get " << arg_x_shape.size() - << "-D input tensor."; + CHECK_EQ(arg_x_shape.size(), 4UL) + << "Only 4-D input tensor is supported, but get " << arg_x_shape.size() + << "-D input tensor."; - builder = net_builder; - x_shape = arg_x_shape; + builder = net_builder; + x_shape = arg_x_shape; param_shape = arg_param_shape; if (data_layout == "NCHW") { - channel_dim = 1; - reduce_dim = {0, 2, 3}; + channel_dim = 1; + reduce_dim = {0, 2, 3}; element_count = x_shape[0] * x_shape[2] * x_shape[3]; } else if (data_layout == "NHWC") { - channel_dim = 3; - reduce_dim = {0, 1, 2}; + channel_dim = 3; + reduce_dim = {0, 1, 2}; element_count = x_shape[0] * x_shape[1] * x_shape[2]; } else { LOG(FATAL) << data_layout << " setting is not 
support!"; } num_instructions = builder->size(); - op_type = bn_op_type; + op_type = bn_op_type; } ~BatchNormHelper() { - VLOG(4) << op_type << " is decomposed to " << builder->size() - num_instructions << " instructions."; + VLOG(4) << op_type << " is decomposed to " + << builder->size() - num_instructions << " instructions."; } std::vector MeanAndVariance(Variable x) { auto mean = Mean(x); - // variance = reduce_sum(x * x) / nhw - mean * mean, shape = [c], simplified by equation: E(x^2) - [E(x)]^2 + // variance = reduce_sum(x * x) / nhw - mean * mean, shape = [c], simplified + // by equation: E(x^2) - [E(x)]^2 auto variance = Variance(x, mean); return {mean, variance}; } - std::vector GradBiasAndScale(Variable x, Variable x_mean, Variable y_grad) { - auto mean_4d = builder->BroadcastTo(x_mean, x->shape, {channel_dim}); + std::vector GradBiasAndScale(Variable x, + Variable x_mean, + Variable y_grad) { + auto mean_4d = builder->BroadcastTo(x_mean, x->shape, {channel_dim}); auto x_mean_diff = builder->Subtract(x, mean_4d); // bias_grad = reduce_sum(y_grad), shape = [c] - auto bias_grad = Reduce(y_grad); - auto sum_of_y_grad_mul_x_mean_diff = Reduce(builder->Multiply(y_grad, x_mean_diff)); + auto bias_grad = Reduce(y_grad); + auto sum_of_y_grad_mul_x_mean_diff = + Reduce(builder->Multiply(y_grad, x_mean_diff)); return {bias_grad, sum_of_y_grad_mul_x_mean_diff}; } // mean = reduce_sum(x) / nhw Variable Mean(Variable x) { - auto sum = Reduce(x); - auto element_count_1d = builder->FillConstant( - sum->shape, element_count, common::UniqName("element_count"), common::Type2Str(sum->type)); + auto sum = Reduce(x); + auto element_count_1d = + builder->FillConstant(sum->shape, + element_count, + common::UniqName("element_count"), + common::Type2Str(sum->type)); auto mean = builder->Divide(sum, element_count_1d); return mean; } // variance = reduce_sum(x * x) / nhw - mean * mean Variable Variance(Variable x, Variable mean) { - auto x_square = builder->Multiply(x, builder->Identity(x)); - auto x_square_sum = Reduce(x_square); - auto element_count_1d = builder->FillConstant( - x_square_sum->shape, element_count, common::UniqName("element_count"), common::Type2Str(x_square_sum->type)); + auto x_square = builder->Multiply(x, builder->Identity(x)); + auto x_square_sum = Reduce(x_square); + auto element_count_1d = + builder->FillConstant(x_square_sum->shape, + element_count, + common::UniqName("element_count"), + common::Type2Str(x_square_sum->type)); auto x_square_mean = builder->Divide(x_square_sum, element_count_1d); - auto variance = builder->Subtract(x_square_mean, builder->Multiply(mean, builder->Identity(mean))); + auto variance = builder->Subtract( + x_square_mean, builder->Multiply(mean, builder->Identity(mean))); return variance; } // std_variance_inv = rsqrt(variance + epsilon) Variable StdVarianceInv1d(Variable variance, float epsilon) { - auto epsilon_1d = - builder->FillConstant(variance->shape, epsilon, common::UniqName("epsilon"), common::Type2Str(variance->type)); + auto epsilon_1d = builder->FillConstant(variance->shape, + epsilon, + common::UniqName("epsilon"), + common::Type2Str(variance->type)); auto std_variance_inv = builder->Rsqrt(builder->Add(variance, epsilon_1d)); return std_variance_inv; } @@ -99,21 +114,32 @@ struct BatchNormHelper { // std_variance_inv = rsqrt(variance + epsilon) Variable StdVarianceInv4d(Variable variance, float epsilon) { auto variance_4d = builder->BroadcastTo(variance, x_shape, {channel_dim}); - auto epsilon_4d = builder->FillConstant( - variance_4d->shape, 
epsilon, common::UniqName("epsilon"), common::Type2Str(variance_4d->type)); - auto std_variance_inv_4d = builder->Rsqrt(builder->Add(variance_4d, epsilon_4d)); + auto epsilon_4d = + builder->FillConstant(variance_4d->shape, + epsilon, + common::UniqName("epsilon"), + common::Type2Str(variance_4d->type)); + auto std_variance_inv_4d = + builder->Rsqrt(builder->Add(variance_4d, epsilon_4d)); return std_variance_inv_4d; } // moving_value = moving_value * momentum + (1.0 - momentum) * saved_value // value maybe mean and variance. - Variable UpdateMeanVariance(Variable moving_value, Variable saved_value, float momentum) { - auto factor_0 = builder->FillConstant( - moving_value->shape, momentum, common::UniqName("factor_0"), common::Type2Str(moving_value->type)); - auto factor_1 = builder->FillConstant( - saved_value->shape, 1.0f - momentum, common::UniqName("factor_1"), common::Type2Str(saved_value->type)); + Variable UpdateMeanVariance(Variable moving_value, + Variable saved_value, + float momentum) { + auto factor_0 = builder->FillConstant(moving_value->shape, + momentum, + common::UniqName("factor_0"), + common::Type2Str(moving_value->type)); + auto factor_1 = builder->FillConstant(saved_value->shape, + 1.0f - momentum, + common::UniqName("factor_1"), + common::Type2Str(saved_value->type)); auto new_moving_value = - builder->Add(builder->Multiply(moving_value, factor_0), builder->Multiply(saved_value, factor_1)); + builder->Add(builder->Multiply(moving_value, factor_0), + builder->Multiply(saved_value, factor_1)); return new_moving_value; } @@ -129,48 +155,55 @@ struct BatchNormHelper { int num_instructions{0}; }; -void batch_norm_train(const Instruction& instr, const DecomposerContext& context) { - CHECK_EQ(instr->inputs.size(), 5UL) << "The number of the given inputs is not equal to the required for op " - << instr->op_type; - CHECK_EQ(instr->outputs.size(), 5UL) << "The number of the given outputs is not equal to the required for op " - << instr->op_type; - - auto& x = instr->inputs[0]; - auto& scale = instr->inputs[1]; - auto& bias = instr->inputs[2]; - auto& moving_mean = instr->inputs[3]; +void batch_norm_train(const Instruction& instr, + const DecomposerContext& context) { + CHECK_EQ(instr->inputs.size(), 5UL) + << "The number of the given inputs is not equal to the required for op " + << instr->op_type; + CHECK_EQ(instr->outputs.size(), 5UL) + << "The number of the given outputs is not equal to the required for op " + << instr->op_type; + + auto& x = instr->inputs[0]; + auto& scale = instr->inputs[1]; + auto& bias = instr->inputs[2]; + auto& moving_mean = instr->inputs[3]; auto& moving_variance = instr->inputs[4]; CHECK_EQ(scale->type, bias->type); CHECK_EQ(scale->type, moving_mean->type); CHECK_EQ(scale->type, moving_variance->type); - float epsilon = instr.GetAttrs("epsilon"); - float momentum = instr.GetAttrs("momentum"); + float epsilon = instr.GetAttrs("epsilon"); + float momentum = instr.GetAttrs("momentum"); std::string layout = instr.GetAttrs("data_layout"); NetBuilder* builder = context.builder(); - BatchNormHelper helper(builder, x->shape, scale->shape, layout, "batch_norm_train"); + BatchNormHelper helper( + builder, x->shape, scale->shape, layout, "batch_norm_train"); auto mean_variance = helper.MeanAndVariance(x); - auto mean = mean_variance[0]; - auto variance = mean_variance[1]; + auto mean = mean_variance[0]; + auto variance = mean_variance[1]; auto mean_4d = builder->BroadcastTo(mean, x->shape, {helper.channel_dim}); // std_variance_inv = rsqrt(variance + epsilon), 
shape = [c] auto std_variance_inv_4d = helper.StdVarianceInv4d(variance, epsilon); // y = scale * (x - mean) * std_variance_inv + bias, shape = [n, c, h, w] - auto scale_4d = builder->BroadcastTo(scale, x->shape, {helper.channel_dim}); - auto bias_4d = builder->BroadcastTo(bias, x->shape, {helper.channel_dim}); - auto normalized = builder->Multiply(builder->Subtract(x, mean_4d), std_variance_inv_4d); + auto scale_4d = builder->BroadcastTo(scale, x->shape, {helper.channel_dim}); + auto bias_4d = builder->BroadcastTo(bias, x->shape, {helper.channel_dim}); + auto normalized = + builder->Multiply(builder->Subtract(x, mean_4d), std_variance_inv_4d); auto scaled_normalized = builder->Multiply(normalized, scale_4d); - auto y = builder->Add(scaled_normalized, bias_4d); + auto y = builder->Add(scaled_normalized, bias_4d); // moving_mean = moving_mean * momentum + (1.0 - momentum) * mean, shape = [c] auto new_moving_mean = helper.UpdateMeanVariance(moving_mean, mean, momentum); - // moving_variance = moving_variance * momentum + (1.0 - momentum) * variance, shape = [c] - auto new_moving_variance = helper.UpdateMeanVariance(moving_variance, variance, momentum); + // moving_variance = moving_variance * momentum + (1.0 - momentum) * variance, + // shape = [c] + auto new_moving_variance = + helper.UpdateMeanVariance(moving_variance, variance, momentum); context.MapOutToOrigin(y, instr->outputs[0]); context.MapOutToOrigin(mean, instr->outputs[1]); @@ -179,64 +212,86 @@ void batch_norm_train(const Instruction& instr, const DecomposerContext& context context.MapOutToOrigin(new_moving_variance, instr->outputs[4]); } -void batch_norm_grad(const Instruction& instr, const DecomposerContext& context) { - CHECK_EQ(instr->inputs.size(), 5UL) << " The number of the given inputs is not equal to the required " - << instr->op_type; - CHECK_EQ(instr->outputs.size(), 3UL) << " The number of the given outputs is not equal to the required" - << instr->op_type; - - auto& y_grad = instr->inputs[0]; - auto& x = instr->inputs[1]; - auto& scale = instr->inputs[2]; - auto& save_mean = instr->inputs[3]; +void batch_norm_grad(const Instruction& instr, + const DecomposerContext& context) { + CHECK_EQ(instr->inputs.size(), 5UL) + << " The number of the given inputs is not equal to the required " + << instr->op_type; + CHECK_EQ(instr->outputs.size(), 3UL) + << " The number of the given outputs is not equal to the required" + << instr->op_type; + + auto& y_grad = instr->inputs[0]; + auto& x = instr->inputs[1]; + auto& scale = instr->inputs[2]; + auto& save_mean = instr->inputs[3]; auto& save_variance = instr->inputs[4]; CHECK_EQ(y_grad->type, x->type); CHECK_EQ(scale->type, save_mean->type); CHECK_EQ(scale->type, save_variance->type); auto epsilon = instr.GetAttrs("epsilon"); - auto layout = instr.GetAttrs("data_layout"); + auto layout = instr.GetAttrs("data_layout"); NetBuilder* builder = context.builder(); - BatchNormHelper helper(builder, x->shape, scale->shape, layout, "batch_norm_grad"); + BatchNormHelper helper( + builder, x->shape, scale->shape, layout, "batch_norm_grad"); - auto vars = helper.GradBiasAndScale(x, save_mean, y_grad); - auto bias_grad = vars[0]; + auto vars = helper.GradBiasAndScale(x, save_mean, y_grad); + auto bias_grad = vars[0]; auto sum_of_y_grad_mul_x_mean_diff = vars[1]; - // scale_grad = reduce_sum(y_grad * (x - mean)) * rsqrt(variance + epsilon), shape = [c] - auto scale_grad = builder->Multiply(sum_of_y_grad_mul_x_mean_diff, helper.StdVarianceInv1d(save_variance, epsilon)); + // scale_grad = 
reduce_sum(y_grad * (x - mean)) * rsqrt(variance + epsilon), + // shape = [c] + auto scale_grad = + builder->Multiply(sum_of_y_grad_mul_x_mean_diff, + helper.StdVarianceInv1d(save_variance, epsilon)); // x_grad = 1/nhw * scale * rsqrt(variance + epsilon) * - // (nhw * y_grad - reduce_sum(y_grad) - (x - mean) * reduce_sum(y_grad * (x - mean)) / (variance + epsilon)) + // (nhw * y_grad - reduce_sum(y_grad) - (x - mean) * reduce_sum(y_grad * (x + // - mean)) / (variance + epsilon)) // => x_grad = tmp0 * (tmp1 - tmp2 - tmp3) - auto scaled_std_variance_inv = builder->Multiply(scale, helper.StdVarianceInv1d(save_variance, epsilon)); - auto element_count_1d = builder->FillConstant(scaled_std_variance_inv->shape, - helper.element_count, - common::UniqName("element_count_1d"), - common::Type2Str(scaled_std_variance_inv->type)); - auto tmp0 = - builder->BroadcastTo(builder->Divide(scaled_std_variance_inv, element_count_1d), x->shape, {helper.channel_dim}); - - auto element_count_4d = builder->FillConstant( - y_grad->shape, helper.element_count, common::UniqName("element_count_4d"), common::Type2Str(y_grad->type)); + auto scaled_std_variance_inv = + builder->Multiply(scale, helper.StdVarianceInv1d(save_variance, epsilon)); + auto element_count_1d = + builder->FillConstant(scaled_std_variance_inv->shape, + helper.element_count, + common::UniqName("element_count_1d"), + common::Type2Str(scaled_std_variance_inv->type)); + auto tmp0 = builder->BroadcastTo( + builder->Divide(scaled_std_variance_inv, element_count_1d), + x->shape, + {helper.channel_dim}); + + auto element_count_4d = + builder->FillConstant(y_grad->shape, + helper.element_count, + common::UniqName("element_count_4d"), + common::Type2Str(y_grad->type)); auto tmp1 = builder->Multiply(y_grad, element_count_4d); auto tmp2 = builder->BroadcastTo(bias_grad, x->shape, {helper.channel_dim}); - auto mean_4d = builder->BroadcastTo(save_mean, x->shape, {helper.channel_dim}); + auto mean_4d = + builder->BroadcastTo(save_mean, x->shape, {helper.channel_dim}); auto x_mean_diff = builder->Subtract(x, mean_4d); - auto sum_of_y_grad_mul_x_mean_diff_4d = - builder->BroadcastTo(sum_of_y_grad_mul_x_mean_diff, x->shape, {helper.channel_dim}); - auto tmp3_0 = builder->Multiply(x_mean_diff, sum_of_y_grad_mul_x_mean_diff_4d); - auto epsilon_1d = builder->FillConstant( - save_variance->shape, epsilon, common::UniqName("epsilon"), common::Type2Str(save_variance->type)); - auto variance_add_eps = builder->Add(save_variance, epsilon_1d); - auto variance_add_eps_4d = builder->BroadcastTo(variance_add_eps, x->shape, {helper.channel_dim}); - auto tmp3 = builder->Divide(tmp3_0, variance_add_eps_4d); - - auto x_grad = builder->Multiply(tmp0, builder->Subtract(builder->Subtract(tmp1, tmp2), tmp3)); + auto sum_of_y_grad_mul_x_mean_diff_4d = builder->BroadcastTo( + sum_of_y_grad_mul_x_mean_diff, x->shape, {helper.channel_dim}); + auto tmp3_0 = + builder->Multiply(x_mean_diff, sum_of_y_grad_mul_x_mean_diff_4d); + auto epsilon_1d = + builder->FillConstant(save_variance->shape, + epsilon, + common::UniqName("epsilon"), + common::Type2Str(save_variance->type)); + auto variance_add_eps = builder->Add(save_variance, epsilon_1d); + auto variance_add_eps_4d = + builder->BroadcastTo(variance_add_eps, x->shape, {helper.channel_dim}); + auto tmp3 = builder->Divide(tmp3_0, variance_add_eps_4d); + + auto x_grad = builder->Multiply( + tmp0, builder->Subtract(builder->Subtract(tmp1, tmp2), tmp3)); context.MapOutToOrigin(x_grad, instr->outputs[0]); context.MapOutToOrigin(scale_grad, 
instr->outputs[1]); @@ -244,37 +299,41 @@ void batch_norm_grad(const Instruction& instr, const DecomposerContext& context) } void batch_norm(const Instruction& instr, const DecomposerContext& context) { - CHECK_EQ(instr->inputs.size(), 5UL) << "The number of the given inputs is not equal to the required for op " - << instr->op_type; - CHECK_EQ(instr->outputs.size(), 1UL) << "The number of the given outputs is not equal to the required for op " - << instr->op_type; - - auto& x = instr->inputs[0]; - auto& scale = instr->inputs[1]; - auto& bias = instr->inputs[2]; - auto& moving_mean = instr->inputs[3]; + CHECK_EQ(instr->inputs.size(), 5UL) + << "The number of the given inputs is not equal to the required for op " + << instr->op_type; + CHECK_EQ(instr->outputs.size(), 1UL) + << "The number of the given outputs is not equal to the required for op " + << instr->op_type; + + auto& x = instr->inputs[0]; + auto& scale = instr->inputs[1]; + auto& bias = instr->inputs[2]; + auto& moving_mean = instr->inputs[3]; auto& moving_variance = instr->inputs[4]; CHECK_EQ(scale->type, bias->type); CHECK_EQ(scale->type, moving_mean->type); CHECK_EQ(scale->type, moving_variance->type); - float epsilon = instr.GetAttrs("epsilon"); - float momentum = instr.GetAttrs("momentum"); + float epsilon = instr.GetAttrs("epsilon"); + float momentum = instr.GetAttrs("momentum"); std::string layout = instr.GetAttrs("data_layout"); NetBuilder* builder = context.builder(); BatchNormHelper helper(builder, x->shape, scale->shape, layout, "batch_norm"); - auto mean_4d = builder->BroadcastTo(moving_mean, x->shape, {helper.channel_dim}); + auto mean_4d = + builder->BroadcastTo(moving_mean, x->shape, {helper.channel_dim}); // std_variance_inv = rsqrt(variance + epsilon), shape = [c] auto std_variance_inv_4d = helper.StdVarianceInv4d(moving_variance, epsilon); // y = scale * (x - mean) * std_variance_inv + bias, shape = [n, c, h, w] - auto scale_4d = builder->BroadcastTo(scale, x->shape, {helper.channel_dim}); - auto bias_4d = builder->BroadcastTo(bias, x->shape, {helper.channel_dim}); - auto normalized = builder->Multiply(builder->Subtract(x, mean_4d), std_variance_inv_4d); + auto scale_4d = builder->BroadcastTo(scale, x->shape, {helper.channel_dim}); + auto bias_4d = builder->BroadcastTo(bias, x->shape, {helper.channel_dim}); + auto normalized = + builder->Multiply(builder->Subtract(x, mean_4d), std_variance_inv_4d); auto scaled_normalized = builder->Multiply(normalized, scale_4d); - auto y = builder->Add(scaled_normalized, bias_4d); + auto y = builder->Add(scaled_normalized, bias_4d); context.MapOutToOrigin(y, instr->outputs[0]); } @@ -290,13 +349,15 @@ CINN_REGISTER_HELPER(batch_norm_decomposer) { } CINN_REGISTER_HELPER(batch_norm_train_decomposer) { - CINN_DECOMPOSER_REGISTER(batch_norm_train, cinn::frontend::decomposer::batch_norm_train); + CINN_DECOMPOSER_REGISTER(batch_norm_train, + cinn::frontend::decomposer::batch_norm_train); return true; } CINN_REGISTER_HELPER(batch_norm_grad_decomposer) { - CINN_DECOMPOSER_REGISTER(batch_norm_grad, cinn::frontend::decomposer::batch_norm_grad); + CINN_DECOMPOSER_REGISTER(batch_norm_grad, + cinn::frontend::decomposer::batch_norm_grad); return true; } diff --git a/paddle/cinn/frontend/decomposer/batch_norm_test.cc b/paddle/cinn/frontend/decomposer/batch_norm_test.cc index 8c5607dfcf378..c570f44321707 100755 --- a/paddle/cinn/frontend/decomposer/batch_norm_test.cc +++ b/paddle/cinn/frontend/decomposer/batch_norm_test.cc @@ -26,7 +26,8 @@ struct Offset { int h; int w; - Offset(int arg_n, int 
arg_c, int arg_h, int arg_w) : n(arg_n), c(arg_c), h(arg_h), w(arg_w) {} + Offset(int arg_n, int arg_c, int arg_h, int arg_w) + : n(arg_n), c(arg_c), h(arg_h), w(arg_w) {} int operator()(int idx_n, int idx_c, int idx_h, int idx_w) const { return idx_n * c * h * w + idx_c * h * w + idx_h * w + idx_w; @@ -67,7 +68,9 @@ void ComputeBatchNormTrainRef(const std::vector& x, // sum memset(saved_mean->data(), 0, sizeof(T) * c); - auto func_sum_x = [=](int in, int ic, int ih, int iw) { saved_mean->at(ic) += x[offset(in, ic, ih, iw)]; }; + auto func_sum_x = [=](int in, int ic, int ih, int iw) { + saved_mean->at(ic) += x[offset(in, ic, ih, iw)]; + }; Loop(func_sum_x, n, c, h, w); // saved mean @@ -75,14 +78,16 @@ void ComputeBatchNormTrainRef(const std::vector& x, for (int ic = 0; ic < c; ++ic) { // Checking result of saved_mean: // output[saved_mean], var_name=var_5, shape={32} - // - Total 0 different results, offset=0, 0.00527001 vs 0.00527001, maximum_relative_diff=0(absolute_diff=0) + // - Total 0 different results, offset=0, 0.00527001 vs 0.00527001, + // maximum_relative_diff=0(absolute_diff=0) saved_mean->at(ic) /= element_count; } // square_sum std::vector x_square_mean(c, 0); auto func_sum_square_x = [&](int in, int ic, int ih, int iw) { - x_square_mean.at(ic) += x[offset(in, ic, ih, iw)] * x[offset(in, ic, ih, iw)]; + x_square_mean.at(ic) += + x[offset(in, ic, ih, iw)] * x[offset(in, ic, ih, iw)]; }; Loop(func_sum_square_x, n, c, h, w); @@ -95,11 +100,14 @@ void ComputeBatchNormTrainRef(const std::vector& x, for (int ic = 0; ic < c; ++ic) { // Checking results of saved_variance and std_variance: // output[saved_variance], var_name=var_6, shape={32} - // - Total 0 different results, offset=0, 0.336347 vs 0.336347, maximum_relative_diff=0(absolute_diff=0) - // output[std_variance], var_name=std_variance, shape={32} - // - Total 0 different results, offset=0, 0.579963 vs 0.579963, maximum_relative_diff=0(absolute_diff=0) - saved_variance->at(ic) = x_square_mean[ic] - (saved_mean->at(ic) * saved_mean->at(ic)); - std_variance[ic] = sqrt(saved_variance->at(ic) + epsilon); + // - Total 0 different results, offset=0, 0.336347 vs 0.336347, + // maximum_relative_diff=0(absolute_diff=0) output[std_variance], + // var_name=std_variance, shape={32} + // - Total 0 different results, offset=0, 0.579963 vs 0.579963, + // maximum_relative_diff=0(absolute_diff=0) + saved_variance->at(ic) = + x_square_mean[ic] - (saved_mean->at(ic) * saved_mean->at(ic)); + std_variance[ic] = sqrt(saved_variance->at(ic) + epsilon); } // compute output @@ -110,7 +118,8 @@ void ComputeBatchNormTrainRef(const std::vector& x, // output[y_nobias], var_name=y_nobias, shape={16, 32, 16, 16} // - Total 0 different results, offset=32104, -0.000488288 vs -0.000488288, // maximum_relative_diff=1.19208e-07(absolute_diff=5.82077e-11) - y_nobias[idx] = (x[idx] - saved_mean->at(ic)) * scale[ic] / std_variance[ic]; + y_nobias[idx] = + (x[idx] - saved_mean->at(ic)) * scale[ic] / std_variance[ic]; }; Loop(func_y_nobias, n, c, h, w); @@ -119,10 +128,13 @@ void ComputeBatchNormTrainRef(const std::vector& x, // Checking result of y: // output[y], var_name=var_4, shape={16, 32, 16, 16} // - Total 80 different results, offset=126409, 1.81794e-06 vs 1.80304e-06, - // maximum_relative_diff=0.00826446(absolute_diff=1.49012e-08) For the following case: - // idx=126409, y[idx]=1.80304e-06, y_nobias[idx]=0.2033332, bias[ic]=-0.2033314 + // maximum_relative_diff=0.00826446(absolute_diff=1.49012e-08) For the + // following case: + // idx=126409, 
y[idx]=1.80304e-06, y_nobias[idx]=0.2033332, + // bias[ic]=-0.2033314 // The computing result of CPU and GPU may have some difference, like - // i=126409, 1.8179417e-06 vs 1.8030405e-06, relative_diff=0.0082644625, absolute_diff=1.4901161e-08 + // i=126409, 1.8179417e-06 vs 1.8030405e-06, relative_diff=0.0082644625, + // absolute_diff=1.4901161e-08 // This case is considered reasonable. y->at(idx) = y_nobias[idx] + bias[ic]; }; @@ -135,32 +147,42 @@ void ComputeBatchNormTrainRef(const std::vector& x, // Checking result of new_moving_mean and new_moving_variance: // output[new_moving_mean], var_name=var_7, shape={32} // - Total 0 different results, offset=9, 0.00123065 vs 0.00123065, - // maximum_relative_diff=9.45967e-08(absolute_diff=1.16415e-10) output[new_moving_variance], var_name=var_8, - // shape={32} + // maximum_relative_diff=9.45967e-08(absolute_diff=1.16415e-10) + // output[new_moving_variance], var_name=var_8, shape={32} // - Total 0 different results, offset=16, -0.00140787 vs -0.00140787, // maximum_relative_diff=5.29211e-06(absolute_diff=7.45058e-09) - new_moving_mean->at(ic) = moving_mean[ic] * factor_0 + saved_mean->at(ic) * factor_1; - new_moving_variance->at(ic) = moving_variance[ic] * factor_0 + saved_variance->at(ic) * factor_1; + new_moving_mean->at(ic) = + moving_mean[ic] * factor_0 + saved_mean->at(ic) * factor_1; + new_moving_variance->at(ic) = + moving_variance[ic] * factor_0 + saved_variance->at(ic) * factor_1; } } TEST(Decomposer, BatchNormTrain) { int n = 16, c = 128, h = 14, w = 14; - float epsilon = 1e-5; - float momentum = 0.9f; + float epsilon = 1e-5; + float momentum = 0.9f; std::string data_layout = "NCHW"; - bool is_test = false; + bool is_test = false; NetBuilder net_builder("batch_norm_train"); std::vector output_names; { - auto x = net_builder.CreateInput(Float(32), {n, c, h, w}, "x"); - auto scale = net_builder.CreateInput(Float(32), {c}, "scale"); - auto bias = net_builder.CreateInput(Float(32), {c}, "bias"); - auto moving_mean = net_builder.CreateInput(Float(32), {c}, "moving_mean"); - auto moving_variance = net_builder.CreateInput(Float(32), {c}, "moving_variance"); - - auto outputs = - net_builder.BatchNorm(x, scale, bias, moving_mean, moving_variance, epsilon, momentum, data_layout, is_test); + auto x = net_builder.CreateInput(Float(32), {n, c, h, w}, "x"); + auto scale = net_builder.CreateInput(Float(32), {c}, "scale"); + auto bias = net_builder.CreateInput(Float(32), {c}, "bias"); + auto moving_mean = net_builder.CreateInput(Float(32), {c}, "moving_mean"); + auto moving_variance = + net_builder.CreateInput(Float(32), {c}, "moving_variance"); + + auto outputs = net_builder.BatchNorm(x, + scale, + bias, + moving_mean, + moving_variance, + epsilon, + momentum, + data_layout, + is_test); for (auto output : outputs) { output_names.push_back(output->id); } @@ -168,7 +190,10 @@ TEST(Decomposer, BatchNormTrain) { auto program = net_builder.Build(); auto target = common::DefaultTarget(); - RunDecomposer(&program, target, cinn::frontend::DefaultTrainingOptimizeOptions().program_passes, output_names); + RunDecomposer(&program, + target, + cinn::frontend::DefaultTrainingOptimizeOptions().program_passes, + output_names); auto graph = std::make_shared(program, target); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); @@ -180,14 +205,16 @@ TEST(Decomposer, BatchNormTrain) { // set input float precision = 1e-3; - std::vector x(n * c * h * w), scale(c), bias(c), moving_mean(c), moving_variance(c); + std::vector x(n * c * h * w), scale(c), bias(c), 
moving_mean(c), + moving_variance(c); InitRandomVector(&x, n * c * h * w, 0.0f, 1.0f, precision); InitRandomVector(&scale, c, 0.0f, 1.0f, precision); InitRandomVector(&bias, c, 10.0f, 20.0f, precision); InitRandomVector(&moving_mean, c, 0.0f, 1.0f, precision); InitRandomVector(&moving_variance, c, 0.0f, 1.0f, precision); - std::vector y(n * c * h * w), new_moving_mean(c), new_moving_variance(c), saved_mean(c), saved_variance(c); + std::vector y(n * c * h * w), new_moving_mean(c), + new_moving_variance(c), saved_mean(c), saved_variance(c); ComputeBatchNormTrainRef(x, scale, bias, @@ -206,21 +233,26 @@ TEST(Decomposer, BatchNormTrain) { momentum); std::vector>> inputs = { - {"x", x}, {"scale", scale}, {"bias", bias}, {"moving_mean", moving_mean}, {"moving_variance", moving_variance}}; + {"x", x}, + {"scale", scale}, + {"bias", bias}, + {"moving_mean", moving_mean}, + {"moving_variance", moving_variance}}; for (auto& input : inputs) { scope->Var(input.first); auto tensor = scope->GetTensor(input.first); - auto* data = tensor->mutable_data(target); + auto* data = tensor->mutable_data(target); CopyFromVector(input.second, tensor, target); } run_program->Execute(); - std::unordered_map>> outputs_ref = { - {"new_moving_variance", {output_names[4], new_moving_variance}}, - {"new_moving_mean", {output_names[3], new_moving_mean}}, - {"saved_variance", {output_names[2], saved_variance}}, - {"saved_mean", {output_names[1], saved_mean}}, - {"y", {output_names[0], y}}}; + std::unordered_map>> + outputs_ref = { + {"new_moving_variance", {output_names[4], new_moving_variance}}, + {"new_moving_mean", {output_names[3], new_moving_mean}}, + {"saved_variance", {output_names[2], saved_variance}}, + {"saved_mean", {output_names[1], saved_mean}}, + {"y", {output_names[0], y}}}; for (auto& iter : outputs_ref) { auto output = iter.second; @@ -228,7 +260,8 @@ TEST(Decomposer, BatchNormTrain) { std::vector data(tensor->shape().numel()); CopyToVector(tensor, &data); - LOG(INFO) << "output[" << iter.first << "], var_name=" << output.first << ", shape=" << tensor->shape().data(); + LOG(INFO) << "output[" << iter.first << "], var_name=" << output.first + << ", shape=" << tensor->shape().data(); CheckOutput(data, output.second, 1e-8, 1e-4); } } @@ -251,7 +284,9 @@ void ComputeBatchNormGradRef(const std::vector& y_grad, // bias_grad memset(bias_grad->data(), 0, sizeof(T) * c); - auto func_bias_grad = [=](int in, int ic, int ih, int iw) { bias_grad->at(ic) += y_grad[offset(in, ic, ih, iw)]; }; + auto func_bias_grad = [=](int in, int ic, int ih, int iw) { + bias_grad->at(ic) += y_grad[offset(in, ic, ih, iw)]; + }; Loop(func_bias_grad, n, c, h, w); // std_variance @@ -274,7 +309,7 @@ void ComputeBatchNormGradRef(const std::vector& y_grad, // std_norm_grad std::vector std_norm_grad(n * c * h * w); auto func_std_norm_grad = [&](int in, int ic, int ih, int iw) { - int idx = offset(in, ic, ih, iw); + int idx = offset(in, ic, ih, iw); std_norm_grad[idx] = y_grad[idx] * scale[ic]; }; Loop(func_std_norm_grad, n, c, h, w); @@ -282,7 +317,7 @@ void ComputeBatchNormGradRef(const std::vector& y_grad, // x_mean_diff_grad std::vector x_mean_diff_grad(n * c * h * w); auto func_x_mean_diff_grad = [&](int in, int ic, int ih, int iw) { - int idx = offset(in, ic, ih, iw); + int idx = offset(in, ic, ih, iw); x_mean_diff_grad[idx] = std_norm_grad[idx] / std_variance[ic]; }; Loop(func_x_mean_diff_grad, n, c, h, w); @@ -291,7 +326,9 @@ void ComputeBatchNormGradRef(const std::vector& y_grad, std::vector std_variance_grad(c, 0); auto 
func_std_variance_grad = [&](int in, int ic, int ih, int iw) { int idx = offset(in, ic, ih, iw); - std_variance_grad[ic] += -1.0f * std_norm_grad[idx] * (x[idx] - save_mean[ic]) / (save_variance[ic] + epsilon); + std_variance_grad[ic] += -1.0f * std_norm_grad[idx] * + (x[idx] - save_mean[ic]) / + (save_variance[ic] + epsilon); }; Loop(func_std_variance_grad, n, c, h, w); @@ -305,7 +342,7 @@ void ComputeBatchNormGradRef(const std::vector& y_grad, float element_count = static_cast(n * h * w); std::vector x_grad_0(n * c * h * w); auto func_x_grad_0 = [&](int in, int ic, int ih, int iw) { - int idx = offset(in, ic, ih, iw); + int idx = offset(in, ic, ih, iw); x_grad_0[idx] = x[idx] * (variance_grad_without_mul[ic] / element_count); }; Loop(func_x_grad_0, n, c, h, w); @@ -322,26 +359,29 @@ void ComputeBatchNormGradRef(const std::vector& y_grad, } auto func_x_grad = [=](int in, int ic, int ih, int iw) { - int idx = offset(in, ic, ih, iw); - x_grad->at(idx) = x_mean_diff_grad[idx] + x_grad_0[idx] - minus_mean_grad[ic]; + int idx = offset(in, ic, ih, iw); + x_grad->at(idx) = + x_mean_diff_grad[idx] + x_grad_0[idx] - minus_mean_grad[ic]; }; Loop(func_x_grad, n, c, h, w); } TEST(Decomposer, BatchNormGrad) { int n = 16, c = 128, h = 14, w = 14; - int num = n * c * h * w; + int num = n * c * h * w; float epsilon = 1e-5; NetBuilder net_builder("batch_norm_grad"); std::vector output_names; { - auto y_grad = net_builder.CreateInput(Float(32), {n, c, h, w}, "y_grad"); - auto x = net_builder.CreateInput(Float(32), {n, c, h, w}, "x"); - auto scale = net_builder.CreateInput(Float(32), {c}, "scale"); - auto saved_mean = net_builder.CreateInput(Float(32), {c}, "saved_mean"); - auto saved_variance = net_builder.CreateInput(Float(32), {c}, "saved_variance"); - - auto outputs = net_builder.BatchNormGrad(y_grad, x, scale, saved_mean, saved_variance, epsilon); + auto y_grad = net_builder.CreateInput(Float(32), {n, c, h, w}, "y_grad"); + auto x = net_builder.CreateInput(Float(32), {n, c, h, w}, "x"); + auto scale = net_builder.CreateInput(Float(32), {c}, "scale"); + auto saved_mean = net_builder.CreateInput(Float(32), {c}, "saved_mean"); + auto saved_variance = + net_builder.CreateInput(Float(32), {c}, "saved_variance"); + + auto outputs = net_builder.BatchNormGrad( + y_grad, x, scale, saved_mean, saved_variance, epsilon); for (auto output : outputs) { output_names.push_back(output->id); } @@ -349,7 +389,10 @@ TEST(Decomposer, BatchNormGrad) { auto program = net_builder.Build(); auto target = common::DefaultTarget(); - RunDecomposer(&program, target, cinn::frontend::DefaultTrainingOptimizeOptions().program_passes, output_names); + RunDecomposer(&program, + target, + cinn::frontend::DefaultTrainingOptimizeOptions().program_passes, + output_names); auto graph = std::make_shared(program, target); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); @@ -361,7 +404,8 @@ TEST(Decomposer, BatchNormGrad) { // set input float precision = 1e-3; - std::vector y_grad(num), x(num), scale(c), saved_mean(c, 0), saved_variance(c, 0); + std::vector y_grad(num), x(num), scale(c), saved_mean(c, 0), + saved_variance(c, 0); InitRandomVector(&y_grad, num, 0.0f, 1.0f, precision); InitRandomVector(&x, num, 0.0f, 1.0f, precision); InitRandomVector(&scale, c, 0.0f, 1.0f, precision); @@ -376,11 +420,16 @@ TEST(Decomposer, BatchNormGrad) { float element_count = static_cast(n * h * w); for (int ic = 0; ic < c; ++ic) { saved_mean[ic] /= element_count; - saved_variance[ic] = saved_variance[ic] / element_count - saved_mean[ic] * 
saved_mean[ic]; + saved_variance[ic] = + saved_variance[ic] / element_count - saved_mean[ic] * saved_mean[ic]; } std::vector>> inputs = { - {"y_grad", y_grad}, {"x", x}, {"scale", scale}, {"saved_mean", saved_mean}, {"saved_variance", saved_variance}}; + {"y_grad", y_grad}, + {"x", x}, + {"scale", scale}, + {"saved_mean", saved_mean}, + {"saved_variance", saved_variance}}; for (auto& input : inputs) { scope->Var(input.first); auto tensor = scope->GetTensor(input.first); @@ -389,13 +438,24 @@ TEST(Decomposer, BatchNormGrad) { run_program->Execute(); std::vector x_grad(num), scale_grad(c), bias_grad(c); - ComputeBatchNormGradRef( - y_grad, x, scale, saved_mean, saved_variance, n, c, h, w, &x_grad, &scale_grad, &bias_grad, epsilon); - - std::unordered_map>> output_refs = { - {"bias_grad", {output_names[2], bias_grad}}, - {"scale_grad", {output_names[1], scale_grad}}, - {"x_grad", {output_names[0], x_grad}}}; + ComputeBatchNormGradRef(y_grad, + x, + scale, + saved_mean, + saved_variance, + n, + c, + h, + w, + &x_grad, + &scale_grad, + &bias_grad, + epsilon); + + std::unordered_map>> + output_refs = {{"bias_grad", {output_names[2], bias_grad}}, + {"scale_grad", {output_names[1], scale_grad}}, + {"x_grad", {output_names[0], x_grad}}}; for (auto& iter : output_refs) { auto output = iter.second; @@ -403,7 +463,8 @@ TEST(Decomposer, BatchNormGrad) { std::vector data(tensor->shape().numel()); CopyToVector(tensor, &data); - LOG(INFO) << "output[" << iter.first << "], var_name=" << output.first << ", shape=" << tensor->shape().data(); + LOG(INFO) << "output[" << iter.first << "], var_name=" << output.first + << ", shape=" << tensor->shape().data(); if (iter.first == "x_grad") { // TODO(Xreki): fix the precision check of x_grad. // CheckOutput(data, output.second, 1e-8, 1e-1); diff --git a/paddle/cinn/frontend/decomposer/broadcast.cc b/paddle/cinn/frontend/decomposer/broadcast.cc index 7f57de1f835c8..67b6b93afc375 100644 --- a/paddle/cinn/frontend/decomposer/broadcast.cc +++ b/paddle/cinn/frontend/decomposer/broadcast.cc @@ -49,11 +49,14 @@ void GetReduceDimsForY(const std::vector& dy_shape, VLOG(3) << "The reduce_dims for Y: " << utils::Join(*reduce_dims, ","); } -void elementwise_add(const Instruction& instr, const DecomposerContext& context) { - CHECK_EQ(instr->inputs.size(), 2UL) << " 2 input tensors for " << instr->op_type; - CHECK_EQ(instr->outputs.size(), 1UL) << "1 output tensor for " << instr->op_type; - auto x = instr->inputs[0]; - auto y = instr->inputs[1]; +void elementwise_add(const Instruction& instr, + const DecomposerContext& context) { + CHECK_EQ(instr->inputs.size(), 2UL) + << " 2 input tensors for " << instr->op_type; + CHECK_EQ(instr->outputs.size(), 1UL) + << "1 output tensor for " << instr->op_type; + auto x = instr->inputs[0]; + auto y = instr->inputs[1]; auto output = instr->outputs[0]; int axis = -1; @@ -62,15 +65,15 @@ void elementwise_add(const Instruction& instr, const DecomposerContext& context) } if (x->shape.size() >= y->shape.size()) { - axis = axis >= 0 ? axis : x->shape.size() - y->shape.size(); + axis = axis >= 0 ? 
axis : x->shape.size() - y->shape.size(); auto* builder = context.builder(); Variable out; Variable bcast_x = x; Variable bcast_y = y; - // e.g., x.shape = [4, 1, 3], y.shape = [2, 3], aixs = 1 out.shape = [4, 2, 3] - // bcast_axes_x = [0, 1, 2], bcast_axes_y = [1, 2] + // e.g., x.shape = [4, 1, 3], y.shape = [2, 3], axis = 1, out.shape = [4, 2, + // 3], bcast_axes_x = [0, 1, 2], bcast_axes_y = [1, 2] if (x->shape != output->shape) { std::vector bcast_axes_x(x->shape.size()); std::iota(bcast_axes_x.begin(), bcast_axes_x.end(), 0); @@ -89,7 +92,7 @@ void elementwise_add(const Instruction& instr, const DecomposerContext& context) // map the output of the decomposed operator to the original. context.MapOutToOrigin(out, output); } else { - axis = axis >= 0 ? axis : y->shape.size() - x->shape.size(); + axis = axis >= 0 ? axis : y->shape.size() - x->shape.size(); auto* builder = context.builder(); Variable out; @@ -115,17 +118,21 @@ void elementwise_add(const Instruction& instr, const DecomposerContext& context) } } -void elementwise_add_grad(const Instruction& instr, const DecomposerContext& context) { - CHECK_EQ(instr->inputs.size(), 3UL) << " 3 input tensors for " << instr->op_type; - CHECK_EQ(instr->outputs.size(), 2UL) << "2 output tensors for " << instr->op_type; +void elementwise_add_grad(const Instruction& instr, + const DecomposerContext& context) { + CHECK_EQ(instr->inputs.size(), 3UL) + << "3 input tensors for " << instr->op_type; + CHECK_EQ(instr->outputs.size(), 2UL) + << "2 output tensors for " << instr->op_type; auto dout = instr->inputs[0]; - auto dx = instr->outputs[0]; - auto dy = instr->outputs[1]; - int axis = instr.GetAttrs("axis"); + auto dx = instr->outputs[0]; + auto dy = instr->outputs[1]; + int axis = instr.GetAttrs("axis"); if (axis < 0 && dx->shape.size() < dy->shape.size()) { - LOG(FATAL) << "Please make sure x'rank greater than or equal to y'rank when axis = -1"; + LOG(FATAL) << "Please make sure x's rank is greater than or equal to y's " "rank when axis = -1"; } - axis = axis >= 0 ? axis : dx->shape.size() - dy->shape.size(); + axis = axis >= 0 ? axis : dx->shape.size() - dy->shape.size(); auto* builder = context.builder(); Variable dx_t; @@ -150,7 +157,7 @@ void elementwise_add_grad(const Instruction& instr, const DecomposerContext& con // may be some extra "1" in the front or back of dy_res's shape. So // the dy_res needs to be reshaped. auto dy_res = builder->ReduceSum(dout, y_reduce_dims, true); - dy_t = builder->Reshape(dy_res, dy->shape); + dy_t = builder->Reshape(dy_res, dy->shape); } // map the output of the decomposed operator to the original.
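A quick note on the gradient path above: elementwise_add_grad obtains dy by reduce-summing dout over every axis that y does not cover once the two shapes are aligned at `axis`, plus any covered axis where y's extent is 1, and then reshaping the result back to dy's shape. The following is a minimal standalone sketch of that axis rule; ReduceDimsForY is an illustrative name, not the actual CINN helper signature.

#include <cstdio>
#include <vector>

// Illustrative restatement of the reduce-axis rule used to form dy.
std::vector<int> ReduceDimsForY(const std::vector<int>& dout_shape,
                                const std::vector<int>& y_shape,
                                int axis) {
  std::vector<int> dims;
  for (int i = 0; i < static_cast<int>(dout_shape.size()); ++i) {
    int j = i - axis;  // position of dout axis i inside y, if any
    bool covered = j >= 0 && j < static_cast<int>(y_shape.size());
    if (!covered || y_shape[j] == 1) dims.push_back(i);
  }
  return dims;
}

int main() {
  // Mirrors elementwise_add_grad_bcast1 below: dout = {32, 64, 32, 32},
  // y = {64}, axis = 1, so dy is the sum of dout over axes {0, 2, 3}.
  for (int d : ReduceDimsForY({32, 64, 32, 32}, {64}, 1)) printf("%d ", d);
  return 0;
}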
@@ -163,13 +170,15 @@ void elementwise_add_grad(const Instruction& instr, const DecomposerContext& con } // namespace cinn CINN_REGISTER_HELPER(broadcast_decomposers) { - CINN_DECOMPOSER_REGISTER(elementwise_add, cinn::frontend::decomposer::elementwise_add); + CINN_DECOMPOSER_REGISTER(elementwise_add, + cinn::frontend::decomposer::elementwise_add); return true; } CINN_REGISTER_HELPER(broadcast_grad_decomposers) { - CINN_DECOMPOSER_REGISTER(elementwise_add_grad, cinn::frontend::decomposer::elementwise_add_grad); + CINN_DECOMPOSER_REGISTER(elementwise_add_grad, + cinn::frontend::decomposer::elementwise_add_grad); return true; } diff --git a/paddle/cinn/frontend/decomposer/broadcast_test.cc b/paddle/cinn/frontend/decomposer/broadcast_test.cc index 39a564649c99f..93a58649219b5 100644 --- a/paddle/cinn/frontend/decomposer/broadcast_test.cc +++ b/paddle/cinn/frontend/decomposer/broadcast_test.cc @@ -20,102 +20,109 @@ namespace cinn::frontend { TEST(Decomposer, elementwise_add_bcast0) { NetBuilder builder("elementwise_add"); - auto x = builder.CreateInput(Float(32), {4, 1, 20, 10}); - auto y = builder.CreateInput(Float(32), {10, 20}); + auto x = builder.CreateInput(Float(32), {4, 1, 20, 10}); + auto y = builder.CreateInput(Float(32), {10, 20}); auto out = builder.Add(x, y, 1); - std::vector input_names = {x.id().data(), y.id().data()}; - std::vector output_names = {out->id}; + std::vector input_names = {x.id().data(), y.id().data()}; + std::vector output_names = {out->id}; std::vector> output_shapes = {{4, 10, 20, 10}}; RunAndCheckShape(builder, input_names, output_names, output_shapes); } TEST(Decomposer, elementwise_add_bcase1) { NetBuilder builder("elementwise_add"); - auto x = builder.CreateInput(Float(32), {10, 20}); - auto y = builder.CreateInput(Float(32), {4, 1, 20, 10}); + auto x = builder.CreateInput(Float(32), {10, 20}); + auto y = builder.CreateInput(Float(32), {4, 1, 20, 10}); auto out = builder.Add(x, y, 1); - std::vector input_names = {x.id().data(), y.id().data()}; - std::vector output_names = {out->id}; + std::vector input_names = {x.id().data(), y.id().data()}; + std::vector output_names = {out->id}; std::vector> output_shapes = {{4, 10, 20, 10}}; RunAndCheckShape(builder, input_names, output_names, output_shapes); } TEST(Decomposer, elementwise_add_grad_bcast0) { NetBuilder builder("elementwise_add_grad"); - auto dout = builder.CreateInput(Float(32), {4, 10, 20, 10}); - auto x = builder.CreateInput(Float(32), {4, 1, 20, 10}); - auto y = builder.CreateInput(Float(32), {10, 20}); + auto dout = builder.CreateInput(Float(32), {4, 10, 20, 10}); + auto x = builder.CreateInput(Float(32), {4, 1, 20, 10}); + auto y = builder.CreateInput(Float(32), {10, 20}); auto out_grads = builder.ElementwiseAddGrad(dout, x, y, 1); - std::vector input_names = {dout.id().data()}; - std::vector output_names = {out_grads[0]->id, out_grads[1]->id}; + std::vector input_names = {dout.id().data()}; + std::vector output_names = {out_grads[0]->id, out_grads[1]->id}; std::vector> output_shapes = {{4, 1, 20, 10}, {10, 20}}; RunAndCheckShape(builder, input_names, output_names, output_shapes); } TEST(Decomposer, elementwise_add_bcast1) { NetBuilder builder("elementwise_add"); - auto x = builder.CreateInput(Float(32), {32, 64, 32, 32}); - auto y = builder.CreateInput(Float(32), {64}); + auto x = builder.CreateInput(Float(32), {32, 64, 32, 32}); + auto y = builder.CreateInput(Float(32), {64}); auto out = builder.Add(x, y, 1); - auto add_cpu = [](const std::vector& lengths, const std::vector& ptrs) { - float* x = 
static_cast(ptrs[0]); - float* y = static_cast(ptrs[1]); + auto add_cpu = [](const std::vector& lengths, + const std::vector& ptrs) { + float* x = static_cast(ptrs[0]); + float* y = static_cast(ptrs[1]); float* out = static_cast(ptrs[2]); for (size_t i = 0; i < 32; ++i) { for (size_t j = 0; j < 64; ++j) { for (size_t k = 0; k < 32 * 32; ++k) { - out[(i * 64 + j) * 32 * 32 + k] = x[(i * 64 + j) * 32 * 32 + k] + y[j]; + out[(i * 64 + j) * 32 * 32 + k] = + x[(i * 64 + j) * 32 * 32 + k] + y[j]; } } } }; - std::vector input_names = {x.id().data(), y.id().data()}; - std::vector output_names = {out->id}; + std::vector input_names = {x.id().data(), y.id().data()}; + std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 64, 32, 32}}; - RunAndCheck(builder, input_names, output_names, output_shapes, add_cpu); + RunAndCheck( + builder, input_names, output_names, output_shapes, add_cpu); } TEST(Decomposer, elementwise_add_bcast1_2) { NetBuilder builder("elementwise_add"); - auto x = builder.CreateInput(Float(32), {64}); - auto y = builder.CreateInput(Float(32), {32, 64, 32, 32}); + auto x = builder.CreateInput(Float(32), {64}); + auto y = builder.CreateInput(Float(32), {32, 64, 32, 32}); auto out = builder.Add(x, y, 1); - auto add_cpu = [](const std::vector& lengths, const std::vector& ptrs) { - float* x = static_cast(ptrs[0]); - float* y = static_cast(ptrs[1]); + auto add_cpu = [](const std::vector& lengths, + const std::vector& ptrs) { + float* x = static_cast(ptrs[0]); + float* y = static_cast(ptrs[1]); float* out = static_cast(ptrs[2]); for (size_t i = 0; i < 32; ++i) { for (size_t j = 0; j < 64; ++j) { for (size_t k = 0; k < 32 * 32; ++k) { - out[(i * 64 + j) * 32 * 32 + k] = y[(i * 64 + j) * 32 * 32 + k] + x[j]; + out[(i * 64 + j) * 32 * 32 + k] = + y[(i * 64 + j) * 32 * 32 + k] + x[j]; } } } }; - std::vector input_names = {x.id().data(), y.id().data()}; - std::vector output_names = {out->id}; + std::vector input_names = {x.id().data(), y.id().data()}; + std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 64, 32, 32}}; - RunAndCheck(builder, input_names, output_names, output_shapes, add_cpu); + RunAndCheck( + builder, input_names, output_names, output_shapes, add_cpu); } TEST(Decomposer, elementwise_add_grad_bcast1) { NetBuilder builder("elementwise_add_grad"); - auto dout = builder.CreateInput(Float(32), {32, 64, 32, 32}); - auto x = builder.CreateInput(Float(32), {32, 64, 32, 32}); - auto y = builder.CreateInput(Float(32), {64}); + auto dout = builder.CreateInput(Float(32), {32, 64, 32, 32}); + auto x = builder.CreateInput(Float(32), {32, 64, 32, 32}); + auto y = builder.CreateInput(Float(32), {64}); auto out_grads = builder.ElementwiseAddGrad(dout, x, y, 1); - auto add_grad_cpu = [](const std::vector& lengths, const std::vector& ptrs) { + auto add_grad_cpu = [](const std::vector& lengths, + const std::vector& ptrs) { float* dout = static_cast(ptrs[0]); - float* dx = static_cast(ptrs[1]); - float* dy = static_cast(ptrs[2]); + float* dx = static_cast(ptrs[1]); + float* dy = static_cast(ptrs[2]); for (size_t j = 0; j < 64; ++j) { dy[j] = 0; } @@ -123,159 +130,172 @@ TEST(Decomposer, elementwise_add_grad_bcast1) { for (size_t j = 0; j < 64; ++j) { for (size_t k = 0; k < 32 * 32; ++k) { dx[(i * 64 + j) * 32 * 32 + k] = dout[(i * 64 + j) * 32 * 32 + k]; - dy[j] = dy[j] + dout[(i * 64 + j) * 32 * 32 + k]; + dy[j] = dy[j] + dout[(i * 64 + j) * 32 * 32 + k]; } } } }; - std::vector input_names = {dout.id().data()}; - std::vector output_names = {out_grads[0]->id, 
out_grads[1]->id}; + std::vector input_names = {dout.id().data()}; + std::vector output_names = {out_grads[0]->id, out_grads[1]->id}; std::vector> output_shapes = {{32, 64, 32, 32}, {64}}; - RunAndCheck(builder, input_names, output_names, output_shapes, add_grad_cpu); + RunAndCheck( + builder, input_names, output_names, output_shapes, add_grad_cpu); } TEST(Decomposer, elementwise_add_bcast2) { NetBuilder builder("elementwise_add"); - auto x = builder.CreateInput(Float(32), {32, 16}); - auto y = builder.CreateInput(Float(32), {1}); + auto x = builder.CreateInput(Float(32), {32, 16}); + auto y = builder.CreateInput(Float(32), {1}); auto out = builder.Add(x, y); - auto add_cpu = [](const std::vector& lengths, const std::vector& ptrs) { - size_t n = lengths[0]; - float* x = static_cast(ptrs[0]); - float* y = static_cast(ptrs[1]); - float* out = static_cast(ptrs[2]); + auto add_cpu = [](const std::vector& lengths, + const std::vector& ptrs) { + size_t n = lengths[0]; + float* x = static_cast(ptrs[0]); + float* y = static_cast(ptrs[1]); + float* out = static_cast(ptrs[2]); float y_data = y[0]; for (size_t i = 0; i < n; ++i) { out[i] = x[i] + y_data; } }; - std::vector input_names = {x.id().data(), y.id().data()}; - std::vector output_names = {out->id}; + std::vector input_names = {x.id().data(), y.id().data()}; + std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 16}}; - RunAndCheck(builder, input_names, output_names, output_shapes, add_cpu); + RunAndCheck( + builder, input_names, output_names, output_shapes, add_cpu); } TEST(Decomposer, elementwise_add_bcast2_2) { NetBuilder builder("elementwise_add"); - auto x = builder.CreateInput(Float(32), {1}); - auto y = builder.CreateInput(Float(32), {32, 16}); + auto x = builder.CreateInput(Float(32), {1}); + auto y = builder.CreateInput(Float(32), {32, 16}); auto out = builder.Add(x, y); - auto add_cpu = [](const std::vector& lengths, const std::vector& ptrs) { - size_t n = 32 * 16; - float* x = static_cast(ptrs[0]); - float* y = static_cast(ptrs[1]); - float* out = static_cast(ptrs[2]); + auto add_cpu = [](const std::vector& lengths, + const std::vector& ptrs) { + size_t n = 32 * 16; + float* x = static_cast(ptrs[0]); + float* y = static_cast(ptrs[1]); + float* out = static_cast(ptrs[2]); float x_data = x[0]; for (size_t i = 0; i < n; ++i) { out[i] = y[i] + x_data; } }; - std::vector input_names = {x.id().data(), y.id().data()}; - std::vector output_names = {out->id}; + std::vector input_names = {x.id().data(), y.id().data()}; + std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 16}}; - RunAndCheck(builder, input_names, output_names, output_shapes, add_cpu); + RunAndCheck( + builder, input_names, output_names, output_shapes, add_cpu); } TEST(Decomposer, elementwise_add_bcast2_3) { constexpr int kLength = 64; - using int_ty = int64_t; + using int_ty = int64_t; NetBuilder builder("elementwise_add"); - auto x = builder.CreateInput(Int(kLength), {32, 16}); - auto y = builder.CreateInput(Int(kLength), {1}); + auto x = builder.CreateInput(Int(kLength), {32, 16}); + auto y = builder.CreateInput(Int(kLength), {1}); auto out = builder.Add(x, y); - auto add_cpu = [](const std::vector& lengths, const std::vector& ptrs) { - size_t n = lengths[0]; - int_ty* x = static_cast(ptrs[0]); - int_ty* y = static_cast(ptrs[1]); - int_ty* out = static_cast(ptrs[2]); + auto add_cpu = [](const std::vector& lengths, + const std::vector& ptrs) { + size_t n = lengths[0]; + int_ty* x = static_cast(ptrs[0]); + int_ty* y = 
static_cast(ptrs[1]); + int_ty* out = static_cast(ptrs[2]); int_ty y_data = y[0]; for (size_t i = 0; i < n; ++i) { out[i] = x[i] + y_data; } }; - std::vector input_names = {x.id().data(), y.id().data()}; - std::vector output_names = {out->id}; + std::vector input_names = {x.id().data(), y.id().data()}; + std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 16}}; - RunAndCheck(builder, input_names, output_names, output_shapes, add_cpu); + RunAndCheck( + builder, input_names, output_names, output_shapes, add_cpu); } TEST(Decomposer, elementwise_add_grad_bcast2) { NetBuilder builder("elementwise_add_grad"); - auto dout = builder.CreateInput(Float(32), {32, 16}); - auto x = builder.CreateInput(Float(32), {32, 16}); - auto y = builder.CreateInput(Float(32), {1}); + auto dout = builder.CreateInput(Float(32), {32, 16}); + auto x = builder.CreateInput(Float(32), {32, 16}); + auto y = builder.CreateInput(Float(32), {1}); auto out_grads = builder.ElementwiseAddGrad(dout, x, y); - auto add_grad_cpu = [](const std::vector& lengths, const std::vector& ptrs) { - size_t n = lengths[0]; + auto add_grad_cpu = [](const std::vector& lengths, + const std::vector& ptrs) { + size_t n = lengths[0]; float* dout = static_cast(ptrs[0]); - float* dx = static_cast(ptrs[1]); - float* dy = static_cast(ptrs[2]); + float* dx = static_cast(ptrs[1]); + float* dy = static_cast(ptrs[2]); for (size_t i = 0; i < n; ++i) { float tmp = dout[i]; - dx[i] = tmp; + dx[i] = tmp; dy[0] += tmp; } }; - std::vector input_names = {dout.id().data()}; - std::vector output_names = {out_grads[0]->id, out_grads[1]->id}; + std::vector input_names = {dout.id().data()}; + std::vector output_names = {out_grads[0]->id, out_grads[1]->id}; std::vector> output_shapes = {{32, 16}, {1}}; - RunAndCheck(builder, input_names, output_names, output_shapes, add_grad_cpu); + RunAndCheck( + builder, input_names, output_names, output_shapes, add_grad_cpu); } TEST(Decomposer, elementwise_add_same_dims) { NetBuilder builder("elementwise_add"); - auto x = builder.CreateInput(Float(32), {32, 16}); - auto y = builder.CreateInput(Float(32), {32, 16}); + auto x = builder.CreateInput(Float(32), {32, 16}); + auto y = builder.CreateInput(Float(32), {32, 16}); auto out = builder.Add(x, y); - auto add_cpu = [](const std::vector& lengths, const std::vector& ptrs) { - size_t n = lengths[0]; - float* x = static_cast(ptrs[0]); - float* y = static_cast(ptrs[1]); + auto add_cpu = [](const std::vector& lengths, + const std::vector& ptrs) { + size_t n = lengths[0]; + float* x = static_cast(ptrs[0]); + float* y = static_cast(ptrs[1]); float* out = static_cast(ptrs[2]); for (size_t i = 0; i < n; ++i) { out[i] = x[i] + y[i]; } }; - std::vector input_names = {x.id().data(), y.id().data()}; - std::vector output_names = {out->id}; + std::vector input_names = {x.id().data(), y.id().data()}; + std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 16}}; - RunAndCheck(builder, input_names, output_names, output_shapes, add_cpu); + RunAndCheck( + builder, input_names, output_names, output_shapes, add_cpu); } TEST(Decomposer, elementwise_add_grad_same_dims) { NetBuilder builder("elementwise_add_grad"); - auto dout = builder.CreateInput(Float(32), {32, 16}); - auto x = builder.CreateInput(Float(32), {32, 16}); - auto y = builder.CreateInput(Float(32), {32, 16}); + auto dout = builder.CreateInput(Float(32), {32, 16}); + auto x = builder.CreateInput(Float(32), {32, 16}); + auto y = builder.CreateInput(Float(32), {32, 16}); auto out_grads = 
builder.ElementwiseAddGrad(dout, x, y); - auto add_grad_cpu = [](const std::vector& lengths, const std::vector& ptrs) { - size_t n = lengths[0]; + auto add_grad_cpu = [](const std::vector& lengths, + const std::vector& ptrs) { + size_t n = lengths[0]; float* dout = static_cast(ptrs[0]); - float* dx = static_cast(ptrs[1]); - float* dy = static_cast(ptrs[2]); + float* dx = static_cast(ptrs[1]); + float* dy = static_cast(ptrs[2]); for (size_t i = 0; i < n; ++i) { float tmp = dout[i]; - dx[i] = tmp; - dy[i] = tmp; + dx[i] = tmp; + dy[i] = tmp; } }; - std::vector input_names = {dout.id().data()}; - std::vector output_names = {out_grads[0]->id, out_grads[1]->id}; + std::vector input_names = {dout.id().data()}; + std::vector output_names = {out_grads[0]->id, out_grads[1]->id}; std::vector> output_shapes = {{32, 16}, {32, 16}}; - RunAndCheck(builder, input_names, output_names, output_shapes, add_grad_cpu); + RunAndCheck( + builder, input_names, output_names, output_shapes, add_grad_cpu); } } // namespace cinn::frontend diff --git a/paddle/cinn/frontend/decomposer/elementwise.cc b/paddle/cinn/frontend/decomposer/elementwise.cc index 9fddfde5de4a8..c284219642a17 100644 --- a/paddle/cinn/frontend/decomposer/elementwise.cc +++ b/paddle/cinn/frontend/decomposer/elementwise.cc @@ -20,10 +20,12 @@ namespace frontend { namespace decomposer { void sum(const Instruction& instr, const DecomposerContext& context) { - CHECK_GT(instr->inputs.size(), 0UL) << "At least 1 input tensor for " << instr->op_type; - CHECK_EQ(instr->outputs.size(), 1UL) << "1 output tensor for " << instr->op_type; - auto inputs = instr->inputs; - auto output = instr->outputs[0]; + CHECK_GT(instr->inputs.size(), 0UL) + << "At least 1 input tensor for " << instr->op_type; + CHECK_EQ(instr->outputs.size(), 1UL) + << "1 output tensor for " << instr->op_type; + auto inputs = instr->inputs; + auto output = instr->outputs[0]; auto* builder = context.builder(); auto sum = builder->Identity(inputs[0]); diff --git a/paddle/cinn/frontend/decomposer/elementwise_test.cc b/paddle/cinn/frontend/decomposer/elementwise_test.cc index 9a8b02bbe1d7e..6f02608ccc378 100644 --- a/paddle/cinn/frontend/decomposer/elementwise_test.cc +++ b/paddle/cinn/frontend/decomposer/elementwise_test.cc @@ -20,26 +20,29 @@ namespace cinn::frontend { TEST(Decomposer, sum) { NetBuilder builder("sum"); - auto x = builder.CreateInput(Float(32), {32, 16}); - auto y = builder.CreateInput(Float(32), {32, 16}); - auto z = builder.CreateInput(Float(32), {32, 16}); + auto x = builder.CreateInput(Float(32), {32, 16}); + auto y = builder.CreateInput(Float(32), {32, 16}); + auto z = builder.CreateInput(Float(32), {32, 16}); auto out = builder.Sum({x, y, z}); - auto sum_cpu = [](const std::vector& lengths, const std::vector& ptrs) { - size_t n = lengths[0]; - float* x = static_cast(ptrs[0]); - float* y = static_cast(ptrs[1]); - float* z = static_cast(ptrs[2]); + auto sum_cpu = [](const std::vector& lengths, + const std::vector& ptrs) { + size_t n = lengths[0]; + float* x = static_cast(ptrs[0]); + float* y = static_cast(ptrs[1]); + float* z = static_cast(ptrs[2]); float* out = static_cast(ptrs[3]); for (size_t i = 0; i < n; ++i) { out[i] = x[i] + y[i] + z[i]; } }; - std::vector input_names = {x.id().data(), y.id().data(), z.id().data()}; - std::vector output_names = {out->id}; + std::vector input_names = { + x.id().data(), y.id().data(), z.id().data()}; + std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 16}}; - RunAndCheck(builder, input_names, output_names, 
output_shapes, sum_cpu); + RunAndCheck( + builder, input_names, output_names, output_shapes, sum_cpu); } } // namespace cinn::frontend diff --git a/paddle/cinn/frontend/decomposer/test_helper.cc b/paddle/cinn/frontend/decomposer/test_helper.cc index a44111f46b3c1..c27b2d8eaae13 100644 --- a/paddle/cinn/frontend/decomposer/test_helper.cc +++ b/paddle/cinn/frontend/decomposer/test_helper.cc @@ -24,7 +24,11 @@ void RunDecomposer(Program* prog, for (int i = 0; i < prog->size(); i++) { VLOG(1) << "instruction: " << (*prog)[i]; } - ProgramPass::Apply(prog, std::unordered_set(fetch_ids.begin(), fetch_ids.end()), target, passes); + ProgramPass::Apply( + prog, + std::unordered_set(fetch_ids.begin(), fetch_ids.end()), + target, + passes); VLOG(1) << "===================== After Program Pass ====================="; for (int i = 0; i < prog->size(); i++) { VLOG(1) << "instruction: " << (*prog)[i]; @@ -32,7 +36,8 @@ void RunDecomposer(Program* prog, } template <> -void InitRandomVector(std::vector* vec, size_t numel, int low, int high, float precision) { +void InitRandomVector( + std::vector* vec, size_t numel, int low, int high, float precision) { std::random_device seed; std::default_random_engine engine(seed()); std::uniform_int_distribution dist(low, high); @@ -44,39 +49,48 @@ void InitRandomVector(std::vector* vec, size_t numel, int low, int hig } template <> -void CopyFromVector(const std::vector& vec, hlir::framework::Tensor tensor, Target target) { +void CopyFromVector(const std::vector& vec, + hlir::framework::Tensor tensor, + Target target) { auto* data = tensor->mutable_data(target); size_t numel = tensor->shape().numel(); CHECK_EQ(vec.size(), numel); #ifdef CINN_WITH_CUDA - // why not use vector ? Because to optimizes space, each value is stored in a single bit. - // So that the vector doesn't has data() function. - CHECK_EQ(sizeof(bool), sizeof(char)) << "The test need ensure the byte size of bool equal to the byte size of char."; + // why not use vector ? Because to optimizes space, each value is stored + // in a single bit. So that the vector doesn't has data() function. + CHECK_EQ(sizeof(bool), sizeof(char)) + << "The test need ensure the byte size of bool equal to the byte size of " + "char."; std::vector vec_char(numel); for (int i = 0; i < numel; ++i) vec_char[i] = static_cast(vec[i]); - cudaMemcpy(data, vec_char.data(), numel * sizeof(bool), cudaMemcpyHostToDevice); + cudaMemcpy( + data, vec_char.data(), numel * sizeof(bool), cudaMemcpyHostToDevice); #else std::copy(vec.begin(), vec.end(), data); #endif } template <> -void CopyToVector(const hlir::framework::Tensor tensor, std::vector* vec) { +void CopyToVector(const hlir::framework::Tensor tensor, + std::vector* vec) { auto* data = tensor->data(); size_t numel = tensor->shape().numel(); vec->resize(numel); #ifdef CINN_WITH_CUDA - // why not use vector ? Because to optimizes space, each value is stored in a single bit. - // So that the vector doesn't has data() function. - CHECK_EQ(sizeof(bool), sizeof(char)) << "The test need ensure the byte size of bool equal to the byte size of char."; + // why not use vector ? Because to optimizes space, each value is stored + // in a single bit. So that the vector doesn't has data() function. 
+ CHECK_EQ(sizeof(bool), sizeof(char)) + << "The test needs to ensure that the byte size of bool equals the byte " + "size of char."; std::vector vec_char(numel); - cudaMemcpy(vec_char.data(), data, numel * sizeof(bool), cudaMemcpyDeviceToHost); + cudaMemcpy( + vec_char.data(), data, numel * sizeof(bool), cudaMemcpyDeviceToHost); for (int i = 0; i < numel; ++i) vec->at(i) = static_cast(vec_char[i]); #else for (size_t i = 0; i < numel; ++i) { diff --git a/paddle/cinn/frontend/decomposer/test_helper.h b/paddle/cinn/frontend/decomposer/test_helper.h index 995b65c607c0e..f2d9dddabda8b 100644 --- a/paddle/cinn/frontend/decomposer/test_helper.h +++ b/paddle/cinn/frontend/decomposer/test_helper.h @@ -34,7 +34,8 @@ namespace cinn::frontend { -using CPUKernelFunc = std::function& lengths, const std::vector& ptrs)>; +using CPUKernelFunc = std::function& lengths, + const std::vector& ptrs)>; template > std::ostream& operator<<(std::ostream& os, const std::vector& vec) { @@ -53,25 +54,31 @@ std::ostream& operator<<(std::ostream& os, const std::vector& vec) { } template -void InitRandomVector( - std::vector* vec, size_t numel, T low = static_cast(0), T high = static_cast(1), float precision = 1e-5) { +void InitRandomVector(std::vector* vec, + size_t numel, + T low = static_cast(0), + T high = static_cast(1), + float precision = 1e-5) { std::random_device seed; std::default_random_engine engine(seed()); std::uniform_real_distribution dist(low, high); vec->resize(numel); for (size_t i = 0; i < numel; ++i) { - T value = static_cast(dist(engine)); - int coeff = static_cast(value / precision); + T value = static_cast(dist(engine)); + int coeff = static_cast(value / precision); vec->at(i) = precision * static_cast(coeff); } } template <> -void InitRandomVector(std::vector* vec, size_t numel, int low, int high, float precision); +void InitRandomVector( + std::vector* vec, size_t numel, int low, int high, float precision); template -void CopyFromVector(const std::vector& vec, hlir::framework::Tensor tensor, Target target) { +void CopyFromVector(const std::vector& vec, + hlir::framework::Tensor tensor, + Target target) { auto* data = tensor->mutable_data(target); size_t numel = tensor->shape().numel(); @@ -81,7 +88,8 @@ void CopyFromVector(const std::vector& vec, hlir::framework::Tensor tensor, T #ifdef CINN_WITH_CUDA cudaMemcpy(data, vec.data(), numel * sizeof(T), cudaMemcpyHostToDevice); #else - LOG(FATAL) << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + LOG(FATAL) + << "NVGPU Target is only supported when the flag CINN_WITH_CUDA is ON!
Please check."; #endif } else { std::copy(vec.begin(), vec.end(), data); @@ -89,7 +97,9 @@ void CopyFromVector(const std::vector& vec, hlir::framework::Tensor tensor, T } template <> -void CopyFromVector(const std::vector& vec, hlir::framework::Tensor tensor, Target target); +void CopyFromVector(const std::vector& vec, + hlir::framework::Tensor tensor, + Target target); template void CopyToVector(const hlir::framework::Tensor tensor, std::vector* vec) { @@ -108,17 +118,23 @@ void CopyToVector(const hlir::framework::Tensor tensor, std::vector* vec) { } template <> -void CopyToVector(const hlir::framework::Tensor tensor, std::vector* vec); +void CopyToVector(const hlir::framework::Tensor tensor, + std::vector* vec); template -void CheckOutput(const std::vector& actual, const std::vector& expect, float atol = 1e-8, float rtol = 1e-5) { +void CheckOutput(const std::vector& actual, + const std::vector& expect, + float atol = 1e-8, + float rtol = 1e-5) { CHECK_EQ(actual.size(), expect.size()); - auto allclose = [](T a, T e, float atol, float rtol) { return abs(a - e) <= (atol + rtol * abs(e)); }; + auto allclose = [](T a, T e, float atol, float rtol) { + return abs(a - e) <= (atol + rtol * abs(e)); + }; float max_diff = 0.0f; - int offset = 0; - int num_diffs = 0; + int offset = 0; + int num_diffs = 0; size_t numel = actual.size(); for (size_t i = 0; i < numel; ++i) { @@ -127,16 +143,21 @@ void CheckOutput(const std::vector& actual, const std::vector& expect, flo float relative_diff = abs(absolute_diff / expect[i]); if (relative_diff > max_diff) { max_diff = relative_diff; - offset = i; + offset = i; } num_diffs += 1; - VLOG(4) << "- i=" << i << ", " << std::setprecision(8) << actual[i] << " (actual) vs " << std::setprecision(8) - << expect[i] << " (expect), relative_diff=" << relative_diff << ", absolute_diff=" << absolute_diff; + VLOG(4) << "- i=" << i << ", " << std::setprecision(8) << actual[i] + << " (actual) vs " << std::setprecision(8) << expect[i] + << " (expect), relative_diff=" << relative_diff + << ", absolute_diff=" << absolute_diff; } } - LOG(INFO) << "- Total " << num_diffs << " different results, offset=" << offset << ", " << actual[offset] - << " (actual) vs " << expect[offset] << " (expect), maximum_relative_diff=" << max_diff - << " (absolute_diff=" << abs((actual[offset] - expect[offset])) << ")"; + LOG(INFO) << "- Total " << num_diffs + << " different results, offset=" << offset << ", " << actual[offset] + << " (actual) vs " << expect[offset] + << " (expect), maximum_relative_diff=" << max_diff + << " (absolute_diff=" << abs((actual[offset] - expect[offset])) + << ")"; CHECK_EQ(num_diffs, 0); } @@ -168,7 +189,7 @@ void ComputeReferenceCpu(const std::vector>& input_vecs, void RunDecomposer(Program* prog, const Target& target, - const std::vector& passes = {"Decomposer"}, + const std::vector& passes = {"Decomposer"}, const std::vector& fetch_ids = {}); template @@ -176,12 +197,12 @@ void RunAndCheckShape(NetBuilder& builder, const std::vector& input_names, const std::vector& output_names, const std::vector>& output_shapes, - std::vector>* input_vecs = nullptr, + std::vector>* input_vecs = nullptr, std::vector>* output_vecs = nullptr, - T low = 0, - T high = 1, - const std::vector& passes = {"Decomposer"}) { - auto prog = builder.Build(); + T low = 0, + T high = 1, + const std::vector& passes = {"Decomposer"}) { + auto prog = builder.Build(); Target target = common::DefaultTarget(); RunDecomposer(&prog, target, passes, output_names); auto graph = std::make_shared(prog, target); 
@@ -191,7 +212,8 @@ void RunAndCheckShape(NetBuilder& builder, auto runtime_program = gc.Build(); std::vector> input_vecs_internal; - std::vector>* input_vecs_ptr = input_vecs ? input_vecs : &input_vecs_internal; + std::vector>* input_vecs_ptr = + input_vecs ? input_vecs : &input_vecs_internal; for (size_t i = 0; i < input_names.size(); ++i) { scope->Var(input_names[i]); auto tensor = scope->GetTensor(input_names[i]); @@ -221,20 +243,30 @@ void RunAndCheck(NetBuilder& builder, const std::vector& output_names, const std::vector>& output_shapes, CPUKernelFunc cpu_kernel_func, - T low = 0, - T high = 1, - float atol = 1e-8, - float rtol = 1e-5, + T low = 0, + T high = 1, + float atol = 1e-8, + float rtol = 1e-5, const std::vector& passes = {"Decomposer"}) { std::vector> input_vecs; std::vector> output_vecs; - RunAndCheckShape(builder, input_names, output_names, output_shapes, &input_vecs, &output_vecs, low, high, passes); + RunAndCheckShape(builder, + input_names, + output_names, + output_shapes, + &input_vecs, + &output_vecs, + low, + high, + passes); std::vector> output_refs; - ComputeReferenceCpu(input_vecs, output_vecs, &output_refs, cpu_kernel_func); + ComputeReferenceCpu( + input_vecs, output_vecs, &output_refs, cpu_kernel_func); for (size_t i = 0; i < output_vecs.size(); ++i) { - LOG(INFO) << "Check the " << i << "-th output, name=" << output_names[i] << ", shape=" << output_shapes[i]; + LOG(INFO) << "Check the " << i << "-th output, name=" << output_names[i] + << ", shape=" << output_shapes[i]; CheckOutput(output_vecs[i], output_refs[i], atol, rtol); } } diff --git a/paddle/cinn/frontend/decomposer/top_k.cc b/paddle/cinn/frontend/decomposer/top_k.cc index 5f8e9401a94d0..4105978896ea5 100644 --- a/paddle/cinn/frontend/decomposer/top_k.cc +++ b/paddle/cinn/frontend/decomposer/top_k.cc @@ -20,24 +20,27 @@ namespace frontend { namespace decomposer { void top_k(const Instruction& instr, const DecomposerContext& context) { - CHECK_EQ(instr->inputs.size(), 1UL) << " 1 input tensor for " << instr->op_type; - CHECK_EQ(instr->outputs.size(), 2UL) << "2 output tensors for " << instr->op_type; - auto x = instr->inputs[0]; - auto output = instr->outputs[0]; + CHECK_EQ(instr->inputs.size(), 1UL) + << " 1 input tensor for " << instr->op_type; + CHECK_EQ(instr->outputs.size(), 2UL) + << "2 output tensors for " << instr->op_type; + auto x = instr->inputs[0]; + auto output = instr->outputs[0]; auto indices = instr->outputs[1]; auto* builder = context.builder(); - int k = instr.GetAttrs("k"); + int k = instr.GetAttrs("k"); CHECK_GT(k, 0) << "The attribute k must be greater than 0."; int axis = instr.GetAttrs("axis"); if (axis < 0) { axis += x->shape.size(); } - auto sort_tmp = builder->Sort(x, axis, false); - auto sort_out = builder->Slice(sort_tmp, {axis}, {0}, {k}); + auto sort_tmp = builder->Sort(x, axis, false); + auto sort_out = builder->Slice(sort_tmp, {axis}, {0}, {k}); auto argsort_tmp = builder->ArgSort(x, axis, false).at(0); - auto argsort_out = builder->Cast(builder->Slice(argsort_tmp, {axis}, {0}, {k}), "int64"); + auto argsort_out = + builder->Cast(builder->Slice(argsort_tmp, {axis}, {0}, {k}), "int64"); // map the the output of decomposed operator to the original. 
context.MapOutToOrigin(sort_out, output); diff --git a/paddle/cinn/frontend/decomposer/top_k_test.cc b/paddle/cinn/frontend/decomposer/top_k_test.cc index 495ddfa713d1a..38f2116cc4a47 100644 --- a/paddle/cinn/frontend/decomposer/top_k_test.cc +++ b/paddle/cinn/frontend/decomposer/top_k_test.cc @@ -32,7 +32,8 @@ TEST(Decomposer, top_k_decomposer) { auto target = common::DefaultTarget(); RunDecomposer(&program, target); - auto graph = std::make_shared(program, output_names, target); + auto graph = + std::make_shared(program, output_names, target); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); @@ -46,7 +47,7 @@ TEST(Decomposer, top_k_decomposer) { for (auto& input : inputs) { scope->Var(input.first); auto tensor = scope->GetTensor(input.first); - auto* data = tensor->mutable_data(target); + auto* data = tensor->mutable_data(target); CopyFromVector(input.second, tensor, target); } run_program->Execute(); diff --git a/paddle/cinn/frontend/decomposer_registry.h b/paddle/cinn/frontend/decomposer_registry.h index bcd0277316d35..bbad4864f4809 100644 --- a/paddle/cinn/frontend/decomposer_registry.h +++ b/paddle/cinn/frontend/decomposer_registry.h @@ -29,7 +29,8 @@ class Decomposer; class DecomposerContext { public: - explicit DecomposerContext(NetBuilder* builder, absl::flat_hash_map* var_map) + explicit DecomposerContext( + NetBuilder* builder, absl::flat_hash_map* var_map) : builder_(builder), var_map_(var_map) {} NetBuilder* builder() const { return builder_; }; // Map the new var to the original var. void MapOutToOrigin(const Variable& new_var, const Variable& ori_var) const { if (new_var->shape != ori_var->shape) { - LOG(FATAL) << "The output shape should be equal to the original. But received : " << new_var->id << ".shape=[" - << utils::Join(new_var->shape, ", ") << "] and the original var " << ori_var->id << ".shape=[" - << utils::Join(ori_var->shape, ", ") << "]."; + LOG(FATAL) + << "The output shape should be equal to the original. But received : " + << new_var->id << ".shape=[" << utils::Join(new_var->shape, ", ") + << "] and the original var " << ori_var->id << ".shape=[" + << utils::Join(ori_var->shape, ", ") << "]."; } if (new_var->type != ori_var->type) { - LOG(FATAL) << "The output type shoule be equal to the original. But received : " << new_var->id - << ".type=" << new_var->type << " and the original var " << ori_var->id << ".type=" << ori_var->type; + LOG(FATAL) + << "The output type should be equal to the original.
But received : " + << new_var->id << ".type=" << new_var->type + << " and the original var " << ori_var->id + << ".type=" << ori_var->type; } (*var_map_)[new_var->id] = ori_var; } @@ -60,13 +66,16 @@ class InstrDecomposerRegistry : public Registry { return &x; } - inline const Decomposer* Get(const std::string& op_name, const common::Target& target) { + inline const Decomposer* Get(const std::string& op_name, + const common::Target& target) { const Decomposer* decomposer = Find(op_name, target); - CHECK(decomposer) << "Decomposer for [" << op_name << ", " << target << "] is not registered"; + CHECK(decomposer) << "Decomposer for [" << op_name << ", " << target + << "] is not registered"; return decomposer; } - inline const Decomposer* Find(const std::string& name, const common::Target& target) { + inline const Decomposer* Find(const std::string& name, + const common::Target& target) { return Registry::Find(name + "_" + target.arch_str()); } @@ -77,14 +86,17 @@ class InstrDecomposerRegistry : public Registry { class Decomposer { public: - using DecomposerKernel = std::function; + using DecomposerKernel = + std::function; Decomposer& SetBody(const DecomposerKernel& kernel) { kernel_ = kernel; return *this; } - void Run(const Instruction& instr, const DecomposerContext& context) const { kernel_(instr, context); } + void Run(const Instruction& instr, const DecomposerContext& context) const { + kernel_(instr, context); + } std::string name; @@ -97,13 +109,14 @@ class Decomposer { ->__REGISTER__(std::string(#name) + "_" + target.arch_str()) \ .SetBody(kernel) -#define CINN_DECOMPOSER_REGISTER_ALL(name, kernel) \ - static std::vector<::cinn::common::Target> all_targets = {::cinn::common::DefaultHostTarget(), \ - ::cinn::common::DefaultNVGPUTarget()}; \ - for (auto& target : all_targets) { \ - ::cinn::frontend::InstrDecomposerRegistry::Global() \ - ->__REGISTER__(std::string(#name) + "_" + target.arch_str()) \ - .SetBody(kernel); \ +#define CINN_DECOMPOSER_REGISTER_ALL(name, kernel) \ + static std::vector<::cinn::common::Target> all_targets = { \ + ::cinn::common::DefaultHostTarget(), \ + ::cinn::common::DefaultNVGPUTarget()}; \ + for (auto& target : all_targets) { \ + ::cinn::frontend::InstrDecomposerRegistry::Global() \ + ->__REGISTER__(std::string(#name) + "_" + target.arch_str()) \ + .SetBody(kernel); \ } /** @@ -121,8 +134,11 @@ class Decomposer { * \endcode */ #define GET_MACRO(_0, _1, _2, FUNC, ...) FUNC -#define CINN_DECOMPOSER_REGISTER(...) \ - GET_MACRO(__VA_ARGS__, CINN_DECOMPOSER_REGISTER_CORE, CINN_DECOMPOSER_REGISTER_ALL)(__VA_ARGS__) +#define CINN_DECOMPOSER_REGISTER(...) \ + GET_MACRO(__VA_ARGS__, \ + CINN_DECOMPOSER_REGISTER_CORE, \ + CINN_DECOMPOSER_REGISTER_ALL) \ + (__VA_ARGS__) } // namespace frontend } // namespace cinn diff --git a/paddle/cinn/frontend/interpreter.cc b/paddle/cinn/frontend/interpreter.cc index c72bffc2ffc6c..6a432d4f58414 100755 --- a/paddle/cinn/frontend/interpreter.cc +++ b/paddle/cinn/frontend/interpreter.cc @@ -29,8 +29,11 @@ DECLARE_bool(enable_auto_tuner); namespace cinn::frontend { struct Interpreter::Impl { - Impl(const std::vector& input_names, const std::vector& input_shapes) - : scope_(std::make_shared()), input_names_(input_names), input_shapes_(input_shapes) {} + Impl(const std::vector& input_names, + const std::vector& input_shapes) + : scope_(std::make_shared()), + input_names_(input_names), + input_shapes_(input_shapes) {} /** * Build the model. 
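The GET_MACRO definition in decomposer_registry.h above is the standard variadic-macro arity dispatch: with two user arguments, CINN_DECOMPOSER_REGISTER_ALL lands in the FUNC slot and is expanded; a third argument shifts the slots so CINN_DECOMPOSER_REGISTER_CORE is chosen instead. A generic sketch of the idiom, where REG, REG_2, and REG_3 are illustrative stand-ins rather than the CINN macros:

#include <cstdio>

#define REG_2(name, kernel) printf("register %s for all targets\n", #name)
#define REG_3(name, kernel, target) \
  printf("register %s for target %s\n", #name, #target)
// Whichever macro name lands in the FUNC slot is the one that gets called.
#define GET_MACRO(_0, _1, _2, FUNC, ...) FUNC
#define REG(...) GET_MACRO(__VA_ARGS__, REG_3, REG_2)(__VA_ARGS__)

int main() {
  REG(elementwise_add, my_kernel);       // two args select REG_2
  REG(elementwise_add, my_kernel, x86);  // three args select REG_3
  return 0;
}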
@@ -67,15 +70,16 @@ void Interpreter::LoadPaddleModel(const std::string& model_dir, for (int idx = 0; idx < impl_->input_names_.size(); ++idx) { input_shape_map[impl_->input_names_[idx]] = impl_->input_shapes_[idx]; } - auto programTuple = LoadPaddleProgram(model_dir, impl_->scope_.get(), input_shape_map, params_combined, target); - auto& program = std::get<0>(programTuple); - auto& var_map = std::get<1>(programTuple); + auto programTuple = LoadPaddleProgram( + model_dir, impl_->scope_.get(), input_shape_map, params_combined, target); + auto& program = std::get<0>(programTuple); + auto& var_map = std::get<1>(programTuple); auto& var_map_paddle_to_program = std::get<2>(programTuple); - auto& fetch_names = std::get<3>(programTuple); + auto& fetch_names = std::get<3>(programTuple); impl_->program_.reset(program.release()); - impl_->var_map_ = var_map; + impl_->var_map_ = var_map; impl_->var_map_paddle_to_cinn_ = var_map_paddle_to_program; - impl_->fetch_names_ = fetch_names; + impl_->fetch_names_ = fetch_names; impl_->Build(target, model_name); } @@ -93,12 +97,14 @@ hlir::framework::Tensor Interpreter::GetTensor(const std::string& name) { auto it = impl_->var_map_paddle_to_cinn_.find(name); if (it == impl_->var_map_paddle_to_cinn_.end()) { LOG(FATAL) << "No variable called [" << name - << "] found in executor\nThe existing vars: " << utils::Join(impl_->scope_->var_names(), ", "); + << "] found in executor\nThe existing vars: " + << utils::Join(impl_->scope_->var_names(), ", "); } return impl_->scope_->GetTensor(it->second); } -void Interpreter::Impl::Build(const Target& target, const std::string& model_name) { +void Interpreter::Impl::Build(const Target& target, + const std::string& model_name) { CHECK(!var_map_.empty()); VLOG(3) << "Program:\n" << *program_; // applay frontend pass @@ -109,22 +115,26 @@ void Interpreter::Impl::Build(const Target& target, const std::string& model_nam } auto graph = Optimize(program_.get(), fetch_var_ids, target); - // auto graph = std::make_shared(*program_, target); + // auto graph = + // std::make_shared(*program_, target); graph->attrs["model_name"] = std::make_shared(model_name); - scope_ = hlir::framework::BuildScope(target, graph, scope_); + scope_ = hlir::framework::BuildScope(target, graph, scope_); - graph_compiler_.reset(new hlir::framework::GraphCompiler(target, scope_, graph)); + graph_compiler_.reset( + new hlir::framework::GraphCompiler(target, scope_, graph)); hlir::framework::GraphCompiler::CompileOptions options; options.with_instantiate_variables = true; if (FLAGS_enable_auto_tuner) { VLOG(4) << "Compile with auto-tune"; auto_schedule::AutoTuner auto_tuner(target, graph.get()); - auto_tuner.Initialize(auto_schedule::AutoTuner::Config(), graph_compiler_.get()); + auto_tuner.Initialize(auto_schedule::AutoTuner::Config(), + graph_compiler_.get()); auto_schedule::TuningOptions tuning_options; auto_schedule::TuningResult tuning_result = auto_tuner.Tune(tuning_options); options.Apply(tuning_result); } - runtime_program_ = graph_compiler_->Build(options, std::move(fetch_var_ids)).runtime_program; + runtime_program_ = + graph_compiler_->Build(options, std::move(fetch_var_ids)).runtime_program; runtime_program_->PreRun(); } @@ -133,8 +143,9 @@ std::shared_ptr Interpreter::GetScope() { return impl_->scope_; } -Interpreter::Interpreter(const std::vector& input_names, - const std::vector& input_shapes) +Interpreter::Interpreter( + const std::vector& input_names, + const std::vector& input_shapes) : impl_(new Impl(input_names, input_shapes)) {} } // 
namespace cinn::frontend diff --git a/paddle/cinn/frontend/interpreter.h b/paddle/cinn/frontend/interpreter.h index 36ef8aa3c15b6..f70a97eb78574 100755 --- a/paddle/cinn/frontend/interpreter.h +++ b/paddle/cinn/frontend/interpreter.h @@ -32,16 +32,18 @@ namespace frontend { */ class Interpreter final { public: - Interpreter(const std::vector& input_names, const std::vector& input_shapes); + Interpreter(const std::vector& input_names, + const std::vector& input_shapes); /** * Load a Paddle model. * @param model_dir The directory path to the model. - * @param params_combined Whether the parameters are composed to a single file. + * @param params_combined Whether the parameters are composed to a single + * file. */ void LoadPaddleModel(const std::string& model_dir, const Target& target, - bool params_combined = false, + bool params_combined = false, const std::string& model_name = ""); /** diff --git a/paddle/cinn/frontend/net_builder.cc b/paddle/cinn/frontend/net_builder.cc index 1a495b430aee6..6a059f8acec3f 100644 --- a/paddle/cinn/frontend/net_builder.cc +++ b/paddle/cinn/frontend/net_builder.cc @@ -18,12 +18,12 @@ #include #include +#include "glog/logging.h" #include "paddle/cinn/frontend/syntax.h" #include "paddle/cinn/hlir/pe/broadcast.h" #include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/utils/functional.h" #include "paddle/cinn/utils/profiler.h" -#include "glog/logging.h" namespace cinn { namespace frontend { @@ -54,23 +54,29 @@ Program NetBuilder::Build(bool in_reverse) { } void NetBuilder::InferShape(Instruction instr) const { - using ShapeFunc = std::function(const std::vector&, const AttributeMap&)>; - using TypeFunc = std::function(const std::vector&, const AttributeMap&)>; + using ShapeFunc = std::function( + const std::vector&, const AttributeMap&)>; + using TypeFunc = std::function(const std::vector&, + const AttributeMap&)>; const auto& op_infershape = Operator::GetAttrs("infershape"); const auto& op_inferdtype = Operator::GetAttrs("inferdtype"); size_t size = instr->inputs.size(); std::vector in_shapes(size); std::vector in_types(size); - std::transform( - instr->inputs.begin(), instr->inputs.end(), in_shapes.begin(), [](const Variable& var) { return var->shape; }); - std::transform( - instr->inputs.begin(), instr->inputs.end(), in_types.begin(), [](const Variable& var) { return var->type; }); - auto key = Operator::Get(instr->op_type); + std::transform(instr->inputs.begin(), + instr->inputs.end(), + in_shapes.begin(), + [](const Variable& var) { return var->shape; }); + std::transform(instr->inputs.begin(), + instr->inputs.end(), + in_types.begin(), + [](const Variable& var) { return var->type; }); + auto key = Operator::Get(instr->op_type); auto out_shapes = op_infershape[key](in_shapes, instr->attrs); - auto out_types = op_inferdtype[key](in_types, instr->attrs); + auto out_types = op_inferdtype[key](in_types, instr->attrs); - auto& outs = instr->outputs; + auto& outs = instr->outputs; size_t origin_out_num = outs.size(); outs.resize(out_shapes.size()); for (size_t i = origin_out_num; i < outs.size(); i++) { @@ -78,13 +84,14 @@ void NetBuilder::InferShape(Instruction instr) const { } for (size_t i = 0; i < outs.size(); i++) { outs[i]->shape = out_shapes[i]; - outs[i]->type = out_types[i]; + outs[i]->type = out_types[i]; } } -const std::vector& NetBuilder::CustomInstr(const std::string& type, - const std::vector& inputs, - const AttributeMap& attrs) { +const std::vector& NetBuilder::CustomInstr( + const std::string& type, + const std::vector& inputs, + const 
AttributeMap& attrs) { Instruction instr(type, inputs); for (auto& kv : attrs) { instr.SetAttr(kv.first, kv.second); @@ -95,24 +102,35 @@ const std::vector& NetBuilder::CustomInstr(const std::string& type, return instr.GetOutputs(); } -Variable NetBuilder::BinaryOp(const std::string& op_type, const Variable& lhs, const Variable& rhs, int axis) { - CHECK_EQ(lhs->type, rhs->type) << "The inputs type of op " << op_type << " should be equal!"; +Variable NetBuilder::BinaryOp(const std::string& op_type, + const Variable& lhs, + const Variable& rhs, + int axis) { + CHECK_EQ(lhs->type, rhs->type) + << "The inputs type of op " << op_type << " should be equal!"; return CustomInstr(op_type, {lhs, rhs}, {{"axis", axis}}).front(); } -Variable NetBuilder::UnaryOp(const std::string& op_type, const Variable& operand) { +Variable NetBuilder::UnaryOp(const std::string& op_type, + const Variable& operand) { return CustomInstr(op_type, {operand}, {}).front(); } -Variable NetBuilder::Reduce(const std::string& op_type, const Variable& x, const std::vector& dim, bool keep_dim) { +Variable NetBuilder::Reduce(const std::string& op_type, + const Variable& x, + const std::vector& dim, + bool keep_dim) { // TODO(thisjiang): move the reduce simplify to frontend pass - auto product = std::accumulate(x->shape.begin(), x->shape.end(), 1, std::multiplies()); + auto product = std::accumulate( + x->shape.begin(), x->shape.end(), 1, std::multiplies()); if (product == 1) { if (keep_dim) { return Identity(x); } else { - CHECK_GE(x->shape.size(), dim.size()) << "The inputs rank should be greater than or equal to axes."; - int new_rank = x->shape.size() == dim.size() ? 1 : x->shape.size() - dim.size(); + CHECK_GE(x->shape.size(), dim.size()) + << "The inputs rank should be greater than or equal to axes."; + int new_rank = + x->shape.size() == dim.size() ? 
1 : x->shape.size() - dim.size(); std::vector new_shape(new_rank, 1); return Reshape(x, new_shape); } @@ -124,11 +142,15 @@ Variable NetBuilder::Reduce(const std::string& op_type, const Variable& x, const reduce_dim[i] = x->shape.size() + reduce_dim[i]; } } - return CustomInstr(op_type, {x}, {{"dim", reduce_dim}, {"keep_dim", keep_dim}}).front(); + return CustomInstr( + op_type, {x}, {{"dim", reduce_dim}, {"keep_dim", keep_dim}}) + .front(); } -#define NETBUILDER_UNARY_OP_DEF(func_name__, op_type__) \ - Variable NetBuilder::func_name__(const Variable& operand) { return UnaryOp(#op_type__, operand); } +#define NETBUILDER_UNARY_OP_DEF(func_name__, op_type__) \ + Variable NetBuilder::func_name__(const Variable& operand) { \ + return UnaryOp(#op_type__, operand); \ + } NETBUILDER_UNARY_OP_DEF(Sqrt, sqrt) NETBUILDER_UNARY_OP_DEF(Tanh, tanh) NETBUILDER_UNARY_OP_DEF(Relu, relu) @@ -171,9 +193,10 @@ NETBUILDER_UNARY_OP_DEF(Reciprocal, reciprocal) #undef NETBUILDER_UNARY_OP_DEF -#define NETBUILDER_BINARY_OP_DEF(func_name__, op_type__) \ - Variable NetBuilder::func_name__(const Variable& lhs, const Variable& rhs, int axis) { \ - return BinaryOp(#op_type__, lhs, rhs, axis); \ +#define NETBUILDER_BINARY_OP_DEF(func_name__, op_type__) \ + Variable NetBuilder::func_name__( \ + const Variable& lhs, const Variable& rhs, int axis) { \ + return BinaryOp(#op_type__, lhs, rhs, axis); \ } NETBUILDER_BINARY_OP_DEF(Add, elementwise_add) NETBUILDER_BINARY_OP_DEF(ElementwiseAdd, elementwise_add) @@ -206,15 +229,16 @@ NETBUILDER_BINARY_OP_DEF(LogicalRightShift, logical_right_shift); #undef NETBUILDER_BINARY_OP_DEF -#define NETBUILDER_REDUCE_OP_DEF(func_name__, op_type__) \ - Variable NetBuilder::func_name__(const Variable& x, const std::vector& dim, bool keep_dim) { \ - std::vector axes = dim; \ - if (axes.size() == 0) { \ - for (int idx = 0; idx < x->shape.size(); ++idx) { \ - axes.push_back(idx); \ - } \ - } \ - return Reduce(#op_type__, x, axes, keep_dim); \ +#define NETBUILDER_REDUCE_OP_DEF(func_name__, op_type__) \ + Variable NetBuilder::func_name__( \ + const Variable& x, const std::vector& dim, bool keep_dim) { \ + std::vector axes = dim; \ + if (axes.size() == 0) { \ + for (int idx = 0; idx < x->shape.size(); ++idx) { \ + axes.push_back(idx); \ + } \ + } \ + return Reduce(#op_type__, x, axes, keep_dim); \ } NETBUILDER_REDUCE_OP_DEF(ReduceSum, reduce_sum) @@ -226,21 +250,25 @@ NETBUILDER_REDUCE_OP_DEF(ReduceAny, reduce_any) #undef NETBUILDER_REDUCE_OP_DEF -Placeholder NetBuilder::CreateInput(const Type& type, const std::vector& shape, const std::string& id_hint) { +Placeholder NetBuilder::CreateInput(const Type& type, + const std::vector& shape, + const std::string& id_hint) { if (!id_hint.empty()) { cinn::utils::TransValidVarName(id_hint); } - std::string id = id_hint.empty() ? Context::Global().NewName("placeholder") : id_hint; + std::string id = + id_hint.empty() ? 
Context::Global().NewName("placeholder") : id_hint; inputs_.emplace_back(id); - auto& var = inputs_.back(); - var->type = type; + auto& var = inputs_.back(); + var->type = type; var->shape = shape; return Placeholder(var); } Placeholder NetBuilder::CreateInput(const Variable& var) { - VLOG_IF(4, var->shape.empty()) << "The input's shape is empty, Create 0D-Tensor for " << var->id; + VLOG_IF(4, var->shape.empty()) + << "The input's shape is empty, Create 0D-Tensor for " << var->id; CHECK(!var->type.is_unk()) << "The input's type is not set yet"; inputs_.push_back(var); return Placeholder(var); } @@ -261,63 +289,83 @@ Variable NetBuilder::FillConstant(const std::vector& shape, } else if (type.is_bool()) { value = !cinn::runtime::CheckStringFlagFalse(str_value); } else { - LOG(FATAL) << "FillConstant only support int/float/bool, but here " << dtype; + LOG(FATAL) << "FillConstant only supports int/float/bool, but here " + << dtype; } - auto out = - CustomInstr("fill_constant", {}, {{"shape", shape}, {"value", value}, {"dtype", dtype}, {"force_cpu", force_cpu}}) - .front(); + auto out = CustomInstr("fill_constant", + {}, + {{"shape", shape}, + {"value", value}, + {"dtype", dtype}, + {"force_cpu", force_cpu}}) + .front(); if (!name.empty()) { out.set_id(cinn::utils::TransValidVarName(name)); } return out; } -std::vector NetBuilder::Split(const Variable& operand, const std::vector& num_or_sections, int axis) { - return CustomInstr("split", {operand}, {{"num_or_sections", num_or_sections}, {"axis", axis}}); +std::vector NetBuilder::Split(const Variable& operand, + const std::vector& num_or_sections, + int axis) { + return CustomInstr("split", + {operand}, + {{"num_or_sections", num_or_sections}, {"axis", axis}}); } Variable NetBuilder::Concat(const std::vector& input_vars, int axis) { - CHECK(!input_vars.empty()) << "The inputs of concat op should not be empty! Please check."; + CHECK(!input_vars.empty()) + << "The inputs of concat op should not be empty! 
Please check."; return CustomInstr("concat", input_vars, {{"axis", axis}}).front(); } -Variable NetBuilder::BroadcastTo(const Variable& operand, const std::vector& out_shape) { +Variable NetBuilder::BroadcastTo(const Variable& operand, + const std::vector& out_shape) { auto x_shape_size = operand->shape.size(); auto y_shape_size = out_shape.size(); - CHECK_GT(x_shape_size, 0) << "Cannot broadcast a empty operand " << operand->id << " to " - << cinn::utils::Join(out_shape, ","); - CHECK_LE(x_shape_size, y_shape_size) << "The broadcast_p's input shape dimension should less than the output's, " - << "but here (" << x_shape_size << " > " << y_shape_size << ")."; - - VLOG(4) << "Try broadcast " << operand->id << " from shape (" << cinn::utils::Join(operand->shape, ",") - << ") to shape (" << cinn::utils::Join(out_shape, ",") << ")"; + CHECK_GT(x_shape_size, 0) + << "Cannot broadcast an empty operand " << operand->id << " to " + << cinn::utils::Join(out_shape, ","); + CHECK_LE(x_shape_size, y_shape_size) + << "The broadcast_p's input shape dimension should be less than the " "output's, " + << "but here (" << x_shape_size << " > " << y_shape_size << ")."; + + VLOG(4) << "Try broadcast " << operand->id << " from shape (" + << cinn::utils::Join(operand->shape, ",") << ") to shape (" + << cinn::utils::Join(out_shape, ",") << ")"; std::vector broadcast_axes(x_shape_size, 0); if (x_shape_size > 1) { for (int i = 1; i <= x_shape_size; ++i) { CHECK((out_shape[y_shape_size - i] == operand->shape[x_shape_size - i]) || (operand->shape[x_shape_size - i] == 1)) - << "We cannot broadcast from shape (" << cinn::utils::Join(operand->shape, ",") << ") to shape (" + << "We cannot broadcast from shape (" + << cinn::utils::Join(operand->shape, ",") << ") to shape (" << cinn::utils::Join(out_shape, ",") << ")"; broadcast_axes[x_shape_size - i] = y_shape_size - i; } } else { - int axis = -1; + int axis = -1; auto x_shape = operand->shape.at(0); if (x_shape == 1) { // Can broadcast directly, default axis 0 axis = 0; } else { - // The broadcast axes is the index of the shape in out_shape when the input dimension is 1 + // The broadcast axes is the index of the shape in out_shape when the + // input dimension is 1 for (int i = 0; i < y_shape_size; ++i) { if (out_shape[i] == x_shape) { axis = i; break; } } - CHECK_NE(axis, -1) << "When we broadcast a 1-dimension shape, the number should contained in the out_shape. " - << "We cannot broadcast from shape (" << cinn::utils::Join(operand->shape, ",") - << ") to shape (" << cinn::utils::Join(out_shape, ",") << ")"; + CHECK_NE(axis, -1) << "When we broadcast a 1-dimension shape, the number " "should be contained in the out_shape. 
" + << "We cannot broadcast from shape (" + << cinn::utils::Join(operand->shape, ",") + << ") to shape (" << cinn::utils::Join(out_shape, ",") + << ")"; } broadcast_axes[0] = axis; } @@ -328,15 +376,25 @@ Variable NetBuilder::BroadcastTo(const Variable& operand, const std::vector Variable NetBuilder::BroadcastTo(const Variable& operand, const std::vector& out_shape, const std::vector& broadcast_axes) { - return CustomInstr("broadcast_to", {operand}, {{"out_shape", out_shape}, {"broadcast_axes", broadcast_axes}}).front(); + return CustomInstr( + "broadcast_to", + {operand}, + {{"out_shape", out_shape}, {"broadcast_axes", broadcast_axes}}) + .front(); } -Variable NetBuilder::Reshape(const Variable& operand, const std::vector& shape) { +Variable NetBuilder::Reshape(const Variable& operand, + const std::vector& shape) { return CustomInstr("reshape", {operand}, {{"shape", shape}}).front(); } -Variable NetBuilder::Transpose(const Variable& operand, const std::vector& axis) { - return CustomInstr("transpose", {operand}, {{"axis", utils::GetPositiveAxes(axis, operand->shape.size())}}).front(); +Variable NetBuilder::Transpose(const Variable& operand, + const std::vector& axis) { + return CustomInstr( + "transpose", + {operand}, + {{"axis", utils::GetPositiveAxes(axis, operand->shape.size())}}) + .front(); } Variable NetBuilder::Slice(const Variable& operand, @@ -365,61 +423,97 @@ Variable NetBuilder::SliceAssign(const Variable& input, const std::vector& strides) { return CustomInstr("slice_assign", {input, assign}, - {{"axes", axes}, {"starts", starts}, {"ends", ends}, {"strides", strides}}) + {{"axes", axes}, + {"starts", starts}, + {"ends", ends}, + {"strides", strides}}) .front(); } -Variable NetBuilder::Reverse(const Variable& operand, const std::vector& axis) { - return CustomInstr("reverse", {operand}, {{"axis", utils::GetPositiveAxes(axis, operand->shape.size())}}).front(); +Variable NetBuilder::Reverse(const Variable& operand, + const std::vector& axis) { + return CustomInstr( + "reverse", + {operand}, + {{"axis", utils::GetPositiveAxes(axis, operand->shape.size())}}) + .front(); } -Variable NetBuilder::Select(const Variable& condition, const Variable& true_value, const Variable& false_value) { - return CustomInstr("select", {condition, true_value, false_value}, {}).front(); +Variable NetBuilder::Select(const Variable& condition, + const Variable& true_value, + const Variable& false_value) { + return CustomInstr("select", {condition, true_value, false_value}, {}) + .front(); } -Variable NetBuilder::Gather(const Variable& operand, const Variable& index, int axis) { +Variable NetBuilder::Gather(const Variable& operand, + const Variable& index, + int axis) { size_t x_ndim = operand->shape.size(); if (axis < 0) { axis += static_cast(x_ndim); } - CHECK_LT(axis, x_ndim) << "Axis must be in [" << -x_ndim << ", " << x_ndim - 1 << ")."; + CHECK_LT(axis, x_ndim) << "Axis must be in [" << -x_ndim << ", " << x_ndim - 1 + << ")."; Variable transformed_index = index; - // If we got 1-D Tensor, the first step is reshape, in order to keep operand.rank == index.rank + // If we got 1-D Tensor, the first step is reshape, in order to keep + // operand.rank == index.rank if (index->shape.size() == 1) { std::vector index_reshape(x_ndim, 1); index_reshape[axis] = index->shape[0]; - transformed_index = Reshape(index, index_reshape); + transformed_index = Reshape(index, index_reshape); } // Then we need to broadcast transformed index - auto broadcast_shape = operand->shape; + auto broadcast_shape = operand->shape; 
broadcast_shape[axis] = transformed_index->shape[axis]; - transformed_index = BroadcastTo(transformed_index, broadcast_shape); - return CustomInstr("gather", {operand, transformed_index}, {{"axis", axis}}).front(); + transformed_index = BroadcastTo(transformed_index, broadcast_shape); + return CustomInstr("gather", {operand, transformed_index}, {{"axis", axis}}) + .front(); } -Variable NetBuilder::ScatterAssign(const Variable& operand, const Variable& updates, const Variable& index, int axis) { - return CustomInstr("scatter_assign", {operand, updates, index}, {{"axis", axis}}).front(); +Variable NetBuilder::ScatterAssign(const Variable& operand, + const Variable& updates, + const Variable& index, + int axis) { + return CustomInstr( + "scatter_assign", {operand, updates, index}, {{"axis", axis}}) + .front(); } -Variable NetBuilder::ScatterAdd(const Variable& operand, const Variable& updates, const Variable& index, int axis) { - return CustomInstr("scatter_add", {operand, updates, index}, {{"axis", axis}}).front(); +Variable NetBuilder::ScatterAdd(const Variable& operand, + const Variable& updates, + const Variable& index, + int axis) { + return CustomInstr("scatter_add", {operand, updates, index}, {{"axis", axis}}) + .front(); } -Variable NetBuilder::IsClose(const Variable& x, const Variable& y, float rtol, float atol, bool equal_nan) { - return CustomInstr("isclose", {x, y}, {{"rtol", rtol}, {"atol", atol}, {"equal_nan", equal_nan}}).front(); +Variable NetBuilder::IsClose(const Variable& x, + const Variable& y, + float rtol, + float atol, + bool equal_nan) { + return CustomInstr("isclose", + {x, y}, + {{"rtol", rtol}, {"atol", atol}, {"equal_nan", equal_nan}}) + .front(); } -Variable NetBuilder::Mul(const Variable& a, const Variable& b, int x_num_col_dims, int y_num_col_dims, bool is_infer) { +Variable NetBuilder::Mul(const Variable& a, + const Variable& b, + int x_num_col_dims, + int y_num_col_dims, + bool is_infer) { return CustomInstr("mul", {a, b}, - {{"x_num_col_dims", x_num_col_dims}, {"y_num_col_dims", y_num_col_dims}, {"is_infer", is_infer}}) + {{"x_num_col_dims", x_num_col_dims}, + {"y_num_col_dims", y_num_col_dims}, + {"is_infer", is_infer}}) .front(); } -const std::vector& NetBuilder::ElementwiseAddGrad(const Variable& dout, - const Variable& x, - const Variable& y, - int axis) { +const std::vector& NetBuilder::ElementwiseAddGrad( + const Variable& dout, const Variable& x, const Variable& y, int axis) { return CustomInstr("elementwise_add_grad", {dout, x, y}, {{"axis", axis}}); } @@ -439,9 +533,13 @@ Variable NetBuilder::Cast(const Variable& operand, const std::string& dtype) { return CustomInstr("cast", {operand}, {{"dtype", dtype}}).front(); } -Variable NetBuilder::BitcastConvert(const Variable& operand, const std::string& dtype) { +Variable NetBuilder::BitcastConvert(const Variable& operand, + const std::string& dtype) { std::string input_data_type = common::Type2Str(operand->type); - return CustomInstr("bitcast_convert", {operand}, {{"dtype", dtype}, {"input_data_type", input_data_type}}).front(); + return CustomInstr("bitcast_convert", + {operand}, + {{"dtype", dtype}, {"input_data_type", input_data_type}}) + .front(); } Variable NetBuilder::OneHot(const Variable& indices, @@ -450,15 +548,19 @@ Variable NetBuilder::OneHot(const Variable& indices, const int depth, const int axis, const std::string& dtype) { - return CustomInstr("one_hot", {indices, on_value, off_value}, {{"depth", depth}, {"axis", axis}, {"dtype", dtype}}) + return CustomInstr("one_hot", + {indices, 
on_value, off_value}, + {{"depth", depth}, {"axis", axis}, {"dtype", dtype}}) .front(); } -Variable NetBuilder::Squeeze(const Variable& operand, const std::vector& axes) { +Variable NetBuilder::Squeeze(const Variable& operand, + const std::vector& axes) { return CustomInstr("squeeze", {operand}, {{"axes", axes}}).front(); } -Variable NetBuilder::ExpandDims(const Variable& operand, const cinn::utils::ShapeType& axes) { +Variable NetBuilder::ExpandDims(const Variable& operand, + const cinn::utils::ShapeType& axes) { return CustomInstr("expand_dims", {operand}, {{"axes", axes}}).front(); } @@ -485,24 +587,41 @@ Variable NetBuilder::Conv(const Variable& lhs, .front(); } -std::vector NetBuilder::ArgSort(const Variable& operand, const int& axis, const bool& is_ascend) { - return CustomInstr("argsort", {operand}, {{"axis", axis}, {"is_ascend", is_ascend}}); +std::vector NetBuilder::ArgSort(const Variable& operand, + const int& axis, + const bool& is_ascend) { + return CustomInstr( + "argsort", {operand}, {{"axis", axis}, {"is_ascend", is_ascend}}); } -Variable NetBuilder::Sort(const Variable& operand, const int& axis, const bool& is_ascend) { - return CustomInstr("sort", {operand}, {{"axis", axis}, {"is_ascend", is_ascend}}).front(); +Variable NetBuilder::Sort(const Variable& operand, + const int& axis, + const bool& is_ascend) { + return CustomInstr( + "sort", {operand}, {{"axis", axis}, {"is_ascend", is_ascend}}) + .front(); } -Variable NetBuilder::Argmax(const Variable& x, const int& axis, const bool& keep_dim) { - return CustomInstr("argmax", {x}, {{"axis", axis}, {"keep_dim", keep_dim}}).front(); +Variable NetBuilder::Argmax(const Variable& x, + const int& axis, + const bool& keep_dim) { + return CustomInstr("argmax", {x}, {{"axis", axis}, {"keep_dim", keep_dim}}) + .front(); } -Variable NetBuilder::Argmin(const Variable& x, const int& axis, const bool& keep_dim) { - return CustomInstr("argmin", {x}, {{"axis", axis}, {"keep_dim", keep_dim}}).front(); +Variable NetBuilder::Argmin(const Variable& x, + const int& axis, + const bool& keep_dim) { + return CustomInstr("argmin", {x}, {{"axis", axis}, {"keep_dim", keep_dim}}) + .front(); } -Variable NetBuilder::LookupTable(const Variable& table, const Variable& ids, int64_t padding_idx) { - return CustomInstr("lookup_table", {table, ids}, {{"padding_idx", padding_idx}}).front(); +Variable NetBuilder::LookupTable(const Variable& table, + const Variable& ids, + int64_t padding_idx) { + return CustomInstr( + "lookup_table", {table, ids}, {{"padding_idx", padding_idx}}) + .front(); } Variable NetBuilder::Conv2d(const Variable& a, @@ -513,7 +632,16 @@ Variable NetBuilder::Conv2d(const Variable& a, int groups, const std::string& data_format, const std::string& padding_algorithm) { - return Conv(a, b, strides, paddings, dilations, groups, "forward", data_format, padding_algorithm, {}); + return Conv(a, + b, + strides, + paddings, + dilations, + groups, + "forward", + data_format, + padding_algorithm, + {}); } Variable NetBuilder::DepthwiseConv2d(const Variable& a, @@ -542,13 +670,13 @@ std::vector UpdatePool2dKernelSize(const std::vector& x_shape, std::vector new_ksize{ksize}; // Setting h/w_axis according to data_format int height_axis = -1; - int width_axis = -1; + int width_axis = -1; if (data_format == "NCHW") { height_axis = 2; - width_axis = 3; + width_axis = 3; } else if (data_format == "NHWC") { height_axis = 1; - width_axis = 2; + width_axis = 2; } else { LOG(FATAL) << "Unsupport data_format: " << data_format; } @@ -571,37 +699,40 @@ std::vector 
UpdatePool2dPaddings(const std::vector& paddings, if (paddings.size() == 2) { new_paddings.insert(new_paddings.end(), paddings.begin(), paddings.end()); } - CHECK_EQ(new_paddings.size(), 4) << "Padding size must be 2 or 4, but got: " << paddings.size(); + CHECK_EQ(new_paddings.size(), 4) + << "Padding size must be 2 or 4, but got: " << paddings.size(); // Setting h/w_axis according to data_format int height_axis = -1; - int width_axis = -1; + int width_axis = -1; if (data_format == "NCHW") { height_axis = 2; - width_axis = 3; + width_axis = 3; } else if (data_format == "NHWC") { height_axis = 1; - width_axis = 2; + width_axis = 2; } else { LOG(FATAL) << "Unsupport data_format: " << data_format; } // When padding_algorithm is VALID, set paddings to [0, 0, 0, 0]. - // When padding_algorithm is SAME, the calculation formula of padding is as follows: - // output_h/w = ceil(input_h/w / stride_h/w) - // padding_sum_h/w = (output_h/w - 1) * stride_h/w + kernel_h/w - input_h/w - // padding_top/left = padding_sum_h/w / 2; - // padding_bottom/right = padding_sum_h/w - padding_top/left + // When padding_algorithm is SAME, the calculation formula of padding is as + // follows: output_h/w = ceil(input_h/w / stride_h/w) padding_sum_h/w = + // (output_h/w - 1) * stride_h/w + kernel_h/w - input_h/w padding_top/left = + // padding_sum_h/w / 2; padding_bottom/right = padding_sum_h/w - + // padding_top/left if (padding_algorithm == "VALID") { new_paddings = {0, 0, 0, 0}; } else if (padding_algorithm == "SAME") { int out_size_h = (x_shape[height_axis] + stride[0] - 1) / stride[0]; int out_size_w = (x_shape[width_axis] + stride[1] - 1) / stride[1]; - int pad_sum_h = std::max((out_size_h - 1) * stride[0] + ksize[0] - x_shape[height_axis], 0); - int pad_sum_w = std::max((out_size_w - 1) * stride[1] + ksize[1] - x_shape[width_axis], 0); - int pad_top = pad_sum_h / 2; + int pad_sum_h = std::max( + (out_size_h - 1) * stride[0] + ksize[0] - x_shape[height_axis], 0); + int pad_sum_w = std::max( + (out_size_w - 1) * stride[1] + ksize[1] - x_shape[width_axis], 0); + int pad_top = pad_sum_h / 2; int pad_bottom = pad_sum_h - pad_top; - int pad_left = pad_sum_w / 2; - int pad_right = pad_sum_w - pad_left; - new_paddings = {pad_top, pad_left, pad_bottom, pad_right}; + int pad_left = pad_sum_w / 2; + int pad_right = pad_sum_w - pad_left; + new_paddings = {pad_top, pad_left, pad_bottom, pad_right}; } // When global_pooling or adaptive is true, set paddings to [0, 0, 0, 0]. 
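A quick numeric check of the SAME-padding formula restated in the comment above, using the same integer arithmetic as UpdatePool2dPaddings (hypothetical sizes):

// same_padding_sketch.cc -- worked example for input_h = 7, stride_h = 2,
// kernel_h = 3.
#include <algorithm>
#include <cassert>

int main() {
  const int in = 7, stride = 2, kernel = 3;
  const int out = (in + stride - 1) / stride;           // ceil(7 / 2) = 4
  const int pad_sum =
      std::max((out - 1) * stride + kernel - in, 0);    // 3 * 2 + 3 - 7 = 2
  const int pad_top = pad_sum / 2;                      // 1
  const int pad_bottom = pad_sum - pad_top;             // 1
  assert(out == 4 && pad_top == 1 && pad_bottom == 1);
  return 0;
}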
if (global_pooling || adaptive) { @@ -622,27 +753,33 @@ Variable NetBuilder::Pool2d(const Variable& a, bool adaptive, const std::string& padding_algorithm) { // Check input dim - CHECK_EQ(a->shape.size(), 4) << "Input's dim must be 4, but " << a->id << "'s shape is [" - << cinn::utils::Join(a->shape, ", ") << "]."; + CHECK_EQ(a->shape.size(), 4) + << "Input's dim must be 4, but " << a->id << "'s shape is [" + << cinn::utils::Join(a->shape, ", ") << "]."; // Transform pool_type std::string pool_type; - std::transform(pooling_type.begin(), pooling_type.end(), std::back_inserter(pool_type), [](unsigned char c) { - return std::tolower(c); - }); - CHECK(pool_type == "avg" || pool_type == "max") << "Pool_type must be avg or max, but got: " << pool_type; + std::transform(pooling_type.begin(), + pooling_type.end(), + std::back_inserter(pool_type), + [](unsigned char c) { return std::tolower(c); }); + CHECK(pool_type == "avg" || pool_type == "max") + << "Pool_type must be avg or max, but got: " << pool_type; // Transform ksize std::vector input_ksize{ksize}; if (input_ksize.size() == 1) { input_ksize.insert(input_ksize.end(), ksize.begin(), ksize.end()); } - CHECK_EQ(input_ksize.size(), 2) << "Kernel_size length must be 1 or 2, but got: " << ksize.size(); + CHECK_EQ(input_ksize.size(), 2) + << "Kernel_size length must be 1 or 2, but got: " << ksize.size(); // Transform stride std::vector new_strides{strides}; if (new_strides.size() == 1) { new_strides.insert(new_strides.end(), strides.begin(), strides.end()); } - CHECK_EQ(new_strides.size(), 2) << "Stride length must be 1 or 2, but got: " << strides.size(); - CHECK(new_strides[0] > 0 && new_strides[1] > 0) << "the value of kernel size for pool2d should greater than 0."; + CHECK_EQ(new_strides.size(), 2) + << "Stride length must be 1 or 2, but got: " << strides.size(); + CHECK(new_strides[0] > 0 && new_strides[1] > 0) + << "the value of kernel size for pool2d should be greater than 0."; // Transform data_format std::string new_data_format{data_format}; if (new_data_format == "AnyLayout") { @@ -651,8 +788,10 @@ Variable NetBuilder::Pool2d(const Variable& a, CHECK(new_data_format == "NCHW" || new_data_format == "NHWC") << "Data_format must be AnyLayout/NCHW/NHWC, but got: " << data_format; // Check padding_algorithm - CHECK(padding_algorithm == "EXPLICIT" || padding_algorithm == "SAME" || padding_algorithm == "VALID") - << "Padding_algorithm must be EXPLICIT/SAME/VALID, but got: " << padding_algorithm; + CHECK(padding_algorithm == "EXPLICIT" || padding_algorithm == "SAME" || + padding_algorithm == "VALID") + << "Padding_algorithm must be EXPLICIT/SAME/VALID, but got: " + << padding_algorithm; utils::AttributeMap attrs = {{"pool_type", pool_type}, {"origin_kernel_size", input_ksize}, {"stride_size", new_strides}, @@ -663,22 +802,31 @@ Variable NetBuilder::Pool2d(const Variable& a, {"data_format", new_data_format}, {"origin_adaptive", adaptive}, {"padding_algorithm", padding_algorithm}}; - // In avg_pool2d, if global_pooling = false, adaptive = true and ksize is [1, 1], we turn off adaptive and use global - // pooling instead - if (pooling_type == "avg" && !global_pooling && adaptive && input_ksize[0] == 1 && input_ksize[1] == 1) { - VLOG(4) << "In avg_pool2d, got global_pooling = false, adaptive = true, ksize = [1, 1], turn off adaptive and " + // In avg_pool2d, if global_pooling = false, adaptive = true and ksize is [1, + // 1], we turn off adaptive and use global pooling instead + if (pooling_type == "avg" && !global_pooling && adaptive && + 
input_ksize[0] == 1 && input_ksize[1] == 1) { + VLOG(4) << "In avg_pool2d, got global_pooling = false, adaptive = true, " + "ksize = [1, 1], turn off adaptive and " "trans to global_pooling"; - adaptive = false; + adaptive = false; global_pooling = true; } // Transform paddings - auto new_paddings = UpdatePool2dPaddings( - paddings, a->shape, input_ksize, new_strides, global_pooling, adaptive, padding_algorithm, new_data_format); + auto new_paddings = UpdatePool2dPaddings(paddings, + a->shape, + input_ksize, + new_strides, + global_pooling, + adaptive, + padding_algorithm, + new_data_format); // Update kernel_size - auto new_ksize = UpdatePool2dKernelSize(a->shape, input_ksize, global_pooling, new_data_format); - attrs["kernel_size"] = new_ksize; - attrs["padding_size"] = new_paddings; - attrs["adaptive"] = adaptive; + auto new_ksize = UpdatePool2dKernelSize( + a->shape, input_ksize, global_pooling, new_data_format); + attrs["kernel_size"] = new_ksize; + attrs["padding_size"] = new_paddings; + attrs["adaptive"] = adaptive; attrs["global_pooling"] = global_pooling; return CustomInstr("pool2d", {a}, attrs).front(); } @@ -698,23 +846,28 @@ Variable NetBuilder::Pool2dGrad(const Variable& x, const std::string& padding_algorithm) { // Transform pool_type std::string pool_type; - std::transform(pooling_type.begin(), pooling_type.end(), std::back_inserter(pool_type), [](unsigned char c) { - return std::tolower(c); - }); - CHECK(pool_type == "avg" || pool_type == "max") << "Pool_type must be avg or max, but got: " << pool_type; + std::transform(pooling_type.begin(), + pooling_type.end(), + std::back_inserter(pool_type), + [](unsigned char c) { return std::tolower(c); }); + CHECK(pool_type == "avg" || pool_type == "max") + << "Pool_type must be avg or max, but got: " << pool_type; // Transform ksize std::vector input_ksize{ksize}; if (input_ksize.size() == 1) { input_ksize.insert(input_ksize.end(), ksize.begin(), ksize.end()); } - CHECK_EQ(input_ksize.size(), 2) << "Kernel_size length must be 1 or 2, but got: " << ksize.size(); + CHECK_EQ(input_ksize.size(), 2) + << "Kernel_size length must be 1 or 2, but got: " << ksize.size(); // Transform stride std::vector new_strides{strides}; if (new_strides.size() == 1) { new_strides.insert(new_strides.end(), strides.begin(), strides.end()); } - CHECK_EQ(new_strides.size(), 2) << "Stride length must be 1 or 2, but got: " << strides.size(); - CHECK(new_strides[0] > 0 && new_strides[1] > 0) << "the value of kernel size for pool2d should greater than 0."; + CHECK_EQ(new_strides.size(), 2) + << "Stride length must be 1 or 2, but got: " << strides.size(); + CHECK(new_strides[0] > 0 && new_strides[1] > 0) + << "the value of kernel size for pool2d should be greater than 0."; // Transform data_format std::string new_data_format{data_format}; if (new_data_format == "AnyLayout") { @@ -723,21 +876,32 @@ Variable NetBuilder::Pool2dGrad(const Variable& x, CHECK(new_data_format == "NCHW" || new_data_format == "NHWC") << "Data_format must be AnyLayout/NCHW/NHWC, but got: " << data_format; // Check padding_algorithm - CHECK(padding_algorithm == "EXPLICIT" || padding_algorithm == "SAME" || padding_algorithm == "VALID") - << "Padding_algorithm must be EXPLICIT/SAME/VALID, but got: " << padding_algorithm; - // In avg_pool2d, if global_pooling = false, adaptive = true and ksize is [1, 1], we turn off adaptive and use global - // pooling instead - if (pooling_type == "avg" && !global_pooling && adaptive && input_ksize[0] == 1 && input_ksize[1] == 1) { - VLOG(4) << "In avg_pool2d, 
got global_pooling = false, adaptive = true, ksize = [1, 1], turn off adaptive and " + CHECK(padding_algorithm == "EXPLICIT" || padding_algorithm == "SAME" || + padding_algorithm == "VALID") + << "Padding_algorithm must be EXPLICIT/SAME/VALID, but got: " + << padding_algorithm; + // In avg_pool2d, if global_pooling = false, adaptive = true and ksize is [1, + // 1], we turn off adaptive and use global pooling instead + if (pooling_type == "avg" && !global_pooling && adaptive && + input_ksize[0] == 1 && input_ksize[1] == 1) { + VLOG(4) << "In avg_pool2d, got global_pooling = false, adaptive = true, " + "ksize = [1, 1], turn off adaptive and " "trans to global_pooling"; - adaptive = false; + adaptive = false; global_pooling = true; } // Transform paddings - auto new_paddings = UpdatePool2dPaddings( - paddings, x->shape, input_ksize, new_strides, global_pooling, adaptive, padding_algorithm, new_data_format); + auto new_paddings = UpdatePool2dPaddings(paddings, + x->shape, + input_ksize, + new_strides, + global_pooling, + adaptive, + padding_algorithm, + new_data_format); // Update kernel_size - auto new_ksize = UpdatePool2dKernelSize(x->shape, input_ksize, global_pooling, new_data_format); + auto new_ksize = UpdatePool2dKernelSize( + x->shape, input_ksize, global_pooling, new_data_format); return CustomInstr("pool2d_grad", {x, y, dy}, {{"pool_type", pool_type}, @@ -754,11 +918,15 @@ Variable NetBuilder::Pool2dGrad(const Variable& x, } Variable NetBuilder::Repeat(const Variable& x, int repeats, int axis) { - return CustomInstr("repeat", {x}, {{"repeats", repeats}, {"axis", axis}}).front(); + return CustomInstr("repeat", {x}, {{"repeats", repeats}, {"axis", axis}}) + .front(); } -Variable NetBuilder::Resize(const Variable& x, const std::vector& out_shape, const std::string& mode) { - return CustomInstr("resize", {x}, {{"out_shape", out_shape}, {"mode", mode}}).front(); +Variable NetBuilder::Resize(const Variable& x, + const std::vector& out_shape, + const std::string& mode) { + return CustomInstr("resize", {x}, {{"out_shape", out_shape}, {"mode", mode}}) + .front(); } std::vector NetBuilder::BatchNorm(const Variable& a, @@ -773,36 +941,55 @@ std::vector NetBuilder::BatchNorm(const Variable& a, std::string op_type = is_test ? 
"batch_norm" : "batch_norm_train"; return CustomInstr(op_type, {a, scale, bias, mean, variance}, - {{"epsilon", epsilon}, {"momentum", momentum}, {"data_layout", data_layout}}); + {{"epsilon", epsilon}, + {"momentum", momentum}, + {"data_layout", data_layout}}); } // batch norm grad, output(grad_x, grad_scale, grad_bias) -std::vector NetBuilder::BatchNormGrad(const Variable& dy, - const Variable& x, - const Variable& scale, - const Variable& save_mean, - const Variable& save_variance, - const float epsilon, - const std::string& data_layout) { +std::vector NetBuilder::BatchNormGrad( + const Variable& dy, + const Variable& x, + const Variable& scale, + const Variable& save_mean, + const Variable& save_variance, + const float epsilon, + const std::string& data_layout) { return CustomInstr("batch_norm_grad", {dy, x, scale, save_mean, save_variance}, {{"epsilon", epsilon}, {"data_layout", data_layout}}); } -Variable NetBuilder::Scale(const Variable& a, float scale, float bias, bool bias_after_scale) { - return CustomInstr("scale", {a}, {{"scale", scale}, {"bias", bias}, {"bias_after_scale", bias_after_scale}}).front(); +Variable NetBuilder::Scale(const Variable& a, + float scale, + float bias, + bool bias_after_scale) { + return CustomInstr("scale", + {a}, + {{"scale", scale}, + {"bias", bias}, + {"bias_after_scale", bias_after_scale}}) + .front(); } Variable NetBuilder::Softmax(const Variable& a, const std::vector& axes, const std::string& mode, const std::string& data_format) { - return CustomInstr("softmax", {a}, {{"axes", axes}, {"mode", mode}, {"data_format", data_format}}).front(); + return CustomInstr( + "softmax", + {a}, + {{"axes", axes}, {"mode", mode}, {"data_format", data_format}}) + .front(); } -Variable NetBuilder::DropoutInfer(const Variable& a, float dropout_prob, const std::string& dropout_implementation) { - return CustomInstr( - "dropout_infer", {a}, {{"dropout_prob", dropout_prob}, {"dropout_implementation", dropout_implementation}}) +Variable NetBuilder::DropoutInfer(const Variable& a, + float dropout_prob, + const std::string& dropout_implementation) { + return CustomInstr("dropout_infer", + {a}, + {{"dropout_prob", dropout_prob}, + {"dropout_implementation", dropout_implementation}}) .front(); } @@ -811,23 +998,53 @@ Variable NetBuilder::Sum(const std::vector& inputs) { ; } -Variable NetBuilder::Arange(const float start, const float stop, const float step, const std::string& dtype) { - return CustomInstr("arange", {}, {{"start", start}, {"stop", stop}, {"step", step}, {"dtype", dtype}}).front(); +Variable NetBuilder::Arange(const float start, + const float stop, + const float step, + const std::string& dtype) { + return CustomInstr("arange", + {}, + {{"start", start}, + {"stop", stop}, + {"step", step}, + {"dtype", dtype}}) + .front(); } -Variable NetBuilder::Flip(const Variable& operand, const std::vector& axes) { - return CustomInstr("reverse", {operand}, {{"axis", utils::GetPositiveAxes(axes, operand->shape.size())}}).front(); +Variable NetBuilder::Flip(const Variable& operand, + const std::vector& axes) { + return CustomInstr( + "reverse", + {operand}, + {{"axis", utils::GetPositiveAxes(axes, operand->shape.size())}}) + .front(); } -Variable NetBuilder::Matmul(const Variable& x, const Variable& y, bool trans_x, bool trans_y, float alpha) { - return CustomInstr("matmul", {x, y}, {{"trans_a", trans_x}, {"trans_b", trans_y}, {"alpha", alpha}}).front(); +Variable NetBuilder::Matmul(const Variable& x, + const Variable& y, + bool trans_x, + bool trans_y, + float alpha) { + 
return CustomInstr( + "matmul", + {x, y}, + {{"trans_a", trans_x}, {"trans_b", trans_y}, {"alpha", alpha}}) + .front(); ; } -Variable NetBuilder::GaussianRandom( - const std::vector& shape, float mean, float std, int seed, const std::string& dtype) { - return CustomInstr( - "gaussian_random", {}, {{"shape", shape}, {"mean", mean}, {"std", std}, {"seed", seed}, {"dtype", dtype}}) +Variable NetBuilder::GaussianRandom(const std::vector& shape, + float mean, + float std, + int seed, + const std::string& dtype) { + return CustomInstr("gaussian_random", + {}, + {{"shape", shape}, + {"mean", mean}, + {"std", std}, + {"seed", seed}, + {"dtype", dtype}}) .front(); } @@ -839,39 +1056,58 @@ Variable NetBuilder::UniformRandom(const std::vector& shape, int diag_num, int diag_step, float diag_val) { - auto uniform_out = - CustomInstr( - "uniform_random", {}, {{"shape", shape}, {"min", min}, {"max", max}, {"seed", seed}, {"dtype", dtype}}) - .front(); + auto uniform_out = CustomInstr("uniform_random", + {}, + {{"shape", shape}, + {"min", min}, + {"max", max}, + {"seed", seed}, + {"dtype", dtype}}) + .front(); if (min == 0.0f && max == 1.0f) { return uniform_out; } - auto uniform_range = FillConstant(shape, max - min, UniqName("uniform_range"), dtype); + auto uniform_range = + FillConstant(shape, max - min, UniqName("uniform_range"), dtype); auto uniform_mul_out = Multiply(uniform_out, uniform_range); - auto uniform_min = FillConstant(shape, min, UniqName("uniform_min"), dtype); - auto uniform_res = Add(uniform_mul_out, uniform_min); + auto uniform_min = FillConstant(shape, min, UniqName("uniform_min"), dtype); + auto uniform_res = Add(uniform_mul_out, uniform_min); if (diag_num > 0) { - int numel = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - CHECK_GT(numel, (diag_num - 1) * (diag_step + 1)) << "(diag_num - 1) * (diag_step + 1) should smaller than numel!"; - auto diag_index = - Arange(0.0f, static_cast(diag_num * (diag_step + 1)), static_cast(diag_step + 1), "int32"); - auto diag_val_tensor = FillConstant(diag_index->shape, diag_val, "diag_val", dtype); + int numel = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + CHECK_GT(numel, (diag_num - 1) * (diag_step + 1)) + << "(diag_num - 1) * (diag_step + 1) should smaller than numel!"; + auto diag_index = Arange(0.0f, + static_cast(diag_num * (diag_step + 1)), + static_cast(diag_step + 1), + "int32"); + auto diag_val_tensor = + FillConstant(diag_index->shape, diag_val, "diag_val", dtype); auto uniform_flatten = Reshape(uniform_res, {-1}); - auto uniform_scatter = ScatterAssign(uniform_flatten, diag_val_tensor, diag_index); - uniform_res = Reshape(uniform_scatter, shape); + auto uniform_scatter = + ScatterAssign(uniform_flatten, diag_val_tensor, diag_index); + uniform_res = Reshape(uniform_scatter, shape); } return uniform_res; } -Variable NetBuilder::RandInt(const std::vector& shape, int min, int max, int seed, const std::string& dtype) { +Variable NetBuilder::RandInt(const std::vector& shape, + int min, + int max, + int seed, + const std::string& dtype) { CHECK_GT(max, min) << "max: " << max << "should greater than" << "min: " << min; - auto randint_out = CustomInstr("randint", {}, {{"shape", shape}, {"seed", seed}, {"dtype", dtype}}).front(); - randint_out = Cast(randint_out, dtype); - auto randint_range = FillConstant(shape, max - min, UniqName("randint_range"), dtype); - auto randint_mod = Mod(randint_out, randint_range); - auto randint_min = FillConstant(shape, min, UniqName("randint_min"), dtype); - auto 
randint_ret = Add(randint_mod, randint_min); + auto randint_out = + CustomInstr( + "randint", {}, {{"shape", shape}, {"seed", seed}, {"dtype", dtype}}) + .front(); + randint_out = Cast(randint_out, dtype); + auto randint_range = + FillConstant(shape, max - min, UniqName("randint_range"), dtype); + auto randint_mod = Mod(randint_out, randint_range); + auto randint_min = FillConstant(shape, min, UniqName("randint_min"), dtype); + auto randint_ret = Add(randint_mod, randint_min); return randint_ret; } @@ -879,35 +1115,47 @@ Variable NetBuilder::Cholesky(const Variable& x, bool upper) { auto cholesky_out = CustomInstr("cholesky", {x}, {{"upper", upper}}).front(); // Set upper/lower triangle of matrices to 0 auto x_ndim = x->shape.size(); - CHECK_GE(x_ndim, 2) << "The input matrix x shape size should >= 2! Please check again."; + CHECK_GE(x_ndim, 2) + << "The input matrix x shape size should >= 2! Please check again."; CHECK_EQ(x->shape[x_ndim - 1], x->shape[x_ndim - 2]) - << "The input matrix x's last 2 dimensions must be the same! Please check again."; - int m = x->shape[x_ndim - 1]; - auto m_tensor = FillConstant({m * m}, m); - auto index = Arange(0.0f, static_cast(m * m), 1.0f, "int32"); + << "The input matrix x's last 2 dimensions must be the same! Please " + "check again."; + int m = x->shape[x_ndim - 1]; + auto m_tensor = FillConstant({m * m}, m); + auto index = Arange(0.0f, static_cast(m * m), 1.0f, "int32"); auto index_row = Mod(index, m_tensor); auto index_col = FloorDivide(index, m_tensor); - auto mask = upper ? GreaterEqual(index_row, index_col) : LessEqual(index_row, index_col); - auto mask_mat = Reshape(mask, {m, m}); + auto mask = upper ? GreaterEqual(index_row, index_col) + : LessEqual(index_row, index_col); + auto mask_mat = Reshape(mask, {m, m}); auto mask_full = BroadcastTo(mask_mat, x->shape); - auto zeros = FillConstant(x->shape, 0.0f, "zeros", common::Type2Str(x->type)); - auto out = Select(mask_full, cholesky_out, zeros); + auto zeros = FillConstant(x->shape, 0.0f, "zeros", common::Type2Str(x->type)); + auto out = Select(mask_full, cholesky_out, zeros); return out; } -Variable NetBuilder::TriangularSolve( - const Variable& input1, const Variable& input2, bool left_side, bool upper, bool transpose_a, bool unit_diagonal) { +Variable NetBuilder::TriangularSolve(const Variable& input1, + const Variable& input2, + bool left_side, + bool upper, + bool transpose_a, + bool unit_diagonal) { // broadcast std::vector inputs{input1, input2}; { auto a_ndim = input1->shape.size(); auto b_ndim = input2->shape.size(); - CHECK_GE(a_ndim, 2) << "The input matrix A shape size should >= 2! Please check again."; - CHECK_GE(b_ndim, 2) << "The input matrix B shape size should >= 2! Please check again."; - std::vector input1_shape_cut(input1->shape.begin(), input1->shape.end() - 2); - std::vector input2_shape_cut(input2->shape.begin(), input2->shape.end() - 2); + CHECK_GE(a_ndim, 2) + << "The input matrix A shape size should >= 2! Please check again."; + CHECK_GE(b_ndim, 2) + << "The input matrix B shape size should >= 2! 
Please check again."; + std::vector input1_shape_cut(input1->shape.begin(), + input1->shape.end() - 2); + std::vector input2_shape_cut(input2->shape.begin(), + input2->shape.end() - 2); std::vector common_shape; - hlir::pe::GetBroadcastOutShape(input1_shape_cut, input2_shape_cut, &common_shape); + hlir::pe::GetBroadcastOutShape( + input1_shape_cut, input2_shape_cut, &common_shape); // broadcast input1 std::vector input1_shape(common_shape.begin(), common_shape.end()); @@ -931,8 +1179,12 @@ Variable NetBuilder::TriangularSolve( .front(); } -std::vector NetBuilder::TopK(const Variable& x, int k, int axis, bool largest) { - return CustomInstr("top_k", {x}, {{"k", k}, {"axis", axis}, {"largest", largest}}); +std::vector NetBuilder::TopK(const Variable& x, + int k, + int axis, + bool largest) { + return CustomInstr( + "top_k", {x}, {{"k", k}, {"axis", axis}, {"largest", largest}}); } } // namespace frontend diff --git a/paddle/cinn/frontend/net_builder.h b/paddle/cinn/frontend/net_builder.h index cc9510e877580..59ed335d35ddc 100644 --- a/paddle/cinn/frontend/net_builder.h +++ b/paddle/cinn/frontend/net_builder.h @@ -173,43 +173,58 @@ class NetBuilder { * * @return The result variable. */ - Variable BinaryOp(const std::string& op_type, const Variable& lhs, const Variable& rhs, int axis = -1); + Variable BinaryOp(const std::string& op_type, + const Variable& lhs, + const Variable& rhs, + int axis = -1); /** * @brief Reduce array elements over the given dims. * * @param op_type The reduce op name. * @param x The input variable. - * @param dim The dims along which a sum is performed. If dim is empty, the operation will sum over all elements - * of the input array. If the dim has negative value, it should count from the last dim to the first. Default is None. - * @param keep_dim If it is set true, the axes which are reduced are left in the result as dimensions with size one. - * With this option, the result will broadcast correctly against the input array. Default is false. + * @param dim The dims along which a sum is performed. If dim is empty, the + * operation will sum over all elements of the input array. If the dim has + * negative value, it should count from the last dim to the first. Default is + * None. + * @param keep_dim If it is set true, the axes which are reduced are left in + * the result as dimensions with size one. With this option, the result will + * broadcast correctly against the input array. Default is false. * * @return The result variable. */ Variable Reduce(const std::string& op_type, const Variable& x, const cinn::utils::ShapeType& dim = {}, - bool keep_dim = false); + bool keep_dim = false); private: // the helper function for Matmul API - std::pair BroadcastMatmulInput( - const Variable& x, const Variable& y, bool trans_x, bool trans_y, float alpha); - cinn::utils::ShapeType GetMatmulOutputShape( - const Variable& x, const Variable& y, bool trans_x, bool trans_y, float alpha); + std::pair BroadcastMatmulInput(const Variable& x, + const Variable& y, + bool trans_x, + bool trans_y, + float alpha); + cinn::utils::ShapeType GetMatmulOutputShape(const Variable& x, + const Variable& y, + bool trans_x, + bool trans_y, + float alpha); // the helper function for Constant API template - std::enable_if_t::value, cinn::utils::ShapeType> GetVectorShape(const std::vector& value) { - CHECK(!value.empty()) << "The vector should not has empty list! 
Please check."; + std::enable_if_t::value, cinn::utils::ShapeType> + GetVectorShape(const std::vector& value) { + CHECK(!value.empty()) + << "The vector should not has empty list! Please check."; return {static_cast(value.size())}; } template - std::enable_if_t::value, cinn::utils::ShapeType> GetVectorShape( - const std::vector& value) { - CHECK(!value.empty()) << "The vector should not has empty list! Please check."; + std::enable_if_t::value, cinn::utils::ShapeType> + GetVectorShape(const std::vector& value) { + CHECK(!value.empty()) + << "The vector should not has empty list! Please check."; auto shape = GetVectorShape(value[0]); shape.insert(shape.begin(), static_cast(value.size())); @@ -220,17 +235,19 @@ class NetBuilder { // ******************************************* // Elementwise Operator /** - * @brief Elementwise compute each element in `input` variable, and return the result Variable. + * @brief Elementwise compute each element in `input` variable, and return the + * result Variable. * @param x The input variable. * @return The output variable. */ -#define NETBUILDER_UNARY_OP_DECL(func_name__) Variable func_name__(const Variable& x); +#define NETBUILDER_UNARY_OP_DECL(func_name__) \ + Variable func_name__(const Variable& x); NETBUILDER_UNARY_OP_FOREACH(NETBUILDER_UNARY_OP_DECL) #undef NETBUILDER_UNARY_OP_DECL /** - * @brief Compute each element in `lhs` variable and `rhs` variable in `axis` dimension, and return the result - * Variable. + * @brief Compute each element in `lhs` variable and `rhs` variable in `axis` + * dimension, and return the result Variable. * @param lhs The left input variable. * @param rhs The right input variable. * @param axis The compute axis. Default is -1. @@ -243,53 +260,72 @@ class NetBuilder { /** * @brief Return array elements depending on condition. - * @param condition The condition which determine return `true_value` or `false_value`. - * @param true_value Return `true_value` if the element of `condition` is true. - * @param false_value Return `false_value` if the element of `condition` is false. + * @param condition The condition which determine return `true_value` or + * `false_value`. + * @param true_value Return `true_value` if the element of `condition` is + * true. + * @param false_value Return `false_value` if the element of `condition` is + * false. * @return The result variable. */ - Variable Select(const Variable& condition, const Variable& true_value, const Variable& false_value); + Variable Select(const Variable& condition, + const Variable& true_value, + const Variable& false_value); /** * @brief Scale operator. * @param x Input N-D variable of scale operator. * @param scale The scale factor of the input. Default is 1.0f. * @param bias The bias to be put on the input. Default is 0.0f. - * @param bias_after_scale Apply bias addition after or before scaling. It is useful for numeric stability in some - * circumstances. Default is true. - * @return Output variable of scale operator, with shape and data type same as input. + * @param bias_after_scale Apply bias addition after or before scaling. It is + * useful for numeric stability in some circumstances. Default is true. + * @return Output variable of scale operator, with shape and data type same as + * input. 
*/ - Variable Scale(const Variable& x, float scale = 1.0f, float bias = 0.0f, bool bias_after_scale = true); + Variable Scale(const Variable& x, + float scale = 1.0f, + float bias = 0.0f, + bool bias_after_scale = true); /** * @brief This OP is used to sum one or more variable of the input. - * @param x A Varaible list. The shape and data type of the list elements should be consistent. - * @return The sum of input `x`. its shape and data types are consistent with `x`. + * @param x A Variable list. The shape and data type of the list elements + * should be consistent. + * @return The sum of input `x`. Its shape and data types are consistent with + * `x`. */ Variable Sum(const std::vector& inputs); /** * @brief Drop or keep each element of x independently. * @param x The input variable. - * @param dropout_prob Probability of setting units to zero. The dropout operator randomly sets (according to the - * given dropout probability) the outputs of some units to zero, while others are remain unchanged. - * @param dropout_implementation Choice the mode of dropout. When "downgrade_in_infer", downgrade the outcome at - * inference: `train: out = input * mask, inference: out = input * (1.0 - dropout_prob)`. When "upscale_in_train", - * upscale the outcome at training time: `train: out = input * mask / ( 1.0 - dropout_prob), inference: out = input`. - * @return A variable representing the dropout, has same shape and data type with input. + * @param dropout_prob Probability of setting units to zero. The dropout + * operator randomly sets (according to the given dropout probability) the + * outputs of some units to zero, while others remain unchanged. + * @param dropout_implementation Choose the mode of dropout. When + * "downgrade_in_infer", downgrade the outcome at inference: `train: out = + * input * mask, inference: out = input * (1.0 - dropout_prob)`. When + * "upscale_in_train", upscale the outcome at training time: `train: out = + * input * mask / ( 1.0 - dropout_prob), inference: out = input`. + * @return A variable representing the dropout, with the same shape and data + * type as the input. */ - Variable DropoutInfer(const Variable& x, - float dropout_prob = 0.5f, - const std::string& dropout_implementation = "downgrade_in_infer"); + Variable DropoutInfer( + const Variable& x, + float dropout_prob = 0.5f, + const std::string& dropout_implementation = "downgrade_in_infer"); Variable GatherNd(const Variable& x, const Variable& index); - Variable Scatter(const Variable& src, const Variable& index, const Variable& out, const int& axis = 0); + Variable Scatter(const Variable& src, + const Variable& index, + const Variable& out, + const int& axis = 0); Variable Scatter(const Variable& src, const Variable& index, const cinn::utils::ShapeType& shape, const float& default_value = 0, - const int& axis = 0); + const int& axis = 0); Variable ScatterNd(const Variable& src, const Variable& index, @@ -298,34 +334,44 @@ class NetBuilder { Variable ScatterNd(const Variable& src, const Variable& index, const cinn::utils::ShapeType& shape, - const float& default_value = 0, + const float& default_value = 0, const cinn::utils::ShapeType& axes = {}); /** - * @brief This operator checks if all `x` and `y` satisfy the condition: `|x - y| <= atol + rtol * |y|` + * @brief This operator checks if all `x` and `y` satisfy the condition: `|x - + * y| <= atol + rtol * |y|` * @param x The first variable. * @param y The second variable. * @param rtol The relative tolerance. Default: 1e−5f. * @param atol The absolute tolerance. 
Default: 1e−8f. - * @param equal_nan If `true`, then two NaNs will be compared as equal. Default: false . + * @param equal_nan If `true`, then two NaNs will be compared as equal. + * Default: false . * @return The output variable, it’s data type is bool. */ - Variable IsClose( - const Variable& x, const Variable& y, float rtol = 1e-05f, float atol = 1e-08f, bool equal_nan = false); + Variable IsClose(const Variable& x, + const Variable& y, + float rtol = 1e-05f, + float atol = 1e-08f, + bool equal_nan = false); // ******************************************* // Reduction operator /** * @brief Reduce array elements over the given dims. * @param x The input variable. - * @param dim The dims along which a sum is performed. If dim is empty, the operation will sum over all elements - * of the input array. If the dim has negative value, it should count from the last dim to the first. Default is None. - * @param keep_dim If it is set true, the axes which are reduced are left in the result as dimensions with size one. - * With this option, the result will broadcast correctly against the input array. Default is false. + * @param dim The dims along which a sum is performed. If dim is empty, the + * operation will sum over all elements of the input array. If the dim has + * negative value, it should count from the last dim to the first. Default is + * None. + * @param keep_dim If it is set true, the axes which are reduced are left in + * the result as dimensions with size one. With this option, the result will + * broadcast correctly against the input array. Default is false. * @return The result variable. */ -#define NETBUILDER_REDUCE_OP_DECL(func_name__) \ - Variable func_name__(const Variable& x, const cinn::utils::ShapeType& dim = {}, bool keep_dim = false); +#define NETBUILDER_REDUCE_OP_DECL(func_name__) \ + Variable func_name__(const Variable& x, \ + const cinn::utils::ShapeType& dim = {}, \ + bool keep_dim = false); NETBUILDER_REDUCE_OP_FOREACH(NETBUILDER_REDUCE_OP_DECL) #undef NETBUILDER_REDUCE_OP_DECL @@ -339,7 +385,8 @@ class NetBuilder { Placeholder CreateInput(const Variable& input); /** - * @brief Create new input, whose data type is `type`, shape is `shape`, and id is `id_hint`. + * @brief Create new input, whose data type is `type`, shape is `shape`, and + * id is `id_hint`. * @param type The input variable's data type. * @param shape The input variable's shape. * @param id_hint The input variable's name. Default is None. @@ -356,11 +403,16 @@ class NetBuilder { * @return The result variable. */ template - std::enable_if_t::value, Variable> Constant(const T& value, - const std::string& name = "", - const std::string& dtype = "") { - auto true_dtype = dtype.empty() ? common::Type2Str(common::type_of()) : dtype; - auto out = CustomInstr("const_scalar", {}, {{"value", value}, {"dtype", true_dtype}}).front(); + std::enable_if_t::value, Variable> Constant( + const T& value, + const std::string& name = "", + const std::string& dtype = "") { + auto true_dtype = + dtype.empty() ? common::Type2Str(common::type_of()) : dtype; + auto out = + CustomInstr( + "const_scalar", {}, {{"value", value}, {"dtype", true_dtype}}) + .front(); if (!name.empty()) { out.set_id(name); @@ -369,19 +421,23 @@ class NetBuilder { } template - std::enable_if_t::value, Variable> Constant(const T& value, - const std::string& name = "", - const std::string& dtype = "") { - CHECK(!value.empty()) << "The value of Constant should not be None or empty list! 
Please check."; + std::enable_if_t::value, Variable> Constant( + const T& value, + const std::string& name = "", + const std::string& dtype = "") { + CHECK(!value.empty()) << "The value of Constant should not be None or " + "empty list! Please check."; // flatten n-dims vector to 1-dim vector auto all_datas = cinn::utils::Flatten(value); - CHECK(!all_datas.empty()) << "The value of Constant should not be None or empty list! Please check."; + CHECK(!all_datas.empty()) << "The value of Constant should not be None or " + "empty list! Please check."; VLOG(4) << "Constant with values: " << cinn::utils::Join(all_datas, ", "); - using TYPE = typename decltype(all_datas)::value_type; - auto true_dtype = dtype.empty() ? common::Type2Str(common::type_of()) : dtype; + using TYPE = typename decltype(all_datas)::value_type; + auto true_dtype = + dtype.empty() ? common::Type2Str(common::type_of()) : dtype; const auto& real_shape = GetVectorShape(value); @@ -389,8 +445,11 @@ class NetBuilder { return Constant(all_datas[0], name, true_dtype); } - auto assign_out = CustomInstr("assign_value", {}, {{"values", all_datas}, {"dtype", true_dtype}}).front(); - auto out = Reshape(assign_out, real_shape); + auto assign_out = + CustomInstr( + "assign_value", {}, {{"values", all_datas}, {"dtype", true_dtype}}) + .front(); + auto out = Reshape(assign_out, real_shape); // set the name correctly if (!name.empty()) { @@ -402,10 +461,12 @@ class NetBuilder { /** * @brief The op return a variable with the specific value, shape and type. * @param shape Shape of the variable to be created. - * @param value The constant value used to initialize the variable to be created. + * @param value The constant value used to initialize the variable to be + * created. * @param name The name of the output variable. * @param dtype Data type of the output variable. - * @param force_cpu Whether the variable should force placed in cpu, default in device memory. Default is false. + * @param force_cpu Whether the variable should force placed in cpu, default + * in device memory. Default is false. * @return The result variable. */ template @@ -414,10 +475,13 @@ class NetBuilder { const std::string& name, const std::string& dtype, bool force_cpu = false) { - auto out = - CustomInstr( - "fill_constant", {}, {{"shape", shape}, {"value", value}, {"dtype", dtype}, {"force_cpu", force_cpu}}) - .front(); + auto out = CustomInstr("fill_constant", + {}, + {{"shape", shape}, + {"value", value}, + {"dtype", dtype}, + {"force_cpu", force_cpu}}) + .front(); if (!name.empty()) { out.set_id(cinn::utils::TransValidVarName(name)); } @@ -425,12 +489,15 @@ class NetBuilder { } /** - * @brief The op return a variable with the specific string value, shape and type. + * @brief The op return a variable with the specific string value, shape and + * type. * @param shape Shape of the variable to be created. - * @param str_value The constant string value used to initialize the variable to be created. + * @param str_value The constant string value used to initialize the variable + * to be created. * @param name The name of the output variable. * @param dtype Data type of the output variable. - * @param force_cpu Whether the variable should force placed in cpu, default in device memory. Default is false. + * @param force_cpu Whether the variable should force placed in cpu, default + * in device memory. Default is false. * @return The result variable. 
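A minimal usage sketch of the `Constant` overloads above (illustrative only; it assumes the same context as net_builder_test.cc, where `NetBuilder` is visible, and all names and values here are made up):

    NetBuilder builder("constant_demo");
    // Scalar overload: the dtype is deduced from the template argument (float).
    auto s = builder.Constant(1.5f, "scalar_const");
    // Vector overload: the nested vector is flattened to 1-D, materialized via
    // assign_value, then reshaped back to its original 2x2 shape.
    std::vector<std::vector<float>> vals = {{1.0f, 2.0f}, {3.0f, 4.0f}};
    auto m = builder.Constant(vals, "matrix_const");
    auto program = builder.Build();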
*/ Variable FillConstant(const cinn::utils::ShapeType& shape, @@ -440,140 +507,187 @@ class NetBuilder { bool force_cpu = false); /** - * @brief The op return a variable with the specific value, shape and type, the type is infered from value. + * @brief The op returns a variable with the specific value, shape and type; + * the type is inferred from value. * @param shape Shape of the variable to be created. - * @param value The constant value used to initialize the variable to be created. + * @param value The constant value used to initialize the variable to be + * created. * @param name The name of the output variable. - * @param force_cpu Whether the variable should force placed in cpu, default in device memory. Default is false. + * @param force_cpu Whether the variable should be forcibly placed in CPU + * memory; by default it is placed in device memory. Default is false. * @return The result variable. */ template Variable FillConstant(const cinn::utils::ShapeType& shape, T value, const std::string& name = "", - bool force_cpu = false) { - return FillConstant(shape, value, name, common::Type2Str(common::type_of()), force_cpu); + bool force_cpu = false) { + return FillConstant( + shape, value, name, common::Type2Str(common::type_of()), force_cpu); } /** - * @brief Return evenly spaced values within a given interval. Values are generated within the half-open interval - * `[start, stop)` (in other words, the interval including start but excluding stop). + * @brief Return evenly spaced values within a given interval. Values are + * generated within the half-open interval + * `[start, stop)` (in other words, the interval including start but excluding + * stop). * @param start Start of interval. The interval includes this value. - * @param stop End of interval. The interval does not include this value, except in some cases where step is not - * an integer and floating point round-off affects the length of out. - * @param step Spacing between values. For any output out, this is the distance between two adjacent values, `out[i+1] + * @param stop End of interval. The interval does not include this value, + * except in some cases where step is not an integer and floating point + * round-off affects the length of out. + * @param step Spacing between values. For any output out, this is the + * distance between two adjacent values, `out[i+1] * - out[i]`. * @param dtype The data type of the output. Default: "float32". - * @return A 1-D variable which is evenly spaced values within a given interval. Its data type is set by dtype. + * @return A 1-D variable of evenly spaced values within a given + * interval. Its data type is set by dtype. */ - Variable Arange(const float start, const float stop, const float step, const std::string& dtype); + Variable Arange(const float start, + const float stop, + const float step, + const std::string& dtype); /** - * @brief This operator is used to perform matrix multiplication for input x and y. + * @brief This operator is used to perform matrix multiplication for input x + * and y. * @param x The first input variable. * @param y The second input variable. - * @param x_num_col_dims If the input `x` is a variable with more than two dimensions, `x` will be flattened into a - * two-dimensional matrix first. The flattening rule is: the first `num_col_dims` will be flattened to form the first - * dimension of the final matrix (the height of the matrix), and the rest `rank(x)` - `num_col_dims` dimensions are - * flattened to form the second dimension of the final matrix (the width of the matrix). 
Default is 1. - * @param y_num_col_dims If the input `y` is a variable with more than two dimensions, `y` will be flattened into a - * two-dimensional matrix first. The attribute `y_num_col_dims` determines how `y` is flattened. See comments of + * @param x_num_col_dims If the input `x` is a variable with more than two + * dimensions, `x` will be flattened into a two-dimensional matrix first. The + * flattening rule is: the first `num_col_dims` will be flattened to form the + * first dimension of the final matrix (the height of the matrix), and the + * rest `rank(x)` - `num_col_dims` dimensions are flattened to form the second + * dimension of the final matrix (the width of the matrix). Default is 1. + * @param y_num_col_dims If the input `y` is a variable with more than two + * dimensions, `y` will be flattened into a two-dimensional matrix first. The + * attribute `y_num_col_dims` determines how `y` is flattened. See comments of * `x_num_col_dims` for more details. Default is 1. * @return The result variable. */ - Variable Mul( - const Variable& x, const Variable& y, int x_num_col_dims = 1, int y_num_col_dims = 1, bool is_infer = false); + Variable Mul(const Variable& x, + const Variable& y, + int x_num_col_dims = 1, + int y_num_col_dims = 1, + bool is_infer = false); /** - * @brief Applies matrix multiplication to two variable. Matmul follows the complete broadcast rules, and its behavior - * is consistent with `np.matmul`. + * @brief Applies matrix multiplication to two variable. Matmul follows the + * complete broadcast rules, and its behavior is consistent with `np.matmul`. * @param x The left input variable. * @param y The right input variable. - * @param trans_x Whether to transpose `x` before multiplication. Default is false. - * @param trans_y Whether to transpose `y` before multiplication. Default is false. + * @param trans_x Whether to transpose `x` before multiplication. Default is + * false. + * @param trans_y Whether to transpose `y` before multiplication. Default is + * false. * @param alpha The scale of output. Default 1.0f. * @return The product variable. */ - Variable Matmul(const Variable& x, const Variable& y, bool trans_x = false, bool trans_y = false, float alpha = 1.0f); + Variable Matmul(const Variable& x, + const Variable& y, + bool trans_x = false, + bool trans_y = false, + float alpha = 1.0f); /** - * @brief This operation calculates the pooling output based on the input, pooling_type and pool_size, pool_stride, - * pool_padding parameters. - * @param x The input variable of pooling operator which is a 4-D variable with shape [N, C, H, W]. The format of - * input variable is “NCHW” or “NHWC”, where N is batch size, C is the number of channels, H is the height of the + * @brief This operation calculates the pooling output based on the input, + * pooling_type and pool_size, pool_stride, pool_padding parameters. + * @param x The input variable of pooling operator which is a 4-D variable + * with shape [N, C, H, W]. The format of input variable is “NCHW” or “NHWC”, + * where N is batch size, C is the number of channels, H is the height of the * feature, and W is the width of the feature. - * @param pooling_type Pooling type, can be “max” for max-pooling and “avg” for average-pooling - * @param ksize The pool kernel size. If pool kernel size is a tuple or list, it must contain two integers, - * (pool_size_Height, pool_size_Width). Otherwise, the pool kernel size will be a square of an int. - * @param strides The pool stride size. 
If pool stride size is a tuple or list, it must contain two integers, - * (pool_stride_Height, pool_stride_Width). Otherwise, the pool stride size will be a square of an int. Default is {1, - * 1}. - * @param paddings The padding size. If padding is a list/tuple, it must contain two integers, (padding_H, padding_W). - * Otherwise, the padding_H = padding_W = padding. Default: padding = {0, 0}. - * @param ceil_mode Whether to use the ceil function to calculate output height and width. False is the default. If it - * is set to False, the floor function will be used. Default False - * @param exclusive Whether to exclude padding points in average pooling mode, default is true. - * @param global_pooling Whether to use the global pooling. If global_pooling = true, kernel size and paddings will be - * ignored. Default False. - * @param data_format Data format that specifies the layout of input. It can be "NCHW" or "NHWC". Default: "NCHW". - * @param adaptive When true, will perform adaptive pooling instead, output shape in H and W dimensions will be same - * as ksize, input data will be divided into grids specify by ksize averagely and perform pooling in each grid area to - * get output pooling value. Default: False. - * @param padding_algorithm Can be "EXPLICIT"/"SAME"/"VALID". Default: "EXPLICIT". - * @return The output variable of pooling result. The data type is same as input variable. + * @param pooling_type Pooling type, can be “max” for max-pooling and “avg” + * for average-pooling + * @param ksize The pool kernel size. If pool kernel size is a tuple or list, + * it must contain two integers, (pool_size_Height, pool_size_Width). + * Otherwise, the pool kernel size will be a square of an int. + * @param strides The pool stride size. If pool stride size is a tuple or + * list, it must contain two integers, (pool_stride_Height, + * pool_stride_Width). Otherwise, the pool stride size will be a square of an + * int. Default is {1, 1}. + * @param paddings The padding size. If padding is a list/tuple, it must + * contain two integers, (padding_H, padding_W). Otherwise, the padding_H = + * padding_W = padding. Default: padding = {0, 0}. + * @param ceil_mode Whether to use the ceil function to calculate output + * height and width. False is the default. If it is set to False, the floor + * function will be used. Default False + * @param exclusive Whether to exclude padding points in average pooling mode, + * default is true. + * @param global_pooling Whether to use the global pooling. If global_pooling + * = true, kernel size and paddings will be ignored. Default False. + * @param data_format Data format that specifies the layout of input. It can + * be "NCHW" or "NHWC". Default: "NCHW". + * @param adaptive When true, will perform adaptive pooling instead, output + * shape in H and W dimensions will be same as ksize, input data will be + * divided into grids specify by ksize averagely and perform pooling in each + * grid area to get output pooling value. Default: False. + * @param padding_algorithm Can be "EXPLICIT"/"SAME"/"VALID". Default: + * "EXPLICIT". + * @return The output variable of pooling result. The data type is same as + * input variable. 
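A usage sketch of the `Matmul` declaration above; the setup mirrors CreateAddProgram in net_builder_test.cc, and the shapes are arbitrary assumptions:

    NetBuilder builder("matmul_demo");
    auto x = builder.CreateInput(Float(32), {32, 16}, "X");
    auto y = builder.CreateInput(Float(32), {16, 8}, "Y");
    // trans_x/trans_y default to false and alpha to 1.0f, so this is a plain x*y.
    auto out = builder.Matmul(x, y);
    auto program = builder.Build();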
*/ Variable Pool2d(const Variable& x, const std::string& pooling_type, const std::vector& ksize, - const std::vector& strides = {1, 1}, - const std::vector& paddings = {0, 0}, - bool ceil_mode = false, - bool exclusive = true, - bool global_pooling = false, - const std::string& data_format = "NCHW", - bool adaptive = false, + const std::vector& strides = {1, 1}, + const std::vector& paddings = {0, 0}, + bool ceil_mode = false, + bool exclusive = true, + bool global_pooling = false, + const std::string& data_format = "NCHW", + bool adaptive = false, const std::string& padding_algorithm = "EXPLICIT"); /** - * @brief This operation calculates the pooling output based on the input, pooling_type and pool_size, pool_stride, - * pool_padding parameters. - * @param x The input variable of pooling operator which is a 4-D variable with shape [N, C, H, W]. The format of - * input variable is “NCHW” or “NHWC”, where N is batch size, C is the number of channels, H is the height of the + * @brief This operation calculates the pooling output based on the input, + * pooling_type and pool_size, pool_stride, pool_padding parameters. + * @param x The input variable of pooling operator which is a 4-D variable + * with shape [N, C, H, W]. The format of input variable is “NCHW” or “NHWC”, + * where N is batch size, C is the number of channels, H is the height of the * feature, and W is the width of the feature. * @param y The output variable of pooling operator. * @param dy The gradient variable of pooling operator's otuput. - * @param pooling_type pooling type, can be “max” for max-pooling and “avg” for average-pooling - * @param ksize The pool kernel size. If pool kernel size is a tuple or list, it must contain two integers, - * (pool_size_Height, pool_size_Width). Otherwise, the pool kernel size will be a square of an int. - * @param strides The pool stride size. If pool stride size is a tuple or list, it must contain two integers, - * (pool_stride_Height, pool_stride_Width). Otherwise, the pool stride size will be a square of an int. Default is {1, - * 1}. - * @param paddings The padding size. If padding is a list/tuple, it must contain two integers, (padding_H, padding_W). - * Otherwise, the padding_H = padding_W = padding. Default: padding = {0, 0}. - * @param ceil_mode Whether to use the ceil function to calculate output height and width. False is the default. If it - * is set to False, the floor function will be used. Default False - * @param exclusive Whether to exclude padding points in average pooling mode, default is true. - * @param global_pooling Whether to use the global pooling. If global_pooling = true, kernel size and paddings will be - * ignored. Default False. - * @param data_format Data format that specifies the layout of input. It can be "NCHW" or "NHWC". Default: "NCHW". - * @param adaptive When true, will perform adaptive pooling instead, output shape in H and W dimensions will be same - * as ksize, input data will be divided into grids specify by ksize averagely and perform pooling in each grid area to - * get output pooling value. Default: False. - * @param padding_algorithm Can be "EXPLICIT"/"SAME"/"VALID". Default: "EXPLICIT". - * @return The gradient variable of pooling input "X". The data type is same as input variable. + * @param pooling_type pooling type, can be “max” for max-pooling and “avg” + * for average-pooling + * @param ksize The pool kernel size. If pool kernel size is a tuple or list, + * it must contain two integers, (pool_size_Height, pool_size_Width). 
+ * Otherwise, the pool kernel size will be a square of an int. + * @param strides The pool stride size. If pool stride size is a tuple or + * list, it must contain two integers, (pool_stride_Height, + * pool_stride_Width). Otherwise, the pool stride size will be a square of an + * int. Default is {1, 1}. + * @param paddings The padding size. If padding is a list/tuple, it must + * contain two integers, (padding_H, padding_W). Otherwise, the padding_H = + * padding_W = padding. Default: padding = {0, 0}. + * @param ceil_mode Whether to use the ceil function to calculate output + * height and width. False is the default. If it is set to False, the floor + * function will be used. Default False + * @param exclusive Whether to exclude padding points in average pooling mode, + * default is true. + * @param global_pooling Whether to use the global pooling. If global_pooling + * = true, kernel size and paddings will be ignored. Default False. + * @param data_format Data format that specifies the layout of input. It can + * be "NCHW" or "NHWC". Default: "NCHW". + * @param adaptive When true, will perform adaptive pooling instead, output + * shape in H and W dimensions will be same as ksize, input data will be + * divided into grids specify by ksize averagely and perform pooling in each + * grid area to get output pooling value. Default: False. + * @param padding_algorithm Can be "EXPLICIT"/"SAME"/"VALID". Default: + * "EXPLICIT". + * @return The gradient variable of pooling input "X". The data type is same + * as input variable. */ Variable Pool2dGrad(const Variable& x, const Variable& y, const Variable& dy, const std::string& pooling_type, const std::vector& ksize, - const std::vector& strides = {1, 1}, - const std::vector& paddings = {0, 0}, - bool ceil_mode = false, - bool exclusive = true, - bool global_pooling = false, - const std::string& data_format = "NCHW", - bool adaptive = false, + const std::vector& strides = {1, 1}, + const std::vector& paddings = {0, 0}, + bool ceil_mode = false, + bool exclusive = true, + bool global_pooling = false, + const std::string& data_format = "NCHW", + bool adaptive = false, const std::string& padding_algorithm = "EXPLICIT"); /** @@ -589,29 +703,33 @@ class NetBuilder { * @brief Resize operator does 2D scaling to the given size. * @param x An input variable, the data layout of input is NCHW * @param out_shape The out size to which the image will be resized. - * @param mode Scale method to used [nearest, bilinear, bicubic], this will default to `bilinear`. + * @param mode Scale method to used [nearest, bilinear, bicubic], this will + * default to `bilinear`. * @return The resized result. */ - Variable Resize(const Variable& x, const std::vector& out_shape, const std::string& mode); + Variable Resize(const Variable& x, + const std::vector& out_shape, + const std::string& mode); // ******************************************* // Broadcast operator /** * @brief Broadcast the input variable to a given shape. * @param x The input variable need to broadcast. - * @param out_shape The result shape after broadcasting. The value -1 in shape means keeping the corresponding - * dimension unchanged. + * @param out_shape The result shape after broadcasting. The value -1 in shape + * means keeping the corresponding dimension unchanged. * @return The result variable with given shape. 
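A usage sketch of `Pool2d` declared earlier in this class (the NCHW input and its shape are assumptions; trailing parameters keep their defaults):

    NetBuilder builder("pool_demo");
    auto x = builder.CreateInput(Float(32), {1, 3, 224, 224}, "X");  // NCHW
    // 2x2 max pooling with stride 2 and zero padding.
    auto y = builder.Pool2d(x, "max", {2, 2}, {2, 2});
    auto program = builder.Build();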
*/ - Variable BroadcastTo(const Variable& x, const cinn::utils::ShapeType& out_shape); + Variable BroadcastTo(const Variable& x, + const cinn::utils::ShapeType& out_shape); /** * @brief Broadcast the input variable to a given shape. * @param x The input variable need to broadcast. - * @param out_shape The result shape after broadcasting. The value -1 in shape means keeping the corresponding - * dimension unchanged. - * @param broadcast_axes The axes need to broadcast, the axis not in `broadcast_axes` of `out_shape`'s value should be - * the same as input shape. + * @param out_shape The result shape after broadcasting. The value -1 in + * shape means keeping the corresponding dimension unchanged. + * @param broadcast_axes The axes need to broadcast, the axis not in + * `broadcast_axes` of `out_shape`'s value should be the same as input shape. * @return The result variable with given shape. */ Variable BroadcastTo(const Variable& x, @@ -623,8 +741,9 @@ class NetBuilder { /** * @brief This OP concatenates the input along the axis. * @param x Variable list with same data type. - * @param axis Specify the axis to operate on the input concatenates. The effective range is [-R, R), where R is - * Rank(x). When axis < 0, it works the same way as axis+R. Default is 0. + * @param axis Specify the axis to operate on the input concatenates. The + * effective range is [-R, R), where R is Rank(x). When axis < 0, it works the + * same way as axis+R. Default is 0. * @return A variable with the same data type as x. */ Variable Concat(const std::vector& x, int axis = 0); @@ -632,30 +751,38 @@ class NetBuilder { /** * @brief Split the input variable into multiple sub-variables. * @param x A N-D variable. - * @param num_or_sections If `num_or_sections` is an int, then `num_or_sections` indicates the number of equal sized - * sub-variables that the `x` will be divided into. If `num_or_sections` is a list, the length of it indicates the - * number of sub-variables and the elements in it indicate the sizes of sub-variables’ dimension orderly. The length - * of the list must not be larger than the `x`'s size of specified axis. - * @param axis The axis along which to split. The effective range is [-R, R), where R is Rank(x). When axis < 0, it - * works the same way as axis+R. Default is 0. + * @param num_or_sections If `num_or_sections` is an int, then + * `num_or_sections` indicates the number of equal sized sub-variables that + * the `x` will be divided into. If `num_or_sections` is a list, the length of + * it indicates the number of sub-variables and the elements in it indicate + * the sizes of sub-variables’ dimension orderly. The length of the list must + * not be larger than the `x`'s size of specified axis. + * @param axis The axis along which to split. The effective range is [-R, R), + * where R is Rank(x). When axis < 0, it works the same way as axis+R. Default + * is 0. * @return The list of segmented variables. */ - std::vector Split(const Variable& x, const std::vector& num_or_sections, int axis = 0); + std::vector Split(const Variable& x, + const std::vector& num_or_sections, + int axis = 0); /** * @brief This operator changes the shape of x without changing its data. * @param x An N-D variable. - * @param shape Define the target shape. At most one dimension of the target shape can be -1. + * @param shape Define the target shape. At most one dimension of the target + * shape can be -1. * @return A reshaped variable with the same data type as x. 
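A usage sketch of `Concat` and `Split` declared above (shapes are assumptions, and it relies on Placeholder converting to Variable as elsewhere in the tests; the sections list gives per-piece sizes along the axis):

    NetBuilder builder("concat_split_demo");
    auto a = builder.CreateInput(Float(32), {4, 8}, "A");
    auto b = builder.CreateInput(Float(32), {4, 8}, "B");
    auto c = builder.Concat({a, b}, /*axis=*/0);        // -> {8, 8}
    auto parts = builder.Split(c, {4, 4}, /*axis=*/0);  // two {4, 8} pieces
    auto program = builder.Build();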
*/ Variable Reshape(const Variable& x, const cinn::utils::ShapeType& shape); /** - * @brief This OP will squeeze single-dimensional entries of input variable shape. If axes is provided, will remove - * the dims by axes, the dims selected by axes should be one. If not provide axes, all dims equal to one will be + * @brief This OP will squeeze single-dimensional entries of input variable + * shape. If axes is provided, it will remove the dims given by axes, and the + * dims selected by axes should be one. If axes is not provided, all dims + * equal to one will be * deleted. * @param x An N-D variable. - * @param axes The dimensions to be squeezed. Axes range is `[−rank(input),rank(input)]`. If `axes` is negative, + * @param axes The dimensions to be squeezed. Axes range is + * `[−rank(input),rank(input)]`. If `axes` is negative, * `axes=axes+rank(input)`. * @return Output squeezed variable. Data type is same as input variable. */ @@ -664,11 +791,13 @@ /** * @brief Creates an operation to insert new dimensions of length 1. * @param operand An N-D variable. - * @param axis The index of the first new dimension (allows negative indices as offsets from the last dimension). + * @param axis The index of the first new dimension (allows negative indices + * as offsets from the last dimension). * @param num_newaxis The number of new dimensions to insert. * @return A variable whose op member is the dim expansion operation. */ - Variable ExpandDims(const Variable& operand, const cinn::utils::ShapeType& axes); + Variable ExpandDims(const Variable& operand, + const cinn::utils::ShapeType& axes); /** * @brief This operator reverses the input along the axis. @@ -679,8 +808,9 @@ Variable Reverse(const Variable& x, const cinn::utils::ShapeType& axis); /** - * @brief Permute the data dimensions of input according to perm. The i-th dimension of the returned variable will - * correspond to the perm[i]-th dimension of input. + * @brief Permute the data dimensions of input according to perm. The i-th + * dimension of the returned variable will correspond to the perm[i]-th + * dimension of input. * @param x An N-D variable. * @param axis Permute the input according to the data of perm. * @return A transposed n-D variable. @@ -691,25 +821,31 @@ /** * @brief This operator produces a slice of x along multiple axes. * @param x An N-D variable. * @param axes Axes that starts and ends apply to. - * @param starts The starting indices of corresponding axis in axes. Default: None. - * @param ends The ending indices of corresponding axis in axes. Default: None. - * @param infer_flags Whether the output shape can be infered in compile time. Now only support all 1. Default: None. + * @param starts The starting indices of corresponding axis in axes. Default: + * None. + * @param ends The ending indices of corresponding axis in axes. Default: + * None. + * @param infer_flags Whether the output shape can be inferred at compile + * time. Now only all 1 is supported. Default: None. * @param strides The slice step of corresponding axis in axes. Default: None. * @param decrease_axis Eliminate the specified dimension. Default: None. - * @return A variable with the same dimension as x. The data type is same as x. + * @return A variable with the same dimension as x. The data type is the same + * as x. 
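A usage sketch of `Reshape` and `ExpandDims` declared above (shapes are assumptions; at most one Reshape dimension may be -1 and is inferred):

    NetBuilder builder("reshape_demo");
    auto x = builder.CreateInput(Float(32), {2, 3, 4}, "X");
    auto r = builder.Reshape(x, {6, -1});  // -1 is inferred as 4
    auto e = builder.ExpandDims(r, {0});   // insert a leading size-1 dim
    auto program = builder.Build();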
*/ Variable Slice(const Variable& x, const cinn::utils::ShapeType& axes, - const std::vector& starts = {}, - const std::vector& ends = {}, - const std::vector& infer_flags = {}, - const std::vector& strides = {}, + const std::vector& starts = {}, + const std::vector& ends = {}, + const std::vector& infer_flags = {}, + const std::vector& strides = {}, const std::vector& decrease_axis = {}); /** - * @brief Returns a new variable which indexes the input variable along dimension axis using the entries in index - * which is a variable. The returned variable has the same number of dimensions as the original x variable. The dim-th - * dimension has the same size as the length of index; other dimensions have the same size as in the x variable. + * @brief Returns a new variable which indexes the input variable along + * dimension axis using the entries in index which is a variable. The returned + * variable has the same number of dimensions as the original x variable. The + * dim-th dimension has the same size as the length of index; other dimensions + * have the same size as in the x variable. * @param x An N-D variable. * @param index The 1-D variable containing the indices to index. * @param axis The dimension in which we index. Default: 0. @@ -718,35 +854,45 @@ class NetBuilder { Variable Gather(const Variable& x, const Variable& index, int axis = 0); /** - * @brief Output is obtained by updating the input on selected indices based on updates. + * @brief Output is obtained by updating the input on selected indices based + * on updates. * @param x The input N-D variable with ndim>=1. - * @param updates pdate input with updates parameter based on index. shape should be the same as input, and dim value - * with dim > 1 should be the same as input. - * @param index The index 1-D variable. The length of index cannot exceed updates’s length, and the value in index - * cannot exceed input’s length. + * @param updates Update input with updates parameter based on index. Shape + * should be the same as input, and dim value with dim > 1 should be the same + * as input. + * @param index The index 1-D variable. The length of index cannot exceed + * updates’s length, and the value in index cannot exceed input’s length. * @param axis The dimension in which we index. Default: 0. * @return A variable with same shape as x. */ - Variable ScatterAssign(const Variable& x, const Variable& updates, const Variable& index, int axis = 0); + Variable ScatterAssign(const Variable& x, + const Variable& updates, + const Variable& index, + int axis = 0); /** - * @brief Output is obtained by adding the `input` and the `updates` on selected indices. + * @brief Output is obtained by adding the `input` and the `updates` on + * selected indices. * @param x The input N-D variable with ndim>=1. - * @param updates Update input with updates parameter based on index. Shape should be the same as input, and dim value - * with dim > 1 should be the same as input. + * @param updates Update input with updates parameter based on index. Shape + * should be the same as input, and dim value with dim > 1 should be the same + * as input. + * @param index The index 1-D variable. The length of index cannot exceed + * updates’s length, and the value in index cannot exceed input’s length. * @param axis The dimension in which we index. Default: 0. * @return A variable with same shape as x. 
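A usage sketch of `Slice` and `Gather` declared above (the shapes and the Int(32) index type are assumptions):

    NetBuilder builder("slice_gather_demo");
    auto x = builder.CreateInput(Float(32), {8, 16}, "X");
    // Rows [2, 6) along axis 0; infer_flags/strides/decrease_axis keep defaults.
    auto s = builder.Slice(x, /*axes=*/{0}, /*starts=*/{2}, /*ends=*/{6});
    auto idx = builder.CreateInput(Int(32), {4}, "Idx");
    auto g = builder.Gather(x, idx, /*axis=*/0);
    auto program = builder.Build();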
*/ - Variable ScatterAdd(const Variable& x, const Variable& updates, const Variable& index, int axis = 0); + Variable ScatterAdd(const Variable& x, + const Variable& updates, + const Variable& index, + int axis = 0); /** - * @brief Replacing the value of `x` by `assign` variable on the range of `slice(x)`. In other word, - * `slice(x)=assign`. + * @brief Replacing the value of `x` by `assign` variable on the range of + * `slice(x)`. In other word, `slice(x)=assign`. * @param x An N-D variable. - * @param assign Update input with assign value based on slice result. Shape should be the same as the `slice` output - * shape. + * @param assign Update input with assign value based on slice result. Shape + * should be the same as the `slice` output shape. * @param axes Axes that starts and ends apply to. * @param starts The starting indices of corresponding axis in axes. * @param ends The ending indices of corresponding axis in axes. @@ -773,22 +919,25 @@ class NetBuilder { /** * @brief This operator implements the softmax layer. * @param x An N-D variable. - * @param axis The index of dimension to perform softmax calculations, it should be in range `[−1,rank−1]`, - * while `rank` is the rank of input variable. Default: -1. -1 means the last dimension. - * @param data_format Specify the data format of the output data, the input will be transformed automatically. - * An optional string from: "AnyLayout", "NHWC", "NCHW". Default: "AnyLayout". + * @param axis The index of dimension to perform softmax calculations, it + * should be in range `[−1,rank−1]`, while `rank` is the rank of input + * variable. Default: -1. -1 means the last dimension. + * @param data_format Specify the data format of the output data, the input + * will be transformed automatically. An optional string from: "AnyLayout", + * "NHWC", "NCHW". Default: "AnyLayout". * @return Output of softmax. The data type and shape are the same as input . */ Variable Softmax(const Variable& x, - const std::vector& axes = {-1}, - const std::string& mode = "fast", + const std::vector& axes = {-1}, + const std::string& mode = "fast", const std::string& data_format = "AnyLayout"); // ******************************************* // Type converter Operator /** - * @brief This OP takes in the Variable `x` with `x.dtype` and casts it to the output with dtype. - * It’s meaningless if the output dtype equals the input `dtype`, but it’s fine if you do so. + * @brief This OP takes in the Variable `x` with `x.dtype` and casts it to the + * output with dtype. It’s meaningless if the output dtype equals the input + * `dtype`, but it’s fine if you do so. * @param x An input N-D variable. * @param dtype Data type of the output. * @return A variable with the same shape as input’s. @@ -796,24 +945,29 @@ class NetBuilder { Variable Cast(const Variable& x, const std::string& dtype); /** - * @brief This OP takes in the Variable `x` with `x.dtype` and casts it to the output with dtype. - * The output data shape will be calculated according to the type of input data and the specified output data type. - * Assuming that the input data type is "T" and it's shape is [...], the output data type is specified as "S". - * If the "T" is larger than "S", then the shape changes from [...] to [..., sizeof(T)/sizeof(S)]. - * If "T" is smaller than "S", this operator requires that the rightmost dimension must be equal to - * sizeof(S)/sizeof(T) and the shape then goes from [..., sizeof(S)/sizeof(T)] to [...]. 
- * It’s meaningless if the output dtype equals the input `dtype`, but it’s fine if you do so. + * @brief This OP takes in the Variable `x` with `x.dtype` and casts it to the + * output with dtype. The output data shape will be calculated according to + * the type of input data and the specified output data type. Assuming that + * the input data type is "T" and its shape is [...], the output data type is + * specified as "S". If the "T" is larger than "S", then the shape changes + * from [...] to [..., sizeof(T)/sizeof(S)]. If "T" is smaller than "S", this + * operator requires that the rightmost dimension must be equal to + * sizeof(S)/sizeof(T) and the shape then goes from [..., sizeof(S)/sizeof(T)] + * to [...]. It’s meaningless if the output dtype equals the input `dtype`, + * but it’s fine if you do so. * @param x An input N-D variable. * @param dtype Data type of the output. - * @return A variable with the same data buffer as input’s, but shape may different. + * @return A variable with the same data buffer as input’s, but the shape may + * differ. */ Variable BitcastConvert(const Variable& x, const std::string& dtype); /** - * @brief Returns a one-hot tensor where the locations repsented by indices take value `on_value`, - * other locations take value `off_value`. + * @brief Returns a one-hot tensor where the locations represented by indices + * take value `on_value`, other locations take value `off_value`. * @param on_value Value to fill at indices. Its shape must be [1]. - * @param on_value Value to fill at all other positions besides indices. Its shape must be [1] + * @param off_value Value to fill at all other positions besides indices. Its + * shape must be [1]. * @param depth Depth of the one-hot dimension. * @param axis Axis to fill. */ @@ -851,53 +1005,64 @@ class NetBuilder { * @brief Compute the convolution. * @param x The image variable. * @param weight The filter variable. - * @param strides The stride size. If stride is a list/tuple, it must contain two integers, (stride_H, stride_W). - * Otherwise, the stride_H = stride_W = stride. Default: stride = {1, 1}. - * @param paddings The padding size. If padding is a list/tuple, it must contain two integers, (padding_H, padding_W). - * Otherwise, the padding_H = padding_W = padding. Default: padding = {0, 0}. - * @param dilations The dilation size. If dilation is a list/tuple, it must contain two integers, (dilation_H, - * dilation_W). Otherwise, the dilation_H = dilation_W = dilation. Default: dilation = {1, 1}. + * @param strides The stride size. If stride is a list/tuple, it must contain + * two integers, (stride_H, stride_W). Otherwise, the stride_H = stride_W = + * stride. Default: stride = {1, 1}. + * @param paddings The padding size. If padding is a list/tuple, it must + * contain two integers, (padding_H, padding_W). Otherwise, the padding_H = + * padding_W = padding. Default: padding = {0, 0}. + * @param dilations The dilation size. If dilation is a list/tuple, it must + * contain two integers, (dilation_H, dilation_W). Otherwise, the dilation_H = + * dilation_W = dilation. Default: dilation = {1, 1}. * @param groups The groups number of the conv layer. Default: groups=1. - * @param conv_type The convolution type. The choice contain "forward"/"backward_data"/"backward_filter", note only - * support "forward" when using cudnn. - * @param data_format Data format that specifies the layout of input. It can be "NCHW" or "NHWC". Default: "NCHW". - * @param padding_algorithm CINN not support! 
It can be "EXPLICIT"/"SAME"/"VALID". Default: "EXPLICIT". + * @param conv_type The convolution type. The choice contain + * "forward"/"backward_data"/"backward_filter", note only support "forward" + * when using cudnn. + * @param data_format Data format that specifies the layout of input. It can + * be "NCHW" or "NHWC". Default: "NCHW". + * @param padding_algorithm CINN not support! It can be + * "EXPLICIT"/"SAME"/"VALID". Default: "EXPLICIT". * @param output_shape The shape of output. Default: None. * @return The convolution result variable. */ Variable Conv(const Variable& x, const Variable& weight, - const std::vector& strides = {1, 1}, - const std::vector& paddings = {0, 0}, - const std::vector& dilations = {1, 1}, - int groups = 1, - const std::string& conv_type = "forward", - const std::string& data_format = "NCHW", - const std::string& padding_algorithm = "EXPLICIT", + const std::vector& strides = {1, 1}, + const std::vector& paddings = {0, 0}, + const std::vector& dilations = {1, 1}, + int groups = 1, + const std::string& conv_type = "forward", + const std::string& data_format = "NCHW", + const std::string& padding_algorithm = "EXPLICIT", const cinn::utils::ShapeType& output_shape = {}); /** * @brief Compute the convolution-2d. * @param x The image variable. * @param weights The filter variable. - * @param strides The stride size. If stride is a list/tuple, it must contain two integers, (stride_H, stride_W). - * Otherwise, the stride_H = stride_W = stride. Default: stride = {1, 1}. - * @param paddings The padding size. If padding is a list/tuple, it must contain two integers, (padding_H, padding_W). - * Otherwise, the padding_H = padding_W = padding. Default: padding = {0, 0}. - * @param dilations The dilation size. If dilation is a list/tuple, it must contain two integers, (dilation_H, - * dilation_W). Otherwise, the dilation_H = dilation_W = dilation. Default: dilation = {1, 1}. + * @param strides The stride size. If stride is a list/tuple, it must contain + * two integers, (stride_H, stride_W). Otherwise, the stride_H = stride_W = + * stride. Default: stride = {1, 1}. + * @param paddings The padding size. If padding is a list/tuple, it must + * contain two integers, (padding_H, padding_W). Otherwise, the padding_H = + * padding_W = padding. Default: padding = {0, 0}. + * @param dilations The dilation size. If dilation is a list/tuple, it must + * contain two integers, (dilation_H, dilation_W). Otherwise, the dilation_H = + * dilation_W = dilation. Default: dilation = {1, 1}. * @param groups The groups number of the conv layer. Default: groups=1. - * @param data_format Data format that specifies the layout of input. It can be "NCHW" or "NHWC". Default: "NCHW". - * @param padding_algorithm CINN not support! It can be "EXPLICIT"/"SAME"/"VALID". Default: "EXPLICIT". + * @param data_format Data format that specifies the layout of input. It can + * be "NCHW" or "NHWC". Default: "NCHW". + * @param padding_algorithm CINN not support! It can be + * "EXPLICIT"/"SAME"/"VALID". Default: "EXPLICIT". * @return The convolution-2d result variable. 
*/ Variable Conv2d(const Variable& x, const Variable& weights, - const std::vector& strides = {1, 1}, - const std::vector& paddings = {0, 0}, - const std::vector& dilations = {1, 1}, - int groups = 1, - const std::string& data_format = "NCHW", + const std::vector& strides = {1, 1}, + const std::vector& paddings = {0, 0}, + const std::vector& dilations = {1, 1}, + int groups = 1, + const std::string& data_format = "NCHW", const std::string& padding_algorithm = "EXPLICIT"); /** @@ -913,86 +1078,105 @@ class NetBuilder { * @param dout The gradient variable of the `conv2d`'s output. * @param x The image variable. * @param weights The filter variable. - * @param strides The stride size. If stride is a list/tuple, it must contain two integers, (stride_H, stride_W). - * Otherwise, the stride_H = stride_W = stride. Default: stride = {1, 1}. - * @param paddings The padding size. If padding is a list/tuple, it must contain two integers, (padding_H, padding_W). - * Otherwise, the padding_H = padding_W = padding. Default: padding = {0, 0}. - * @param dilations The dilation size. If dilation is a list/tuple, it must contain two integers, (dilation_H, - * dilation_W). Otherwise, the dilation_H = dilation_W = dilation. Default: dilation = {1, 1}. + * @param strides The stride size. If stride is a list/tuple, it must contain + * two integers, (stride_H, stride_W). Otherwise, the stride_H = stride_W = + * stride. Default: stride = {1, 1}. + * @param paddings The padding size. If padding is a list/tuple, it must + * contain two integers, (padding_H, padding_W). Otherwise, the padding_H = + * padding_W = padding. Default: padding = {0, 0}. + * @param dilations The dilation size. If dilation is a list/tuple, it must + * contain two integers, (dilation_H, dilation_W). Otherwise, the dilation_H = + * dilation_W = dilation. Default: dilation = {1, 1}. * @param groups The groups number of the conv layer. Default: groups=1. - * @param data_format Data format that specifies the layout of input. It can be "NCHW" or "NHWC". Default: "NCHW". - * @param padding_algorithm CINN not support! It can be "EXPLICIT"/"SAME"/"VALID". Default: "EXPLICIT". + * @param data_format Data format that specifies the layout of input. It can + * be "NCHW" or "NHWC". Default: "NCHW". + * @param padding_algorithm CINN not support! It can be + * "EXPLICIT"/"SAME"/"VALID". Default: "EXPLICIT". * @return The gradient variable of 'x'. */ - std::vector Conv2dGrad(const Variable& dout, - const Variable& x, - const Variable& weights, - const std::vector& strides = {1, 1}, - const std::vector& paddings = {0, 0}, - const std::vector& dilations = {1, 1}, - const int groups = 1, - const std::string& data_format = "NCHW", - const std::string& padding_algorithm = "EXPLICIT"); + std::vector Conv2dGrad( + const Variable& dout, + const Variable& x, + const Variable& weights, + const std::vector& strides = {1, 1}, + const std::vector& paddings = {0, 0}, + const std::vector& dilations = {1, 1}, + const int groups = 1, + const std::string& data_format = "NCHW", + const std::string& padding_algorithm = "EXPLICIT"); /** * @brief Compute the depthwise convolution-2d. * @param x The image variable. * @param weights The filter variable. - * @param strides The stride size. If stride is a list/tuple, it must contain two integers, (stride_H, stride_W). - * Otherwise, the stride_H = stride_W = stride. Default: stride = {1, 1}. - * @param paddings The padding size. If padding is a list/tuple, it must contain two integers, (padding_H, padding_W). 
- * Otherwise, the padding_H = padding_W = padding. Default: padding = {0, 0}. - * @param dilations The dilation size. If dilation is a list/tuple, it must contain two integers, (dilation_H, - * dilation_W). Otherwise, the dilation_H = dilation_W = dilation. Default: dilation = {1, 1}. + * @param strides The stride size. If stride is a list/tuple, it must contain + * two integers, (stride_H, stride_W). Otherwise, the stride_H = stride_W = + * stride. Default: stride = {1, 1}. + * @param paddings The padding size. If padding is a list/tuple, it must + * contain two integers, (padding_H, padding_W). Otherwise, the padding_H = + * padding_W = padding. Default: padding = {0, 0}. + * @param dilations The dilation size. If dilation is a list/tuple, it must + * contain two integers, (dilation_H, dilation_W). Otherwise, the dilation_H = + * dilation_W = dilation. Default: dilation = {1, 1}. * @param groups The groups number of the conv layer. Default: groups=1. - * @param data_format Data format that specifies the layout of input. It can be "NCHW" or "NHWC". Default: "NCHW". - * @param padding_algorithm CINN not support! It can be "EXPLICIT"/"SAME"/"VALID". Default: "EXPLICIT". + * @param data_format Data format that specifies the layout of input. It can + * be "NCHW" or "NHWC". Default: "NCHW". + * @param padding_algorithm Not supported by CINN! It can be + * "EXPLICIT"/"SAME"/"VALID". Default: "EXPLICIT". * @return The depthwise convolution-2d result variable. */ Variable DepthwiseConv2d(const Variable& x, const Variable& weights, - const std::vector& strides = {1, 1}, - const std::vector& paddings = {0, 0}, - const std::vector& dilations = {1, 1}, - int groups = 1, - const std::string& data_format = "NCHW", + const std::vector& strides = {1, 1}, + const std::vector& paddings = {0, 0}, + const std::vector& dilations = {1, 1}, + int groups = 1, + const std::string& data_format = "NCHW", const std::string& padding_algorithm = "EXPLICIT"); /** * @brief Batch normalization layer. * @param x The image variable. - * @param scale Scale is a 1-dimensional tensor of size C that is applied to the output. - * @param bias Bias is a 1-dimensional tensor of size C that is applied to the output. + * @param scale Scale is a 1-dimensional tensor of size C that is applied to + * the output. + * @param bias Bias is a 1-dimensional tensor of size C that is applied to the + * output. * @param mean The global mean (for training) or estimated mean (for testing). - * @param variance The global variance (for training) or estimated Variance (for testing) - * @param epsilon The small value added to the variance to prevent division by zero. Default: 1e-5f. - * @param momentum The value used for the moving_mean and moving_var computation. Default: 0.9f. - * @param data_layout Specify the input data format, may be “NC”, “NCL”, “NCHW”, “NCDHW”, “NLC”, “NHWC” or “NDHWC”. - * Defalut “NCHW”. + * @param variance The global variance (for training) or estimated variance + * (for testing). + * @param epsilon The small value added to the variance to prevent division by + * zero. Default: 1e-5f. + * @param momentum The value used for the moving_mean and moving_var + * computation. Default: 0.9f. + * @param data_layout Specify the input data format, may be “NC”, “NCL”, + * “NCHW”, “NCDHW”, “NLC”, “NHWC” or “NDHWC”. Default “NCHW”. * @param is_test A flag indicating whether it is in the test phase or not. 
- * @return `{out}` if `is_test` it true, `{out, saved_mean, saved_variance, moving_mean, moving_variance}` if - * `is_test` is false. + * @return `{out}` if `is_test` is true, `{out, saved_mean, saved_variance, + * moving_mean, moving_variance}` if `is_test` is false. */ std::vector BatchNorm(const Variable& x, const Variable& scale, const Variable& bias, const Variable& mean, const Variable& variance, - float epsilon = 1e-5f, - float momentum = 0.9f, + float epsilon = 1e-5f, + float momentum = 0.9f, const std::string& data_layout = "NCHW", - bool is_test = false); + bool is_test = false); /** * @brief The gradient function of BatchNorm training. - * @param dout The gradient variable of the `batch_norm_training`'s first output. + * @param dout The gradient variable of the `batch_norm_training`'s first + * output. * @param x The image variable. - * @param scale Scale is a 1-dimensional tensor of size C that is applied to the output. + * @param scale Scale is a 1-dimensional tensor of size C that is applied to + * the output. * @param save_mean The global mean saved from forward compute. * @param save_variance The global variance from forward compute. - * @param epsilon The small value added to the variance to prevent division by zero. Default: 1e-5f. - * @param data_layout Specify the input data format, may be “NC”, “NCL”, “NCHW”, “NCDHW”, “NLC”, “NHWC” or “NDHWC”. - * Defalut “NCHW”. + * @param epsilon The small value added to the variance to prevent division by + * zero. Default: 1e-5f. + * @param data_layout Specify the input data format, may be “NC”, “NCL”, + * “NCHW”, “NCDHW”, “NLC”, “NHWC” or “NDHWC”. Default “NCHW”. * @return `{x_grad, scale_grad, bias_grad}`. */ // batch norm grad, output(x_grad, scale_grad, bias_grad) @@ -1001,7 +1185,7 @@ class NetBuilder { const Variable& scale, const Variable& save_mean, const Variable& save_variance, - const float epsilon = 1e-5f, + const float epsilon = 1e-5f, const std::string& data_layout = "NCHW"); /** @@ -1012,7 +1196,9 @@ class NetBuilder { * Default “NCHW”. * @return `Index of variable x to the maximum value`. */ - Variable Argmax(const Variable& x, const int& axis = 0, const bool& keep_dim = false); + Variable Argmax(const Variable& x, + const int& axis = 0, + const bool& keep_dim = false); /** * @brief Get index of variable x to the minimum value along the given axis. @@ -1022,28 +1208,35 @@ * Default “NCHW”. * @return `Index of variable x to the minimum value`. */ - Variable Argmin(const Variable& x, const int& axis = 0, const bool& keep_dim = false); + Variable Argmin(const Variable& x, + const int& axis = 0, + const bool& keep_dim = false); /** - * @brief Sort Variable x along the given axis and return sorted index. The original Variable x will not be changed. + * @brief Sort Variable x along the given axis and return sorted index. The + * original Variable x will not be changed. * @param operand The variable that will be sorted. * @param axis Specify the axis to operate on the input. Default: 0. * @param is_ascend Sort mode. * Default “NCHW”. * @return `Sorted variable index`. */ - std::vector ArgSort(const Variable& operand, const int& axis, const bool& is_ascend = true); + std::vector ArgSort(const Variable& operand, + const int& axis, + const bool& is_ascend = true); /** - * @brief Sort Variable x along the given axis and return sorted variable. 
The + * original Variable x will not be changed. * @param operand The variable that will be sorted. * @param axis Specify the axis to operate on the input. Default: 0. * @param is_ascend Sort mode. * Default “NCHW”. * @return `Sorted variable`. */ - Variable Sort(const Variable& operand, const int& axis, const bool& is_ascend = true); + Variable Sort(const Variable& operand, + const int& axis, + const bool& is_ascend = true); /** * @brief Lookup embeddings vector of ids provided by x. @@ -1054,7 +1247,9 @@ with zeros whenever lookup encounters it in Ids. * @return `The concatenated variable of selected values`. */ - Variable LookupTable(const Variable& table, const Variable& ids, int64_t padding_idx); + Variable LookupTable(const Variable& table, + const Variable& ids, + int64_t padding_idx); /** * @brief Gaussian random * @param shape Shape of the variable to be created. * @param mean Mean of the output variable, default is 0.0f. * @param std Standard deviation of the output variable, default is 1.0f. * @param seed Random seed of generator, default is 0. - * @param dtype Data type of output variable, supported data types: float32, float64. + * @param dtype Data type of output variable, supported data types: float32, + * float64. */ Variable GaussianRandom(const std::vector& shape, - float mean = 0.0f, - float std = 1.0f, - int seed = 0, + float mean = 0.0f, + float std = 1.0f, + int seed = 0, const std::string& dtype = "float32"); /** * @brief Uniform random * @param shape Shape of the variable to be created. - * @param min The lower bound of the range of random values generated, min is included in the range. - * @param max The upper bound of the range of random values generated, max is not included in the range. + * @param min The lower bound of the range of random values generated, + * min is included in the range. + * @param max The upper bound of the range of random values generated, + * max is not included in the range. * @param seed Random seed of generator, default is 0. - * @param dtype Data tpye of output variable, supported data types: float32, float64. + * @param dtype Data type of output variable, supported data types: float32, + * float64. */ Variable UniformRandom(const std::vector& shape, - float min = -1.0f, - float max = 1.0f, - int seed = 0, + float min = -1.0f, + float max = 1.0f, + int seed = 0, const std::string& dtype = "float32", - int diag_num = 0, - int diag_step = 0, - float diag_val = 1.0f); + int diag_num = 0, + int diag_step = 0, + float diag_val = 1.0f); /** * @brief Generate random integers in the range min to max * @param shape Shape of the variable to be created. - * @param min The lower bound of the range of random values generated, min is included in the range. - * @param max The upper bound of the range of random values generated, max is not included in the range. + * @param min The lower bound of the range of random values generated, + * min is included in the range. + * @param max The upper bound of the range of random values generated, + * max is not included in the range. * @param seed Random seed of generator, default is 0. - * @param dtype Data tpye of output variable, supported data types: int32, int64. + * @param dtype Data type of output variable, supported data types: int32, + * int64. 
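A usage sketch of the random-number builders declared above (shapes and seeds are arbitrary assumptions):

    NetBuilder builder("random_demo");
    // float32 samples from a normal distribution with mean 0 and std 1.
    auto g = builder.GaussianRandom({64, 64}, 0.0f, 1.0f, /*seed=*/123);
    // float32 samples uniform over [-1, 1); the diag_* arguments keep defaults.
    auto u = builder.UniformRandom({64, 64}, -1.0f, 1.0f, /*seed=*/123);
    auto program = builder.Build();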
/** * @brief Generate random integers in the range min to max * @param shape Shape of the variable to be created. - * @param min The lower bound of the range of random values generated, min is included in the range. - * @param max The upper bound of the range of random values generated, max is not included in the range. + * @param min The lower bound of the range of random values generated, + * min is included in the range. + * @param max The upper bound of the range of random values generated, + * max is not included in the range. * @param seed Random seed of generator, default is 0. - * @param dtype Data tpye of output variable, supported data types: int32, int64. + * @param dtype Data type of output variable, supported data types: int32, + * int64. */ - Variable RandInt( - const std::vector& shape, int min = 0, int max = 0, int seed = 0, const std::string& dtype = "int64"); + Variable RandInt(const std::vector& shape, + int min = 0, + int max = 0, + int seed = 0, + const std::string& dtype = "int64"); /** - * @brief Compute cholesky decomposition of a positive definite symmetric matrix. + * @brief Compute Cholesky decomposition of a positive definite symmetric + matrix. * @param x Positive definite symmetric matrix. - * @param upper When upper is true, calculate and return the upper triangular matrix. - When upper is false, calculate and return the lower triangular matrix. + * @param upper When upper is true, calculate and return the upper triangular + matrix. When upper is false, calculate and return the lower triangular + matrix. * @return Triangular matrix, shape is same as input. */ Variable Cholesky(const Variable& x, bool upper = false); @@ -1113,28 +1320,34 @@ class NetBuilder { * @param input2 matrix on the right hand side. * @param left_side When left_side is true, compute A*X = B. When left_side is false, compute X*A = B. - * @param upper When upper is true, use the upper part of the triangular matrix. - When upper is false, use the lower part of the triangular matrix. + * @param upper When upper is true, use the upper part of the triangular + matrix. When upper is false, use the lower part of the triangular matrix. * @param transpose_a When transpose_a is true, use the transpose of matrix A - * @param unit_diagonal When unit_diagonal is true, assume the elements on the main diagonal of matrix A are unity + * @param unit_diagonal When unit_diagonal is true, assume the elements on the + main diagonal of matrix A are unity * @return The solution for the triangular linear systems. */ - Variable TriangularSolve( - const Variable& input1, const Variable& input2, bool left_side, bool upper, bool transpose_a, bool unit_diagonal); + Variable TriangularSolve(const Variable& input1, + const Variable& input2, + bool left_side, + bool upper, + bool transpose_a, + bool unit_diagonal); /** - * @brief Return values and indices of the k largest or smallest at the optional axis. - * If the input is a 1-D Tensor, finds the k largest or smallest values and indices. - * If the input is a Tensor with higher rank, this operator computes the top k values - * and indices along the axis. + * @brief Return values and indices of the k largest or smallest at the + * optional axis. If the input is a 1-D Tensor, finds the k largest or + * smallest values and indices. If the input is a Tensor with higher rank, + * this operator computes the top k values and indices along the axis. * @param x Input tensor. * @param k The number of top elements to look for along the axis. - * @param axis Axis to compute indices along. The effective range is [-R, R), where R is - * x.ndim. when axis < 0, it works the same way as axis + R. Default is -1. - * @param largest largest is a flag, if set to true, algorithm will sort by descending - * order, otherwise sort by ascending order. Default is True. + * @param axis Axis to compute indices along. The effective range is [-R, R), + * where R is x.ndim. When axis < 0, it works the same way as axis + R. + * Default is -1. + * @param largest largest is a flag, if set to true, algorithm will sort by + * descending order, otherwise sort by ascending order. Default is true. + * @return The values and indices. 
The value data type is the same as the + * input x. The indices data type is int64. */ std::vector TopK(const Variable& x, int k, int axis, bool largest); diff --git a/paddle/cinn/frontend/net_builder_test.cc b/paddle/cinn/frontend/net_builder_test.cc index 14450ae8bad34..ba104d642f2a2 100644 --- a/paddle/cinn/frontend/net_builder_test.cc +++ b/paddle/cinn/frontend/net_builder_test.cc @@ -44,10 +44,10 @@ Program CreateAddProgram() { constexpr int N = 24; NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {M, N}, "A"); - auto b = builder.CreateInput(Float(32), {M, N}, "B"); - auto c = builder.Add(a, b); - auto d = builder.Add(a, c); + auto a = builder.CreateInput(Float(32), {M, N}, "A"); + auto b = builder.CreateInput(Float(32), {M, N}, "B"); + auto c = builder.Add(a, b); + auto d = builder.Add(a, c); auto program = builder.Build(); return program; @@ -65,8 +65,10 @@ std::ostream& operator<<(std::ostream& os, const std::vector& vec) { } // namespace TEST(net_build, basic) { - LOG(INFO) << "The size of registered operators: " << OpRegistry::Global()->ListAllNames().size(); - LOG(INFO) << "Registered operators:\n" << OpRegistry::Global()->ListAllNames(); + LOG(INFO) << "The size of registered operators: " + << OpRegistry::Global()->ListAllNames().size(); + LOG(INFO) << "Registered operators:\n" + << OpRegistry::Global()->ListAllNames(); auto program = CreateAddProgram(); // output program for (int i = 0; i < program.size(); i++) { @@ -135,11 +137,11 @@ TEST(net_build, program_execute_fc) { scope->Var(std::string(b.id())); scope->Var(std::string(mul_out->id)); - auto a_ten = scope->GetTensor(std::string(a.id())); - auto w_ten = scope->GetTensor(std::string(w.id())); - auto b_ten = scope->GetTensor(std::string(b.id())); + auto a_ten = scope->GetTensor(std::string(a.id())); + auto w_ten = scope->GetTensor(std::string(w.id())); + auto b_ten = scope->GetTensor(std::string(b.id())); auto fake_out_ten = scope->GetTensor(std::string(mul_out->id)); - auto add_out_ten = scope->GetTensor(std::string(add_out->id)); + auto add_out_ten = scope->GetTensor(std::string(add_out->id)); SetRandData(a_ten, target); SetRandData(w_ten, target); SetRandData(b_ten, target); @@ -154,10 +156,10 @@ TEST(net_build, program_execute_multi_elementwise_add_bf16) { constexpr int N = 24; NetBuilder builder("net_builder"); - auto a = builder.CreateInput(cinn::common::BFloat16(), {M, N}, "A"); - auto b = builder.CreateInput(cinn::common::BFloat16(), {M, N}, "B"); - auto c = builder.Add(a, b); - auto d = builder.Add(a, c); + auto a = builder.CreateInput(cinn::common::BFloat16(), {M, N}, "A"); + auto b = builder.CreateInput(cinn::common::BFloat16(), {M, N}, "B"); + auto c = builder.Add(a, b); + auto d = builder.Add(a, c); auto program = builder.Build(); #ifdef CINN_WITH_CUDA @@ -193,8 +195,9 @@ TEST(net_build, program_execute_fc_bf16) { NetBuilder builder("net_builder"); auto a = builder.CreateInput(cinn::common::BFloat16(), {B * M, K}, "A"); - auto w = builder.CreateInput(cinn::common::BFloat16(), {K, N}, "W"); // weight - auto b = builder.CreateInput(cinn::common::BFloat16(), {N}, "B"); // bias + auto w = + builder.CreateInput(cinn::common::BFloat16(), {K, N}, "W"); // weight + auto b = builder.CreateInput(cinn::common::BFloat16(), {N}, "B"); // bias auto mul_out = builder.Matmul(a, w); auto add_out = builder.Add(mul_out, b); @@ -219,11 +222,11 @@ TEST(net_build, program_execute_fc_bf16) { scope->Var(std::string(b.id())); scope->Var(std::string(mul_out->id)); - auto a_ten = 
scope->GetTensor(std::string(a.id())); - auto w_ten = scope->GetTensor(std::string(w.id())); - auto b_ten = scope->GetTensor(std::string(b.id())); + auto a_ten = scope->GetTensor(std::string(a.id())); + auto w_ten = scope->GetTensor(std::string(w.id())); + auto b_ten = scope->GetTensor(std::string(b.id())); auto fake_out_ten = scope->GetTensor(std::string(mul_out->id)); - auto add_out_ten = scope->GetTensor(std::string(add_out->id)); + auto add_out_ten = scope->GetTensor(std::string(add_out->id)); SetRandData(a_ten, target); SetRandData(w_ten, target); SetRandData(b_ten, target); @@ -239,18 +242,18 @@ TEST(net_build, program_execute_pool2d) { const int W = 112; NetBuilder builder("net_builder"); - Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "Img"); + Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "Img"); std::string pooling_type = "max"; std::vector ksize{3, 3}; std::vector strides{2, 2}; std::vector paddings{1, 1, 1, 1}; - bool ceil_mode = false; - bool exclusive = true; - bool global_pooling = false; - std::string data_format = "NCHW"; - bool adaptive = false; + bool ceil_mode = false; + bool exclusive = true; + bool global_pooling = false; + std::string data_format = "NCHW"; + bool adaptive = false; std::string padding_algorithm = "EXPLICIT"; - Variable pool_out = builder.Pool2d(input, + Variable pool_out = builder.Pool2d(input, pooling_type, ksize, strides, @@ -261,7 +264,7 @@ TEST(net_build, program_execute_pool2d) { data_format, adaptive, padding_algorithm); - auto program = builder.Build(); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -290,9 +293,9 @@ TEST(net_build, program_execute_reverse) { const int W = 224; NetBuilder builder("net_builder"); - Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "Img"); + Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "Img"); Variable reverse_out = builder.Reverse(input, {2, 3}); - auto program = builder.Build(); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -317,15 +320,15 @@ TEST(net_build, program_execute_reverse) { } TEST(net_build, program_execute_gather) { - const int B = 4; + const int B = 4; const int H_IN1 = 18; const int H_IN2 = 14; NetBuilder builder("net_builder"); Placeholder input1 = builder.CreateInput(Float(32), {B, H_IN1}, "In1"); Placeholder input2 = builder.CreateInput(Int(32), {H_IN2}, "In2"); - Variable output = builder.Gather(input1, input2, 1); - auto program = builder.Build(); + Variable output = builder.Gather(input1, input2, 1); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -353,7 +356,7 @@ TEST(net_build, program_execute_gather) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_tensor->type(), Float(32)); EXPECT_EQ(output_shape.size(), 2UL); @@ -365,8 +368,8 @@ TEST(net_build, program_execute_gather) { for (int b = 0; b < B; ++b) { for (int h = 0; h < H_IN2; ++h) { std::string line; - int index = h + H_IN2 * b; - float in_data = input1_data[input2_data[h] + H_IN1 * b]; + int index = h + H_IN2 * b; + float in_data = input1_data[input2_data[h] + H_IN1 * b]; float out_data = output_data[index]; line += (std::to_string(out_data) + ", "); EXPECT_EQ(in_data, out_data); @@ -376,15 +379,15 @@ 
TEST(net_build, program_execute_gather) { } TEST(net_build, program_execute_gather_nd) { - const int B = 4; + const int B = 4; const int H_IN1 = 11; const int H_IN2 = 14; NetBuilder builder("net_builder"); Placeholder input1 = builder.CreateInput(Float(32), {B, H_IN1}, "In1"); Placeholder input2 = builder.CreateInput(Int(32), {B, H_IN2, 1}, "In2"); - Variable output = builder.GatherNd(input1, input2); - auto program = builder.Build(); + Variable output = builder.GatherNd(input1, input2); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -413,7 +416,7 @@ TEST(net_build, program_execute_gather_nd) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_tensor->type(), Float(32)); EXPECT_EQ(output_shape.size(), 3UL); @@ -427,8 +430,8 @@ TEST(net_build, program_execute_gather_nd) { for (int h = 0; h < H_IN2; ++h) { std::string line; for (int c = 0; c < H_IN1; ++c) { - float in_data = input1_data[input2_data[b * H_IN2 + h] * H_IN1 + c]; - int out_index = c + h * H_IN1 + H_IN1 * H_IN2 * b; + float in_data = input1_data[input2_data[b * H_IN2 + h] * H_IN1 + c]; + int out_index = c + h * H_IN1 + H_IN1 * H_IN2 * b; float out_data = output_data[out_index]; line += (std::to_string(out_data) + ", "); EXPECT_EQ(in_data, out_data); @@ -444,8 +447,8 @@ TEST(net_build, program_execute_cast) { NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Int(32), {B, H}, "In"); - Variable output = builder.Cast(input, "float"); - auto program = builder.Build(); + Variable output = builder.Cast(input, "float"); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -468,7 +471,7 @@ TEST(net_build, program_execute_cast) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_tensor->type(), Float(32)); EXPECT_EQ(output_shape.size(), 2UL); @@ -480,8 +483,8 @@ TEST(net_build, program_execute_cast) { for (int b = 0; b < B; ++b) { for (int h = 0; h < H; ++h) { std::string line; - int index = h + H * b; - float in_data = (float)input_data[index]; + int index = h + H * b; + float in_data = (float)input_data[index]; float out_data = output_data[index]; line += (std::to_string(out_data) + ", "); EXPECT_EQ(in_data, out_data); @@ -498,8 +501,8 @@ TEST(net_build, program_execute_squeeze_case0) { NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "In"); - Variable output = builder.Squeeze(input, {1}); - auto program = builder.Build(); + Variable output = builder.Squeeze(input, {1}); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -522,7 +525,7 @@ TEST(net_build, program_execute_squeeze_case0) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_shape.size(), 3UL); EXPECT_EQ(output_shape[0], B); @@ -537,8 +540,8 @@ TEST(net_build, program_execute_squeeze_case0) { for (int h = 0; h < H; ++h) { std::string line; for (int w = 0; w < W; ++w) { - int 
index = w + W * (h + H * (c + C * b)); - float in_data = input_data[index]; + int index = w + W * (h + H * (c + C * b)); + float in_data = input_data[index]; float out_data = output_data[index]; line += (std::to_string(out_data) + ", "); EXPECT_EQ(in_data, out_data); @@ -557,8 +560,8 @@ TEST(net_build, program_execute_squeeze_case1) { NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "In"); - Variable output = builder.Squeeze(input, {-3}); - auto program = builder.Build(); + Variable output = builder.Squeeze(input, {-3}); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -581,7 +584,7 @@ TEST(net_build, program_execute_squeeze_case1) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_shape.size(), 3UL); EXPECT_EQ(output_shape[0], B); @@ -596,8 +599,8 @@ TEST(net_build, program_execute_squeeze_case1) { for (int h = 0; h < H; ++h) { std::string line; for (int w = 0; w < W; ++w) { - int index = w + W * (h + H * (c + C * b)); - float in_data = input_data[index]; + int index = w + W * (h + H * (c + C * b)); + float in_data = input_data[index]; float out_data = output_data[index]; line += (std::to_string(out_data) + ", "); EXPECT_EQ(in_data, out_data); @@ -616,8 +619,8 @@ TEST(net_build, program_execute_squeeze_case2) { NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "In"); - Variable output = builder.Squeeze(input, {1, 3}); - auto program = builder.Build(); + Variable output = builder.Squeeze(input, {1, 3}); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -640,7 +643,7 @@ TEST(net_build, program_execute_squeeze_case2) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_shape.size(), 2UL); EXPECT_EQ(output_shape[0], B); @@ -654,8 +657,8 @@ TEST(net_build, program_execute_squeeze_case2) { for (int h = 0; h < H; ++h) { std::string line; for (int w = 0; w < W; ++w) { - int index = w + W * (h + H * (c + C * b)); - float in_data = input_data[index]; + int index = w + W * (h + H * (c + C * b)); + float in_data = input_data[index]; float out_data = output_data[index]; line += (std::to_string(out_data) + ", "); EXPECT_EQ(in_data, out_data); @@ -674,8 +677,8 @@ TEST(net_build, program_execute_squeeze_case3) { NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "In"); - Variable output = builder.Squeeze(input, {1, -1}); - auto program = builder.Build(); + Variable output = builder.Squeeze(input, {1, -1}); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -698,7 +701,7 @@ TEST(net_build, program_execute_squeeze_case3) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_shape.size(), 2UL); EXPECT_EQ(output_shape[0], B); @@ -712,8 +715,8 @@ TEST(net_build, program_execute_squeeze_case3) { for (int h = 0; h < H; ++h) { std::string 
line; for (int w = 0; w < W; ++w) { - int index = w + W * (h + H * (c + C * b)); - float in_data = input_data[index]; + int index = w + W * (h + H * (c + C * b)); + float in_data = input_data[index]; float out_data = output_data[index]; line += (std::to_string(out_data) + ", "); EXPECT_EQ(in_data, out_data); @@ -732,8 +735,8 @@ TEST(net_build, program_execute_squeeze_case4) { NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "In"); - Variable output = builder.Squeeze(input, {}); - auto program = builder.Build(); + Variable output = builder.Squeeze(input, {}); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -756,7 +759,7 @@ TEST(net_build, program_execute_squeeze_case4) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_shape.size(), 2UL); EXPECT_EQ(output_shape[0], B); @@ -770,8 +773,8 @@ TEST(net_build, program_execute_squeeze_case4) { for (int h = 0; h < H; ++h) { std::string line; for (int w = 0; w < W; ++w) { - int index = w + W * (h + H * (c + C * b)); - float in_data = input_data[index]; + int index = w + W * (h + H * (c + C * b)); + float in_data = input_data[index]; float out_data = output_data[index]; line += (std::to_string(out_data) + ", "); EXPECT_EQ(in_data, out_data); @@ -788,8 +791,8 @@ TEST(net_build, program_execute_argsort) { NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {B, H}, "In"); - Variable output = builder.ArgSort(input, 0, true).at(0); - auto program = builder.Build(); + Variable output = builder.ArgSort(input, 0, true).at(0); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -812,7 +815,7 @@ TEST(net_build, program_execute_argsort) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_tensor->type(), Int(32)); EXPECT_EQ(output_shape.size(), 2UL); @@ -833,9 +836,9 @@ TEST(net_build, program_execute_argsort) { for (int b = 0; b < B; ++b) { std::string line; - int index = h + H * b; + int index = h + H * b; float true_data = sorted_data[b]; - float out_data = out_sorted_data[b]; + float out_data = out_sorted_data[b]; line += (std::to_string(out_data) + ", "); EXPECT_EQ(true_data, out_data); VLOG(6) << line; @@ -849,8 +852,8 @@ TEST(net_build, program_execute_sort) { NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {B, H}, "In"); - Variable output = builder.Sort(input, 0, true); - auto program = builder.Build(); + Variable output = builder.Sort(input, 0, true); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -873,7 +876,7 @@ TEST(net_build, program_execute_sort) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_tensor->type(), Float(32)); EXPECT_EQ(output_shape.size(), 2UL); @@ -892,9 +895,9 @@ TEST(net_build, program_execute_sort) { for (int b = 0; b < B; ++b) { std::string line; - int index = h + H * b; + int 
index = h + H * b; float true_data = sorted_data[b]; - float out_data = output_data[index]; + float out_data = output_data[index]; line += (std::to_string(out_data) + ", "); EXPECT_EQ(true_data, out_data); VLOG(6) << line; @@ -903,9 +906,9 @@ TEST(net_build, program_execute_sort) { } TEST(net_build, program_execute_arange_float) { - const float start = 1.5F; - const float stop = 31.5F; - const float step = 2.0F; + const float start = 1.5F; + const float stop = 31.5F; + const float step = 2.0F; const std::string dtype = "float32"; NetBuilder builder("net_builder"); @@ -928,7 +931,7 @@ TEST(net_build, program_execute_arange_float) { runtime_program->Execute(); - auto out_tensor = scope->GetTensor(std::string(out->id)); + auto out_tensor = scope->GetTensor(std::string(out->id)); const std::vector& out_tensor_shape = out_tensor->shape().data(); EXPECT_EQ(out_tensor->type(), Float(32)); EXPECT_EQ(out_tensor_shape.size(), 1UL); @@ -944,9 +947,9 @@ TEST(net_build, program_execute_arange_float) { } TEST(net_build, program_execute_arange_int) { - const float start = 1.5F; - const float stop = 31.5F; - const float step = 1.6F; + const float start = 1.5F; + const float stop = 31.5F; + const float step = 1.6F; const std::string dtype = "int32"; NetBuilder builder("net_builder"); @@ -969,7 +972,7 @@ TEST(net_build, program_execute_arange_int) { runtime_program->Execute(); - auto out_tensor = scope->GetTensor(std::string(out->id)); + auto out_tensor = scope->GetTensor(std::string(out->id)); const std::vector& out_tensor_shape = out_tensor->shape().data(); EXPECT_EQ(out_tensor->type(), Int(32)); EXPECT_EQ(out_tensor_shape.size(), 1UL); @@ -985,16 +988,16 @@ TEST(net_build, program_execute_arange_int) { } TEST(net_build, program_argmax_case1) { - const int N = 4; - const int IN_C = 3; + const int N = 4; + const int IN_C = 3; const int OUT_C = 1; - const int H = 7; - const int W = 7; + const int H = 7; + const int W = 7; NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {N, IN_C, H, W}, "In"); - Variable output = builder.Argmax(input, 1, true); - auto program = builder.Build(); + Variable output = builder.Argmax(input, 1, true); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -1030,7 +1033,7 @@ TEST(net_build, program_argmax_case1) { } runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_shape.size(), 4UL); EXPECT_EQ(output_shape[0], N); @@ -1046,13 +1049,13 @@ TEST(net_build, program_argmax_case1) { for (int h = 0; h < H; ++h) { std::string line; for (int w = 0; w < W; ++w) { - int index = w + W * (h + H * (c + IN_C * n)); + int index = w + W * (h + H * (c + IN_C * n)); int out_index = w + W * (h + H * n); float in_data = input_data[index]; - int out_data = output_data[out_index]; + int out_data = output_data[out_index]; EXPECT_LE(0, out_data); EXPECT_LT(out_data, IN_C); - int max_index = w + W * (h + H * (out_data + IN_C * n)); + int max_index = w + W * (h + H * (out_data + IN_C * n)); float max_value = input_data[max_index]; line += (std::to_string(out_data) + ", "); EXPECT_LE(in_data, max_value); @@ -1064,15 +1067,15 @@ TEST(net_build, program_argmax_case1) { } TEST(net_build, program_argmax_case2) { - const int N = 4; + const int N = 4; const int IN_C = 3; - const int H = 7; - const int W = 7; + const int H = 7; + const int 
W = 7; NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {N, IN_C, H, W}, "In"); - Variable output = builder.Argmax(input, 1, false); - auto program = builder.Build(); + Variable output = builder.Argmax(input, 1, false); + auto program = builder.Build(); Target target = common::DefaultHostTarget(); std::unordered_set fetch_ids; @@ -1104,7 +1107,7 @@ TEST(net_build, program_argmax_case2) { } runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_shape.size(), 3UL); EXPECT_EQ(output_shape[0], N); @@ -1119,13 +1122,13 @@ TEST(net_build, program_argmax_case2) { for (int h = 0; h < H; ++h) { std::string line; for (int w = 0; w < W; ++w) { - int index = w + W * (h + H * (c + IN_C * n)); + int index = w + W * (h + H * (c + IN_C * n)); int out_index = w + W * (h + H * n); float in_data = input_data[index]; - int out_data = output_data[out_index]; + int out_data = output_data[out_index]; EXPECT_LE(0, out_data); EXPECT_LT(out_data, IN_C); - int max_index = w + W * (h + H * (out_data + IN_C * n)); + int max_index = w + W * (h + H * (out_data + IN_C * n)); float max_value = input_data[max_index]; line += (std::to_string(out_data) + ", "); EXPECT_LE(in_data, max_value); @@ -1137,16 +1140,16 @@ TEST(net_build, program_argmax_case2) { } TEST(net_build, program_argmin_case1) { - const int N = 4; - const int IN_C = 3; + const int N = 4; + const int IN_C = 3; const int OUT_C = 1; - const int H = 7; - const int W = 7; + const int H = 7; + const int W = 7; NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {N, IN_C, H, W}, "In"); - Variable output = builder.Argmin(input, 1, true); - auto program = builder.Build(); + Variable output = builder.Argmin(input, 1, true); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); #else @@ -1182,7 +1185,7 @@ TEST(net_build, program_argmin_case1) { } runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_shape.size(), 4UL); EXPECT_EQ(output_shape[0], N); @@ -1198,13 +1201,13 @@ TEST(net_build, program_argmin_case1) { for (int h = 0; h < H; ++h) { std::string line; for (int w = 0; w < W; ++w) { - int index = w + W * (h + H * (c + IN_C * n)); + int index = w + W * (h + H * (c + IN_C * n)); int out_index = w + W * (h + H * n); float in_data = input_data[index]; - int out_data = output_data[out_index]; + int out_data = output_data[out_index]; EXPECT_LE(0, out_data); EXPECT_LT(out_data, IN_C); - int max_index = w + W * (h + H * (out_data + IN_C * n)); + int max_index = w + W * (h + H * (out_data + IN_C * n)); float max_value = input_data[max_index]; line += (std::to_string(out_data) + ", "); EXPECT_GE(in_data, max_value); @@ -1216,15 +1219,15 @@ TEST(net_build, program_argmin_case1) { } TEST(net_build, program_argmin_case2) { - const int N = 4; + const int N = 4; const int IN_C = 3; - const int H = 7; - const int W = 7; + const int H = 7; + const int W = 7; NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {N, IN_C, H, W}, "In"); - Variable output = builder.Argmin(input, 1, false); - auto program = builder.Build(); + Variable output = builder.Argmin(input, 1, 
false); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); #else @@ -1259,7 +1262,7 @@ TEST(net_build, program_argmin_case2) { } runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); EXPECT_EQ(output_shape.size(), 3UL); EXPECT_EQ(output_shape[0], N); @@ -1274,13 +1277,13 @@ TEST(net_build, program_argmin_case2) { for (int h = 0; h < H; ++h) { std::string line; for (int w = 0; w < W; ++w) { - int index = w + W * (h + H * (c + IN_C * n)); + int index = w + W * (h + H * (c + IN_C * n)); int out_index = w + W * (h + H * n); float in_data = input_data[index]; - int out_data = output_data[out_index]; + int out_data = output_data[out_index]; EXPECT_LE(0, out_data); EXPECT_LT(out_data, IN_C); - int max_index = w + W * (h + H * (out_data + IN_C * n)); + int max_index = w + W * (h + H * (out_data + IN_C * n)); float max_value = input_data[max_index]; line += (std::to_string(out_data) + ", "); EXPECT_GE(in_data, max_value); @@ -1292,15 +1295,15 @@ TEST(net_build, program_argmin_case2) { } TEST(net_build, program_execute_repeat_axis_0) { - const int M = 4; - const int N = 4; + const int M = 4; + const int N = 4; const int repeats = 3; - const int axis = 0; + const int axis = 0; NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {M, N}, "In"); - Variable output = builder.Repeat(input, repeats, axis); - auto program = builder.Build(); + Variable output = builder.Repeat(input, repeats, axis); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -1323,7 +1326,7 @@ TEST(net_build, program_execute_repeat_axis_0) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); const int new_M = M * repeats; @@ -1336,9 +1339,9 @@ TEST(net_build, program_execute_repeat_axis_0) { std::vector output_data = GetTensorData(output_tensor, target); for (int m = 0; m < new_M; ++m) { for (int n = 0; n < new_N; ++n) { - int in_index = n + N * static_cast(std::floor((float)m / repeats)); - int out_index = n + new_N * m; - float in_data = input_data[in_index]; + int in_index = n + N * static_cast(std::floor((float)m / repeats)); + int out_index = n + new_N * m; + float in_data = input_data[in_index]; float out_data = output_data[out_index]; EXPECT_EQ(in_data, out_data); } @@ -1346,15 +1349,15 @@ TEST(net_build, program_execute_repeat_axis_0) { } TEST(net_build, program_execute_repeat_axis_1) { - const int M = 4; - const int N = 4; + const int M = 4; + const int N = 4; const int repeats = 3; - const int axis = 1; + const int axis = 1; NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {M, N}, "In"); - Variable output = builder.Repeat(input, repeats, axis); - auto program = builder.Build(); + Variable output = builder.Repeat(input, repeats, axis); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -1377,7 +1380,7 @@ TEST(net_build, program_execute_repeat_axis_1) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = 
output_tensor->shape().data(); const int new_M = M; @@ -1390,9 +1393,9 @@ TEST(net_build, program_execute_repeat_axis_1) { std::vector output_data = GetTensorData(output_tensor, target); for (int m = 0; m < new_M; ++m) { for (int n = 0; n < new_N; ++n) { - int in_index = N * m + static_cast(std::floor((float)n / repeats)); - int out_index = n + new_N * m; - float in_data = input_data[in_index]; + int in_index = N * m + static_cast(std::floor((float)n / repeats)); + int out_index = n + new_N * m; + float in_data = input_data[in_index]; float out_data = output_data[out_index]; EXPECT_EQ(in_data, out_data); } @@ -1400,20 +1403,21 @@ TEST(net_build, program_execute_repeat_axis_1) { } TEST(net_build, program_execute_one_hot) { - const int M = 4; - const int N = 4; - const int on_value = 1; - const int off_value = 0; - const int depth = 11; - const int axis = 0; // [-1 , M] + const int M = 4; + const int N = 4; + const int on_value = 1; + const int off_value = 0; + const int depth = 11; + const int axis = 0; // [-1 , M] const std::string dtype = "int32"; NetBuilder builder("net_builder"); - Placeholder input = builder.CreateInput(Int(32), {M, N}, "In"); - Placeholder on_value_input = builder.CreateInput(Int(32), {1}, "OnValue"); + Placeholder input = builder.CreateInput(Int(32), {M, N}, "In"); + Placeholder on_value_input = builder.CreateInput(Int(32), {1}, "OnValue"); Placeholder off_value_input = builder.CreateInput(Int(32), {1}, "OffValue"); - Variable output = builder.OneHot(input, on_value_input, off_value_input, depth, axis, dtype); - auto program = builder.Build(); + Variable output = builder.OneHot( + input, on_value_input, off_value_input, depth, axis, dtype); + auto program = builder.Build(); #ifdef CINN_WITH_CUDA Target target = common::DefaultNVGPUTarget(); @@ -1432,7 +1436,7 @@ TEST(net_build, program_execute_one_hot) { scope->Var(std::string(off_value_input.id())); scope->Var(std::string(output->id)); - auto input_tensor = scope->GetTensor(std::string(input.id())); + auto input_tensor = scope->GetTensor(std::string(input.id())); const std::vector& intput_shape = input_tensor->shape().data(); SetRandInt(input_tensor, target); std::vector input_data = GetTensorData(input_tensor, target); @@ -1445,14 +1449,14 @@ TEST(net_build, program_execute_one_hot) { runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); + auto output_tensor = scope->GetTensor(std::string(output->id)); const std::vector& output_shape = output_tensor->shape().data(); - std::vector output_data = GetTensorData(output_tensor, target); + std::vector output_data = GetTensorData(output_tensor, target); EXPECT_EQ(output_tensor->type(), Int(32)); EXPECT_EQ(output_shape.size(), intput_shape.size() + 1); - const int true_axis = axis == -1 ? M : axis; + const int true_axis = axis == -1 ? 
M : axis; int input_shape_index = 0; for (int i = 0; i < output_shape.size(); i++) { @@ -1468,9 +1472,9 @@ TEST(net_build, program_execute_one_hot) { for (int j = 0; j < output_shape[1]; ++j) { for (int k = 0; k < output_shape[2]; ++k) { std::vector s = {i, j, k}; - int input_index = 0; - int output_index = 0; - int base = 1; + int input_index = 0; + int output_index = 0; + int base = 1; for (int x = s.size() - 1; x >= 0; --x) { if (x == true_axis) { diff --git a/paddle/cinn/frontend/op_mapper_registry.cc b/paddle/cinn/frontend/op_mapper_registry.cc index a38b3209b4e8c..883ac8104d9ae 100644 --- a/paddle/cinn/frontend/op_mapper_registry.cc +++ b/paddle/cinn/frontend/op_mapper_registry.cc @@ -19,39 +19,51 @@ namespace cinn { namespace frontend { -void OpMapperContext::AddVar(const std::string& origin_name, const Variable& var, bool can_inplace) const { +void OpMapperContext::AddVar(const std::string& origin_name, + const Variable& var, + bool can_inplace) const { CHECK(can_inplace || !var_map_->count(origin_name)) - << "Duplicate variable \"" << origin_name << "\" found, whose id is " << var_map_->at(origin_name)->id; + << "Duplicate variable \"" << origin_name << "\" found, whose id is " + << var_map_->at(origin_name)->id; if (var_map_->count(origin_name)) { - VLOG(1) << "The Paddle inplace output var \"" << origin_name + paddle::InplaceOutSuffix - << "\" is mapped to CINN var \"" << var->id << "\" with shape=[" << cinn::utils::Join(var->shape, ", ") - << "], dtype=" << var->type << ". The input var \"" << origin_name << "\" still mapped to \"" + VLOG(1) << "The Paddle inplace output var \"" + << origin_name + paddle::InplaceOutSuffix + << "\" is mapped to CINN var \"" << var->id << "\" with shape=[" + << cinn::utils::Join(var->shape, ", ") << "], dtype=" << var->type + << ". The input var \"" << origin_name << "\" still mapped to \"" << var_map_->at(origin_name)->id << "\""; } else { - VLOG(1) << "The Paddle var \"" << origin_name << "\" is mapped to CINN var \"" << var->id << "\" with shape=[" + VLOG(1) << "The Paddle var \"" << origin_name + << "\" is mapped to CINN var \"" << var->id << "\" with shape=[" << cinn::utils::Join(var->shape, ", ") << "], dtype=" << var->type; } (*var_map_)[origin_name] = var; } -void OpMapperContext::AddVarModelToProgram(const std::string& name, const std::string& id, bool can_inplace) const { - CHECK(!id.empty()) << "Paddle name [" << name << "]'s program id is empty ! Please check."; +void OpMapperContext::AddVarModelToProgram(const std::string& name, + const std::string& id, + bool can_inplace) const { + CHECK(!id.empty()) << "Paddle name [" << name + << "]'s program id is empty ! 
Please check."; if (!var_model_to_program_map_->count(name)) { (*var_model_to_program_map_)[name] = id; VLOG(4) << "Paddle name [" << name << "] map to program id " << id; } else { - CHECK(can_inplace) << "Duplicate variable [" << name << "] found, whose id is " + CHECK(can_inplace) << "Duplicate variable [" << name + << "] found, whose id is " << var_model_to_program_map_->at(name); - const auto& inplace_out_name = name + paddle::InplaceOutSuffix; + const auto& inplace_out_name = name + paddle::InplaceOutSuffix; (*var_model_to_program_map_)[inplace_out_name] = id; - VLOG(4) << "Paddle name [" << name << "] 's trick output [" << inplace_out_name << "] map to program id [" << id - << "]"; + VLOG(4) << "Paddle name [" << name << "] 's trick output [" + << inplace_out_name << "] map to program id [" << id << "]"; } } -void OpMapperContext::AddFetchVarName(const std::string& name) const { fetch_var_names_->insert(name); } +void OpMapperContext::AddFetchVarName(const std::string& name) const { + fetch_var_names_->insert(name); +} Variable OpMapperContext::GetVar(const std::string& origin_name) const { auto it = var_map_->find(origin_name); @@ -66,7 +78,7 @@ Variable OpMapperContext::GetVar(const std::string& origin_name) const { Variable local_var; local_var.set_id(name); local_var->shape = tensor->shape().data(); - local_var->type = tensor->type(); + local_var->type = tensor->type(); AddVar(origin_name, local_var); return local_var; } @@ -75,13 +87,17 @@ Variable OpMapperContext::GetVar(const std::string& origin_name) const { return Variable(); } -void OpMapperContext::AddFeedInfo(const std::string& name, const FeedInfo& info) { - CHECK(!feed_info_map_.count(name)) << "Duplicate variable info [" << name << "] found"; +void OpMapperContext::AddFeedInfo(const std::string& name, + const FeedInfo& info) { + CHECK(!feed_info_map_.count(name)) + << "Duplicate variable info [" << name << "] found"; feed_info_map_[name] = info; } -const OpMapperContext::FeedInfo& OpMapperContext::GetFeedInfo(const std::string& name) const { - CHECK(feed_info_map_.count(name)) << "No variable info called [" << name << "] exists"; +const OpMapperContext::FeedInfo& OpMapperContext::GetFeedInfo( + const std::string& name) const { + CHECK(feed_info_map_.count(name)) + << "No variable info called [" << name << "] exists"; return feed_info_map_.at(name); } diff --git a/paddle/cinn/frontend/op_mapper_registry.h b/paddle/cinn/frontend/op_mapper_registry.h index 51864d6bf2eec..9351e60e8ff70 100644 --- a/paddle/cinn/frontend/op_mapper_registry.h +++ b/paddle/cinn/frontend/op_mapper_registry.h @@ -40,7 +40,7 @@ namespace paddle { // same as Paddle's!!! 
The definition ref to // https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/operator.h#L97 inline std::string GradVarName(const std::string& var_name) { - constexpr char kGradVarSuffix[] = "@GRAD"; + constexpr char kGradVarSuffix[] = "@GRAD"; constexpr size_t kGradVarSuffixSize = 5U; std::string result; @@ -57,12 +57,13 @@ constexpr char InplaceOutSuffix[] = "@InplaceOut"; class OpMapperContext { public: - OpMapperContext(const hlir::framework::Scope& scope, - const common::Target& target, - NetBuilder* builder, - std::unordered_map* var_map, - std::unordered_map* var_model_to_program_map, - std::unordered_set* fetch_var_names) + OpMapperContext( + const hlir::framework::Scope& scope, + const common::Target& target, + NetBuilder* builder, + std::unordered_map* var_map, + std::unordered_map* var_model_to_program_map, + std::unordered_set* fetch_var_names) : scope_(scope), target_(target), builder_(builder), @@ -82,13 +83,17 @@ class OpMapperContext { NetBuilder* Builder() const { return builder_; } // add Variable into local var_map - void AddVar(const std::string& name, const Variable& var, bool can_inplace = true) const; + void AddVar(const std::string& name, + const Variable& var, + bool can_inplace = true) const; // get Variable from local var_map or scope Variable GetVar(const std::string& name) const; // add map from paddle name to cinn name into var_model_to_program_map - void AddVarModelToProgram(const std::string& name, const std::string& id, bool can_inplace = true) const; + void AddVarModelToProgram(const std::string& name, + const std::string& id, + bool can_inplace = true) const; void AddFetchVarName(const std::string& name) const; @@ -108,7 +113,8 @@ class OpMapperContext { std::unordered_map* var_map_{nullptr}; // map from var in Paddle model to var name in program. 
- std::unordered_map* var_model_to_program_map_{nullptr}; + std::unordered_map* var_model_to_program_map_{ + nullptr}; // fetch var names used in Paddle std::unordered_set* fetch_var_names_{nullptr}; @@ -117,7 +123,8 @@ class OpMapperContext { class OpMapper { public: - using OpMapperFunc = std::function; + using OpMapperFunc = + std::function; OpMapper() = default; @@ -125,7 +132,10 @@ class OpMapper { kernel_ = kernel; return *this; } - void Run(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) const { kernel_(op_desc, ctx); } + void Run(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) const { + kernel_(op_desc, ctx); + } std::string name; @@ -141,11 +151,14 @@ class OpMapperRegistry : public Registry { CINN_DISALLOW_COPY_AND_ASSIGN(OpMapperRegistry); }; -#define UNIQUE_OPMAPPER_NAME(OpName) static ::cinn::frontend::OpMapper& __op_mapper_registrar_##OpName +#define UNIQUE_OPMAPPER_NAME(OpName) \ + static ::cinn::frontend::OpMapper& __op_mapper_registrar_##OpName #define CINN_REGISTER_OP_MAPPER(OpName, Kernel) \ CINN_STR_CONCAT(UNIQUE_OPMAPPER_NAME(OpName), __COUNTER__) = \ - ::cinn::frontend::OpMapperRegistry::Global()->__REGISTER_OR_GET__(#OpName).Set(Kernel); + ::cinn::frontend::OpMapperRegistry::Global() \ + ->__REGISTER_OR_GET__(#OpName) \ + .Set(Kernel); } // namespace frontend } // namespace cinn diff --git a/paddle/cinn/frontend/op_mapper_registry_test.cc b/paddle/cinn/frontend/op_mapper_registry_test.cc index be9699b7053ad..5852bbfe51298 100644 --- a/paddle/cinn/frontend/op_mapper_registry_test.cc +++ b/paddle/cinn/frontend/op_mapper_registry_test.cc @@ -26,7 +26,8 @@ namespace frontend { TEST(OpMapperRegistryTest, list_all_opmappers) { auto all_opmappers_names = OpMapperRegistry::Global()->ListAllNames(); - LOG(INFO) << "Total has " << all_opmappers_names.size() << " registered OpMappers:\n" + LOG(INFO) << "Total has " << all_opmappers_names.size() + << " registered OpMappers:\n" << cinn::utils::Join(all_opmappers_names, ", "); ASSERT_FALSE(all_opmappers_names.empty()); } diff --git a/paddle/cinn/frontend/op_mappers/common_utils.h b/paddle/cinn/frontend/op_mappers/common_utils.h index 2ef4293192023..387a2c1fe7a8c 100644 --- a/paddle/cinn/frontend/op_mappers/common_utils.h +++ b/paddle/cinn/frontend/op_mappers/common_utils.h @@ -30,38 +30,46 @@ namespace frontend { namespace utils { template -inline T GetAttrOrDefault(const paddle::cpp::OpDesc& op_desc, const std::string& name, const T& default_value = T{}) { +inline T GetAttrOrDefault(const paddle::cpp::OpDesc& op_desc, + const std::string& name, + const T& default_value = T{}) { if (op_desc.HasAttr(name)) { return op_desc.GetAttr(name); } return default_value; } -#define EXPAND_SINGLE_NUM_TO_VECTOR(DATA_TYPE, ATTR_TYPE) \ - template <> \ - inline std::vector GetAttrOrDefault( \ - const paddle::cpp::OpDesc& op_desc, const std::string& name, const std::vector& default_value) { \ - if (op_desc.HasAttr(name)) { \ - auto attr_type = op_desc.GetAttrType(name); \ - using AttrType = paddle::cpp::OpDescAPI::AttrType; \ - switch (attr_type) { \ - case AttrType::ATTR_TYPE##S: \ - return op_desc.GetAttr>(name); \ - case AttrType::ATTR_TYPE: \ - return std::vector{op_desc.GetAttr(name)}; \ - default: \ - if (attr_type == AttrType::BOOLEANS) { \ - LOG(WARNING) << "Op \"" << op_desc.Type() << "\"'s attribute \"" << name << "\" should be " << #ATTR_TYPE \ - << "S, but here is BOOLEANS, considering the type of python empty list in cpp are BOOLEANS," \ - << " here we will return a empty vector."; \ - return {}; \ - 
} else { \ - LOG(FATAL) << "Op \"" << op_desc.Type() << "\"'s attribute \"" << name << "\" should be " << #ATTR_TYPE \ - << "S. But here " << static_cast(attr_type) << " Please Check!"; \ - } \ - } \ - } \ - return default_value; \ +#define EXPAND_SINGLE_NUM_TO_VECTOR(DATA_TYPE, ATTR_TYPE) \ + template <> \ + inline std::vector GetAttrOrDefault( \ + const paddle::cpp::OpDesc& op_desc, \ + const std::string& name, \ + const std::vector& default_value) { \ + if (op_desc.HasAttr(name)) { \ + auto attr_type = op_desc.GetAttrType(name); \ + using AttrType = paddle::cpp::OpDescAPI::AttrType; \ + switch (attr_type) { \ + case AttrType::ATTR_TYPE##S: \ + return op_desc.GetAttr>(name); \ + case AttrType::ATTR_TYPE: \ + return std::vector{op_desc.GetAttr(name)}; \ + default: \ + if (attr_type == AttrType::BOOLEANS) { \ + LOG(WARNING) << "Op \"" << op_desc.Type() << "\"'s attribute \"" \ + << name << "\" should be " << #ATTR_TYPE \ + << "S, but here is BOOLEANS; since a Python empty " \ + "list maps to BOOLEANS in C++," \ + << " here we will return an empty vector."; \ + return {}; \ + } else { \ + LOG(FATAL) << "Op \"" << op_desc.Type() << "\"'s attribute \"" \ + << name << "\" should be " << #ATTR_TYPE \ + << "S, but got " << static_cast(attr_type) \ + << ". Please check!"; \ + } \ + } \ + } \ + return default_value; \ } EXPAND_SINGLE_NUM_TO_VECTOR(int, INT) @@ -72,7 +80,9 @@ EXPAND_SINGLE_NUM_TO_VECTOR(bool, BOOLEAN) #undef EXPAND_SINGLE_NUM_TO_VECTOR template <> -inline bool GetAttrOrDefault(const paddle::cpp::OpDesc& op_desc, const std::string& name, const bool& default_value) { +inline bool GetAttrOrDefault(const paddle::cpp::OpDesc& op_desc, + const std::string& name, + const bool& default_value) { if (op_desc.HasAttr(name)) { auto attr_type = op_desc.GetAttrType(name); using AttrType = paddle::cpp::OpDescAPI::AttrType; @@ -84,7 +94,8 @@ inline bool GetAttrOrDefault(const paddle::cpp::OpDesc& op_desc, const std::stri case AttrType::LONG: return static_cast(op_desc.GetAttr(name)); default: - LOG(FATAL) << "Op " << op_desc.Type() << "'s attribute " << name << " should be BOOLEAN. Please Check!"; + LOG(FATAL) << "Op " << op_desc.Type() << "'s attribute " << name + << " should be BOOLEAN. Please check!"; } } return default_value; } @@ -103,16 +114,18 @@ inline int64_t GetAttrOrDefault(const paddle::cpp::OpDesc& op_desc, case AttrType::INT: return static_cast(op_desc.GetAttr(name)); default: - LOG(FATAL) << "Op " << op_desc.Type() << "'s attribute " << name << " should be LONG. Please Check!"; + LOG(FATAL) << "Op " << op_desc.Type() << "'s attribute " << name + << " should be LONG. 
Please Check!"; } } return default_value; } template <> -inline std::vector GetAttrOrDefault(const paddle::cpp::OpDesc& op_desc, - const std::string& name, - const std::vector& default_value) { +inline std::vector GetAttrOrDefault( + const paddle::cpp::OpDesc& op_desc, + const std::string& name, + const std::vector& default_value) { if (op_desc.HasAttr(name)) { auto attr_type = op_desc.GetAttrType(name); using AttrType = paddle::cpp::OpDescAPI::AttrType; @@ -122,19 +135,23 @@ inline std::vector GetAttrOrDefault(const paddle::cpp::OpDesc& op_desc, case AttrType::LONG: return std::vector{GetAttrOrDefault(op_desc, name)}; case AttrType::INTS: { - const auto& ints_val = GetAttrOrDefault>(op_desc, name); + const auto& ints_val = + GetAttrOrDefault>(op_desc, name); return std::vector{ints_val.begin(), ints_val.end()}; } case AttrType::INT: return std::vector{GetAttrOrDefault(op_desc, name)}; case AttrType::BOOLEANS: { - LOG(WARNING) << "Op \"" << op_desc.Type() << "\"'s attribute \"" << name << "\" should be LONGS, " - << "but here is BOOLEANS, considering the type of python empty list in cpp are BOOLEANS, " + LOG(WARNING) << "Op \"" << op_desc.Type() << "\"'s attribute \"" << name + << "\" should be LONGS, " + << "but here is BOOLEANS, considering the type of python " + "empty list in cpp are BOOLEANS, " << "here we will return a empty vector."; return {}; } default: - LOG(FATAL) << "Op " << op_desc.Type() << "'s attribute " << name << " should be LONGS. Please Check!"; + LOG(FATAL) << "Op " << op_desc.Type() << "'s attribute " << name + << " should be LONGS. Please Check!"; } } return default_value; @@ -153,11 +170,12 @@ inline cinn::utils::DimType ToDimType(const T& val) { inline std::string GetPaddleDtype(const paddle::cpp::OpDesc& op_desc, const std::string& dtype_attr_name, paddle::cpp::VarDescAPI::Type default_dtype) { - auto dtype_id = GetAttrOrDefault(op_desc, dtype_attr_name, static_cast(default_dtype)); + auto dtype_id = GetAttrOrDefault( + op_desc, dtype_attr_name, static_cast(default_dtype)); if (dtype_id < 0) { return ""; } - auto dtype_pd = static_cast(dtype_id); + auto dtype_pd = static_cast(dtype_id); auto dtype_cinn = CppVarType2CommonType(dtype_pd); if (dtype_cinn.is_unk()) { return ""; diff --git a/paddle/cinn/frontend/op_mappers/paddle/arg_min_max.cc b/paddle/cinn/frontend/op_mappers/paddle/arg_min_max.cc index a7845b6557388..5c37eea4542cb 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/arg_min_max.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/arg_min_max.cc @@ -23,43 +23,58 @@ namespace paddle_mappers { enum class ArgType { ArgMax, ArgMin }; template -Variable ArgImpl(NetBuilder* builder, const Variable& x, int axis, bool keepdims); +Variable ArgImpl(NetBuilder* builder, + const Variable& x, + int axis, + bool keepdims); template <> -Variable ArgImpl(NetBuilder* builder, const Variable& x, int axis, bool keepdims) { +Variable ArgImpl(NetBuilder* builder, + const Variable& x, + int axis, + bool keepdims) { return builder->Argmax(x, axis, keepdims); } template <> -Variable ArgImpl(NetBuilder* builder, const Variable& x, int axis, bool keepdims) { +Variable ArgImpl(NetBuilder* builder, + const Variable& x, + int axis, + bool keepdims) { return builder->Argmin(x, axis, keepdims); } template -void ArgOpMapperHelper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ArgOpMapperHelper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); 
CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); - CHECK(op_desc.HasAttr("axis")) << "Argmax/Argmin op should has attribute \"axis\"! Please check."; + CHECK(op_desc.HasAttr("axis")) + << "Argmax/Argmin op should has attribute \"axis\"! Please check."; auto keepdims = utils::GetAttrOrDefault(op_desc, "keepdims", false); - CHECK(op_desc.HasAttr("keepdims")) << "Argmax/Argmin op should has attribute \"keepdims\"! Please check."; + CHECK(op_desc.HasAttr("keepdims")) + << "Argmax/Argmin op should has attribute \"keepdims\"! Please check."; auto flatten = utils::GetAttrOrDefault(op_desc, "flatten", false); - CHECK(op_desc.HasAttr("flatten")) << "Argmax/Argmin op should has attribute \"flatten\"! Please check."; + CHECK(op_desc.HasAttr("flatten")) + << "Argmax/Argmin op should has attribute \"flatten\"! Please check."; - auto dtype = utils::GetPaddleDtype(op_desc, "dtype", paddle::cpp::VarDescAPI::Type::INT64); - CHECK(dtype == "int32" || dtype == "int64") << "the indices dtype must be int32 or int64, but got dtype = " << dtype; + auto dtype = utils::GetPaddleDtype( + op_desc, "dtype", paddle::cpp::VarDescAPI::Type::INT64); + CHECK(dtype == "int32" || dtype == "int64") + << "the indices dtype must be int32 or int64, but got dtype = " << dtype; int ndim = x->shape.size(); // If flatten = true, flatten x and do opration on axis 0. if (flatten) { - x = ctx.Builder()->Reshape(x, {-1}); + x = ctx.Builder()->Reshape(x, {-1}); axis = 0; ndim = x->shape.size(); } @@ -72,11 +87,13 @@ void ArgOpMapperHelper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext ctx.AddVarModelToProgram(out_name, out->id); } -void ArgMaxOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ArgMaxOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { ArgOpMapperHelper(op_desc, ctx); } -void ArgMinOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ArgMinOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { ArgOpMapperHelper(op_desc, ctx); } @@ -85,8 +102,10 @@ void ArgMinOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c } // namespace cinn CINN_REGISTER_HELPER(paddle_arg) { - CINN_REGISTER_OP_MAPPER(arg_max, cinn::frontend::paddle_mappers::ArgMaxOpMapper) - CINN_REGISTER_OP_MAPPER(arg_min, cinn::frontend::paddle_mappers::ArgMinOpMapper) + CINN_REGISTER_OP_MAPPER(arg_max, + cinn::frontend::paddle_mappers::ArgMaxOpMapper) + CINN_REGISTER_OP_MAPPER(arg_min, + cinn::frontend::paddle_mappers::ArgMinOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/argsort.cc b/paddle/cinn/frontend/op_mappers/paddle/argsort.cc index fe290a3476062..fb2e76d01ecd9 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/argsort.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/argsort.cc @@ -22,7 +22,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void ArgsortOpMapper(const paddle::cpp::OpDesc& op_desc, const cinn::frontend::OpMapperContext& ctx) { +void ArgsortOpMapper(const paddle::cpp::OpDesc& op_desc, + const cinn::frontend::OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); @@ -32,10 +33,12 @@ void ArgsortOpMapper(const paddle::cpp::OpDesc& op_desc, const cinn::frontend::O CHECK_EQ(op_desc.Output("Indices").size(), 1UL); auto indices_name = 
op_desc.Output("Indices").front(); - auto is_ascend = !(utils::GetAttrOrDefault(op_desc, "descending", false)); - auto axis = utils::ToDimType(utils::GetAttrOrDefault(op_desc, "axis", 0)); + auto is_ascend = + !(utils::GetAttrOrDefault(op_desc, "descending", false)); + auto axis = + utils::ToDimType(utils::GetAttrOrDefault(op_desc, "axis", 0)); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto out = ctx.Builder()->ArgSort(x, axis, is_ascend); auto idx = ctx.Builder()->Cast(out[0], "int64"); @@ -43,8 +46,9 @@ void ArgsortOpMapper(const paddle::cpp::OpDesc& op_desc, const cinn::frontend::O ctx.AddVarModelToProgram(indices_name, idx->id); // TODO: return the sorted tensor here. Now out[1] is a temporary tensor. - // this is because output 'Out' is never uesd in Paddle API, but CINN need to return 2 output vars - // to meet the op defination, this should be resolved after sort op restructured. + // this is because output 'Out' is never uesd in Paddle API, but CINN need to + // return 2 output vars to meet the op defination, this should be resolved + // after sort op restructured. ctx.AddVar(out_name, out[1]); ctx.AddVarModelToProgram(out_name, out[1]->id); } @@ -54,6 +58,7 @@ void ArgsortOpMapper(const paddle::cpp::OpDesc& op_desc, const cinn::frontend::O } // namespace cinn CINN_REGISTER_HELPER(paddle_argsort) { - CINN_REGISTER_OP_MAPPER(argsort, cinn::frontend::paddle_mappers::ArgsortOpMapper) + CINN_REGISTER_OP_MAPPER(argsort, + cinn::frontend::paddle_mappers::ArgsortOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/atan.cc b/paddle/cinn/frontend/op_mappers/paddle/atan.cc index dbae16696ab30..961e643ae88e1 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/atan.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/atan.cc @@ -22,7 +22,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void Atan2OpMapper(const paddle::cpp::OpDesc& op_desc, const cinn::frontend::OpMapperContext& ctx) { +void Atan2OpMapper(const paddle::cpp::OpDesc& op_desc, + const cinn::frontend::OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X1").size(), 1UL); auto x1_name = op_desc.Input("X1").front(); CHECK_EQ(op_desc.Input("X2").size(), 1UL); diff --git a/paddle/cinn/frontend/op_mappers/paddle/batchnorm.cc b/paddle/cinn/frontend/op_mappers/paddle/batchnorm.cc index 4d379b2c4208f..b99b13fdaaeb4 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/batchnorm.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/batchnorm.cc @@ -19,17 +19,21 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void BatchNormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { - auto add_output = [&op_desc, &ctx]( - const std::string& pd_param_name, const Variable& out, bool can_inplace = false) -> void { +void BatchNormOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { + auto add_output = [&op_desc, &ctx](const std::string& pd_param_name, + const Variable& out, + bool can_inplace = false) -> void { if (!op_desc.HasOutput(pd_param_name)) { - VLOG(4) << "Cannot find parameter " << pd_param_name << " in op " << op_desc.Type(); + VLOG(4) << "Cannot find parameter " << pd_param_name << " in op " + << op_desc.Type(); return; } CHECK_EQ(op_desc.Output(pd_param_name).size(), 1UL); auto output_name = op_desc.Output(pd_param_name).front(); - VLOG(4) << "The " << op_desc.Type() << "'s output " << pd_param_name << " is " << output_name; + VLOG(4) << "The " << op_desc.Type() << "'s output " << pd_param_name + << " is " << output_name; 
ctx.AddVar(output_name, out, can_inplace); ctx.AddVarModelToProgram(output_name, out->id, can_inplace); @@ -46,28 +50,42 @@ void BatchNormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext CHECK_EQ(op_desc.Input("Variance").size(), 1UL); auto variance_name = op_desc.Input("Variance").front(); - auto epsilon = utils::GetAttrOrDefault<float>(op_desc, "epsilon", 1e-5f); - auto momentum = utils::GetAttrOrDefault<float>(op_desc, "momentum", 0.9f); - auto data_layout = utils::GetAttrOrDefault<std::string>(op_desc, "data_layout", "NCHW"); - auto x = ctx.GetVar(x_name); - auto scale = ctx.GetVar(scale_name); - auto bias = ctx.GetVar(bias_name); - auto mean = ctx.GetVar(mean_name); - auto variance = ctx.GetVar(variance_name); - - auto is_test = utils::GetAttrOrDefault<bool>(op_desc, "is_test", false); - auto trainable_stats = utils::GetAttrOrDefault<bool>(op_desc, "trainable_statistics", false); - auto use_global_stats = utils::GetAttrOrDefault<bool>(op_desc, "use_global_stats", false); - bool use_run_stat = (is_test && (!trainable_stats)) || use_global_stats; - - VLOG(4) << "Try compute batch_norm(X:" << x_name << ", Scale:" << scale_name << ", Bias:" << bias_name + auto epsilon = utils::GetAttrOrDefault<float>(op_desc, "epsilon", 1e-5f); + auto momentum = utils::GetAttrOrDefault<float>(op_desc, "momentum", 0.9f); + auto data_layout = + utils::GetAttrOrDefault<std::string>(op_desc, "data_layout", "NCHW"); + auto x = ctx.GetVar(x_name); + auto scale = ctx.GetVar(scale_name); + auto bias = ctx.GetVar(bias_name); + auto mean = ctx.GetVar(mean_name); + auto variance = ctx.GetVar(variance_name); + + auto is_test = utils::GetAttrOrDefault<bool>(op_desc, "is_test", false); + auto trainable_stats = + utils::GetAttrOrDefault<bool>(op_desc, "trainable_statistics", false); + auto use_global_stats = + utils::GetAttrOrDefault<bool>(op_desc, "use_global_stats", false); + bool use_run_stat = (is_test && (!trainable_stats)) || use_global_stats; + + VLOG(4) << "Try compute batch_norm(X:" << x_name << ", Scale:" << scale_name + << ", Bias:" << bias_name << "," ", Mean:" - << mean_name << ", Variance:" << variance_name << ", epsilon=" << epsilon << ", momentum=" << momentum - << ", data_layout=" << data_layout << ", is_test=" << is_test << ", trainable_statistics=" << trainable_stats + << mean_name << ", Variance:" << variance_name + << ", epsilon=" << epsilon << ", momentum=" << momentum + << ", data_layout=" << data_layout << ", is_test=" << is_test + << ", trainable_statistics=" << trainable_stats << ", use_global_stats=" << use_global_stats << ")"; - auto outs = ctx.Builder()->BatchNorm(x, scale, bias, mean, variance, epsilon, momentum, data_layout, use_run_stat); + auto outs = ctx.Builder()->BatchNorm(x, + scale, + bias, + mean, + variance, + epsilon, + momentum, + data_layout, + use_run_stat); if (use_run_stat) { VLOG(4) << "Invoke batch_norm OpMapper with test mode"; @@ -78,15 +96,17 @@ void BatchNormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext add_output("SavedMean", save_mean); auto save_variance = ctx.Builder()->Identity(variance); add_output("SavedVariance", save_variance); - // Just for skip error of "Variable(batch_norm2d_0.w_2@InplaceOut) not applied in cinn" when run batchnorm in - // paddle, remove after inpace mechanism perfect. The value should shared memory with mean and variance. + // Just to skip the error of "Variable(batch_norm2d_0.w_2@InplaceOut) not + // applied in cinn" when running batchnorm in paddle; remove after the + // inplace mechanism is perfected. The values should share memory with mean and variance.
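// [editor's note] use_run_stat, computed above as
// (is_test && !trainable_statistics) || use_global_stats, is what splits the
// two branches here: when it is true (inference / global statistics), the
// statistics outputs below are synthesized as Identity aliases of mean and
// variance; when it is false (training), the else-branch's
// CHECK_EQ(outs.size(), 5U) relies on BatchNorm returning five variables,
// with outs[0] = Y and outs[1] = SavedMean visible in this hunk (the
// remaining output order is cut off by the hunk boundary).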
auto mean_out = ctx.Builder()->Identity(mean); add_output("MeanOut", mean_out, true); auto variance_out = ctx.Builder()->Identity(variance); add_output("VarianceOut", variance_out, true); } else { VLOG(4) << "Invoke batch_norm OpMapper with train mode"; - CHECK_EQ(outs.size(), 5U) << "batch_norm in train mode should only has 5 output! Please check."; + CHECK_EQ(outs.size(), 5U) + << "batch_norm in train mode should only have 5 outputs! Please check."; add_output("Y", outs[0]); add_output("SavedMean", outs[1]); @@ -97,19 +117,23 @@ void BatchNormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext } } -void BatchNormGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void BatchNormGradOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { std::unordered_map<std::string, std::string> input_names_map; - auto get_input_var = [&op_desc, &ctx, &input_names_map](const std::string& op_name) { - CHECK_EQ(op_desc.Input(op_name).size(), 1UL); - auto var_name = op_desc.Input(op_name).front(); - input_names_map.emplace(op_name, var_name); - return ctx.GetVar(var_name); - }; + auto get_input_var = + [&op_desc, &ctx, &input_names_map](const std::string& op_name) { + CHECK_EQ(op_desc.Input(op_name).size(), 1UL); + auto var_name = op_desc.Input(op_name).front(); + input_names_map.emplace(op_name, var_name); + return ctx.GetVar(var_name); + }; std::unordered_map<std::string, std::string> output_names_map; - auto get_output_name = [&op_desc, &output_names_map](const std::string& op_name) -> std::string { + auto get_output_name = + [&op_desc, &output_names_map](const std::string& op_name) -> std::string { if (op_desc.Output(op_name).empty()) { - CHECK_NE(op_name, paddle::GradVarName("X")) << "The input X should not empty."; + CHECK_NE(op_name, paddle::GradVarName("X")) + << "The input X should not be empty."; return ""; } @@ -119,33 +143,39 @@ void BatchNormGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperCon return var_name; }; - std::vector<std::string> output_names = {get_output_name(paddle::GradVarName("X")), - get_output_name(paddle::GradVarName("Scale")), - get_output_name(paddle::GradVarName("Bias"))}; + std::vector<std::string> output_names = { + get_output_name(paddle::GradVarName("X")), + get_output_name(paddle::GradVarName("Scale")), + get_output_name(paddle::GradVarName("Bias"))}; - auto x = get_input_var("X"); - auto dy = get_input_var(paddle::GradVarName("Y")); - auto scale = get_input_var("Scale"); - auto saved_mean = get_input_var("SavedMean"); + auto x = get_input_var("X"); + auto dy = get_input_var(paddle::GradVarName("Y")); + auto scale = get_input_var("Scale"); + auto saved_mean = get_input_var("SavedMean"); auto saved_variance = get_input_var("SavedVariance"); - auto data_layout = utils::GetAttrOrDefault<std::string>(op_desc, "data_layout", "NCHW"); - auto epsilon = utils::GetAttrOrDefault<float>(op_desc, "epsilon", 1e-5f); + auto data_layout = + utils::GetAttrOrDefault<std::string>(op_desc, "data_layout", "NCHW"); + auto epsilon = utils::GetAttrOrDefault<float>(op_desc, "epsilon", 1e-5f); - auto get_arg_debug_info = [](const std::unordered_map<std::string, std::string>& names_map) { - std::string res; - for (const auto& pair : names_map) { - res.append(pair.first + ":" + pair.second + ", "); - } - return res; - }; + auto get_arg_debug_info = + [](const std::unordered_map<std::string, std::string>& names_map) { + std::string res; + for (const auto& pair : names_map) { + res.append(pair.first + ":" + pair.second + ", "); + } + return res; + }; - VLOG(4) << "{" << get_arg_debug_info(output_names_map) << "} = batch_norm_grad(" - << get_arg_debug_info(input_names_map) << ", data_layout="
<< data_layout << ", epsilon=" << epsilon << ")"; + VLOG(4) << "{" << get_arg_debug_info(output_names_map) + << "} = batch_norm_grad(" << get_arg_debug_info(input_names_map) + << ", data_layout=" << data_layout << ", epsilon=" << epsilon << ")"; // batch norm grad, output(grad_x, grad_scale, grad_bias) - auto outs = ctx.Builder()->BatchNormGrad(dy, x, scale, saved_mean, saved_variance, epsilon, data_layout); - CHECK_EQ(outs.size(), 3ul) << "batch_norm_grad APIs should return 3 Variable!"; + auto outs = ctx.Builder()->BatchNormGrad( + dy, x, scale, saved_mean, saved_variance, epsilon, data_layout); + CHECK_EQ(outs.size(), 3ul) + << "batch_norm_grad APIs should return 3 Variable!"; for (int i = 0; i < outs.size(); i++) { if (output_names[i].empty()) { @@ -162,7 +192,9 @@ void BatchNormGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperCon } // namespace cinn CINN_REGISTER_HELPER(paddle_batchnorm) { - CINN_REGISTER_OP_MAPPER(batch_norm, cinn::frontend::paddle_mappers::BatchNormOpMapper) - CINN_REGISTER_OP_MAPPER(batch_norm_grad, cinn::frontend::paddle_mappers::BatchNormGradOpMapper) + CINN_REGISTER_OP_MAPPER(batch_norm, + cinn::frontend::paddle_mappers::BatchNormOpMapper) + CINN_REGISTER_OP_MAPPER(batch_norm_grad, + cinn::frontend::paddle_mappers::BatchNormGradOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/binary.cc b/paddle/cinn/frontend/op_mappers/paddle/binary.cc index 98b550f89484c..02016c2835439 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/binary.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/binary.cc @@ -19,19 +19,20 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -#define BINARY_OPMAPPER_FUNCTION(OP_NAME) \ - void OP_NAME##OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { \ - CHECK_EQ(op_desc.Input("X").size(), 1UL); \ - auto x_name = op_desc.Input("X").front(); \ - CHECK_EQ(op_desc.Input("Y").size(), 1UL); \ - auto y_name = op_desc.Input("Y").front(); \ - CHECK_EQ(op_desc.Output("Out").size(), 1UL); \ - auto out_name = op_desc.Output("Out").front(); \ - auto x = ctx.GetVar(x_name); \ - auto y = ctx.GetVar(y_name); \ - auto out = ctx.Builder()->OP_NAME(x, y); \ - ctx.AddVar(out_name, out); \ - ctx.AddVarModelToProgram(out_name, out->id); \ +#define BINARY_OPMAPPER_FUNCTION(OP_NAME) \ + void OP_NAME##OpMapper(const paddle::cpp::OpDesc& op_desc, \ + const OpMapperContext& ctx) { \ + CHECK_EQ(op_desc.Input("X").size(), 1UL); \ + auto x_name = op_desc.Input("X").front(); \ + CHECK_EQ(op_desc.Input("Y").size(), 1UL); \ + auto y_name = op_desc.Input("Y").front(); \ + CHECK_EQ(op_desc.Output("Out").size(), 1UL); \ + auto out_name = op_desc.Output("Out").front(); \ + auto x = ctx.GetVar(x_name); \ + auto y = ctx.GetVar(y_name); \ + auto out = ctx.Builder()->OP_NAME(x, y); \ + ctx.AddVar(out_name, out); \ + ctx.AddVarModelToProgram(out_name, out->id); \ } BINARY_OPMAPPER_FUNCTION(LogicalAnd) @@ -49,7 +50,8 @@ BINARY_OPMAPPER_FUNCTION(BitwiseXor) CINN_REGISTER_HELPER(paddle_binary) { #define BINARY_OPMAPPER_REGISTER(PD_OP, CINN_OP) \ - CINN_REGISTER_OP_MAPPER(PD_OP, cinn::frontend::paddle_mappers::CINN_OP##OpMapper) + CINN_REGISTER_OP_MAPPER(PD_OP, \ + cinn::frontend::paddle_mappers::CINN_OP##OpMapper) BINARY_OPMAPPER_REGISTER(logical_and, LogicalAnd) BINARY_OPMAPPER_REGISTER(logical_or, LogicalOr) diff --git a/paddle/cinn/frontend/op_mappers/paddle/cholesky.cc b/paddle/cinn/frontend/op_mappers/paddle/cholesky.cc index 1364ed57962a5..51c8a2fb1dc53 100644 --- 
a/paddle/cinn/frontend/op_mappers/paddle/cholesky.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/cholesky.cc @@ -19,7 +19,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void CholeskyOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void CholeskyOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); @@ -28,7 +29,7 @@ void CholeskyOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& auto upper = utils::GetAttrOrDefault<bool>(op_desc, "upper", false); VLOG(4) << out_name << " = cholesky(" << x_name << ", upper=" << upper << ")"; - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto out = ctx.Builder()->Cholesky(x, upper); ctx.AddVar(out_name, out); @@ -40,6 +41,7 @@ void CholeskyOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& } // namespace cinn CINN_REGISTER_HELPER(paddle_cholesky) { - CINN_REGISTER_OP_MAPPER(cholesky, cinn::frontend::paddle_mappers::CholeskyOpMapper) + CINN_REGISTER_OP_MAPPER(cholesky, + cinn::frontend::paddle_mappers::CholeskyOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/clip.cc b/paddle/cinn/frontend/op_mappers/paddle/clip.cc index bcd65a15ae700..1dc659b7410f4 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/clip.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/clip.cc @@ -19,17 +19,19 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void ClipOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ClipOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = ctx.GetVar(x_name); - auto builder = ctx.Builder(); + auto x = ctx.GetVar(x_name); + auto builder = ctx.Builder(); if (op_desc.HasInput("Min") && op_desc.Input("Min").size() > 0) { - CHECK_EQ(op_desc.Input("Min").size(), 1) << "clip op should have only one input for Min"; - auto min_val_name = op_desc.Input("Min").front(); + CHECK_EQ(op_desc.Input("Min").size(), 1) + << "clip op should have only one input for Min"; + auto min_val_name = op_desc.Input("Min").front(); auto min_val_tensor = ctx.GetVar(min_val_name); CHECK(min_val_tensor->shape == cinn::utils::ShapeType{1}) << "The [Min] tensor shape of clip op should be [1], but here [" @@ -38,18 +40,23 @@ void ClipOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx min_val_tensor = builder->Cast(min_val_tensor, common::Type2Str(x->type)); } min_val_tensor = builder->BroadcastTo(min_val_tensor, x->shape); - x = builder->Max(x, min_val_tensor); + x = builder->Max(x, min_val_tensor); } else { - CHECK(op_desc.HasAttr("min")) << "The clip op should has [min] attribute or [Min] tensor input."; + CHECK(op_desc.HasAttr("min")) + << "The clip op should have [min] attribute or [Min] tensor input."; auto min_value = op_desc.GetAttr<float>("min"); auto min_val_tensor = - builder->FillConstant(x->shape, min_value, common::UniqName(x->id + "_min"), common::Type2Str(x->type)); + builder->FillConstant(x->shape, + min_value, + common::UniqName(x->id + "_min"), + common::Type2Str(x->type)); x = builder->Max(x, min_val_tensor); } if (op_desc.HasInput("Max") && op_desc.Input("Max").size() > 0) { - CHECK_EQ(op_desc.Input("Max").size(), 1) << "clip op should 
have only one input for Max"; - auto max_val_name = op_desc.Input("Max").front(); + CHECK_EQ(op_desc.Input("Max").size(), 1) + << "clip op should have only one input for Max"; + auto max_val_name = op_desc.Input("Max").front(); auto max_val_tensor = ctx.GetVar(max_val_name); CHECK(max_val_tensor->shape == cinn::utils::ShapeType{1}) << "The [Max] tensor shape of clip op should be [1], but here [" @@ -58,12 +65,15 @@ void ClipOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx max_val_tensor = builder->Cast(max_val_tensor, common::Type2Str(x->type)); } max_val_tensor = builder->BroadcastTo(max_val_tensor, x->shape); - x = builder->Min(x, max_val_tensor); + x = builder->Min(x, max_val_tensor); } else { - CHECK(op_desc.HasAttr("max")) << "The clip op should has [max] attribute or [Max] tensor input."; + CHECK(op_desc.HasAttr("max")) + << "The clip op should has [max] attribute or [Max] tensor input."; auto max_value = op_desc.GetAttr("max"); - auto max_val_tensor = - builder->FillConstant(x->shape, max_value, common::UniqName("constant"), common::Type2Str(x->type)); + auto max_val_tensor = builder->FillConstant(x->shape, + max_value, + common::UniqName("constant"), + common::Type2Str(x->type)); x = builder->Min(x, max_val_tensor); } diff --git a/paddle/cinn/frontend/op_mappers/paddle/compare.cc b/paddle/cinn/frontend/op_mappers/paddle/compare.cc index 158477a7eb0c8..ecf0313de6d50 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/compare.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/compare.cc @@ -28,25 +28,28 @@ static const std::string& GetCompareDebugString(const std::string& compare_op) { {"Equal", " == "}, {"NotEqual", " != "}, }; - CHECK_GT(compare_debug_map.count(compare_op), 0) << "Unsupported compare op " << compare_op; + CHECK_GT(compare_debug_map.count(compare_op), 0) + << "Unsupported compare op " << compare_op; return compare_debug_map[compare_op]; } -#define COMPARE_OPMAPPER_FUNCTION(OP_NAME) \ - void OP_NAME##OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { \ - CHECK_EQ(op_desc.Input("X").size(), 1UL); \ - auto x_name = op_desc.Input("X").front(); \ - CHECK_EQ(op_desc.Input("Y").size(), 1UL); \ - auto y_name = op_desc.Input("Y").front(); \ - CHECK_EQ(op_desc.Output("Out").size(), 1UL); \ - auto out_name = op_desc.Output("Out").front(); \ - auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); \ - VLOG(4) << out_name << " = " << x_name << GetCompareDebugString(#OP_NAME) << y_name << " at " << axis; \ - auto x = ctx.GetVar(x_name); \ - auto y = ctx.GetVar(y_name); \ - auto out = ctx.Builder()->OP_NAME(x, y, axis); \ - ctx.AddVar(out_name, out); \ - ctx.AddVarModelToProgram(out_name, out->id); \ +#define COMPARE_OPMAPPER_FUNCTION(OP_NAME) \ + void OP_NAME##OpMapper(const paddle::cpp::OpDesc& op_desc, \ + const OpMapperContext& ctx) { \ + CHECK_EQ(op_desc.Input("X").size(), 1UL); \ + auto x_name = op_desc.Input("X").front(); \ + CHECK_EQ(op_desc.Input("Y").size(), 1UL); \ + auto y_name = op_desc.Input("Y").front(); \ + CHECK_EQ(op_desc.Output("Out").size(), 1UL); \ + auto out_name = op_desc.Output("Out").front(); \ + auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); \ + VLOG(4) << out_name << " = " << x_name << GetCompareDebugString(#OP_NAME) \ + << y_name << " at " << axis; \ + auto x = ctx.GetVar(x_name); \ + auto y = ctx.GetVar(y_name); \ + auto out = ctx.Builder()->OP_NAME(x, y, axis); \ + ctx.AddVar(out_name, out); \ + ctx.AddVarModelToProgram(out_name, out->id); \ } COMPARE_OPMAPPER_FUNCTION(GreaterThan) @@ -64,7 
+67,8 @@ COMPARE_OPMAPPER_FUNCTION(NotEqual) CINN_REGISTER_HELPER(paddle_compare) { #define COMPARE_OPMAPPER_REGISTER(PD_OP, CINN_OP) \ - CINN_REGISTER_OP_MAPPER(PD_OP, cinn::frontend::paddle_mappers::CINN_OP##OpMapper) + CINN_REGISTER_OP_MAPPER(PD_OP, \ + cinn::frontend::paddle_mappers::CINN_OP##OpMapper) COMPARE_OPMAPPER_REGISTER(greater_than, GreaterThan) COMPARE_OPMAPPER_REGISTER(greater_equal, GreaterEqual) diff --git a/paddle/cinn/frontend/op_mappers/paddle/concat.cc b/paddle/cinn/frontend/op_mappers/paddle/concat.cc index 61a84a8aa4789..48518898e7c33 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/concat.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/concat.cc @@ -21,7 +21,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void ConcatOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ConcatOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_GE(op_desc.Input("X").size(), 1UL); auto x_names = op_desc.Input("X"); CHECK_EQ(op_desc.Output("Out").size(), 1UL); @@ -34,10 +35,14 @@ void ConcatOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c xs.emplace_back(ctx.GetVar(name)); } - auto err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) { return x->type != xs.front()->type; }); - CHECK(err_x == xs.end()) << "All input's dtype of [concat] should be the same, be the input " << (*err_x)->id - << "'s dtype [" << (*err_x)->type << "] not equal to the first input " << xs.front()->id - << "'s dtype [" << xs.front()->type << "]"; + auto err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) { + return x->type != xs.front()->type; + }); + CHECK(err_x == xs.end()) + << "All input's dtype of [concat] should be the same, but the input " + << (*err_x)->id << "'s dtype [" << (*err_x)->type + << "] is not equal to the first input " << xs.front()->id << "'s dtype [" + << xs.front()->type << "]"; auto out = ctx.Builder()->Concat(xs, axis); @@ -45,7 +50,8 @@ void ConcatOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c ctx.AddVarModelToProgram(out_name, out->id); } -void StackOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void StackOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_GE(op_desc.Input("X").size(), 1UL); auto x_names = op_desc.Input("X"); @@ -57,7 +63,8 @@ void StackOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ct CHECK_EQ(op_desc.Output("Y").size(), 1UL); out_name = op_desc.Output("Y").front(); } else { - LOG(FATAL) << "The output argument name of [stack] should be 'Out' or 'Y', but here cannot found! Please check."; + LOG(FATAL) << "The output argument name of [stack] should be 'Out' or 'Y', " "but cannot be found here! 
Please check."; } auto axis = utils::GetAttrOrDefault<int>(op_desc, "axis", 0); @@ -67,22 +74,31 @@ void StackOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ct xs.emplace_back(ctx.GetVar(name)); } - auto err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) { return x->type != xs.front()->type; }); - CHECK(err_x == xs.end()) << "All input's dtype of [concat] should be the same, be the input " << (*err_x)->id - << "'s dtype [" << (*err_x)->type << "] not equal to the first input " << xs.front()->id - << "'s dtype [" << xs.front()->type << "]"; - - err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) { return x->shape != xs.front()->shape; }); - CHECK(err_x == xs.end()) << "All input shape of [stack] should be the same, be the input " << (*err_x)->id - << "'s shape [" << cinn::utils::Join((*err_x)->shape, ", ") << "] not equal to " - << "the first input " << xs.front()->id << "'s shape [" - << cinn::utils::Join(xs.front()->shape, ", ") << "]"; + auto err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) { + return x->type != xs.front()->type; + }); + CHECK(err_x == xs.end()) + << "All input's dtype of [stack] should be the same, but the input " + << (*err_x)->id << "'s dtype [" << (*err_x)->type + << "] is not equal to the first input " << xs.front()->id << "'s dtype [" + << xs.front()->type << "]"; + + err_x = std::find_if(xs.begin(), xs.end(), [&](Variable x) { + return x->shape != xs.front()->shape; + }); + CHECK(err_x == xs.end()) + << "All input shapes of [stack] should be the same, but the input " + << (*err_x)->id << "'s shape [" + << cinn::utils::Join((*err_x)->shape, ", ") << "] is not equal to " + << "the first input " << xs.front()->id << "'s shape [" + << cinn::utils::Join(xs.front()->shape, ", ") << "]"; auto concat_out = ctx.Builder()->Concat(xs, axis); int rank = concat_out->shape.size(); - axis = axis >= 0 ? axis : axis + rank; - CHECK(axis >= 0 && axis < rank) << "The axis of stack should be >= 0 and < " << rank << ", but got " << axis; + axis = axis >= 0 ? axis : axis + rank; + CHECK(axis >= 0 && axis < rank) + << "The axis of stack should be >= 0 and < " << rank << ", but got " << axis; // N * [A, B] with axis=0 --> [N, A, B]; N * [A, B] with axis=1 --> [A, N, B]; cinn::utils::ShapeType new_shape; @@ -102,54 +118,64 @@ void StackOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ct ctx.AddVarModelToProgram(out_name, out->id); } -void SplitOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void SplitOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); ; CHECK_GE(op_desc.Output("Out").size(), 1UL); auto out_names = op_desc.Output("Out"); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); int rank = x->shape.size(); auto axis = utils::GetAttrOrDefault<int>(op_desc, "axis", 0); - CHECK(axis >= -rank && axis < rank) << "The [axis] should in [-" << rank << ", " << rank << "), but here is " << axis; + CHECK(axis >= -rank && axis < rank) + << "The [axis] should be in [-" << rank << ", " << rank << "), but here is " + << axis; if (axis < 0) { axis += rank; } - auto num = utils::GetAttrOrDefault<int>(op_desc, "num", 0); - auto sections = utils::GetAttrOrDefault<std::vector<int>>(op_desc, "sections"); + auto num = utils::GetAttrOrDefault<int>(op_desc, "num", 0); + auto sections = + utils::GetAttrOrDefault<std::vector<int>>(op_desc, "sections"); auto dim = x->shape[axis]; - CHECK(num != 0 || !sections.empty()) << "The [num_or_sections] in split op should not empty! Please check."; + CHECK(num != 0 || !sections.empty()) + << "The [num_or_sections] in split op should not be empty! 
Please check."; if (num != 0) { - CHECK(dim % num == 0) << "The num_or_sections:" << num << " cannot divided by the split axis:" << axis + CHECK(dim % num == 0) << "The num_or_sections:" << num + << " cannot be divided by the split axis:" << axis << " 's dimension:" << dim; sections.clear(); sections.resize(num, dim / num); } CHECK_EQ(sections.size(), out_names.size()) - << "The output number of split op should be " << sections.size() << ", but actual " << out_names.size(); + << "The output number of split op should be " << sections.size() + << ", but actual " << out_names.size(); int neg_idx = -1, sum = 0; for (int i = 0; i < sections.size(); ++i) { if (sections[i] < 0) { - CHECK_LT(neg_idx, 0) << "The [num_or_sections] should only has one -1! But here " - << cinn::utils::Join(sections, ", "); + CHECK_LT(neg_idx, 0) + << "The [num_or_sections] should only have one -1! But here " + << cinn::utils::Join(sections, ", "); neg_idx = i; } else { sum += sections[i]; } } if (neg_idx > 0) { - CHECK_LT(sum, dim) << "The sum of [num_or_sections] should less than to the dimension of split [axis] when -1 " + CHECK_LT(sum, dim) << "The sum of [num_or_sections] should be less than " "the dimension of split [axis] when -1 is " "found in [num_or_sections]! But here " << cinn::utils::Join(sections, ", "); sections[neg_idx] = dim - sum; } else { - CHECK_EQ(sum, dim) << "The sum of [num_or_sections] should equal to the dimension of split [axis]! But here " + CHECK_EQ(sum, dim) << "The sum of [num_or_sections] should be equal to the " "dimension of split [axis]! But here " << cinn::utils::Join(sections, ", "); } @@ -166,7 +192,8 @@ void SplitOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ct } // namespace cinn CINN_REGISTER_HELPER(paddle_concat) { - CINN_REGISTER_OP_MAPPER(concat, cinn::frontend::paddle_mappers::ConcatOpMapper) + CINN_REGISTER_OP_MAPPER(concat, + cinn::frontend::paddle_mappers::ConcatOpMapper) CINN_REGISTER_OP_MAPPER(stack, cinn::frontend::paddle_mappers::StackOpMapper) CINN_REGISTER_OP_MAPPER(split, cinn::frontend::paddle_mappers::SplitOpMapper) return true; diff --git a/paddle/cinn/frontend/op_mappers/paddle/constant.cc b/paddle/cinn/frontend/op_mappers/paddle/constant.cc index e4c2b305a0806..8f38bb4ee9034 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/constant.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/constant.cc @@ -28,58 +28,69 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void AssignOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void AssignOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto out = ctx.Builder()->Identity(x); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); } -void ShapeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ShapeOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("Input").size(), 1UL); auto x_name = op_desc.Input("Input").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = ctx.GetVar(x_name); - auto out = ctx.Builder()->Constant(x->shape, cinn::utils::TransValidVarName(out_name)); + auto x = ctx.GetVar(x_name); + auto out = ctx.Builder()->Constant(x->shape, + 
cinn::utils::TransValidVarName(out_name)); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); } -void FillConstantOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void FillConstantOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto y_name = op_desc.Output("Out").front(); const auto& cinn_name = cinn::utils::TransValidVarName(y_name); CheckVarNameValid(cinn_name); - auto shape = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "shape")); - auto value = utils::GetAttrOrDefault(op_desc, "value", 0.0f); - auto str_value = utils::GetAttrOrDefault(op_desc, "str_value", ""); + auto shape = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "shape")); + auto value = utils::GetAttrOrDefault(op_desc, "value", 0.0f); + auto str_value = + utils::GetAttrOrDefault(op_desc, "str_value", ""); auto force_cpu = utils::GetAttrOrDefault(op_desc, "force_cpu", false); - auto dtype = utils::GetPaddleDtype(op_desc, "dtype", paddle::cpp::VarDescAPI::Type::FP32); - CHECK(!dtype.empty()) << "The op \"fill_constant\"'s attribute \"dtype\" should not be unknown type! Please check."; + auto dtype = utils::GetPaddleDtype( + op_desc, "dtype", paddle::cpp::VarDescAPI::Type::FP32); + CHECK(!dtype.empty()) << "The op \"fill_constant\"'s attribute \"dtype\" " + "should not be unknown type! Please check."; absl::optional out; - if (op_desc.HasInput("ValueTensor") && !op_desc.Input("ValueTensor").empty()) { + if (op_desc.HasInput("ValueTensor") && + !op_desc.Input("ValueTensor").empty()) { CHECK_EQ(op_desc.Input("ValueTensor").size(), 1UL); - auto value_name = op_desc.Input("ValueTensor").front(); + auto value_name = op_desc.Input("ValueTensor").front(); auto value_tensor = ctx.GetVar(value_name); - VLOG(4) << "fill constant " << value_name << "=" << value_tensor << " with shape (" << cinn::utils::Join(shape, ",") + VLOG(4) << "fill constant " << value_name << "=" << value_tensor + << " with shape (" << cinn::utils::Join(shape, ",") << ") and dtype [" << dtype << "]"; - CHECK(value_tensor->shape == cinn::utils::ShapeType{1}) << "The shape of [ValueTensor] should be [1], but here [" - << cinn::utils::Join(value_tensor->shape, ", ") << "]"; + CHECK(value_tensor->shape == cinn::utils::ShapeType{1}) + << "The shape of [ValueTensor] should be [1], but here [" + << cinn::utils::Join(value_tensor->shape, ", ") << "]"; if (common::Type2Str(value_tensor->type) != dtype) { value_tensor = ctx.Builder()->Cast(value_tensor, dtype); } @@ -87,13 +98,17 @@ void FillConstantOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperCont out.value().set_id(cinn_name); } else { if (!str_value.empty()) { - VLOG(4) << "fill constant (" << str_value << ") with shape (" << cinn::utils::Join(shape, ",") << ") and dtype [" - << dtype << "]"; - out = ctx.Builder()->FillConstant(shape, str_value, cinn_name, dtype, force_cpu); + VLOG(4) << "fill constant (" << str_value << ") with shape (" + << cinn::utils::Join(shape, ",") << ") and dtype [" << dtype + << "]"; + out = ctx.Builder()->FillConstant( + shape, str_value, cinn_name, dtype, force_cpu); } else { - VLOG(4) << "fill constant (" << value << ") with shape (" << cinn::utils::Join(shape, ",") << ") and dtype [" - << dtype << "]"; - out = ctx.Builder()->FillConstant(shape, value, cinn_name, dtype, force_cpu); + VLOG(4) << "fill constant (" << value << ") with shape (" + << cinn::utils::Join(shape, ",") << ") and dtype [" << dtype + << "]"; + out = 
ctx.Builder()->FillConstant( + shape, value, cinn_name, dtype, force_cpu); } } @@ -101,10 +116,11 @@ void FillConstantOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperCont ctx.AddVarModelToProgram(y_name, out.value()->id); } -void FillAnyLikeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void FillAnyLikeOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto y_name = op_desc.Output("Out").front(); @@ -112,13 +128,14 @@ void FillAnyLikeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperConte auto shape = utils::ToShapeType(x->shape); auto value = utils::GetAttrOrDefault(op_desc, "value"); - auto dtype = utils::GetPaddleDtype(op_desc, "dtype", paddle::cpp::VarDescAPI::Type::FP32); + auto dtype = utils::GetPaddleDtype( + op_desc, "dtype", paddle::cpp::VarDescAPI::Type::FP32); if (dtype.empty()) { dtype = common::Type2Str(x->type); } - VLOG(4) << "FillAnyLikeOp: fill constant (" << value << ") with shape (" << cinn::utils::Join(shape, ", ") - << ") and dtype [" << dtype << "]"; + VLOG(4) << "FillAnyLikeOp: fill constant (" << value << ") with shape (" + << cinn::utils::Join(shape, ", ") << ") and dtype [" << dtype << "]"; const auto& cinn_name = cinn::utils::TransValidVarName(y_name); CheckVarNameValid(cinn_name); @@ -144,77 +161,98 @@ std::pair IsArithmeticSequence(const std::vector& vec) { return {true, first_diff}; } -void AssignValueOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void AssignValueOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Output("Out").size(), 1UL); - auto out_name = op_desc.Output("Out").front(); + auto out_name = op_desc.Output("Out").front(); const auto& cinn_out_name = cinn::utils::TransValidVarName(out_name); - const auto& bool_values_tmp = utils::GetAttrOrDefault>(op_desc, "bool_values"); + const auto& bool_values_tmp = + utils::GetAttrOrDefault>(op_desc, "bool_values"); std::vector bool_values; if (!bool_values_tmp.empty()) { - std::transform(bool_values_tmp.begin(), bool_values_tmp.end(), std::back_inserter(bool_values), [](int x) { - return static_cast(x); - }); + std::transform(bool_values_tmp.begin(), + bool_values_tmp.end(), + std::back_inserter(bool_values), + [](int x) { return static_cast(x); }); } - const auto& fp32_values = utils::GetAttrOrDefault>(op_desc, "fp32_values"); - const auto& int32_values = utils::GetAttrOrDefault>(op_desc, "int32_values"); - const auto& int64_values = utils::GetAttrOrDefault>(op_desc, "int64_values"); + const auto& fp32_values = + utils::GetAttrOrDefault>(op_desc, "fp32_values"); + const auto& int32_values = + utils::GetAttrOrDefault>(op_desc, "int32_values"); + const auto& int64_values = + utils::GetAttrOrDefault>(op_desc, "int64_values"); absl::optional out; if (!bool_values.empty()) { - VLOG(4) << "The input of assign_value is [" << cinn::utils::Join(bool_values, ", ") << "]"; + VLOG(4) << "The input of assign_value is [" + << cinn::utils::Join(bool_values, ", ") << "]"; out = ctx.Builder()->Constant(bool_values, cinn_out_name); } else if (!fp32_values.empty()) { - VLOG(4) << "The input of assign_value is [" << cinn::utils::Join(fp32_values, ", ") << "]"; + VLOG(4) << "The input of assign_value is [" + << cinn::utils::Join(fp32_values, ", ") << "]"; auto adj_diff = 
IsArithmeticSequence(fp32_values); if (adj_diff.first) { - VLOG(4) << "The input of assign_value is a arithmetic sequence. Using Arange instead of Constant."; - auto epsilone = - adj_diff.second > 0 ? std::numeric_limits<float>::epsilon() : -std::numeric_limits<float>::epsilon(); - - out = ctx.Builder()->Arange(fp32_values.front(), fp32_values.back() + epsilone, adj_diff.second, "float32"); + VLOG(4) << "The input of assign_value is an arithmetic sequence. Using " "Arange instead of Constant."; + auto epsilone = adj_diff.second > 0 + ? std::numeric_limits<float>::epsilon() + : -std::numeric_limits<float>::epsilon(); + + out = ctx.Builder()->Arange(fp32_values.front(), + fp32_values.back() + epsilone, + adj_diff.second, + "float32"); } else { out = ctx.Builder()->Constant(fp32_values, cinn_out_name); } } else if (!int32_values.empty()) { - VLOG(4) << "The input of assign_value is [" << cinn::utils::Join(int32_values, ", ") << "]"; + VLOG(4) << "The input of assign_value is [" + << cinn::utils::Join(int32_values, ", ") << "]"; auto adj_diff = IsArithmeticSequence(int32_values); if (adj_diff.first) { - VLOG(4) << "The input of assign_value is a arithmetic sequence. Using Arange instead of Constant."; + VLOG(4) << "The input of assign_value is an arithmetic sequence. Using " + "Arange instead of Constant."; auto epsilone = adj_diff.second > 0 ? 1 : -1; - out = ctx.Builder()->Arange(static_cast<float>(int32_values.front()), - static_cast<float>(int32_values.back() + epsilone), - static_cast<float>(adj_diff.second), - "int32"); + out = ctx.Builder()->Arange( + static_cast<float>(int32_values.front()), + static_cast<float>(int32_values.back() + epsilone), + static_cast<float>(adj_diff.second), + "int32"); } else { out = ctx.Builder()->Constant(int32_values, cinn_out_name); } } else if (!int64_values.empty()) { - VLOG(4) << "The input of assign_value is [" << cinn::utils::Join(int64_values, ", ") << "]"; + VLOG(4) << "The input of assign_value is [" + << cinn::utils::Join(int64_values, ", ") << "]"; auto adj_diff = IsArithmeticSequence(int64_values); if (adj_diff.first) { - VLOG(4) << "The input of assign_value is a arithmetic sequence. Using Arange instead of Constant."; + VLOG(4) << "The input of assign_value is an arithmetic sequence. Using " + "Arange instead of Constant."; auto epsilone = adj_diff.second > 0 ? 1 : -1; - out = ctx.Builder()->Arange(static_cast<float>(int64_values.front()), - static_cast<float>(int64_values.back() + epsilone), - static_cast<float>(adj_diff.second), - "int64"); + out = ctx.Builder()->Arange( + static_cast<float>(int64_values.front()), + static_cast<float>(int64_values.back() + epsilone), + static_cast<float>(adj_diff.second), + "int64"); } else { out = ctx.Builder()->Constant(int64_values, cinn_out_name); } } - CHECK(out) << "assign_value's input should not empty, but " << out_name << "not! Please check."; + CHECK(out) << "assign_value's input should not be empty, but " << out_name + << " is! 
Please check."; + const auto& shape = utils::GetAttrOrDefault>( + op_desc, "shape", out.value()->shape); if (shape != out.value()->shape) { out = ctx.Builder()->Reshape(out.value(), shape); } @@ -228,11 +266,15 @@ void AssignValueOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperConte } // namespace cinn CINN_REGISTER_HELPER(paddle_constant) { - CINN_REGISTER_OP_MAPPER(assign, cinn::frontend::paddle_mappers::AssignOpMapper) + CINN_REGISTER_OP_MAPPER(assign, + cinn::frontend::paddle_mappers::AssignOpMapper) CINN_REGISTER_OP_MAPPER(shape, cinn::frontend::paddle_mappers::ShapeOpMapper) - CINN_REGISTER_OP_MAPPER(fill_constant, cinn::frontend::paddle_mappers::FillConstantOpMapper) - CINN_REGISTER_OP_MAPPER(fill_any_like, cinn::frontend::paddle_mappers::FillAnyLikeOpMapper) - CINN_REGISTER_OP_MAPPER(assign_value, cinn::frontend::paddle_mappers::AssignValueOpMapper) + CINN_REGISTER_OP_MAPPER(fill_constant, + cinn::frontend::paddle_mappers::FillConstantOpMapper) + CINN_REGISTER_OP_MAPPER(fill_any_like, + cinn::frontend::paddle_mappers::FillAnyLikeOpMapper) + CINN_REGISTER_OP_MAPPER(assign_value, + cinn::frontend::paddle_mappers::AssignValueOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/conv2d.cc b/paddle/cinn/frontend/op_mappers/paddle/conv2d.cc index df524f5599fe7..21f1645752ffb 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/conv2d.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/conv2d.cc @@ -20,7 +20,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void Conv2dOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void Conv2dOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("Input").size(), 1UL); auto x_name = op_desc.Input("Input").front(); CHECK_EQ(op_desc.Input("Filter").size(), 1UL); @@ -29,35 +30,51 @@ void Conv2dOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c CHECK_EQ(op_desc.Output("Output").size(), 1UL); auto out_name = op_desc.Output("Output").front(); - auto strides = utils::GetAttrOrDefault>(op_desc, "strides", {1, 1}); - auto paddings = utils::GetAttrOrDefault>(op_desc, "paddings", {0, 0}); - auto dilations = utils::GetAttrOrDefault>(op_desc, "dilations", {1, 1}); - auto groups = utils::GetAttrOrDefault(op_desc, "groups", 1); + auto strides = + utils::GetAttrOrDefault>(op_desc, "strides", {1, 1}); + auto paddings = + utils::GetAttrOrDefault>(op_desc, "paddings", {0, 0}); + auto dilations = + utils::GetAttrOrDefault>(op_desc, "dilations", {1, 1}); + auto groups = utils::GetAttrOrDefault(op_desc, "groups", 1); - auto data_format = utils::GetAttrOrDefault(op_desc, "data_format", "AnyLayout"); + auto data_format = + utils::GetAttrOrDefault(op_desc, "data_format", "AnyLayout"); if (data_format == "AnyLayout") { data_format = "NCHW"; } - auto padding_algorithm = utils::GetAttrOrDefault(op_desc, "padding_algorithm", "EXPLICIT"); - auto x = ctx.GetVar(x_name); - auto y = ctx.GetVar(y_name); + auto padding_algorithm = utils::GetAttrOrDefault( + op_desc, "padding_algorithm", "EXPLICIT"); + auto x = ctx.GetVar(x_name); + auto y = ctx.GetVar(y_name); - CHECK_EQ(x->shape.size(), 4) << "CINN conv2d operator only support 4-D tensor now, but Input's shape is [" + CHECK_EQ(x->shape.size(), 4) << "CINN conv2d operator only support 4-D " + "tensor now, but Input's shape is [" << cinn::utils::Join(x->shape, ", ") << "]"; - CHECK_EQ(y->shape.size(), 4) << "CINN conv2d operator only support 4-D tensor now, but Filter's shape is [" + 
CHECK_EQ(y->shape.size(), 4) << "CINN conv2d operator only support 4-D " + "tensor now, but Filter's shape is [" << cinn::utils::Join(y->shape, ", ") << "]"; if (data_format == "NHWC") { - // the weight in paddle always be NCHW, but cudnn need the same as input, transpose before + // the weight in paddle always be NCHW, but cudnn need the same as input, + // transpose before y = ctx.Builder()->Transpose(y, {0, 2, 3, 1}); } - auto out = ctx.Builder()->Conv2d(x, y, strides, paddings, dilations, groups, data_format, padding_algorithm); + auto out = ctx.Builder()->Conv2d(x, + y, + strides, + paddings, + dilations, + groups, + data_format, + padding_algorithm); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); } -void DepthwiseConv2dOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void DepthwiseConv2dOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("Input").size(), 1UL); auto x_name = op_desc.Input("Input").front(); CHECK_EQ(op_desc.Input("Filter").size(), 1UL); @@ -66,32 +83,52 @@ void DepthwiseConv2dOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperC CHECK_EQ(op_desc.Output("Output").size(), 1UL); auto out_name = op_desc.Output("Output").front(); - auto strides = utils::GetAttrOrDefault>(op_desc, "strides", {1, 1}); - auto paddings = utils::GetAttrOrDefault>(op_desc, "paddings", {0, 0}); - auto dilations = utils::GetAttrOrDefault>(op_desc, "dilations", {1, 1}); - auto groups = utils::GetAttrOrDefault(op_desc, "groups", 1); + auto strides = + utils::GetAttrOrDefault>(op_desc, "strides", {1, 1}); + auto paddings = + utils::GetAttrOrDefault>(op_desc, "paddings", {0, 0}); + auto dilations = + utils::GetAttrOrDefault>(op_desc, "dilations", {1, 1}); + auto groups = utils::GetAttrOrDefault(op_desc, "groups", 1); - auto data_format = utils::GetAttrOrDefault(op_desc, "data_format", "NCHW"); + auto data_format = + utils::GetAttrOrDefault(op_desc, "data_format", "NCHW"); if (data_format == "AnyLayout") { data_format = "NCHW"; } - auto padding_algorithm = utils::GetAttrOrDefault(op_desc, "padding_algorithm", "EXPLICIT"); - auto x = ctx.GetVar(x_name); - auto y = ctx.GetVar(y_name); + auto padding_algorithm = utils::GetAttrOrDefault( + op_desc, "padding_algorithm", "EXPLICIT"); + auto x = ctx.GetVar(x_name); + auto y = ctx.GetVar(y_name); Variable out; if (ctx.Target().arch == Target::Arch::X86) { - out = ctx.Builder()->Conv2d(x, y, strides, paddings, dilations, groups, data_format, padding_algorithm); + out = ctx.Builder()->Conv2d(x, + y, + strides, + paddings, + dilations, + groups, + data_format, + padding_algorithm); } else { - out = ctx.Builder()->DepthwiseConv2d(x, y, strides, paddings, dilations, groups, data_format, padding_algorithm); + out = ctx.Builder()->DepthwiseConv2d(x, + y, + strides, + paddings, + dilations, + groups, + data_format, + padding_algorithm); } ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); } -void Conv2dGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void Conv2dGradOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { // get dy CHECK_EQ(op_desc.Input(paddle::GradVarName("Output")).size(), 1UL); auto dy_name = op_desc.Input(paddle::GradVarName("Output")).front(); @@ -113,49 +150,76 @@ void Conv2dGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContex CHECK_EQ(op_desc.Output(paddle::GradVarName("Filter")).size(), 1UL); auto dw_name = 
op_desc.Output(paddle::GradVarName("Filter")).front(); - auto strides = utils::GetAttrOrDefault>(op_desc, "strides", {1, 1}); - auto paddings = utils::GetAttrOrDefault>(op_desc, "paddings", {0, 0}); - auto dilations = utils::GetAttrOrDefault>(op_desc, "dilations", {1, 1}); - auto groups = utils::GetAttrOrDefault(op_desc, "groups", 1); + auto strides = + utils::GetAttrOrDefault>(op_desc, "strides", {1, 1}); + auto paddings = + utils::GetAttrOrDefault>(op_desc, "paddings", {0, 0}); + auto dilations = + utils::GetAttrOrDefault>(op_desc, "dilations", {1, 1}); + auto groups = utils::GetAttrOrDefault(op_desc, "groups", 1); - auto data_format = utils::GetAttrOrDefault(op_desc, "data_format", "AnyLayout"); + auto data_format = + utils::GetAttrOrDefault(op_desc, "data_format", "AnyLayout"); if (data_format == "AnyLayout") { data_format = "NCHW"; } - auto padding_algorithm = utils::GetAttrOrDefault(op_desc, "padding_algorithm", "EXPLICIT"); + auto padding_algorithm = utils::GetAttrOrDefault( + op_desc, "padding_algorithm", "EXPLICIT"); - auto dy = ctx.GetVar(dy_name); - auto x = ctx.GetVar(x_name); + auto dy = ctx.GetVar(dy_name); + auto x = ctx.GetVar(x_name); auto weight = ctx.GetVar(w_name); - CHECK_EQ(x->shape.size(), 4) << "CINN conv2d_grad operator only support 4-D tensor now, but Input's shape is [" + CHECK_EQ(x->shape.size(), 4) << "CINN conv2d_grad operator only support 4-D " + "tensor now, but Input's shape is [" << cinn::utils::Join(x->shape, ", ") << "]"; - CHECK_EQ(dy->shape.size(), 4) << "CINN conv2d_grad operator only support 4-D tensor now, but " - << paddle::GradVarName("Output") << "'s shape is [" - << cinn::utils::Join(dy->shape, ", ") << "]"; - CHECK_EQ(weight->shape.size(), 4) << "CINN conv2d_grad operator only support 4-D tensor now, but Filter's shape is [" - << cinn::utils::Join(weight->shape, ", ") << "]"; + CHECK_EQ(dy->shape.size(), 4) + << "CINN conv2d_grad operator only support 4-D tensor now, but " + << paddle::GradVarName("Output") << "'s shape is [" + << cinn::utils::Join(dy->shape, ", ") << "]"; + CHECK_EQ(weight->shape.size(), 4) + << "CINN conv2d_grad operator only support 4-D tensor now, but Filter's " + "shape is [" + << cinn::utils::Join(weight->shape, ", ") << "]"; if (data_format == "NHWC") { - // the weight in paddle always be NCHW, but cudnn need the same as input, transpose before + // the weight in paddle always be NCHW, but cudnn need the same as input, + // transpose before weight = ctx.Builder()->Transpose(weight, {0, 2, 3, 1}); } if (has_dx) { // create backward data - auto dx = ctx.Builder()->Conv( - weight, dy, strides, paddings, dilations, groups, "backward_data", data_format, padding_algorithm, x->shape); + auto dx = ctx.Builder()->Conv(weight, + dy, + strides, + paddings, + dilations, + groups, + "backward_data", + data_format, + padding_algorithm, + x->shape); ctx.AddVar(dx_name, dx); ctx.AddVarModelToProgram(dx_name, dx->id); } // create backward filter - auto dw = ctx.Builder()->Conv( - x, dy, strides, paddings, dilations, groups, "backward_filter", data_format, padding_algorithm, weight->shape); + auto dw = ctx.Builder()->Conv(x, + dy, + strides, + paddings, + dilations, + groups, + "backward_filter", + data_format, + padding_algorithm, + weight->shape); if (data_format == "NHWC") { - // the weight in paddle always be NCHW, but cudnn need the same as input, transpose back + // the weight in paddle always be NCHW, but cudnn need the same as input, + // transpose back dw = ctx.Builder()->Transpose(dw, {0, 3, 1, 2}); } @@ -168,11 +232,14 @@ void 
Conv2dGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContex } // namespace cinn CINN_REGISTER_HELPER(paddle_conv2d) { - CINN_REGISTER_OP_MAPPER(conv2d, cinn::frontend::paddle_mappers::Conv2dOpMapper) - CINN_REGISTER_OP_MAPPER(depthwise_conv2d, cinn::frontend::paddle_mappers::DepthwiseConv2dOpMapper) + CINN_REGISTER_OP_MAPPER(conv2d, + cinn::frontend::paddle_mappers::Conv2dOpMapper) + CINN_REGISTER_OP_MAPPER( + depthwise_conv2d, cinn::frontend::paddle_mappers::DepthwiseConv2dOpMapper) #ifdef CINN_WITH_CUDNN - CINN_REGISTER_OP_MAPPER(conv2d_grad, cinn::frontend::paddle_mappers::Conv2dGradOpMapper) + CINN_REGISTER_OP_MAPPER(conv2d_grad, + cinn::frontend::paddle_mappers::Conv2dGradOpMapper) #endif return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/cumsum.cc b/paddle/cinn/frontend/op_mappers/paddle/cumsum.cc index 3ffaafd958a22..080d53302dc17 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/cumsum.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/cumsum.cc @@ -20,34 +20,37 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void CumsumOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void CumsumOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto input = ctx.GetVar(x_name); - auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); - auto flatten = utils::GetAttrOrDefault(op_desc, "flatten", false); + auto input = ctx.GetVar(x_name); + auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); + auto flatten = utils::GetAttrOrDefault(op_desc, "flatten", false); auto exclusive = utils::GetAttrOrDefault(op_desc, "exclusive", false); - auto reverse = utils::GetAttrOrDefault(op_desc, "reverse", false); + auto reverse = utils::GetAttrOrDefault(op_desc, "reverse", false); - auto x = input; + auto x = input; int ndim = x->shape.size(); // If flatten = true, flatten x and do cumsum on axis 0. if (flatten) { - x = ctx.Builder()->Reshape(x, {-1}); + x = ctx.Builder()->Reshape(x, {-1}); axis = 0; ndim = x->shape.size(); } - CHECK(-ndim <= axis && axis < ndim) << "Axis expected to be in range of [" << -ndim << "," << ndim << "]. But got " - << axis << "."; + CHECK(-ndim <= axis && axis < ndim) + << "Axis expected to be in range of [" << -ndim << "," << ndim + << "]. 
But got " << axis << "."; if (axis < 0) { axis = ndim + axis; } - x = ctx.Builder()->ExpandDims(x, {axis + 1}); - auto rg = ctx.Builder()->Arange(0.0f, static_cast(x->shape[axis]), 1.0f, "int32"); + x = ctx.Builder()->ExpandDims(x, {axis + 1}); + auto rg = ctx.Builder()->Arange( + 0.0f, static_cast(x->shape[axis]), 1.0f, "int32"); cinn::frontend::Variable mask; if (reverse) { mask = ctx.Builder()->GreaterEqual(ctx.Builder()->ExpandDims(rg, {1}), rg); @@ -58,22 +61,26 @@ void CumsumOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c mask = ctx.Builder()->ExpandDims(mask, {-1}); } // Infer broadcast shape for x and mask - int x_ndim = x->shape.size(); + int x_ndim = x->shape.size(); int mask_ndim = mask->shape.size(); std::vector broadcast_shape(std::max(x_ndim, mask_ndim), 0); int broadcast_shape_size = broadcast_shape.size(); for (int i = broadcast_shape.size() - 1; i >= 0; --i) { if (i - (broadcast_shape_size - x_ndim) >= 0) { - broadcast_shape[i] = std::max(broadcast_shape[i], x->shape[i - (broadcast_shape_size - x_ndim)]); + broadcast_shape[i] = std::max( + broadcast_shape[i], x->shape[i - (broadcast_shape_size - x_ndim)]); } if (i - (broadcast_shape_size - mask_ndim) >= 0) { - broadcast_shape[i] = std::max(broadcast_shape[i], mask->shape[i - (broadcast_shape_size - mask_ndim)]); + broadcast_shape[i] = + std::max(broadcast_shape[i], + mask->shape[i - (broadcast_shape_size - mask_ndim)]); } } // Do broadcast shape on mask, x and false_value - mask = ctx.Builder()->BroadcastTo(mask, broadcast_shape); - x = ctx.Builder()->BroadcastTo(x, broadcast_shape); - auto false_value = ctx.Builder()->FillConstant(x->shape, 0, UniqName("false_value"), common::Type2Str(x->type)); + mask = ctx.Builder()->BroadcastTo(mask, broadcast_shape); + x = ctx.Builder()->BroadcastTo(x, broadcast_shape); + auto false_value = ctx.Builder()->FillConstant( + x->shape, 0, UniqName("false_value"), common::Type2Str(x->type)); // Select elements with mask auto selected_x = ctx.Builder()->Select(mask, x, false_value); // Do reduce sum @@ -91,6 +98,7 @@ void CumsumOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c } // namespace cinn CINN_REGISTER_HELPER(paddle_cumsum) { - CINN_REGISTER_OP_MAPPER(cumsum, cinn::frontend::paddle_mappers::CumsumOpMapper) + CINN_REGISTER_OP_MAPPER(cumsum, + cinn::frontend::paddle_mappers::CumsumOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/dropout.cc b/paddle/cinn/frontend/op_mappers/paddle/dropout.cc index 868bca420c2c2..0e88af83c718d 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/dropout.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/dropout.cc @@ -19,17 +19,20 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void DropoutInferOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void DropoutInferOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto dropout_prob = utils::GetAttrOrDefault(op_desc, "dropout_prob", 0.5f); - auto dropout_implementation = - utils::GetAttrOrDefault(op_desc, "dropout_implementation", "downgrade_in_infer"); - auto x = ctx.GetVar(x_name); - auto out = ctx.Builder()->DropoutInfer(x, dropout_prob, dropout_implementation); + auto dropout_prob = + utils::GetAttrOrDefault(op_desc, "dropout_prob", 0.5f); + auto dropout_implementation = 
utils::GetAttrOrDefault( + op_desc, "dropout_implementation", "downgrade_in_infer"); + auto x = ctx.GetVar(x_name); + auto out = + ctx.Builder()->DropoutInfer(x, dropout_prob, dropout_implementation); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); @@ -40,6 +43,7 @@ void DropoutInferOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperCont } // namespace cinn CINN_REGISTER_HELPER(paddle_dropout) { - CINN_REGISTER_OP_MAPPER(dropout, cinn::frontend::paddle_mappers::DropoutInferOpMapper) + CINN_REGISTER_OP_MAPPER(dropout, + cinn::frontend::paddle_mappers::DropoutInferOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/elementwise.cc b/paddle/cinn/frontend/op_mappers/paddle/elementwise.cc index c2cb5dede47d9..ee408a300eb9f 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/elementwise.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/elementwise.cc @@ -22,7 +22,17 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -enum class EltwiseType { kUnk = 0, kAdd, kDiv, kMul, kSub, kPow, kMod, kMax, kMin }; +enum class EltwiseType { + kUnk = 0, + kAdd, + kDiv, + kMul, + kSub, + kPow, + kMod, + kMax, + kMin +}; template std::string GetEltwiseTypeString(); @@ -46,10 +56,12 @@ EXPAND_ELTWISETYPE_STRING(kMin, " min ") template struct OpBuilder {}; -#define ELTWISE_SPEC(enum_t, function) \ - template <> \ - struct OpBuilder { \ - constexpr static Variable (NetBuilder::*func)(const Variable&, const Variable&, int){&function}; \ +#define ELTWISE_SPEC(enum_t, function) \ + template <> \ + struct OpBuilder { \ + constexpr static Variable (NetBuilder::*func)(const Variable&, \ + const Variable&, \ + int){&function}; \ } ELTWISE_SPEC(EltwiseType::kAdd, NetBuilder::Add); ELTWISE_SPEC(EltwiseType::kDiv, NetBuilder::Divide); @@ -61,7 +73,8 @@ ELTWISE_SPEC(EltwiseType::kMax, NetBuilder::Max); ELTWISE_SPEC(EltwiseType::kMin, NetBuilder::Min); #undef ELTWISE_SPEC -void AddOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void AddOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Y").size(), 1UL); @@ -71,8 +84,8 @@ void AddOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) VLOG(4) << out_name << " = " << x_name << " + " << y_name; - auto x = ctx.GetVar(x_name); - auto y = ctx.GetVar(y_name); + auto x = ctx.GetVar(x_name); + auto y = ctx.GetVar(y_name); auto out = ctx.Builder()->Add(x, y); ctx.AddVar(out_name, out); @@ -80,7 +93,8 @@ void AddOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) } template -void ElementwiseOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ElementwiseOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { VLOG(5) << "Elementwise operator mapping type: " << static_cast(Type); CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); @@ -91,17 +105,19 @@ void ElementwiseOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperConte auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); - VLOG(4) << out_name << " = " << x_name << GetEltwiseTypeString() << y_name << " at " << axis; + VLOG(4) << out_name << " = " << x_name << GetEltwiseTypeString() + << y_name << " at " << axis; - auto x = ctx.GetVar(x_name); - auto y = ctx.GetVar(y_name); + auto x = ctx.GetVar(x_name); + auto y = ctx.GetVar(y_name); auto out = 
(ctx.Builder()->*OpBuilder::func)(x, y, axis); ctx.AddVar(out_name, out, true); ctx.AddVarModelToProgram(out_name, out->id, true); } -void ElementwiseAddGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ElementwiseAddGradOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Y").size(), 1UL); @@ -121,11 +137,12 @@ void ElementwiseAddGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapp auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); - VLOG(4) << "{X@GRAD=" << dx_name << ", Y@GRAD=" << dy_name << "}=elementwise_add_grad(X=" << x_name - << ", Y=" << y_name << ", OUT@GRAD=" << dout_name << ", axis=" << axis << ")"; + VLOG(4) << "{X@GRAD=" << dx_name << ", Y@GRAD=" << dy_name + << "}=elementwise_add_grad(X=" << x_name << ", Y=" << y_name + << ", OUT@GRAD=" << dout_name << ", axis=" << axis << ")"; - auto x = ctx.GetVar(x_name); - auto y = ctx.GetVar(y_name); + auto x = ctx.GetVar(x_name); + auto y = ctx.GetVar(y_name); auto dout = ctx.GetVar(dout_name); auto outs = ctx.Builder()->ElementwiseAddGrad(dout, x, y, axis); CHECK_EQ(outs.size(), 2) << "elementwise_add_grad should return 2 variables"; @@ -142,7 +159,8 @@ void ElementwiseAddGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapp } } -void SumOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void SumOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_GE(op_desc.Input("X").size(), 1UL); auto x_names = op_desc.Input("X"); CHECK_EQ(op_desc.Output("Out").size(), 1UL); @@ -161,48 +179,60 @@ void SumOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) ctx.AddVarModelToProgram(out_name, out->id, true); } -void CastOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void CastOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - CHECK(op_desc.HasAttr("out_dtype")) << "The cast op should has [out_dtype] attribute!"; - auto dtype = utils::GetPaddleDtype(op_desc, "out_dtype", paddle::cpp::VarDescAPI::Type::FP32); - CHECK(!dtype.empty()) << "The op \"cast\"'s attribute \"out_dtype\" should not be unknown type! Please check."; + CHECK(op_desc.HasAttr("out_dtype")) + << "The cast op should have the [out_dtype] attribute!"; + auto dtype = utils::GetPaddleDtype( + op_desc, "out_dtype", paddle::cpp::VarDescAPI::Type::FP32); + CHECK(!dtype.empty()) << "The op \"cast\"'s attribute \"out_dtype\" should " "not be unknown type! 
Please check."; - VLOG(4) << out_name << " = cast(X:" << x_name << ", out_dtype=" << dtype << ")"; + VLOG(4) << out_name << " = cast(X:" << x_name << ", out_dtype=" << dtype + << ")"; - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto out = ctx.Builder()->Cast(x, dtype); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); } -void PowOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void PowOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); absl::optional y; - if (op_desc.HasInput("FactorTensor") && !op_desc.Input("FactorTensor").empty()) { + if (op_desc.HasInput("FactorTensor") && + !op_desc.Input("FactorTensor").empty()) { CHECK_EQ(op_desc.Input("FactorTensor").size(), 1UL); auto y_name = op_desc.Input("FactorTensor").front(); - y = ctx.GetVar(y_name); + y = ctx.GetVar(y_name); } else if (op_desc.HasAttr("factor")) { auto factor = utils::GetAttrOrDefault(op_desc, "factor"); - y = ctx.Builder()->FillConstant(x->shape, factor, cinn::UniqName(x_name + "_factor"), common::Type2Str(x->type)); + y = ctx.Builder()->FillConstant(x->shape, + factor, + cinn::UniqName(x_name + "_factor"), + common::Type2Str(x->type)); } else { - LOG(FATAL) << "Cannot found [FactorTensor] input or [factor] attribute in paddle.pow! Please check."; + LOG(FATAL) << "Cannot found [FactorTensor] input or [factor] attribute in " + "paddle.pow! Please check."; } VLOG(4) << out_name << " = pow(" << x_name << ", " << y.value()->id << ")"; - CHECK_EQ(x->type, y.value()->type) << "The data type of pow's inputs should be equal, but here x:" << x->type - << " != y:" << y.value()->type; + CHECK_EQ(x->type, y.value()->type) + << "The data type of pow's inputs should be equal, but here x:" << x->type + << " != y:" << y.value()->type; auto out = ctx.Builder()->Pow(x, y.value()); @@ -210,7 +240,8 @@ void PowOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) ctx.AddVarModelToProgram(out_name, out->id); } -void FloorDivideOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void FloorDivideOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Y").size(), 1UL); @@ -238,20 +269,30 @@ void FloorDivideOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperConte CINN_REGISTER_HELPER(paddle_elementwise) { using namespace cinn::frontend::paddle_mappers; CINN_REGISTER_OP_MAPPER(add, AddOpMapper) - CINN_REGISTER_OP_MAPPER(elementwise_add, ElementwiseOpMapper) + CINN_REGISTER_OP_MAPPER(elementwise_add, + ElementwiseOpMapper) CINN_REGISTER_OP_MAPPER(elementwise_add_grad, ElementwiseAddGradOpMapper) - CINN_REGISTER_OP_MAPPER(elementwise_mul, ElementwiseOpMapper) - CINN_REGISTER_OP_MAPPER(elementwise_div, ElementwiseOpMapper) - CINN_REGISTER_OP_MAPPER(elementwise_sub, ElementwiseOpMapper) - CINN_REGISTER_OP_MAPPER(elementwise_pow, ElementwiseOpMapper) - CINN_REGISTER_OP_MAPPER(elementwise_mod, ElementwiseOpMapper) - CINN_REGISTER_OP_MAPPER(elementwise_max, ElementwiseOpMapper) - CINN_REGISTER_OP_MAPPER(elementwise_min, ElementwiseOpMapper) + CINN_REGISTER_OP_MAPPER(elementwise_mul, + ElementwiseOpMapper) + CINN_REGISTER_OP_MAPPER(elementwise_div, + 
ElementwiseOpMapper) + CINN_REGISTER_OP_MAPPER(elementwise_sub, + ElementwiseOpMapper) + CINN_REGISTER_OP_MAPPER(elementwise_pow, + ElementwiseOpMapper) + CINN_REGISTER_OP_MAPPER(elementwise_mod, + ElementwiseOpMapper) + CINN_REGISTER_OP_MAPPER(elementwise_max, + ElementwiseOpMapper) + CINN_REGISTER_OP_MAPPER(elementwise_min, + ElementwiseOpMapper) CINN_REGISTER_OP_MAPPER(sum, SumOpMapper) CINN_REGISTER_OP_MAPPER(cast, CastOpMapper) CINN_REGISTER_OP_MAPPER(pow, PowOpMapper) - CINN_REGISTER_OP_MAPPER(grad_add, - ElementwiseOpMapper) // special elementwise_add for gradient accumulation + CINN_REGISTER_OP_MAPPER( + grad_add, + ElementwiseOpMapper) // special elementwise_add for + // gradient accumulation CINN_REGISTER_OP_MAPPER(elementwise_floordiv, FloorDivideOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/expand.cc b/paddle/cinn/frontend/op_mappers/paddle/expand.cc index 9fb4222d53a09..e05e573b82fae 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/expand.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/expand.cc @@ -19,23 +19,27 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void ExpandOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ExpandOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); CHECK(op_desc.HasAttr("expand_times")); - auto expand_times = utils::GetAttrOrDefault>(op_desc, "expand_times"); + auto expand_times = + utils::GetAttrOrDefault>(op_desc, "expand_times"); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto x_shape = x->shape; VLOG(4) << "expand: x shape: " << cinn::utils::Join(x_shape, ", "); - VLOG(4) << "expand: attr expand_times: " << cinn::utils::Join(expand_times, ", "); + VLOG(4) << "expand: attr expand_times: " + << cinn::utils::Join(expand_times, ", "); - CHECK_EQ(expand_times.size(), x_shape.size()) << "The size of `expand_times' should == the rank[" << x_shape.size() - << "] of x's shape, but got " << expand_times.size(); + CHECK_EQ(expand_times.size(), x_shape.size()) + << "The size of `expand_times' should == the rank[" << x_shape.size() + << "] of x's shape, but got " << expand_times.size(); std::vector out_shape(x_shape.size()); for (size_t i = 0; i < x_shape.size(); ++i) { @@ -50,7 +54,8 @@ void ExpandOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c ctx.AddVarModelToProgram(out_name, out->id); } -void ExpandV2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ExpandV2OpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); @@ -59,14 +64,15 @@ void ExpandV2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& CHECK(op_desc.HasAttr("shape")); auto shape = utils::GetAttrOrDefault>(op_desc, "shape"); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto x_shape = x->shape; VLOG(4) << "expand_v2: x shape: " << cinn::utils::Join(x_shape, ", "); VLOG(4) << "expand_v2: attr shape: " << cinn::utils::Join(shape, ", "); - CHECK_GE(shape.size(), x_shape.size()) << "The size of `shape' should >= the rank[" << x_shape.size() - << "] of x's shape, but got " << shape.size(); + CHECK_GE(shape.size(), x_shape.size()) + << "The size of `shape' should >= the 
rank[" << x_shape.size() + << "] of x's shape, but got " << shape.size(); auto diff = shape.size() - x_shape.size(); x_shape.insert(x_shape.begin(), diff, 1); @@ -74,23 +80,27 @@ void ExpandV2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& std::vector out_shape(x_shape.size()); for (size_t i = 0; i < x_shape.size(); ++i) { - CHECK_NE(shape[i], 0) << "The " << i << "th element in shape cannot be zero."; + CHECK_NE(shape[i], 0) << "The " << i + << "th element in shape cannot be zero."; if (i < diff) { - CHECK_GT(shape[i], 0) << "The " << i << "th element[" << shape[i] - << "] for non-existing dimensions must be positive."; + CHECK_GT(shape[i], 0) + << "The " << i << "th element[" << shape[i] + << "] for non-existing dimensions must be positive."; out_shape[i] = shape[i]; } else if (shape[i] > 0) { if (x_shape[i] != 1) { - CHECK_EQ(shape[i], x_shape[i]) << "The " << i << "th element[" << shape[i] - << "] of the non-singleton dimension does not match" - " the corresponding element[" - << x_shape[i] << "] in x's shape."; + CHECK_EQ(shape[i], x_shape[i]) + << "The " << i << "th element[" << shape[i] + << "] of the non-singleton dimension does not match" + " the corresponding element[" + << x_shape[i] << "] in x's shape."; out_shape[i] = shape[i]; } else { out_shape[i] = shape[i]; } } else { - CHECK_EQ(shape[i], -1) << "When the element in shape is negative for expand_v2 op, only -1 is supported, but got " + CHECK_EQ(shape[i], -1) << "When the element in shape is negative for " + "expand_v2 op, only -1 is supported, but got " << shape[i]; out_shape[i] = x_shape[i]; } @@ -114,7 +124,9 @@ void ExpandV2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& } // namespace cinn CINN_REGISTER_HELPER(paddle_expand) { - CINN_REGISTER_OP_MAPPER(expand, cinn::frontend::paddle_mappers::ExpandOpMapper) - CINN_REGISTER_OP_MAPPER(expand_v2, cinn::frontend::paddle_mappers::ExpandV2OpMapper) + CINN_REGISTER_OP_MAPPER(expand, + cinn::frontend::paddle_mappers::ExpandOpMapper) + CINN_REGISTER_OP_MAPPER(expand_v2, + cinn::frontend::paddle_mappers::ExpandV2OpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/fetch_feed.cc b/paddle/cinn/frontend/op_mappers/paddle/fetch_feed.cc index befa88869afcc..70b89718c0cd2 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/fetch_feed.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/fetch_feed.cc @@ -20,32 +20,35 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void FetchOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void FetchOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto output_name = op_desc.Input("X").front(); ctx.AddFetchVarName(output_name); VLOG(4) << "detect model output: [" << output_name << "]"; } -void FeedOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void FeedOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto feed_name = op_desc.Output("Out").front(); VLOG(4) << "Model get feed [" << feed_name << "]"; // For input parameters if (ctx.Scope().FindVar(cinn::utils::TransValidVarName(feed_name))) { - auto param = ctx.GetVar(feed_name); + auto param = ctx.GetVar(feed_name); const auto& var = ctx.Builder()->CreateInput(param); VLOG(4) << "Create param [" << feed_name << "]" - << " to " << var.id() << " with shape=[" << cinn::utils::Join(var.shape(), ",") - << "], dtype=" << var.type(); + << " 
to " << var.id() << " with shape=[" + << cinn::utils::Join(var.shape(), ",") << "], dtype=" << var.type(); return; } // For input variables const auto& feed_info = ctx.GetFeedInfo(feed_name); - auto cinn_id = cinn::utils::TransValidVarName(feed_name); - auto input = ctx.Builder()->CreateInput(feed_info.type, feed_info.shape, cinn_id); + auto cinn_id = cinn::utils::TransValidVarName(feed_name); + auto input = + ctx.Builder()->CreateInput(feed_info.type, feed_info.shape, cinn_id); ctx.AddVar(feed_name, input); ctx.AddVarModelToProgram(feed_name, input.id().data()); } diff --git a/paddle/cinn/frontend/op_mappers/paddle/flip.cc b/paddle/cinn/frontend/op_mappers/paddle/flip.cc index 321429f6ef261..419f7c9758db1 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/flip.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/flip.cc @@ -19,16 +19,19 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void FlipOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void FlipOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto axes = utils::GetAttrOrDefault>(op_desc, "axis", std::vector{}); - VLOG(4) << "out_name = flip(" << x_name << ", axis=[" << cinn::utils::Join(axes, ", ") << "])"; + auto axes = utils::GetAttrOrDefault>( + op_desc, "axis", std::vector{}); + VLOG(4) << "out_name = flip(" << x_name << ", axis=[" + << cinn::utils::Join(axes, ", ") << "])"; - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); const auto& ndim = x->shape.size(); for (auto& axis : axes) { if (axis < 0) { diff --git a/paddle/cinn/frontend/op_mappers/paddle/gather.cc b/paddle/cinn/frontend/op_mappers/paddle/gather.cc index 0337bc9ab5c1d..8cc37c1d92358 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/gather.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/gather.cc @@ -21,7 +21,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void GatherOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void GatherOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Index").size(), 1UL); @@ -29,12 +30,13 @@ void GatherOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto index = ctx.GetVar(index_name); auto axis = utils::GetAttrOrDefault(op_desc, "axis", 0); - VLOG(4) << "Gather X:" << x_name << "[" << cinn::utils::Join(x->shape, ",") << "] with index:" << index_name << "[" + VLOG(4) << "Gather X:" << x_name << "[" << cinn::utils::Join(x->shape, ",") + << "] with index:" << index_name << "[" << cinn::utils::Join(index->shape, ",") << "] at axis=" << axis; if (index->shape.size() > 1) { @@ -42,13 +44,15 @@ void GatherOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c bool is_rank_1 = false; for (auto dim : index->shape) { if (dim != 1) { - CHECK(!is_rank_1) << "The \"index\" of \"Gather\" only support rank 1 tensor, but here index.shape=[" + CHECK(!is_rank_1) << "The \"index\" of \"Gather\" only support rank 1 " + "tensor, but here index.shape=[" << cinn::utils::Join(index->shape, ",") << "]"; is_rank_1 = true; } } - auto num = 
std::accumulate(index->shape.begin(), index->shape.end(), 1, std::multiplies()); - index = ctx.Builder()->Reshape(index, {num}); + auto num = std::accumulate( + index->shape.begin(), index->shape.end(), 1, std::multiplies()); + index = ctx.Builder()->Reshape(index, {num}); } // now paddle science only need reduce sum @@ -63,6 +67,7 @@ void GatherOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c } // namespace cinn CINN_REGISTER_HELPER(paddle_gather) { - CINN_REGISTER_OP_MAPPER(gather, cinn::frontend::paddle_mappers::GatherOpMapper) + CINN_REGISTER_OP_MAPPER(gather, + cinn::frontend::paddle_mappers::GatherOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/gather_nd.cc b/paddle/cinn/frontend/op_mappers/paddle/gather_nd.cc index 29d54622fcd40..05327174fc0d9 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/gather_nd.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/gather_nd.cc @@ -19,7 +19,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void GatherNdOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void GatherNdOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Index").size(), 1UL); @@ -27,10 +28,11 @@ void GatherNdOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto index = ctx.GetVar(index_name); - VLOG(4) << "GatherND X:" << x_name << "[" << cinn::utils::Join(x->shape, ",") << "] with index:" << index_name << "[" + VLOG(4) << "GatherND X:" << x_name << "[" << cinn::utils::Join(x->shape, ",") + << "] with index:" << index_name << "[" << cinn::utils::Join(index->shape, ",") << "]"; auto out = ctx.Builder()->GatherNd(x, index); @@ -44,6 +46,7 @@ void GatherNdOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& } // namespace cinn CINN_REGISTER_HELPER(paddle_gather_nd) { - CINN_REGISTER_OP_MAPPER(gather_nd, cinn::frontend::paddle_mappers::GatherNdOpMapper) + CINN_REGISTER_OP_MAPPER(gather_nd, + cinn::frontend::paddle_mappers::GatherNdOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/gaussian_random.cc b/paddle/cinn/frontend/op_mappers/paddle/gaussian_random.cc index 50790ee2016c3..07c011b1a62ac 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/gaussian_random.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/gaussian_random.cc @@ -20,22 +20,28 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void GaussianRandomOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void GaussianRandomOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto shape_origin = utils::GetAttrOrDefault>(op_desc, "shape"); - auto shape = utils::ToShapeType(shape_origin); + auto shape_origin = + utils::GetAttrOrDefault>(op_desc, "shape"); + auto shape = utils::ToShapeType(shape_origin); auto mean = utils::GetAttrOrDefault(op_desc, "mean", 0.0f); - auto std = utils::GetAttrOrDefault(op_desc, "std", 1.0f); + auto std = utils::GetAttrOrDefault(op_desc, "std", 1.0f); auto seed = utils::GetAttrOrDefault(op_desc, "seed", 0); - auto dtype = utils::GetPaddleDtype(op_desc, "dtype", paddle::cpp::VarDescAPI::Type::FP32); - 
CHECK(!dtype.empty()) << "The op \"gaussian_random\"'s attribute \"dtype\" should not be unknown type! Please check."; + auto dtype = utils::GetPaddleDtype( + op_desc, "dtype", paddle::cpp::VarDescAPI::Type::FP32); + CHECK(!dtype.empty()) << "The op \"gaussian_random\"'s attribute \"dtype\" " + "should not be unknown type! Please check."; - VLOG(4) << out_name << "[" << cinn::utils::Join(shape, ", ") << "] = uniform_random(mean=" << mean << ", std=" << std - << ", seed=" << seed << ", dtype=" << dtype << ", shape=[" << cinn::utils::Join(shape, ", ") << "])"; + VLOG(4) << out_name << "[" << cinn::utils::Join(shape, ", ") + << "] = uniform_random(mean=" << mean << ", std=" << std + << ", seed=" << seed << ", dtype=" << dtype << ", shape=[" + << cinn::utils::Join(shape, ", ") << "])"; auto out = ctx.Builder()->GaussianRandom(shape, mean, std, seed, dtype); ctx.AddVar(out_name, out); @@ -47,6 +53,7 @@ void GaussianRandomOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperCo } // namespace cinn CINN_REGISTER_HELPER(paddle_gaussian_random) { - CINN_REGISTER_OP_MAPPER(gaussian_random, cinn::frontend::paddle_mappers::GaussianRandomOpMapper) + CINN_REGISTER_OP_MAPPER( + gaussian_random, cinn::frontend::paddle_mappers::GaussianRandomOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/layer_norm.cc b/paddle/cinn/frontend/op_mappers/paddle/layer_norm.cc index d53832ba9bd28..c9a138c3dbc4c 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/layer_norm.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/layer_norm.cc @@ -25,7 +25,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void LayerNormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void LayerNormOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { auto get_input = [&op_desc](const std::string& name) { CHECK_EQ(op_desc.Input(name).size(), 1UL); return op_desc.Input(name).front(); @@ -46,8 +47,9 @@ void LayerNormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext bias_name = get_input("Bias"); } // get attribute values - auto epsilon = utils::GetAttrOrDefault(op_desc, "epsilon", 1e-5f); - auto begin_norm_axis = utils::GetAttrOrDefault(op_desc, "begin_norm_axis", 1); + auto epsilon = utils::GetAttrOrDefault(op_desc, "epsilon", 1e-5f); + auto begin_norm_axis = + utils::GetAttrOrDefault(op_desc, "begin_norm_axis", 1); // get input variable auto x = ctx.GetVar(x_name); absl::optional scale; @@ -59,13 +61,16 @@ void LayerNormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext bias = ctx.GetVar(*bias_name); } - VLOG(4) << "layer_norm X=" << x_name << "[" << x << "], Scale=" << scale_name.value() << "[" << scale.value() - << "], Bias=" << bias_name.value() << "[" << bias.value() << "], epsilon=" << epsilon + VLOG(4) << "layer_norm X=" << x_name << "[" << x + << "], Scale=" << scale_name.value() << "[" << scale.value() + << "], Bias=" << bias_name.value() << "[" << bias.value() + << "], epsilon=" << epsilon << ", begin_norm_axis=" << begin_norm_axis; const auto& x_shape = x->shape; - auto x_ndim = x_shape.size(); - CHECK_LT(begin_norm_axis, x_ndim) << "`begin_norm_axis` must be less than the dimensions of X, but received " + auto x_ndim = x_shape.size(); + CHECK_LT(begin_norm_axis, x_ndim) << "`begin_norm_axis` must be less than " + "the dimensions of X, but received " << begin_norm_axis; VLOG(4) << "-- [layer_norm] begin_norm_axis = " << begin_norm_axis; int left = 1; @@ -88,27 +93,36 @@ void LayerNormOpMapper(const 
paddle::cpp::OpDesc& op_desc, const OpMapperContext std::vector shape{left, right}; auto x_reshape = builder->Reshape(x, shape); - auto x_reduce = builder->ReduceSum(x_reshape, {1}); - auto ele_num = builder->FillConstant( - {left}, static_cast(right), common::UniqName("layer_norm_ele_num"), common::Type2Str(x->type)); + auto x_reduce = builder->ReduceSum(x_reshape, {1}); + auto ele_num = builder->FillConstant({left}, + static_cast(right), + common::UniqName("layer_norm_ele_num"), + common::Type2Str(x->type)); auto x_mean = builder->Divide(x_reduce, ele_num); // use `E[|x|^2] - |E[x]|^2` instead of `E[|x - E[x]|^2])` to compute variance - auto x2 = builder->Multiply(x_reshape, builder->Identity(x_reshape)); + auto x2 = builder->Multiply(x_reshape, builder->Identity(x_reshape)); auto x2_reduce = builder->ReduceSum(x2, {1}); - auto x2_mean = builder->Divide(x2_reduce, ele_num); - auto x_mean2 = builder->Multiply(x_mean, builder->Identity(x_mean)); - auto zero = builder->FillConstant({left}, 0.f, common::UniqName("layer_norm_zero"), common::Type2Str(x->type)); - auto x_var = builder->Max(builder->Subtract(x2_mean, x_mean2), zero); + auto x2_mean = builder->Divide(x2_reduce, ele_num); + auto x_mean2 = builder->Multiply(x_mean, builder->Identity(x_mean)); + auto zero = builder->FillConstant({left}, + 0.f, + common::UniqName("layer_norm_zero"), + common::Type2Str(x->type)); + auto x_var = builder->Max(builder->Subtract(x2_mean, x_mean2), zero); // compute x norm auto x_mean_broadcast = builder->BroadcastTo(x_mean, shape, {0}); - auto y_sub = builder->Subtract(x_reshape, x_mean_broadcast); + auto y_sub = builder->Subtract(x_reshape, x_mean_broadcast); auto epsilon_var = - builder->FillConstant({left}, epsilon, common::UniqName("layer_norm_epsilon"), common::Type2Str(x->type)); - auto x_var_eps = builder->Add(x_var, epsilon_var); + builder->FillConstant({left}, + epsilon, + common::UniqName("layer_norm_epsilon"), + common::Type2Str(x->type)); + auto x_var_eps = builder->Add(x_var, epsilon_var); auto x_var_sqrt = builder->Sqrt(x_var_eps); - auto y_out = builder->Divide(y_sub, builder->BroadcastTo(x_var_sqrt, shape, {0})); + auto y_out = + builder->Divide(y_sub, builder->BroadcastTo(x_var_sqrt, shape, {0})); // multiply scale if (scale) { @@ -116,7 +130,7 @@ void LayerNormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext scale = ctx.Builder()->Cast(scale.value(), "float32"); } auto scale_broadcast = builder->BroadcastTo(*scale, shape, {1}); - y_out = builder->Multiply(y_out, scale_broadcast); + y_out = builder->Multiply(y_out, scale_broadcast); } // add bias @@ -125,7 +139,7 @@ void LayerNormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext bias = ctx.Builder()->Cast(bias.value(), "float32"); } auto bias_broadcast = builder->BroadcastTo(*bias, shape, {1}); - y_out = builder->Add(y_out, bias_broadcast); + y_out = builder->Add(y_out, bias_broadcast); } // reshape to the original shape @@ -138,8 +152,8 @@ void LayerNormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext } // get output names - auto y_name = get_output("Y"); - auto mean_name = get_output("Mean"); + auto y_name = get_output("Y"); + auto mean_name = get_output("Mean"); auto variance_name = get_output("Variance"); // re-mapper outputs ctx.AddVar(y_name, y_out); @@ -155,6 +169,7 @@ void LayerNormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext } // namespace cinn CINN_REGISTER_HELPER(paddle_layer_norm) { - CINN_REGISTER_OP_MAPPER(layer_norm, 
cinn::frontend::paddle_mappers::LayerNormOpMapper) + CINN_REGISTER_OP_MAPPER(layer_norm, + cinn::frontend::paddle_mappers::LayerNormOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/log.cc b/paddle/cinn/frontend/op_mappers/paddle/log.cc index 9a1197a869213..bb849119e93b4 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/log.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/log.cc @@ -19,46 +19,50 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void LogOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void LogOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto out = ctx.Builder()->Log(x); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); } -void Log2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void Log2OpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto out = ctx.Builder()->Log2(x); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); } -void Log10OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void Log10OpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto out = ctx.Builder()->Log10(x); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); } -void Log1pOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void Log1pOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); @@ -66,9 +70,11 @@ void Log1pOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ct auto x = ctx.GetVar(x_name); - auto one = - ctx.Builder()->FillConstant(x->shape, 1.0f, cinn::UniqName(x->id + "_1p"), cinn::common::Type2Str(x->type)); - auto y = ctx.Builder()->Add(x, one); + auto one = ctx.Builder()->FillConstant(x->shape, + 1.0f, + cinn::UniqName(x->id + "_1p"), + cinn::common::Type2Str(x->type)); + auto y = ctx.Builder()->Add(x, one); auto out = ctx.Builder()->Log(y); ctx.AddVar(out_name, out); diff --git a/paddle/cinn/frontend/op_mappers/paddle/lookup_table.cc b/paddle/cinn/frontend/op_mappers/paddle/lookup_table.cc index 9eefbae5bc9c0..4a7df547ce9c1 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/lookup_table.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/lookup_table.cc @@ -19,36 +19,40 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void LookupTableOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void LookupTableOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("W").size(), 1UL); auto w_name = op_desc.Input("W").front(); 
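// ---------------------------------------------------------------------------
// NOTE: a standalone sketch, illustrative only -- not part of this patch and
// not the CINN NetBuilder API. It shows the semantics the LookupTable mappers
// in this file lower to: out[i, :] = W[ids[i], :], with the `padding_idx` row
// (if any) forced to zeros. `LookupTableRef` is a hypothetical helper name.
#include <cstdint>
#include <vector>

std::vector<float> LookupTableRef(
    const std::vector<float>& table,  // rows x width, row-major
    int64_t width,
    const std::vector<int64_t>& ids,
    int64_t padding_idx) {  // pass -1 for "no padding row"
  std::vector<float> out(ids.size() * width, 0.0f);
  for (size_t i = 0; i < ids.size(); ++i) {
    if (ids[i] == padding_idx) continue;  // padding rows stay all-zero
    for (int64_t j = 0; j < width; ++j) {
      out[i * width + j] = table[static_cast<size_t>(ids[i] * width + j)];
    }
  }
  return out;
}
// ---------------------------------------------------------------------------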
CHECK_EQ(op_desc.Input("Ids").size(), 1UL); auto ids_name = op_desc.Input("Ids").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto w = ctx.GetVar(w_name); - auto ids = ctx.GetVar(ids_name); + auto w = ctx.GetVar(w_name); + auto ids = ctx.GetVar(ids_name); CHECK(op_desc.HasAttr("padding_idx")); - auto padding_idx = utils::GetAttrOrDefault(op_desc, "padding_idx", -1); - auto out = ctx.Builder()->LookupTable(w, ids, padding_idx); + auto padding_idx = + utils::GetAttrOrDefault(op_desc, "padding_idx", -1); + auto out = ctx.Builder()->LookupTable(w, ids, padding_idx); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); } -void LookupTableV2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void LookupTableV2OpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("W").size(), 1UL); auto w_name = op_desc.Input("W").front(); CHECK_EQ(op_desc.Input("Ids").size(), 1UL); auto ids_name = op_desc.Input("Ids").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto w = ctx.GetVar(w_name); - auto ids = ctx.GetVar(ids_name); - ids = ctx.Builder()->ExpandDims(ids, {-1}); + auto w = ctx.GetVar(w_name); + auto ids = ctx.GetVar(ids_name); + ids = ctx.Builder()->ExpandDims(ids, {-1}); CHECK(op_desc.HasAttr("padding_idx")); - auto padding_idx = utils::GetAttrOrDefault(op_desc, "padding_idx", -1); - auto out = ctx.Builder()->LookupTable(w, ids, padding_idx); + auto padding_idx = + utils::GetAttrOrDefault(op_desc, "padding_idx", -1); + auto out = ctx.Builder()->LookupTable(w, ids, padding_idx); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); @@ -59,7 +63,9 @@ void LookupTableV2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperCon } // namespace cinn CINN_REGISTER_HELPER(paddle_lookup_table) { - CINN_REGISTER_OP_MAPPER(lookup_table, cinn::frontend::paddle_mappers::LookupTableOpMapper) - CINN_REGISTER_OP_MAPPER(lookup_table_v2, cinn::frontend::paddle_mappers::LookupTableV2OpMapper) + CINN_REGISTER_OP_MAPPER(lookup_table, + cinn::frontend::paddle_mappers::LookupTableOpMapper) + CINN_REGISTER_OP_MAPPER(lookup_table_v2, + cinn::frontend::paddle_mappers::LookupTableV2OpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/matmul.cc b/paddle/cinn/frontend/op_mappers/paddle/matmul.cc index 0a35cc56e708d..baf758a7fac75 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/matmul.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/matmul.cc @@ -19,7 +19,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void MatMulOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void MatMulOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Y").size(), 1UL); @@ -28,18 +29,19 @@ void MatMulOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c auto out_name = op_desc.Output("Out").front(); auto trans_x = utils::GetAttrOrDefault(op_desc, "trans_x", false); - trans_x = utils::GetAttrOrDefault(op_desc, "transpose_X", trans_x); + trans_x = utils::GetAttrOrDefault(op_desc, "transpose_X", trans_x); auto trans_y = utils::GetAttrOrDefault(op_desc, "trans_y", false); - trans_y = utils::GetAttrOrDefault(op_desc, "transpose_Y", trans_y); + trans_y = utils::GetAttrOrDefault(op_desc, "transpose_Y", trans_y); auto alpha = 
utils::GetAttrOrDefault(op_desc, "alpha", 1.0f); - VLOG(4) << out_name << "=matmul{" << x_name << ", " << y_name << ", trans_x=" << trans_x << ", trans_y=" << trans_y + VLOG(4) << out_name << "=matmul{" << x_name << ", " << y_name + << ", trans_x=" << trans_x << ", trans_y=" << trans_y << ", alpha=" << alpha << "}"; - auto x = ctx.GetVar(x_name); - auto y = ctx.GetVar(y_name); + auto x = ctx.GetVar(x_name); + auto y = ctx.GetVar(y_name); auto out = ctx.Builder()->Matmul(x, y, trans_x, trans_y, alpha); ctx.AddVar(out_name, out); @@ -51,7 +53,9 @@ void MatMulOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c } // namespace cinn CINN_REGISTER_HELPER(paddle_matmul) { - CINN_REGISTER_OP_MAPPER(matmul, cinn::frontend::paddle_mappers::MatMulOpMapper) - CINN_REGISTER_OP_MAPPER(matmul_v2, cinn::frontend::paddle_mappers::MatMulOpMapper) + CINN_REGISTER_OP_MAPPER(matmul, + cinn::frontend::paddle_mappers::MatMulOpMapper) + CINN_REGISTER_OP_MAPPER(matmul_v2, + cinn::frontend::paddle_mappers::MatMulOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/mul.cc b/paddle/cinn/frontend/op_mappers/paddle/mul.cc index a1f2985b912c9..eedb8c828c6a5 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/mul.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/mul.cc @@ -20,7 +20,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void MulOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void MulOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Y").size(), 1UL); @@ -30,8 +31,10 @@ void MulOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) auto y = ctx.GetVar(y_name); // Step1: flatten multi-dimension matrix input to two-dimension matrix - auto x_num_col_dims = utils::GetAttrOrDefault(op_desc, "x_num_col_dims", 1); - auto y_num_col_dims = utils::GetAttrOrDefault(op_desc, "y_num_col_dims", 1); + auto x_num_col_dims = + utils::GetAttrOrDefault(op_desc, "x_num_col_dims", 1); + auto y_num_col_dims = + utils::GetAttrOrDefault(op_desc, "y_num_col_dims", 1); auto out = ctx.Builder()->Mul(x, y, x_num_col_dims, y_num_col_dims); diff --git a/paddle/cinn/frontend/op_mappers/paddle/norm.cc b/paddle/cinn/frontend/op_mappers/paddle/norm.cc index 6bc7273a4bbd1..8e45ead8bf185 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/norm.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/norm.cc @@ -21,16 +21,19 @@ namespace paddle_mappers { struct NormHelper { NormHelper(NetBuilder* net_builder, int32_t axis) { - builder = net_builder; - reduce_dim = {axis}; + builder = net_builder; + reduce_dim = {axis}; num_instructions = builder->size(); } - ~NormHelper() { VLOG(4) << "norm is decomposed to " << builder->size() - num_instructions << " instructions."; } + ~NormHelper() { + VLOG(4) << "norm is decomposed to " << builder->size() - num_instructions + << " instructions."; + } // square_sum = reduce_sum(x * x) Variable SquareSum(Variable x) { - auto x_square = builder->Multiply(x, builder->Identity(x)); + auto x_square = builder->Multiply(x, builder->Identity(x)); auto x_square_sum = Reduce(x_square); return x_square_sum; @@ -38,20 +41,25 @@ struct NormHelper { // std_square_sum = sqrt(square_sum + epsilon) Variable StdSquareSum(Variable square_sum, float epsilon) { - auto epsilon_1d = builder->FillConstant( - square_sum->shape, epsilon, common::UniqName("norm_epsilon"), common::Type2Str(square_sum->type)); + 
auto epsilon_1d = builder->FillConstant(square_sum->shape, + epsilon, + common::UniqName("norm_epsilon"), + common::Type2Str(square_sum->type)); auto std_square_sum = builder->Sqrt(builder->Add(square_sum, epsilon_1d)); return std_square_sum; } - Variable Reduce(Variable x) { return builder->ReduceSum(x, reduce_dim, true); } + Variable Reduce(Variable x) { + return builder->ReduceSum(x, reduce_dim, true); + } NetBuilder* builder{nullptr}; std::vector reduce_dim; int num_instructions{0}; }; -void NormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void NormOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); @@ -64,14 +72,17 @@ void NormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx } CHECK(op_desc.HasAttr("axis")); - auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); + auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); auto epsilon = utils::GetAttrOrDefault(op_desc, "epsilon", 1.0e-10f); - auto is_test = utils::GetAttrOrDefault(op_desc, "is_test", norm_name.empty()); + auto is_test = + utils::GetAttrOrDefault(op_desc, "is_test", norm_name.empty()); auto x = ctx.GetVar(x_name); - VLOG(4) << "Out=" << out_name << ", Norm=" << norm_name << " = norm(X:" << x_name << "=" << x << ", axis=" << axis - << ", epsilon=" << epsilon << ", is_test=" << std::ios::boolalpha << is_test; + VLOG(4) << "Out=" << out_name << ", Norm=" << norm_name + << " = norm(X:" << x_name << "=" << x << ", axis=" << axis + << ", epsilon=" << epsilon << ", is_test=" << std::boolalpha + << is_test; if (axis < 0) { axis += x->shape.size(); @@ -85,16 +96,17 @@ void NormOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx if (in_type.is_float16() || in_type.is_bfloat16()) { x = ctx.Builder()->Cast(x, "float32"); } - auto square_sum = helper.SquareSum(x); + auto square_sum = helper.SquareSum(x); auto std_square_sum = helper.StdSquareSum(square_sum, epsilon); - auto normalized = ctx.Builder()->Divide(x, std_square_sum); - auto y = ctx.Builder()->Cast(normalized, common::Type2Str(in_type)); + auto normalized = ctx.Builder()->Divide(x, std_square_sum); + auto y = ctx.Builder()->Cast(normalized, common::Type2Str(in_type)); ctx.AddVar(out_name, y); ctx.AddVarModelToProgram(out_name, y->id); if (!norm_name.empty()) { - auto norm_grad = ctx.Builder()->Cast(std_square_sum, common::Type2Str(in_type)); + auto norm_grad = + ctx.Builder()->Cast(std_square_sum, common::Type2Str(in_type)); ctx.AddVar(norm_name, norm_grad); ctx.AddVarModelToProgram(norm_name, norm_grad->id); } diff --git a/paddle/cinn/frontend/op_mappers/paddle/one_hot.cc b/paddle/cinn/frontend/op_mappers/paddle/one_hot.cc index 3cf731110f00f..40623d0150040 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/one_hot.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/one_hot.cc @@ -20,45 +20,56 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void OneHotOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void OneHotOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); auto depth = utils::GetAttrOrDefault(op_desc, "depth", 1); - auto axis = 
utils::GetAttrOrDefault(op_desc, "axis", -1); - auto on_value = ctx.Builder()->FillConstant({1}, 1, cinn::UniqName(x_name + "_on_value"), "int32"); - auto off_value = ctx.Builder()->FillConstant({1}, 0, cinn::UniqName(x_name + "_off_value"), "int32"); + auto on_value = ctx.Builder()->FillConstant( + {1}, 1, cinn::UniqName(x_name + "_on_value"), "int32"); + auto off_value = ctx.Builder()->FillConstant( + {1}, 0, cinn::UniqName(x_name + "_off_value"), "int32"); - auto dtype = utils::GetPaddleDtype(op_desc, "dtype", paddle::cpp::VarDescAPI::Type::FP32); - CHECK(!dtype.empty()) << "The op \"ont_hot\"'s attribute \"dtype\" should not be unknown type! Please check."; + auto dtype = utils::GetPaddleDtype( + op_desc, "dtype", paddle::cpp::VarDescAPI::Type::FP32); + CHECK(!dtype.empty()) << "The op \"one_hot\"'s attribute \"dtype\" should " "not be unknown type! Please check."; - auto x = ctx.GetVar(x_name); - x = ctx.Builder()->Slice(x, {static_cast(x->shape.size()) - 1}, {0}, {1}, {}, {1}, {}); - x = ctx.Builder()->Squeeze(x, {-1}); + auto x = ctx.GetVar(x_name); + x = ctx.Builder()->Slice( + x, {static_cast(x->shape.size()) - 1}, {0}, {1}, {}, {1}, {}); + x = ctx.Builder()->Squeeze(x, {-1}); auto out = ctx.Builder()->OneHot(x, on_value, off_value, depth, axis, dtype); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); } -void OneHotV2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void OneHotV2OpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); auto depth = utils::GetAttrOrDefault(op_desc, "depth", 1); - auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); + auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); - auto on_value = ctx.Builder()->FillConstant({1}, 1, cinn::UniqName(x_name + "_on_value"), "int32"); - auto off_value = ctx.Builder()->FillConstant({1}, 0, cinn::UniqName(x_name + "_off_value"), "int32"); + auto on_value = ctx.Builder()->FillConstant( + {1}, 1, cinn::UniqName(x_name + "_on_value"), "int32"); + auto off_value = ctx.Builder()->FillConstant( + {1}, 0, cinn::UniqName(x_name + "_off_value"), "int32"); - auto dtype = utils::GetPaddleDtype(op_desc, "dtype", paddle::cpp::VarDescAPI::Type::FP32); - CHECK(!dtype.empty()) << "The op \"one_hot_v2\"'s attribute \"dtype\" should not be unknown type! Please check."; + auto dtype = utils::GetPaddleDtype( + op_desc, "dtype", paddle::cpp::VarDescAPI::Type::FP32); + CHECK(!dtype.empty()) << "The op \"one_hot_v2\"'s attribute \"dtype\" should " "not be unknown type! 
Please check."; - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto out = ctx.Builder()->OneHot(x, on_value, off_value, depth, axis, dtype); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); @@ -69,7 +80,9 @@ void OneHotV2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& } // namespace cinn CINN_REGISTER_HELPER(paddle_one_hot) { - CINN_REGISTER_OP_MAPPER(one_hot, cinn::frontend::paddle_mappers::OneHotOpMapper) - CINN_REGISTER_OP_MAPPER(one_hot_v2, cinn::frontend::paddle_mappers::OneHotV2OpMapper) + CINN_REGISTER_OP_MAPPER(one_hot, + cinn::frontend::paddle_mappers::OneHotOpMapper) + CINN_REGISTER_OP_MAPPER(one_hot_v2, + cinn::frontend::paddle_mappers::OneHotV2OpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/pool2d.cc b/paddle/cinn/frontend/op_mappers/paddle/pool2d.cc index 22da2aebd7433..838b2b0bb9182 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/pool2d.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/pool2d.cc @@ -19,28 +19,35 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void Pool2dOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void Pool2dOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); CHECK(op_desc.HasAttr("pooling_type")); - auto pooling_type = utils::GetAttrOrDefault(op_desc, "pooling_type"); + auto pooling_type = + utils::GetAttrOrDefault(op_desc, "pooling_type"); CHECK(op_desc.HasAttr("ksize")); auto ksize = utils::GetAttrOrDefault>(op_desc, "ksize"); - auto strides = utils::GetAttrOrDefault>(op_desc, "strides", {1, 1}); - auto padding_size = utils::GetAttrOrDefault>(op_desc, "paddings", {0, 0}); - - auto ceil_mode = utils::GetAttrOrDefault(op_desc, "ceil_mode", false); - auto exclusive = utils::GetAttrOrDefault(op_desc, "exclusive", true); - auto global_pooling = utils::GetAttrOrDefault(op_desc, "global_pooling", false); - auto data_format = utils::GetAttrOrDefault(op_desc, "data_format", "NCHW"); - auto adaptive = utils::GetAttrOrDefault(op_desc, "adaptive", false); - auto padding_algorithm = utils::GetAttrOrDefault(op_desc, "padding_algorithm", "EXPLICIT"); - auto x = ctx.GetVar(x_name); - auto out = ctx.Builder()->Pool2d(x, + auto strides = + utils::GetAttrOrDefault>(op_desc, "strides", {1, 1}); + auto padding_size = + utils::GetAttrOrDefault>(op_desc, "paddings", {0, 0}); + + auto ceil_mode = utils::GetAttrOrDefault(op_desc, "ceil_mode", false); + auto exclusive = utils::GetAttrOrDefault(op_desc, "exclusive", true); + auto global_pooling = + utils::GetAttrOrDefault(op_desc, "global_pooling", false); + auto data_format = + utils::GetAttrOrDefault(op_desc, "data_format", "NCHW"); + auto adaptive = utils::GetAttrOrDefault(op_desc, "adaptive", false); + auto padding_algorithm = utils::GetAttrOrDefault( + op_desc, "padding_algorithm", "EXPLICIT"); + auto x = ctx.GetVar(x_name); + auto out = ctx.Builder()->Pool2d(x, pooling_type, ksize, strides, @@ -56,7 +63,8 @@ void Pool2dOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c ctx.AddVarModelToProgram(out_name, out->id); } -void Pool2dGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void Pool2dGradOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = 
op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Out").size(), 1UL); @@ -68,22 +76,28 @@ void Pool2dGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContex auto dx_name = op_desc.Output(paddle::GradVarName("X")).front(); CHECK(op_desc.HasAttr("pooling_type")); - auto pooling_type = utils::GetAttrOrDefault(op_desc, "pooling_type"); + auto pooling_type = + utils::GetAttrOrDefault(op_desc, "pooling_type"); CHECK(op_desc.HasAttr("ksize")); auto ksize = utils::GetAttrOrDefault>(op_desc, "ksize"); - auto strides = utils::GetAttrOrDefault>(op_desc, "strides", {1, 1}); - auto padding_size = utils::GetAttrOrDefault>(op_desc, "paddings", {0, 0}); - - auto ceil_mode = utils::GetAttrOrDefault(op_desc, "ceil_mode", false); - auto exclusive = utils::GetAttrOrDefault(op_desc, "exclusive", true); - auto global_pooling = utils::GetAttrOrDefault(op_desc, "global_pooling", false); - auto data_format = utils::GetAttrOrDefault(op_desc, "data_format", "NCHW"); - auto adaptive = utils::GetAttrOrDefault(op_desc, "adaptive", false); - auto padding_algorithm = utils::GetAttrOrDefault(op_desc, "padding_algorithm", "EXPLICIT"); - - auto x = ctx.GetVar(x_name); - auto y = ctx.GetVar(y_name); + auto strides = + utils::GetAttrOrDefault>(op_desc, "strides", {1, 1}); + auto padding_size = + utils::GetAttrOrDefault>(op_desc, "paddings", {0, 0}); + + auto ceil_mode = utils::GetAttrOrDefault(op_desc, "ceil_mode", false); + auto exclusive = utils::GetAttrOrDefault(op_desc, "exclusive", true); + auto global_pooling = + utils::GetAttrOrDefault(op_desc, "global_pooling", false); + auto data_format = + utils::GetAttrOrDefault(op_desc, "data_format", "NCHW"); + auto adaptive = utils::GetAttrOrDefault(op_desc, "adaptive", false); + auto padding_algorithm = utils::GetAttrOrDefault( + op_desc, "padding_algorithm", "EXPLICIT"); + + auto x = ctx.GetVar(x_name); + auto y = ctx.GetVar(y_name); auto dy = ctx.GetVar(dy_name); auto out = ctx.Builder()->Pool2dGrad(x, @@ -109,7 +123,9 @@ void Pool2dGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContex } // namespace cinn CINN_REGISTER_HELPER(paddle_pool2d) { - CINN_REGISTER_OP_MAPPER(pool2d, cinn::frontend::paddle_mappers::Pool2dOpMapper) - CINN_REGISTER_OP_MAPPER(pool2d_grad, cinn::frontend::paddle_mappers::Pool2dGradOpMapper) + CINN_REGISTER_OP_MAPPER(pool2d, + cinn::frontend::paddle_mappers::Pool2dOpMapper) + CINN_REGISTER_OP_MAPPER(pool2d_grad, + cinn::frontend::paddle_mappers::Pool2dGradOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/randint.cc b/paddle/cinn/frontend/op_mappers/paddle/randint.cc index 3fde481723ec1..cb072cf704a48 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/randint.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/randint.cc @@ -12,34 +12,42 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "glog/logging.h" #include "paddle/cinn/frontend/op_mapper_registry.h" #include "paddle/cinn/frontend/op_mappers/common_utils.h" #include "paddle/cinn/frontend/var_type_utils.h" -#include "glog/logging.h" namespace cinn { namespace frontend { namespace paddle_mappers { -void RandIntOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void RandIntOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - CHECK(op_desc.HasAttr("shape")) << "Cannot find attribute \"shape\" in paddle op \"randint\"! 
Please check."; - auto shape_origin = utils::GetAttrOrDefault>(op_desc, "shape"); - auto shape = utils::ToShapeType(shape_origin); + CHECK(op_desc.HasAttr("shape")) << "Cannot find attribute \"shape\" in " + "paddle op \"randint\"! Please check."; + auto shape_origin = + utils::GetAttrOrDefault>(op_desc, "shape"); + auto shape = utils::ToShapeType(shape_origin); - CHECK(op_desc.HasAttr("low")) << "Cannot find attribute \"low\" in paddle op \"randint\"! Please check."; + CHECK(op_desc.HasAttr("low")) << "Cannot find attribute \"low\" in paddle op " + "\"randint\"! Please check."; auto min = utils::GetAttrOrDefault(op_desc, "low", 0); - CHECK(op_desc.HasAttr("high")) << "Cannot find attribute \"high\" in paddle op \"randint\"! Please check."; + CHECK(op_desc.HasAttr("high")) << "Cannot find attribute \"high\" in paddle " + "op \"randint\"! Please check."; auto max = utils::GetAttrOrDefault(op_desc, "high", 0); - CHECK_GT(max, min) << "max(" << max << ") should greater than min(" << min << ")! Please check."; + CHECK_GT(max, min) << "max(" << max << ") should greater than min(" << min + << ")! Please check."; auto seed = utils::GetAttrOrDefault(op_desc, "seed", 0); - auto dtype = utils::GetPaddleDtype(op_desc, "dtype", paddle::cpp::VarDescAPI::Type::INT64); - CHECK(dtype == "int32" || dtype == "int64") << "the indices dtype must be int32 or int64, but got dtype = " << dtype; + auto dtype = utils::GetPaddleDtype( + op_desc, "dtype", paddle::cpp::VarDescAPI::Type::INT64); + CHECK(dtype == "int32" || dtype == "int64") + << "the indices dtype must be int32 or int64, but got dtype = " << dtype; auto out = ctx.Builder()->RandInt(shape, min, max, seed, dtype); ctx.AddVar(out_name, out); @@ -51,6 +59,7 @@ void RandIntOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& } // namespace cinn CINN_REGISTER_HELPER(paddle_randint) { - CINN_REGISTER_OP_MAPPER(randint, cinn::frontend::paddle_mappers::RandIntOpMapper) + CINN_REGISTER_OP_MAPPER(randint, + cinn::frontend::paddle_mappers::RandIntOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/reduce.cc b/paddle/cinn/frontend/op_mappers/paddle/reduce.cc index 24fbd99b2a0eb..0d52d7ec6d7cd 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/reduce.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/reduce.cc @@ -20,20 +20,24 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void ReduceOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx, const std::string& reduce_type) { +void ReduceOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx, + const std::string& reduce_type) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto axis = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "dim")); - auto keepdim = utils::GetAttrOrDefault(op_desc, "keep_dim", false); + auto axis = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "dim")); + auto keepdim = utils::GetAttrOrDefault(op_desc, "keep_dim", false); auto reduce_all = utils::GetAttrOrDefault(op_desc, "reduce_all", false); auto x = ctx.GetVar(x_name); - VLOG(4) << "Reudce " << reduce_type << " x:" << x_name << " from shape (" << cinn::utils::Join(x->shape, ",") - << "), with dim=[" << cinn::utils::Join(axis, ",") << "], keepdim=" << keepdim + VLOG(4) << "Reudce " << reduce_type << " x:" << x_name << " from shape (" + << cinn::utils::Join(x->shape, ",") << "), with dim=[" + << 
cinn::utils::Join(axis, ",") << "], keepdim=" << keepdim << ", reduce_all=" << reduce_all; if (reduce_all) { @@ -60,22 +64,27 @@ void ReduceOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c } else if (reduce_type == "Mean") { int num = 1; if (axis.empty()) { - num = std::accumulate(x->shape.begin(), x->shape.end(), 1, std::multiplies()); + num = std::accumulate( + x->shape.begin(), x->shape.end(), 1, std::multiplies()); } else { for (int i = 0; i < axis.size(); ++i) { num *= x->shape[axis[i]]; } } - const auto& sum = ctx.Builder()->ReduceSum(x, axis, keepdim); - const auto& size = ctx.Builder()->FillConstant( - sum->shape, num, cinn::common::UniqName(x->id + "_mean"), cinn::common::Type2Str(sum->type)); + const auto& sum = ctx.Builder()->ReduceSum(x, axis, keepdim); + const auto& size = + ctx.Builder()->FillConstant(sum->shape, + num, + cinn::common::UniqName(x->id + "_mean"), + cinn::common::Type2Str(sum->type)); out = ctx.Builder()->Divide(sum, size); } CHECK(out) << "Not support Reduce " << reduce_type << "! Please check."; - auto dtype = utils::GetPaddleDtype(op_desc, "out_dtype", static_cast(-1)); + auto dtype = utils::GetPaddleDtype( + op_desc, "out_dtype", static_cast(-1)); if (!dtype.empty() && common::Type2Str(out.value()->type) != dtype) { out = ctx.Builder()->Cast(out.value(), dtype); } @@ -84,9 +93,10 @@ void ReduceOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c ctx.AddVarModelToProgram(out_name, out.value()->id); } -#define EXPAND_REDUCE_OPMAPPER(ReduceType) \ - void Reduce##ReduceType##OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { \ - ReduceOpMapper(op_desc, ctx, #ReduceType); \ +#define EXPAND_REDUCE_OPMAPPER(ReduceType) \ + void Reduce##ReduceType##OpMapper(const paddle::cpp::OpDesc& op_desc, \ + const OpMapperContext& ctx) { \ + ReduceOpMapper(op_desc, ctx, #ReduceType); \ } EXPAND_REDUCE_OPMAPPER(Sum) @@ -104,7 +114,8 @@ EXPAND_REDUCE_OPMAPPER(Mean) CINN_REGISTER_HELPER(paddle_reduce) { #define EXPAND_REDUCE_OP_MAPPER_REGISTER(op_name, ReduceType) \ - CINN_REGISTER_OP_MAPPER(op_name, cinn::frontend::paddle_mappers::Reduce##ReduceType##OpMapper) + CINN_REGISTER_OP_MAPPER( \ + op_name, cinn::frontend::paddle_mappers::Reduce##ReduceType##OpMapper) EXPAND_REDUCE_OP_MAPPER_REGISTER(reduce_sum, Sum) EXPAND_REDUCE_OP_MAPPER_REGISTER(reduce_prod, Prod) diff --git a/paddle/cinn/frontend/op_mappers/paddle/relu.cc b/paddle/cinn/frontend/op_mappers/paddle/relu.cc index 828e6c98f60a9..0bfc802dbef3e 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/relu.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/relu.cc @@ -19,33 +19,36 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void ReluOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ReluOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = ctx.GetVar(x_name); - auto out = ctx.Builder()->Relu(x); + auto x = ctx.GetVar(x_name); + auto out = ctx.Builder()->Relu(x); ctx.AddVar(out_name, out, true); ctx.AddVarModelToProgram(out_name, out->id, true); } -void Relu6OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void Relu6OpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); 
CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); auto threshold = utils::GetAttrOrDefault(op_desc, "threshold", 6.0f); - auto x = ctx.GetVar(x_name); - auto out = ctx.Builder()->Relu6(x, threshold); + auto x = ctx.GetVar(x_name); + auto out = ctx.Builder()->Relu6(x, threshold); ctx.AddVar(out_name, out, true); ctx.AddVarModelToProgram(out_name, out->id, true); } -void ReluGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ReluGradOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input(paddle::GradVarName("Out")).size(), 1UL); auto dout_name = op_desc.Input(paddle::GradVarName("Out")).front(); CHECK_EQ(op_desc.Input("Out").size(), 1UL); @@ -54,8 +57,8 @@ void ReluGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& auto dx_name = op_desc.Output(paddle::GradVarName("X")).front(); auto dout = ctx.GetVar(dout_name); - auto out = ctx.GetVar(out_name); - auto dx = ctx.Builder()->ReluGrad(dout, out); + auto out = ctx.GetVar(out_name); + auto dx = ctx.Builder()->ReluGrad(dout, out); ctx.AddVar(dx_name, dx, true); ctx.AddVarModelToProgram(dx_name, dx->id, true); @@ -67,7 +70,8 @@ void ReluGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& CINN_REGISTER_HELPER(paddle_relu) { CINN_REGISTER_OP_MAPPER(relu, cinn::frontend::paddle_mappers::ReluOpMapper) - CINN_REGISTER_OP_MAPPER(relu_grad, cinn::frontend::paddle_mappers::ReluGradOpMapper) + CINN_REGISTER_OP_MAPPER(relu_grad, + cinn::frontend::paddle_mappers::ReluGradOpMapper) CINN_REGISTER_OP_MAPPER(relu6, cinn::frontend::paddle_mappers::Relu6OpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/reshape.cc b/paddle/cinn/frontend/op_mappers/paddle/reshape.cc index 440870bd7fb20..f284501c71a5b 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/reshape.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/reshape.cc @@ -20,10 +20,11 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void ReshapeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ReshapeOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto shape = utils::GetAttrOrDefault>(op_desc, "shape"); @@ -38,7 +39,8 @@ void ReshapeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx.AddVarModelToProgram(out_name, out->id); } -void ReshapeGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ReshapeGradOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { auto get_input_var = [&op_desc, &ctx](const std::string& op_name) { CHECK_EQ(op_desc.Input(op_name).size(), 1UL); auto var_name = op_desc.Input(op_name).front(); @@ -63,10 +65,11 @@ void ReshapeGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperConte ctx.AddVarModelToProgram(out_name, out->id); } -void Reshape2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void Reshape2OpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto shape = utils::GetAttrOrDefault>(op_desc, "shape"); @@ -96,7 +99,8 @@ void Reshape2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& } } 
-void Reshape2GradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void Reshape2GradOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { auto get_input_var = [&op_desc, &ctx](const std::string& op_name) { CHECK_EQ(op_desc.Input(op_name).size(), 1UL); auto var_name = op_desc.Input(op_name).front(); @@ -126,10 +130,14 @@ void Reshape2GradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperCont } // namespace cinn CINN_REGISTER_HELPER(paddle_reshape) { - CINN_REGISTER_OP_MAPPER(reshape, cinn::frontend::paddle_mappers::ReshapeOpMapper) - CINN_REGISTER_OP_MAPPER(reshape2, cinn::frontend::paddle_mappers::Reshape2OpMapper) - - CINN_REGISTER_OP_MAPPER(reshape_grad, cinn::frontend::paddle_mappers::ReshapeGradOpMapper) - CINN_REGISTER_OP_MAPPER(reshape2_grad, cinn::frontend::paddle_mappers::Reshape2GradOpMapper) + CINN_REGISTER_OP_MAPPER(reshape, + cinn::frontend::paddle_mappers::ReshapeOpMapper) + CINN_REGISTER_OP_MAPPER(reshape2, + cinn::frontend::paddle_mappers::Reshape2OpMapper) + + CINN_REGISTER_OP_MAPPER(reshape_grad, + cinn::frontend::paddle_mappers::ReshapeGradOpMapper) + CINN_REGISTER_OP_MAPPER(reshape2_grad, + cinn::frontend::paddle_mappers::Reshape2GradOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/reverse.cc b/paddle/cinn/frontend/op_mappers/paddle/reverse.cc index 9c257e7767208..17418263ae0c5 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/reverse.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/reverse.cc @@ -19,16 +19,19 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void ReverseOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ReverseOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto axes = utils::GetAttrOrDefault>(op_desc, "axis", std::vector{}); - VLOG(4) << "out_name = reverse(" << x_name << ", axis=[" << cinn::utils::Join(axes, ", ") << "])"; + auto axes = utils::GetAttrOrDefault>( + op_desc, "axis", std::vector{}); + VLOG(4) << "out_name = reverse(" << x_name << ", axis=[" + << cinn::utils::Join(axes, ", ") << "])"; - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto out = ctx.Builder()->Reverse(x, axes); ctx.AddVar(out_name, out); @@ -40,6 +43,7 @@ void ReverseOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& } // namespace cinn CINN_REGISTER_HELPER(paddle_reverse) { - CINN_REGISTER_OP_MAPPER(reverse, cinn::frontend::paddle_mappers::ReverseOpMapper) + CINN_REGISTER_OP_MAPPER(reverse, + cinn::frontend::paddle_mappers::ReverseOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/roll.cc b/paddle/cinn/frontend/op_mappers/paddle/roll.cc index c2b6ee0136e38..9b268c2856888 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/roll.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/roll.cc @@ -20,7 +20,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void RollOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void RollOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { // input CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); @@ -31,27 +32,31 @@ void RollOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx // attr shifts and axis CHECK(op_desc.HasAttr("shifts")); 
CHECK(op_desc.HasAttr("axis")); - std::vector shifts = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "shifts", {1})); - std::vector axis = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "axis", {})); + std::vector shifts = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "shifts", {1})); + std::vector axis = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "axis", {})); - auto x = ctx.GetVar(x_name); - auto vec_x_dims = std::vector(x->shape); + auto x = ctx.GetVar(x_name); + auto vec_x_dims = std::vector(x->shape); std::vector output_shape = vec_x_dims; // check axis and shifts and when axis is None, we should flatten x. bool axis_None = false; if (axis.size() == 0) { - CHECK_EQ(shifts.size(), 1) << "shifts.size() should be equal to 1 when axis is None"; + CHECK_EQ(shifts.size(), 1) + << "shifts.size() should be equal to 1 when axis is None"; axis.push_back(0); - axis_None = true; + axis_None = true; int reshape_num = 1; for (int i = 0; i < vec_x_dims.size(); ++i) { reshape_num *= vec_x_dims[i]; } vec_x_dims = std::vector{reshape_num}; - x = ctx.Builder()->Reshape(x, vec_x_dims); + x = ctx.Builder()->Reshape(x, vec_x_dims); } else { - CHECK_EQ(shifts.size(), axis.size()) << "shifts.size() should be equal to axis.size()"; + CHECK_EQ(shifts.size(), axis.size()) + << "shifts.size() should be equal to axis.size()"; } // preprocessing the shifts and axis @@ -59,8 +64,10 @@ void RollOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx std::unordered_map axis_to_shifts; for (int i = 0; i < shifts_size; ++i) { int vec_x_dims_size = vec_x_dims.size(); - CHECK_GE(axis[i], -vec_x_dims_size) << "axis value should be >= " << -vec_x_dims_size; - CHECK_LT(axis[i], vec_x_dims_size) << "axis value should be < " << vec_x_dims_size; + CHECK_GE(axis[i], -vec_x_dims_size) + << "axis value should be >= " << -vec_x_dims_size; + CHECK_LT(axis[i], vec_x_dims_size) + << "axis value should be < " << vec_x_dims_size; if (axis[i] < 0) { axis[i] += vec_x_dims_size; } @@ -80,11 +87,13 @@ void RollOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx // use Split + Concat for each shift for (const auto& pair : axis_to_shifts) { if (pair.second > 0) { - int length = vec_x_dims[pair.first]; - auto front_slice = ctx.Builder()->Slice(output, {pair.first}, {0}, {length - pair.second}); - auto behind_slice = ctx.Builder()->Slice(output, {pair.first}, {length - pair.second}, {length}); + int length = vec_x_dims[pair.first]; + auto front_slice = ctx.Builder()->Slice( + output, {pair.first}, {0}, {length - pair.second}); + auto behind_slice = ctx.Builder()->Slice( + output, {pair.first}, {length - pair.second}, {length}); auto split_output = std::vector{behind_slice, front_slice}; - output = ctx.Builder()->Concat(split_output, pair.first); + output = ctx.Builder()->Concat(split_output, pair.first); } } diff --git a/paddle/cinn/frontend/op_mappers/paddle/scale.cc b/paddle/cinn/frontend/op_mappers/paddle/scale.cc index be1a9984e738b..b8c8b8c549885 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/scale.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/scale.cc @@ -22,37 +22,45 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void ScaleOpMapper(const paddle::cpp::OpDesc& op_desc, const cinn::frontend::OpMapperContext& ctx) { +void ScaleOpMapper(const paddle::cpp::OpDesc& op_desc, + const cinn::frontend::OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 
1UL); auto out_name = op_desc.Output("Out").front(); - auto bias = utils::GetAttrOrDefault(op_desc, "bias", 0.0f); - auto bias_after_scale = utils::GetAttrOrDefault(op_desc, "bias_after_scale", true); + auto bias = utils::GetAttrOrDefault(op_desc, "bias", 0.0f); + auto bias_after_scale = + utils::GetAttrOrDefault(op_desc, "bias_after_scale", true); auto x = ctx.GetVar(x_name); absl::optional out; - if (op_desc.HasInput("ScaleTensor") && !op_desc.Input("ScaleTensor").empty()) { + if (op_desc.HasInput("ScaleTensor") && + !op_desc.Input("ScaleTensor").empty()) { CHECK_EQ(op_desc.Input("ScaleTensor").size(), 1); - auto scale_name = op_desc.Input("ScaleTensor").front(); + auto scale_name = op_desc.Input("ScaleTensor").front(); auto scale_tensor = ctx.GetVar(scale_name); - VLOG(4) << out_name << " = scale(" << x_name << "=" << x << ", scale=" << scale_name << "[" << scale_tensor + VLOG(4) << out_name << " = scale(" << x_name << "=" << x + << ", scale=" << scale_name << "[" << scale_tensor << "], bias=" << bias << ", bias_after_scale=" << bias_after_scale; - CHECK(scale_tensor->shape == cinn::utils::ShapeType{1}) << "The shape of [ScaleTensor] should be [1], but here [" - << cinn::utils::Join(scale_tensor->shape, ", ") << "]"; + CHECK(scale_tensor->shape == cinn::utils::ShapeType{1}) + << "The shape of [ScaleTensor] should be [1], but got [" + << cinn::utils::Join(scale_tensor->shape, ", ") << "]"; scale_tensor = ctx.Builder()->Cast(scale_tensor, common::Type2Str(x->type)); scale_tensor = ctx.Builder()->BroadcastTo(scale_tensor, x->shape); if (bias != 0.0f) { - auto bias_tensor = ctx.Builder()->FillConstant(x->shape, bias, x->id + "_bias", common::Type2Str(x->type)); + auto bias_tensor = ctx.Builder()->FillConstant( + x->shape, bias, x->id + "_bias", common::Type2Str(x->type)); if (bias_after_scale) { - out = ctx.Builder()->Add(bias_tensor, ctx.Builder()->Multiply(x, scale_tensor)); + out = ctx.Builder()->Add(bias_tensor, + ctx.Builder()->Multiply(x, scale_tensor)); } else { - out = ctx.Builder()->Multiply(scale_tensor, ctx.Builder()->Add(x, bias_tensor)); + out = ctx.Builder()->Multiply(scale_tensor, + ctx.Builder()->Add(x, bias_tensor)); } } else { out = ctx.Builder()->Multiply(scale_tensor, x); @@ -60,7 +68,8 @@ void ScaleOpMapper(const paddle::cpp::OpDesc& op_desc, const cinn::frontend::OpM } else { auto scale = utils::GetAttrOrDefault(op_desc, "scale", 1.0f); - VLOG(4) << out_name << " = scale(" << x_name << "=" << x << ", scale=" << scale << ", bias=" << bias + VLOG(4) << out_name << " = scale(" << x_name << "=" << x + << ", scale=" << scale << ", bias=" << bias << ", bias_after_scale=" << bias_after_scale; out = ctx.Builder()->Scale(x, scale, bias, bias_after_scale); diff --git a/paddle/cinn/frontend/op_mappers/paddle/scatter.cc b/paddle/cinn/frontend/op_mappers/paddle/scatter.cc index 72697c4a84b1c..ca7f15ab254c7 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/scatter.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/scatter.cc @@ -20,7 +20,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void ScatterOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ScatterOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Ids").size(), 1UL); @@ -31,13 +32,14 @@ void ScatterOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& auto out_name = op_desc.Output("Out").front(); bool overwrite =
utils::GetAttrOrDefault(op_desc, "overwrite", true); - VLOG(4) << "out_name = scatter(X=" << x_name << ", Ids=" << ids_name << ", Updates=" << updates_name - << ", overwrite=" << overwrite << ")"; + VLOG(4) << "out_name = scatter(X=" << x_name << ", Ids=" << ids_name + << ", Updates=" << updates_name << ", overwrite=" << overwrite << ")"; - const auto& input = ctx.GetVar(x_name); - auto indices = ctx.GetVar(ids_name); + const auto& input = ctx.GetVar(x_name); + auto indices = ctx.GetVar(ids_name); const auto& updates = ctx.GetVar(updates_name); - CHECK(input->type == updates->type) << "checks whether the type of the input and the updates are the same."; + CHECK(input->type == updates->type) + << "checks whether the type of the input and the updates are the same."; CHECK(indices->type == common::Int(32) || indices->type == common::Int(64)) << "checks whether the data type of the indices is either int32 or int64"; if (indices->type == common::Int(64)) { @@ -48,15 +50,19 @@ void ScatterOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& indices = ctx.Builder()->Reshape(indices, {1}); } if (indices->shape.size() == 2) { - indices = ctx.Builder()->Reshape(indices, {indices->shape[0] * indices->shape[1]}); + indices = ctx.Builder()->Reshape(indices, + {indices->shape[0] * indices->shape[1]}); } Variable out; if (overwrite) { out = ctx.Builder()->ScatterAssign(input, updates, indices); } else { - const auto& zeros = ctx.Builder()->FillConstant( - updates->shape, 0, common::UniqName("scatter_zeros"), common::Type2Str(updates->type)); + const auto& zeros = + ctx.Builder()->FillConstant(updates->shape, + 0, + common::UniqName("scatter_zeros"), + common::Type2Str(updates->type)); out = ctx.Builder()->ScatterAssign(input, zeros, indices); out = ctx.Builder()->ScatterAdd(out, updates, indices); } @@ -70,6 +76,7 @@ void ScatterOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& } // namespace cinn CINN_REGISTER_HELPER(paddle_scatter) { - CINN_REGISTER_OP_MAPPER(scatter, cinn::frontend::paddle_mappers::ScatterOpMapper) + CINN_REGISTER_OP_MAPPER(scatter, + cinn::frontend::paddle_mappers::ScatterOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/slice.cc b/paddle/cinn/frontend/op_mappers/paddle/slice.cc index cadd9fd3fe79d..6b62ec72410e6 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/slice.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/slice.cc @@ -19,7 +19,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void SliceOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void SliceOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("Input").size(), 1UL); auto x_name = op_desc.Input("Input").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); @@ -32,11 +33,14 @@ void SliceOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ct CHECK(op_desc.HasAttr("axes")); auto axes = utils::GetAttrOrDefault>(op_desc, "axes"); - auto infer_flags = utils::GetAttrOrDefault>(op_desc, "infer_flags"); - auto strides = utils::GetAttrOrDefault>(op_desc, "strides"); - auto decrease_axis = utils::GetAttrOrDefault>(op_desc, "decrease_axis"); - auto x = ctx.GetVar(x_name); - auto out = ctx.Builder()->Slice(x, axes, starts, ends, infer_flags, strides, decrease_axis); + auto infer_flags = + utils::GetAttrOrDefault>(op_desc, "infer_flags"); + auto strides = utils::GetAttrOrDefault>(op_desc, "strides"); + auto decrease_axis = + utils::GetAttrOrDefault>(op_desc, 
"decrease_axis"); + auto x = ctx.GetVar(x_name); + auto out = ctx.Builder()->Slice( + x, axes, starts, ends, infer_flags, strides, decrease_axis); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); diff --git a/paddle/cinn/frontend/op_mappers/paddle/softmax.cc b/paddle/cinn/frontend/op_mappers/paddle/softmax.cc index 12a1c86bef442..654c21c56c0b4 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/softmax.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/softmax.cc @@ -19,16 +19,18 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void SoftmaxOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void SoftmaxOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); - auto data_format = utils::GetAttrOrDefault(op_desc, "data_format", "AnyLayout"); + auto axis = utils::GetAttrOrDefault(op_desc, "axis", -1); + auto data_format = + utils::GetAttrOrDefault(op_desc, "data_format", "AnyLayout"); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto out = ctx.Builder()->Softmax(x, {axis}, data_format); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); @@ -39,6 +41,7 @@ void SoftmaxOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& } // namespace cinn CINN_REGISTER_HELPER(paddle_softmax) { - CINN_REGISTER_OP_MAPPER(softmax, cinn::frontend::paddle_mappers::SoftmaxOpMapper) + CINN_REGISTER_OP_MAPPER(softmax, + cinn::frontend::paddle_mappers::SoftmaxOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/squeeze.cc b/paddle/cinn/frontend/op_mappers/paddle/squeeze.cc index 430be863661dd..ee7606675bad4 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/squeeze.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/squeeze.cc @@ -19,10 +19,11 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void Squeeze2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void Squeeze2OpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto axes = utils::GetAttrOrDefault>(op_desc, "axes"); @@ -57,7 +58,8 @@ void Squeeze2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& } // namespace cinn CINN_REGISTER_HELPER(paddle_squeeze) { - CINN_REGISTER_OP_MAPPER(squeeze2, cinn::frontend::paddle_mappers::Squeeze2OpMapper) + CINN_REGISTER_OP_MAPPER(squeeze2, + cinn::frontend::paddle_mappers::Squeeze2OpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/strided_slice.cc b/paddle/cinn/frontend/op_mappers/paddle/strided_slice.cc index 318dc7e14a6f3..8a0a5765de555 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/strided_slice.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/strided_slice.cc @@ -20,26 +20,34 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void StridedSliceOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void StridedSliceOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("Input").size(), 1UL); auto x_name = op_desc.Input("Input").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = 
op_desc.Output("Out").front(); CHECK(op_desc.HasAttr("starts")); - auto starts = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "starts")); + auto starts = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "starts")); CHECK(op_desc.HasAttr("ends")); - auto ends = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "ends")); + auto ends = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "ends")); CHECK(op_desc.HasAttr("axes")); - auto axes = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "axes")); + auto axes = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "axes")); CHECK(op_desc.HasAttr("strides")); - auto strides = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "strides")); + auto strides = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "strides")); CHECK(op_desc.HasAttr("infer_flags")); - auto infer_flags = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "infer_flags")); - auto decrease_axis = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "decrease_axis")); + auto infer_flags = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "infer_flags")); + auto decrease_axis = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "decrease_axis")); - auto x = ctx.GetVar(x_name); - auto out = ctx.Builder()->Slice(x, axes, starts, ends, infer_flags, strides, decrease_axis); + auto x = ctx.GetVar(x_name); + auto out = ctx.Builder()->Slice( + x, axes, starts, ends, infer_flags, strides, decrease_axis); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); @@ -50,6 +58,7 @@ void StridedSliceOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperCont } // namespace cinn CINN_REGISTER_HELPER(paddle_strided_slice) { - CINN_REGISTER_OP_MAPPER(strided_slice, cinn::frontend::paddle_mappers::StridedSliceOpMapper) + CINN_REGISTER_OP_MAPPER(strided_slice, + cinn::frontend::paddle_mappers::StridedSliceOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/take_along_axis.cc b/paddle/cinn/frontend/op_mappers/paddle/take_along_axis.cc index 5c74cf55c10ca..deec9a0ee9c41 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/take_along_axis.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/take_along_axis.cc @@ -20,13 +20,14 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void TakeAlongAxis2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void TakeAlongAxis2OpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("Input").size(), 1UL); auto x_name = op_desc.Input("Input").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); CHECK_EQ(op_desc.Input("Index").size(), 1UL); auto index_name = op_desc.Input("Index").front(); - auto index = ctx.GetVar(index_name); + auto index = ctx.GetVar(index_name); auto axis = utils::GetAttrOrDefault(op_desc, "Axis"); @@ -47,6 +48,7 @@ void TakeAlongAxis2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperCo } // namespace cinn CINN_REGISTER_HELPER(paddle_take_along_axis) { - CINN_REGISTER_OP_MAPPER(take_along_axis, cinn::frontend::paddle_mappers::TakeAlongAxis2OpMapper) + CINN_REGISTER_OP_MAPPER( + take_along_axis, cinn::frontend::paddle_mappers::TakeAlongAxis2OpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/tile.cc b/paddle/cinn/frontend/op_mappers/paddle/tile.cc index 01f83bdccfbb2..4eb2db1d6f08d 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/tile.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/tile.cc @@ -20,7 +20,8 @@ namespace 
cinn { namespace frontend { namespace paddle_mappers { -void TileOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void TileOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { // input CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); @@ -29,7 +30,8 @@ void TileOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx auto out_name = op_desc.Output("Out").front(); // attr repeat_times - std::vector repeat_times = op_desc.GetAttr>("repeat_times"); + std::vector repeat_times = + op_desc.GetAttr>("repeat_times"); for (auto i : repeat_times) { CHECK_GT(i, 0) << "repeat_times's element must be greater than 0"; @@ -48,7 +50,8 @@ void TileOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx } CHECK_EQ(vec_x_dims.size(), repeat_times.size()) - << "vec_x_dims's size must be equal to repeat_times's size after promotion"; + << "vec_x_dims's size must be equal to repeat_times's size after " + "promotion"; // output's shape std::vector output_shape = vec_x_dims; @@ -60,9 +63,11 @@ void TileOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx VLOG(4) << "output_shape: " << cinn::utils::Join(output_shape, ","); - // NOTE(wuweilong): Paddle's tile is implemented by Eigen's broadcast directly, but CINN's tile can not be implemented - // by BroadcastTo directly, because it is different from Eigen's broadcast. The semantics of Eigen's broadcast is same - // as tile, but CINN can not use Eigen's broadcast. So we need to Combine Reshape and BroadcastTo to implement tile. + // NOTE(wuweilong): Paddle's tile is implemented by Eigen's broadcast + // directly, but CINN's tile cannot be implemented by BroadcastTo directly, + // because it is different from Eigen's broadcast. The semantics of Eigen's + // broadcast are the same as tile's, but CINN cannot use Eigen's broadcast. + // So we need to combine Reshape and BroadcastTo to implement tile.
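+  // As a concrete illustration of the Reshape + BroadcastTo idea (with
+  // hypothetical shapes, not values computed by this mapper): tiling x of
+  // shape [2, 3] with repeat_times = [2, 2] can reshape x to [1, 2, 1, 3],
+  // broadcast it to [2, 2, 2, 3], and reshape the result to the final
+  // shape [4, 6].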
// make a copy of vec_x_dims std::vector vec_x_dims_copy = vec_x_dims; @@ -81,7 +86,7 @@ void TileOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx } } - auto tmp = ctx.Builder()->BroadcastTo(x, vec_x_dims_copy); + auto tmp = ctx.Builder()->BroadcastTo(x, vec_x_dims_copy); auto output = ctx.Builder()->Reshape(tmp, output_shape); ctx.AddVar(out_name, output); diff --git a/paddle/cinn/frontend/op_mappers/paddle/top_k.cc b/paddle/cinn/frontend/op_mappers/paddle/top_k.cc index 10c610c9b5eb1..07ab1c12b8802 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/top_k.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/top_k.cc @@ -19,17 +19,18 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void TopKOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void TopKOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); CHECK_EQ(op_desc.Output("Indices").size(), 1UL); auto indices_name = op_desc.Output("Indices").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); CHECK(op_desc.HasAttr("k")); - auto k = utils::GetAttrOrDefault(op_desc, "k"); + auto k = utils::GetAttrOrDefault(op_desc, "k"); auto outs = ctx.Builder()->TopK(x, k, -1, true); ctx.AddVar(out_name, outs[0]); diff --git a/paddle/cinn/frontend/op_mappers/paddle/transpose.cc b/paddle/cinn/frontend/op_mappers/paddle/transpose.cc index bc79123d042c7..e835a930ad641 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/transpose.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/transpose.cc @@ -20,10 +20,11 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void TransposeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void TransposeOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto axis = utils::GetAttrOrDefault>(op_desc, "axis"); @@ -38,10 +39,11 @@ void TransposeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext ctx.AddVarModelToProgram(out_name, out->id); } -void Transpose2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void Transpose2OpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto axis = utils::GetAttrOrDefault>(op_desc, "axis"); @@ -76,7 +78,9 @@ void Transpose2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContex } // namespace cinn CINN_REGISTER_HELPER(paddle_transpose) { - CINN_REGISTER_OP_MAPPER(transpose, cinn::frontend::paddle_mappers::TransposeOpMapper) - CINN_REGISTER_OP_MAPPER(transpose2, cinn::frontend::paddle_mappers::Transpose2OpMapper) + CINN_REGISTER_OP_MAPPER(transpose, + cinn::frontend::paddle_mappers::TransposeOpMapper) + CINN_REGISTER_OP_MAPPER(transpose2, + cinn::frontend::paddle_mappers::Transpose2OpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/triangular_solve.cc b/paddle/cinn/frontend/op_mappers/paddle/triangular_solve.cc index e2d9f55d650d3..756a335d604dc 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/triangular_solve.cc +++ 
b/paddle/cinn/frontend/op_mappers/paddle/triangular_solve.cc @@ -19,7 +19,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void TriangularSolveOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void TriangularSolveOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Y").size(), 1UL); @@ -28,15 +29,19 @@ void TriangularSolveOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperC auto out_name = op_desc.Output("Out").front(); constexpr bool left_side = true; - auto upper = utils::GetAttrOrDefault(op_desc, "upper", true); - auto transpose = utils::GetAttrOrDefault(op_desc, "transpose", false); - auto unitriangular = utils::GetAttrOrDefault(op_desc, "unitriangular", false); - VLOG(4) << "out_name = triangular_solve(" << x_name << ", left_side=" << left_side << ", upper=" << upper - << ", transpose=" << transpose << ", unitriangular=" << unitriangular << ")"; - - auto x = ctx.GetVar(x_name); - auto y = ctx.GetVar(y_name); - auto out = ctx.Builder()->TriangularSolve(x, y, left_side, upper, transpose, unitriangular); + auto upper = utils::GetAttrOrDefault(op_desc, "upper", true); + auto transpose = utils::GetAttrOrDefault(op_desc, "transpose", false); + auto unitriangular = + utils::GetAttrOrDefault(op_desc, "unitriangular", false); + VLOG(4) << "out_name = triangular_solve(" << x_name + << ", left_side=" << left_side << ", upper=" << upper + << ", transpose=" << transpose << ", unitriangular=" << unitriangular + << ")"; + + auto x = ctx.GetVar(x_name); + auto y = ctx.GetVar(y_name); + auto out = ctx.Builder()->TriangularSolve( + x, y, left_side, upper, transpose, unitriangular); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); @@ -47,6 +52,7 @@ void TriangularSolveOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperC } // namespace cinn CINN_REGISTER_HELPER(paddle_triangular_solve) { - CINN_REGISTER_OP_MAPPER(triangular_solve, cinn::frontend::paddle_mappers::TriangularSolveOpMapper) + CINN_REGISTER_OP_MAPPER( + triangular_solve, cinn::frontend::paddle_mappers::TriangularSolveOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/unary.cc b/paddle/cinn/frontend/op_mappers/paddle/unary.cc index c7ce01f0a169a..575b2d04eb188 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/unary.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/unary.cc @@ -19,17 +19,19 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -#define UNARY_OPMAPPER_FUNCTION(OP_NAME) \ - void OP_NAME##OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { \ - CHECK_EQ(op_desc.Input("X").size(), 1UL); \ - auto x_name = op_desc.Input("X").front(); \ - CHECK_EQ(op_desc.Output("Out").size(), 1UL); \ - auto out_name = op_desc.Output("Out").front(); \ - auto x = ctx.GetVar(x_name); \ - VLOG(4) << #OP_NAME << " X:" << x_name << "[" << cinn::utils::Join(x->shape, ",") << "] to Out:" << out_name; \ - auto out = ctx.Builder()->OP_NAME(x); \ - ctx.AddVar(out_name, out); \ - ctx.AddVarModelToProgram(out_name, out->id); \ +#define UNARY_OPMAPPER_FUNCTION(OP_NAME) \ + void OP_NAME##OpMapper(const paddle::cpp::OpDesc& op_desc, \ + const OpMapperContext& ctx) { \ + CHECK_EQ(op_desc.Input("X").size(), 1UL); \ + auto x_name = op_desc.Input("X").front(); \ + CHECK_EQ(op_desc.Output("Out").size(), 1UL); \ + auto out_name = op_desc.Output("Out").front(); \ + auto x = ctx.GetVar(x_name); \ + 
VLOG(4) << #OP_NAME << " X:" << x_name << "[" \ + << cinn::utils::Join(x->shape, ",") << "] to Out:" << out_name; \ + auto out = ctx.Builder()->OP_NAME(x); \ + ctx.AddVar(out_name, out); \ + ctx.AddVarModelToProgram(out_name, out->id); \ } UNARY_OPMAPPER_FUNCTION(LogicalNot) @@ -70,7 +72,8 @@ UNARY_OPMAPPER_FUNCTION(IsInf) CINN_REGISTER_HELPER(paddle_unary) { #define UNARY_OPMAPPER_REGISTER(PD_OP, CINN_OP) \ - CINN_REGISTER_OP_MAPPER(PD_OP, cinn::frontend::paddle_mappers::CINN_OP##OpMapper) + CINN_REGISTER_OP_MAPPER(PD_OP, \ + cinn::frontend::paddle_mappers::CINN_OP##OpMapper) UNARY_OPMAPPER_REGISTER(logical_not, LogicalNot) UNARY_OPMAPPER_REGISTER(bitwise_not, BitwiseNot) diff --git a/paddle/cinn/frontend/op_mappers/paddle/uniform_random.cc b/paddle/cinn/frontend/op_mappers/paddle/uniform_random.cc index bc842c5bd3ef9..c477072c9877f 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/uniform_random.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/uniform_random.cc @@ -20,28 +20,35 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void UniformRandomOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void UniformRandomOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto shape_origin = utils::GetAttrOrDefault>(op_desc, "shape"); - auto shape = utils::ToShapeType(shape_origin); + auto shape_origin = + utils::GetAttrOrDefault>(op_desc, "shape"); + auto shape = utils::ToShapeType(shape_origin); - auto min = utils::GetAttrOrDefault(op_desc, "min", -1.0f); - auto max = utils::GetAttrOrDefault(op_desc, "max", 1.0f); + auto min = utils::GetAttrOrDefault(op_desc, "min", -1.0f); + auto max = utils::GetAttrOrDefault(op_desc, "max", 1.0f); auto seed = utils::GetAttrOrDefault(op_desc, "seed", 0); - auto diag_num = utils::GetAttrOrDefault(op_desc, "diag_num", 0); + auto diag_num = utils::GetAttrOrDefault(op_desc, "diag_num", 0); auto diag_step = utils::GetAttrOrDefault(op_desc, "diag_step", 0); - auto diag_val = utils::GetAttrOrDefault(op_desc, "diag_val", 1.0f); + auto diag_val = utils::GetAttrOrDefault(op_desc, "diag_val", 1.0f); - auto dtype = utils::GetPaddleDtype(op_desc, "dtype", paddle::cpp::VarDescAPI::Type::FP32); - CHECK(!dtype.empty()) << "The op \"uniform_random\"'s attribute \"dtype\" should not be unknown type! Please check."; + auto dtype = utils::GetPaddleDtype( + op_desc, "dtype", paddle::cpp::VarDescAPI::Type::FP32); + CHECK(!dtype.empty()) << "The op \"uniform_random\"'s attribute \"dtype\" " + "should not be unknown type! 
Please check."; - VLOG(4) << out_name << "[" << cinn::utils::Join(shape, ", ") << "] = uniform_random(min=" << min << ", max=" << max - << ", seed=" << seed << ", dtype=" << dtype << ", shape=[" << cinn::utils::Join(shape, ", ") << "])"; + VLOG(4) << out_name << "[" << cinn::utils::Join(shape, ", ") + << "] = uniform_random(min=" << min << ", max=" << max + << ", seed=" << seed << ", dtype=" << dtype << ", shape=[" + << cinn::utils::Join(shape, ", ") << "])"; - auto out = ctx.Builder()->UniformRandom(shape, min, max, seed, dtype, diag_num, diag_step, diag_val); + auto out = ctx.Builder()->UniformRandom( + shape, min, max, seed, dtype, diag_num, diag_step, diag_val); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); } @@ -51,6 +58,7 @@ void UniformRandomOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperCon } // namespace cinn CINN_REGISTER_HELPER(paddle_uniform_random) { - CINN_REGISTER_OP_MAPPER(uniform_random, cinn::frontend::paddle_mappers::UniformRandomOpMapper) + CINN_REGISTER_OP_MAPPER(uniform_random, + cinn::frontend::paddle_mappers::UniformRandomOpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/unsqueeze.cc b/paddle/cinn/frontend/op_mappers/paddle/unsqueeze.cc index 90b480edecf0b..80ad42c3dd3cd 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/unsqueeze.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/unsqueeze.cc @@ -20,10 +20,11 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void UnSqueeze2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void UnSqueeze2OpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto axes = utils::GetAttrOrDefault>(op_desc, "axes"); @@ -58,6 +59,7 @@ void UnSqueeze2OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContex } // namespace cinn CINN_REGISTER_HELPER(paddle_unsqueeze) { - CINN_REGISTER_OP_MAPPER(unsqueeze2, cinn::frontend::paddle_mappers::UnSqueeze2OpMapper) + CINN_REGISTER_OP_MAPPER(unsqueeze2, + cinn::frontend::paddle_mappers::UnSqueeze2OpMapper) return true; } diff --git a/paddle/cinn/frontend/op_mappers/paddle/where.cc b/paddle/cinn/frontend/op_mappers/paddle/where.cc index c2e7386d5c94b..23b8798aeb0a2 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/where.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/where.cc @@ -19,7 +19,8 @@ namespace cinn { namespace frontend { namespace paddle_mappers { -void WhereOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void WhereOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("Condition").size(), 1UL); auto c_name = op_desc.Input("Condition").front(); CHECK_EQ(op_desc.Input("X").size(), 1UL); diff --git a/paddle/cinn/frontend/op_mappers/science/compare.cc b/paddle/cinn/frontend/op_mappers/science/compare.cc index 54c39029dbd11..5917890723e67 100644 --- a/paddle/cinn/frontend/op_mappers/science/compare.cc +++ b/paddle/cinn/frontend/op_mappers/science/compare.cc @@ -19,19 +19,20 @@ namespace cinn { namespace frontend { namespace science_mappers { -#define COMPARE_OPMAPPER_FUNCTION(OP_NAME) \ - void OP_NAME##OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { \ - CHECK_EQ(op_desc.Input("X").size(), 1UL); \ - auto x_name = op_desc.Input("X").front(); \ - CHECK_EQ(op_desc.Input("Y").size(), 1UL); \ - auto y_name = 
op_desc.Input("Y").front(); \ - CHECK_EQ(op_desc.Output("Z").size(), 1UL); \ - auto out_name = op_desc.Output("Z").front(); \ - auto x = ctx.GetVar(x_name); \ - auto y = ctx.GetVar(y_name); \ - auto out = ctx.Builder()->OP_NAME(x, y); \ - ctx.AddVar(out_name, out); \ - ctx.AddVarModelToProgram(out_name, out->id); \ +#define COMPARE_OPMAPPER_FUNCTION(OP_NAME) \ + void OP_NAME##OpMapper(const paddle::cpp::OpDesc& op_desc, \ + const OpMapperContext& ctx) { \ + CHECK_EQ(op_desc.Input("X").size(), 1UL); \ + auto x_name = op_desc.Input("X").front(); \ + CHECK_EQ(op_desc.Input("Y").size(), 1UL); \ + auto y_name = op_desc.Input("Y").front(); \ + CHECK_EQ(op_desc.Output("Z").size(), 1UL); \ + auto out_name = op_desc.Output("Z").front(); \ + auto x = ctx.GetVar(x_name); \ + auto y = ctx.GetVar(y_name); \ + auto out = ctx.Builder()->OP_NAME(x, y); \ + ctx.AddVar(out_name, out); \ + ctx.AddVarModelToProgram(out_name, out->id); \ } COMPARE_OPMAPPER_FUNCTION(GreaterThan) @@ -49,7 +50,8 @@ COMPARE_OPMAPPER_FUNCTION(NotEqual) CINN_REGISTER_HELPER(science_compare) { #define COMPARE_OPMAPPER_REGISTER(PD_OP, CINN_OP) \ - CINN_REGISTER_OP_MAPPER(PD_OP, cinn::frontend::science_mappers::CINN_OP##OpMapper) + CINN_REGISTER_OP_MAPPER(PD_OP, \ + cinn::frontend::science_mappers::CINN_OP##OpMapper) COMPARE_OPMAPPER_REGISTER(gt_p, GreaterThan) COMPARE_OPMAPPER_REGISTER(ge_p, GreaterEqual) diff --git a/paddle/cinn/frontend/op_mappers/science/math.cc b/paddle/cinn/frontend/op_mappers/science/math.cc index 843d6095b3b0e..078a74bed634c 100644 --- a/paddle/cinn/frontend/op_mappers/science/math.cc +++ b/paddle/cinn/frontend/op_mappers/science/math.cc @@ -19,20 +19,22 @@ namespace cinn { namespace frontend { namespace science_mappers { -#define BINARY_OPMAPPER(op_name) \ - void op_name##OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { \ - CHECK_EQ(op_desc.Input("X").size(), 1UL); \ - auto x_name = op_desc.Input("X").front(); \ - CHECK_EQ(op_desc.Input("Y").size(), 1UL); \ - auto y_name = op_desc.Input("Y").front(); \ - CHECK_EQ(op_desc.Output("Z").size(), 1UL); \ - auto out_name = op_desc.Output("Z").front(); \ - VLOG(3) << out_name << " = " << #op_name << "(" << x_name << ", " << y_name << ")"; \ - auto x = ctx.GetVar(x_name); \ - auto y = ctx.GetVar(y_name); \ - auto out = ctx.Builder()->op_name(x, y); \ - ctx.AddVar(out_name, out); \ - ctx.AddVarModelToProgram(out_name, out->id); \ +#define BINARY_OPMAPPER(op_name) \ + void op_name##OpMapper(const paddle::cpp::OpDesc& op_desc, \ + const OpMapperContext& ctx) { \ + CHECK_EQ(op_desc.Input("X").size(), 1UL); \ + auto x_name = op_desc.Input("X").front(); \ + CHECK_EQ(op_desc.Input("Y").size(), 1UL); \ + auto y_name = op_desc.Input("Y").front(); \ + CHECK_EQ(op_desc.Output("Z").size(), 1UL); \ + auto out_name = op_desc.Output("Z").front(); \ + VLOG(3) << out_name << " = " << #op_name << "(" << x_name << ", " \ + << y_name << ")"; \ + auto x = ctx.GetVar(x_name); \ + auto y = ctx.GetVar(y_name); \ + auto out = ctx.Builder()->op_name(x, y); \ + ctx.AddVar(out_name, out); \ + ctx.AddVarModelToProgram(out_name, out->id); \ } BINARY_OPMAPPER(Add) @@ -46,17 +48,18 @@ BINARY_OPMAPPER(Min) #undef BINARY_OPMAPPER -#define UNARY_OPMAPPER(op_name) \ - void op_name##OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { \ - CHECK_EQ(op_desc.Input("X").size(), 1UL); \ - auto x_name = op_desc.Input("X").front(); \ - CHECK_EQ(op_desc.Output("Y").size(), 1UL); \ - auto out_name = op_desc.Output("Y").front(); \ - VLOG(3) << out_name << " = " << 
#op_name << "(" << x_name << ")"; \ - auto x = ctx.GetVar(x_name); \ - auto out = ctx.Builder()->op_name(x); \ - ctx.AddVar(out_name, out); \ - ctx.AddVarModelToProgram(out_name, out->id); \ +#define UNARY_OPMAPPER(op_name) \ + void op_name##OpMapper(const paddle::cpp::OpDesc& op_desc, \ + const OpMapperContext& ctx) { \ + CHECK_EQ(op_desc.Input("X").size(), 1UL); \ + auto x_name = op_desc.Input("X").front(); \ + CHECK_EQ(op_desc.Output("Y").size(), 1UL); \ + auto out_name = op_desc.Output("Y").front(); \ + VLOG(3) << out_name << " = " << #op_name << "(" << x_name << ")"; \ + auto x = ctx.GetVar(x_name); \ + auto out = ctx.Builder()->op_name(x); \ + ctx.AddVar(out_name, out); \ + ctx.AddVarModelToProgram(out_name, out->id); \ } UNARY_OPMAPPER(Sqrt) @@ -78,7 +81,8 @@ UNARY_OPMAPPER(Abs) CINN_REGISTER_HELPER(science_math) { #define EXPAND_OP_MAPPER_REGISTER(psci_op, cinn_op) \ - CINN_REGISTER_OP_MAPPER(psci_op, cinn::frontend::science_mappers::cinn_op##OpMapper) + CINN_REGISTER_OP_MAPPER(psci_op, \ + cinn::frontend::science_mappers::cinn_op##OpMapper) EXPAND_OP_MAPPER_REGISTER(add_p, Add) EXPAND_OP_MAPPER_REGISTER(sub_p, Subtract) diff --git a/paddle/cinn/frontend/op_mappers/science/transform.cc b/paddle/cinn/frontend/op_mappers/science/transform.cc index 2048a37c01dbd..45faa1961790d 100644 --- a/paddle/cinn/frontend/op_mappers/science/transform.cc +++ b/paddle/cinn/frontend/op_mappers/science/transform.cc @@ -27,7 +27,8 @@ namespace science_mappers { using cinn::utils::ShapeType; -void ConcatOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ConcatOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_GE(op_desc.Input("XS").size(), 1UL); auto x_names = op_desc.Input("XS"); CHECK_EQ(op_desc.Output("Y").size(), 1UL); @@ -37,14 +38,15 @@ void ConcatOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c if (x_names.size() == 1) { // if concat only has one input, using Identity to copy the input and return auto x = ctx.GetVar(x_names.front()); - out = ctx.Builder()->Identity(x); + out = ctx.Builder()->Identity(x); } else { std::vector xs; for (const auto& name : x_names) { xs.emplace_back(ctx.GetVar(name)); } - auto axis = utils::ToDimType(utils::GetAttrOrDefault(op_desc, "axis", 0)); + auto axis = + utils::ToDimType(utils::GetAttrOrDefault(op_desc, "axis", 0)); out = ctx.Builder()->Concat(xs, axis); } @@ -53,52 +55,65 @@ void ConcatOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c ctx.AddVarModelToProgram(out_name, out->id); } -void SplitOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void SplitOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_GE(op_desc.Output("YS").size(), 1UL); auto out_name = op_desc.Output("YS"); CHECK(op_desc.HasAttr("num_or_sections")); - auto num_or_sections = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "num_or_sections")); + auto num_or_sections = + utils::ToShapeType(utils::GetAttrOrDefault>( + op_desc, "num_or_sections")); - CHECK(!num_or_sections.empty()) << "The Split op cannot found [num_or_sections] attrbute! ! Please check."; + CHECK(!num_or_sections.empty()) + << "The Split op cannot found [num_or_sections] attrbute! ! 
Please " + "check."; - auto axis = utils::ToDimType(utils::GetAttrOrDefault(op_desc, "axis", 0)); + auto axis = + utils::ToDimType(utils::GetAttrOrDefault(op_desc, "axis", 0)); auto x = ctx.GetVar(x_name); auto x_shape = x->shape; if (num_or_sections.size() == 1U) { CHECK_EQ(x_shape[axis] % num_or_sections[0], 0) - << "If the attribute 'num_or_sections' is a number, it should be divisible by the " + << "If the attribute 'num_or_sections' is a number, it should be " + "divisible by the " "axis's dimension of inputs A ! Please check."; } else { cinn::utils::DimType sec_sum = 0; - bool has_neg = false; + bool has_neg = false; for (auto sec : num_or_sections) { if (sec > 0) { sec_sum += sec; } else if (sec == -1 && !has_neg) { has_neg = true; } else if (sec == 0) { - LOG(FATAL) << "The attribute 'num_or_sections' of split should not has 0 ! Please check."; + LOG(FATAL) << "The attribute 'num_or_sections' of split should not has " + "0 ! Please check."; } else { - LOG(FATAL) << "The attribute 'num_or_sections' of split can only have at most one '-1' ! Please check."; + LOG(FATAL) << "The attribute 'num_or_sections' of split can only have " + "at most one '-1' ! Please check."; } } CHECK(!has_neg && sec_sum == x_shape[axis]) - << "The sum of attr sections should be equal with the axis's dimension value of " + << "The sum of attr sections should be equal with the axis's dimension " + "value of " "inputs A in Split ! Please check."; } - VLOG(4) << "Split " << x_name << " with shape (" << cinn::utils::Join(x->shape, ",") << ") " - << " to section (" << cinn::utils::Join(num_or_sections, ",") << ") at dimension " << axis; + VLOG(4) << "Split " << x_name << " with shape (" + << cinn::utils::Join(x->shape, ",") << ") " + << " to section (" << cinn::utils::Join(num_or_sections, ",") + << ") at dimension " << axis; auto out = ctx.Builder()->Split(x, num_or_sections, axis); - CHECK_EQ(out.size(), out_name.size()) << "The Split op should has " << out_name.size() << " output, but only " - << out.size(); + CHECK_EQ(out.size(), out_name.size()) + << "The Split op should has " << out_name.size() << " output, but only " + << out.size(); for (int i = 0; i < out.size(); ++i) { ctx.AddVar(out_name[i], out[i]); @@ -106,15 +121,18 @@ void SplitOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ct } } -void ReshapeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ReshapeOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - auto shape = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "shape")); + auto shape = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "shape")); auto x = ctx.GetVar(x_name); - VLOG(4) << "Reshape " << x_name << "from shape (" << cinn::utils::Join(x->shape, ",") << ") to (" + VLOG(4) << "Reshape " << x_name << "from shape (" + << cinn::utils::Join(x->shape, ",") << ") to (" << cinn::utils::Join(shape, ",") << ")."; auto out = ctx.Builder()->Reshape(x, shape); @@ -125,7 +143,8 @@ void ReshapeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx.AddVarModelToProgram(out_name, out->id); } -void TransposeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void TransposeOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Y").size(), 1UL); @@ -134,7 +153,8 @@ void 
TransposeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext auto x = ctx.GetVar(x_name); CHECK(x->shape.size() == 2) << "Now transpose_p only support 2-dim matrix."; - VLOG(4) << "Transpose " << x_name << " with shape (" << cinn::utils::Join(x->shape, ",") << ")."; + VLOG(4) << "Transpose " << x_name << " with shape (" + << cinn::utils::Join(x->shape, ",") << ")."; auto out = ctx.Builder()->Transpose(x, {1, 0}); @@ -142,26 +162,34 @@ void TransposeOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext ctx.AddVarModelToProgram(out_name, out->id); } -void SliceSelectOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void SliceSelectOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Y").size(), 1UL); auto out_name = op_desc.Output("Y").front(); CHECK(op_desc.HasAttr("starts")); - auto starts = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "starts")); + auto starts = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "starts")); CHECK(op_desc.HasAttr("ends")); - auto ends = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "ends")); + auto ends = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "ends")); CHECK(op_desc.HasAttr("axis")); - auto axes = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "axis")); + auto axes = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "axis")); CHECK(op_desc.HasAttr("strides")); - auto strides = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "strides")); + auto strides = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "strides")); auto x = ctx.GetVar(x_name); - VLOG(4) << "SliceSelect " << x_name << " from shape (" << cinn::utils::Join(x->shape, ",") << ") with starts [" - << cinn::utils::Join(starts, ",") << "], ends [" << cinn::utils::Join(ends, ",") << "], axis [" - << cinn::utils::Join(axes, ",") << "], strides [" << cinn::utils::Join(strides, ",") << "]."; + VLOG(4) << "SliceSelect " << x_name << " from shape (" + << cinn::utils::Join(x->shape, ",") << ") with starts [" + << cinn::utils::Join(starts, ",") << "], ends [" + << cinn::utils::Join(ends, ",") << "], axis [" + << cinn::utils::Join(axes, ",") << "], strides [" + << cinn::utils::Join(strides, ",") << "]."; auto out = ctx.Builder()->Slice(x, axes, starts, ends, ShapeType{}, strides); @@ -169,7 +197,8 @@ void SliceSelectOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperConte ctx.AddVarModelToProgram(out_name, out->id); } -void SliceAssignOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void SliceAssignOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Y").size(), 1UL); @@ -178,20 +207,27 @@ void SliceAssignOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperConte auto out_name = op_desc.Output("Z").front(); CHECK(op_desc.HasAttr("starts")); - auto starts = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "starts")); + auto starts = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "starts")); CHECK(op_desc.HasAttr("ends")); - auto ends = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "ends")); + auto ends = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "ends")); CHECK(op_desc.HasAttr("axis")); - auto axes = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "axis")); 
+ auto axes = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "axis")); CHECK(op_desc.HasAttr("strides")); - auto strides = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "strides")); + auto strides = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "strides")); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto assign = ctx.GetVar(y_name); - VLOG(4) << "SliceAssign " << x_name << " from shape (" << cinn::utils::Join(x->shape, ",") << ") with starts [" - << cinn::utils::Join(starts, ",") << "], ends [" << cinn::utils::Join(ends, ",") << "], axis [" - << cinn::utils::Join(axes, ",") << "], strides [" << cinn::utils::Join(strides, ",") << "]."; + VLOG(4) << "SliceAssign " << x_name << " from shape (" + << cinn::utils::Join(x->shape, ",") << ") with starts [" + << cinn::utils::Join(starts, ",") << "], ends [" + << cinn::utils::Join(ends, ",") << "], axis [" + << cinn::utils::Join(axes, ",") << "], strides [" + << cinn::utils::Join(strides, ",") << "]."; absl::optional out; if (x->shape == assign->shape) { @@ -204,19 +240,23 @@ void SliceAssignOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperConte ctx.AddVarModelToProgram(out_name, out.value()->id); } -void ReduceOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx, const std::string& reduce_type) { +void ReduceOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx, + const std::string& reduce_type) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Y").size(), 1UL); auto out_name = op_desc.Output("Y").front(); - auto axis = utils::ToShapeType(utils::GetAttrOrDefault>(op_desc, "axis")); + auto axis = utils::ToShapeType( + utils::GetAttrOrDefault>(op_desc, "axis")); auto keepdim = utils::GetAttrOrDefault(op_desc, "keepdim", false); auto x = ctx.GetVar(x_name); - VLOG(4) << "Reudce " << reduce_type << " x:" << x_name << " from shape (" << cinn::utils::Join(x->shape, ",") - << "), with axis [" << cinn::utils::Join(axis, ",") << "], keepdim " << keepdim; + VLOG(4) << "Reduce " << reduce_type << " x:" << x_name << " from shape (" + << cinn::utils::Join(x->shape, ",") << "), with axis [" + << cinn::utils::Join(axis, ",") << "], keepdim " << keepdim; // now paddle science only need reduce sum absl::optional out; @@ -240,9 +280,10 @@ void ReduceOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c ctx.AddVarModelToProgram(out_name, out.value()->id); } -#define EXPAND_REDUCE_OPMAPPER(ReduceType) \ - void Reduce##ReduceType##OpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { \ - ReduceOpMapper(op_desc, ctx, #ReduceType); \ +#define EXPAND_REDUCE_OPMAPPER(ReduceType) \ + void Reduce##ReduceType##OpMapper(const paddle::cpp::OpDesc& op_desc, \ + const OpMapperContext& ctx) { \ + ReduceOpMapper(op_desc, ctx, #ReduceType); \ } EXPAND_REDUCE_OPMAPPER(Sum) @@ -253,7 +294,8 @@ EXPAND_REDUCE_OPMAPPER(All) EXPAND_REDUCE_OPMAPPER(Any) #undef EXPAND_REDUCE_OPMAPPER -void GatherOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void GatherOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("IndexTensor").size(), 1UL); @@ -261,12 +303,14 @@ void GatherOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c CHECK_EQ(op_desc.Output("Y").size(), 1UL); auto out_name = op_desc.Output("Y").front(); - auto axis =
utils::ToDimType(utils::GetAttrOrDefault(op_desc, "axis", 0)); + auto axis = + utils::ToDimType(utils::GetAttrOrDefault(op_desc, "axis", 0)); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto index = ctx.GetVar(index_name); - VLOG(4) << "Gather " << index_name << " (" << cinn::utils::Join(index->shape, ",") << ") from " << x_name + VLOG(4) << "Gather " << index_name << " (" + << cinn::utils::Join(index->shape, ",") << ") from " << x_name << " shape (" << cinn::utils::Join(x->shape, ",") << ") " << "at dimension " << axis; @@ -276,7 +320,8 @@ void GatherOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c ctx.AddVarModelToProgram(out_name, out->id); } -void IndexAssignOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void IndexAssignOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Y").size(), 1UL); @@ -286,15 +331,17 @@ void IndexAssignOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperConte CHECK_EQ(op_desc.Output("Z").size(), 1UL); auto out_name = op_desc.Output("Z").front(); - auto axis = utils::ToDimType(utils::GetAttrOrDefault(op_desc, "axis", 0)); + auto axis = + utils::ToDimType(utils::GetAttrOrDefault(op_desc, "axis", 0)); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto updates = ctx.GetVar(updates_name); - auto index = ctx.GetVar(index_name); + auto index = ctx.GetVar(index_name); auto out = ctx.Builder()->ScatterAssign(x, updates, index, axis); - VLOG(4) << "IndexAssign " << updates_name << " (" << cinn::utils::Join(updates->shape, ",") << ") to " << x_name + VLOG(4) << "IndexAssign " << updates_name << " (" + << cinn::utils::Join(updates->shape, ",") << ") to " << x_name << " shape (" << cinn::utils::Join(x->shape, ",") << ") " << "at dimension " << axis; @@ -302,7 +349,8 @@ void IndexAssignOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperConte ctx.AddVarModelToProgram(out_name, out->id); } -void ScatterAddOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void ScatterAddOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Y").size(), 1UL); @@ -312,15 +360,17 @@ void ScatterAddOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContex CHECK_EQ(op_desc.Output("Z").size(), 1UL); auto out_name = op_desc.Output("Z").front(); - auto axis = utils::ToDimType(utils::GetAttrOrDefault(op_desc, "axis", 0)); + auto axis = + utils::ToDimType(utils::GetAttrOrDefault(op_desc, "axis", 0)); - auto x = ctx.GetVar(x_name); + auto x = ctx.GetVar(x_name); auto updates = ctx.GetVar(updates_name); - auto index = ctx.GetVar(index_name); + auto index = ctx.GetVar(index_name); auto out = ctx.Builder()->ScatterAdd(x, updates, index, axis); - VLOG(4) << "ScatterAdd " << updates_name << " (" << cinn::utils::Join(updates->shape, ",") << ") to " << x_name + VLOG(4) << "ScatterAdd " << updates_name << " (" + << cinn::utils::Join(updates->shape, ",") << ") to " << x_name << " shape (" << cinn::utils::Join(x->shape, ",") << ") " << "at dimension " << axis; @@ -328,7 +378,8 @@ void ScatterAddOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContex ctx.AddVarModelToProgram(out_name, out->id); } -void SelectOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void SelectOpMapper(const 
paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("Condition").size(), 1UL); auto cond_name = op_desc.Input("Condition").front(); CHECK_EQ(op_desc.Input("X").size(), 1UL); @@ -341,15 +392,16 @@ void SelectOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c VLOG(4) << cond_name << " ? " << x_name << " : " << y_name; auto cond = ctx.GetVar(cond_name); - auto x = ctx.GetVar(x_name); - auto y = ctx.GetVar(y_name); - auto out = ctx.Builder()->Select(cond, x, y); + auto x = ctx.GetVar(x_name); + auto y = ctx.GetVar(y_name); + auto out = ctx.Builder()->Select(cond, x, y); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, out->id); } -void CastOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void CastOpMapper(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Y").size(), 1UL); @@ -357,10 +409,11 @@ void CastOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx auto x = ctx.GetVar(x_name); - auto dtype_id = utils::GetAttrOrDefault(op_desc, "dtype", static_cast(paddle::cpp::VarDescAPI::Type::FP32)); + auto dtype_id = utils::GetAttrOrDefault( + op_desc, "dtype", static_cast(paddle::cpp::VarDescAPI::Type::FP32)); auto dtype_pd = static_cast(dtype_id); auto dtype_cinn = utils::CppVarType2CommonType(dtype_pd); - auto dtype = common::Type2Str(dtype_cinn); + auto dtype = common::Type2Str(dtype_cinn); VLOG(4) << out_name << " = cast(" << x_name << ", dtype=" << dtype << ")"; @@ -375,22 +428,35 @@ void CastOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx } // namespace cinn CINN_REGISTER_HELPER(science_transform) { - CINN_REGISTER_OP_MAPPER(concat_p, cinn::frontend::science_mappers::ConcatOpMapper) - CINN_REGISTER_OP_MAPPER(split_p, cinn::frontend::science_mappers::SplitOpMapper) - CINN_REGISTER_OP_MAPPER(reshape_p, cinn::frontend::science_mappers::ReshapeOpMapper) - CINN_REGISTER_OP_MAPPER(transpose_p, cinn::frontend::science_mappers::TransposeOpMapper) - CINN_REGISTER_OP_MAPPER(slice_select_p, cinn::frontend::science_mappers::SliceSelectOpMapper) - CINN_REGISTER_OP_MAPPER(slice_assign_p, cinn::frontend::science_mappers::SliceAssignOpMapper) - CINN_REGISTER_OP_MAPPER(index_select_p, cinn::frontend::science_mappers::GatherOpMapper) - CINN_REGISTER_OP_MAPPER(gather_p, cinn::frontend::science_mappers::GatherOpMapper) - CINN_REGISTER_OP_MAPPER(index_assign_p, cinn::frontend::science_mappers::IndexAssignOpMapper) - CINN_REGISTER_OP_MAPPER(scatter_add_p, cinn::frontend::science_mappers::ScatterAddOpMapper) - CINN_REGISTER_OP_MAPPER(reduce_p, cinn::frontend::science_mappers::ReduceSumOpMapper) - CINN_REGISTER_OP_MAPPER(select_p, cinn::frontend::science_mappers::SelectOpMapper) + CINN_REGISTER_OP_MAPPER(concat_p, + cinn::frontend::science_mappers::ConcatOpMapper) + CINN_REGISTER_OP_MAPPER(split_p, + cinn::frontend::science_mappers::SplitOpMapper) + CINN_REGISTER_OP_MAPPER(reshape_p, + cinn::frontend::science_mappers::ReshapeOpMapper) + CINN_REGISTER_OP_MAPPER(transpose_p, + cinn::frontend::science_mappers::TransposeOpMapper) + CINN_REGISTER_OP_MAPPER(slice_select_p, + cinn::frontend::science_mappers::SliceSelectOpMapper) + CINN_REGISTER_OP_MAPPER(slice_assign_p, + cinn::frontend::science_mappers::SliceAssignOpMapper) + CINN_REGISTER_OP_MAPPER(index_select_p, + cinn::frontend::science_mappers::GatherOpMapper) + CINN_REGISTER_OP_MAPPER(gather_p, + 
cinn::frontend::science_mappers::GatherOpMapper)
+  CINN_REGISTER_OP_MAPPER(index_assign_p,
+                          cinn::frontend::science_mappers::IndexAssignOpMapper)
+  CINN_REGISTER_OP_MAPPER(scatter_add_p,
+                          cinn::frontend::science_mappers::ScatterAddOpMapper)
+  CINN_REGISTER_OP_MAPPER(reduce_p,
+                          cinn::frontend::science_mappers::ReduceSumOpMapper)
+  CINN_REGISTER_OP_MAPPER(select_p,
+                          cinn::frontend::science_mappers::SelectOpMapper)
   CINN_REGISTER_OP_MAPPER(cast_p, cinn::frontend::science_mappers::CastOpMapper)
 
 #define EXPAND_REDUCE_OP_MAPPER_REGISTER(op_name, ReduceType) \
-  CINN_REGISTER_OP_MAPPER(op_name, cinn::frontend::science_mappers::Reduce##ReduceType##OpMapper)
+  CINN_REGISTER_OP_MAPPER( \
+      op_name, cinn::frontend::science_mappers::Reduce##ReduceType##OpMapper)
 
 EXPAND_REDUCE_OP_MAPPER_REGISTER(reduce_sum_p, Sum)
 EXPAND_REDUCE_OP_MAPPER_REGISTER(reduce_prod_p, Prod)
 
diff --git a/paddle/cinn/frontend/optimize.cc b/paddle/cinn/frontend/optimize.cc
index 393f5b35ea0b4..3387d32612c7b 100644
--- a/paddle/cinn/frontend/optimize.cc
+++ b/paddle/cinn/frontend/optimize.cc
@@ -56,7 +56,8 @@ OptimizeOptions DefaultTrainingOptimizeOptions() {
   auto can_find_custom_call_deny_op = [](const std::string& op) {
     return FLAGS_cinn_custom_call_deny_ops.find(op) != std::string::npos;
   };
-  bool is_gemm_use_cublas = FLAGS_cinn_use_custom_call && !can_find_custom_call_deny_op("matmul") &&
+  bool is_gemm_use_cublas = FLAGS_cinn_use_custom_call &&
+                            !can_find_custom_call_deny_op("matmul") &&
                             !can_find_custom_call_deny_op("cublas_gemm") &&
                             !can_find_custom_call_deny_op("cublas_matmul");
   if (is_gemm_use_cublas) {
@@ -105,9 +106,10 @@ OptimizeOptions DefaultTrainingOptimizeOptions() {
 #endif
 
   // WARNING: the pass must be the last pass !!!
-  if (!cinn::runtime::CheckStringFlagFalse(FLAGS_cinn_check_fusion_accuracy_pass)) {
-    // Check the correct of fusion kernels, if the results not satisfied 'allclose(rtol=1e-05f, atol=1e-08f)', report
-    // error and exited.
+  if (!cinn::runtime::CheckStringFlagFalse(
+          FLAGS_cinn_check_fusion_accuracy_pass)) {
+    // Check the correctness of fused kernels: if the results do not satisfy
+    // 'allclose(rtol=1e-05f, atol=1e-08f)', report an error and exit.
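For reference, 'allclose(rtol, atol)' above is the numpy-style elementwise tolerance test. A minimal sketch of that criterion (illustrative only; this is not CINN's actual CheckFusionAccuracyPass implementation):

#include <cmath>
#include <cstddef>

// True when every element satisfies |a - b| <= atol + rtol * |b|.
bool AllClose(const float* actual, const float* expected, size_t n,
              float rtol = 1e-05f, float atol = 1e-08f) {
  for (size_t i = 0; i < n; ++i) {
    if (std::fabs(actual[i] - expected[i]) >
        atol + rtol * std::fabs(expected[i])) {
      return false;
    }
  }
  return true;
}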
options.graph_passes.emplace_back("CheckFusionAccuracyPass"); options.graph_passes.emplace_back("TransToCustomCallPass"); } @@ -122,16 +124,19 @@ std::vector DefaultOpFusionPasses() { return passes; } -std::shared_ptr Optimize(frontend::Program* program, - const std::unordered_set& fetch_ids, - common::Target target, - const OptimizeOptions& options) { +std::shared_ptr Optimize( + frontend::Program* program, + const std::unordered_set& fetch_ids, + common::Target target, + const OptimizeOptions& options) { cinn::hlir::framework::PassPrinter::GetInstance()->Begin(fetch_ids); // Apply program passes VLOG(3) << "Before frontend::ProgramPass::Apply"; - frontend::ProgramPass::Apply(program, fetch_ids, target, options.program_passes); + frontend::ProgramPass::Apply( + program, fetch_ids, target, options.program_passes); // Apply graph passes - auto graph = std::make_shared(*program, fetch_ids, target); + auto graph = + std::make_shared(*program, fetch_ids, target); VLOG(3) << "Before hlir::framework::ApplyPasses"; hlir::framework::ApplyPasses(graph.get(), options.graph_passes); @@ -139,17 +144,19 @@ std::shared_ptr Optimize(frontend::Program* program, return graph; } -std::shared_ptr Optimize(frontend::Program* program, - const std::unordered_set& fetch_ids, - common::Target target, - const std::vector& passes) { +std::shared_ptr Optimize( + frontend::Program* program, + const std::unordered_set& fetch_ids, + common::Target target, + const std::vector& passes) { OptimizeOptions options; bool enbale_fusion = false; if (!passes.empty()) { for (const auto& pass : passes) { auto* p_pass = ProgramPassRegistry::Global()->Find(pass); - auto* g_pass = Registry::Global()->Find(pass); + auto* g_pass = + Registry::Global()->Find(pass); if (p_pass) { options.program_passes.emplace_back(pass); } else if (g_pass) { @@ -158,7 +165,8 @@ std::shared_ptr Optimize(frontend::Program* program, enbale_fusion = true; } } else { - LOG(FATAL) << "Pass " << pass << " unsupported in CINN! Please check.\n"; + LOG(FATAL) << "Pass " << pass + << " unsupported in CINN! 
Please check.\n"; } } diff --git a/paddle/cinn/frontend/optimize.h b/paddle/cinn/frontend/optimize.h index 2646a52359c91..543c027308d7b 100755 --- a/paddle/cinn/frontend/optimize.h +++ b/paddle/cinn/frontend/optimize.h @@ -35,15 +35,17 @@ OptimizeOptions DefaultTrainingOptimizeOptions(); std::vector DefaultOpFusionPasses(); -std::shared_ptr Optimize(frontend::Program* program, - const std::unordered_set& fetch_ids, - common::Target target, - const OptimizeOptions& options = DefaultTrainingOptimizeOptions()); - -std::shared_ptr Optimize(frontend::Program* program, - const std::unordered_set& fetch_ids, - common::Target target, - const std::vector& passes); +std::shared_ptr Optimize( + frontend::Program* program, + const std::unordered_set& fetch_ids, + common::Target target, + const OptimizeOptions& options = DefaultTrainingOptimizeOptions()); + +std::shared_ptr Optimize( + frontend::Program* program, + const std::unordered_set& fetch_ids, + common::Target target, + const std::vector& passes); } // namespace frontend } // namespace cinn diff --git a/paddle/cinn/frontend/paddle/compatible_pb.cc b/paddle/cinn/frontend/paddle/compatible_pb.cc index 93f392e86bd97..d7cf6c056d6dc 100644 --- a/paddle/cinn/frontend/paddle/compatible_pb.cc +++ b/paddle/cinn/frontend/paddle/compatible_pb.cc @@ -26,20 +26,22 @@ namespace cinn::frontend::paddle { namespace framework_proto = ::cinn::frontend::paddle::proto; /// For VarDesc transfrom -#define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \ - template <> \ - void TransformVarDescCppToAny(const cpp::VarDesc &cpp_desc, T *any_desc) { \ - any_desc->SetName(cpp_desc.Name()); \ - any_desc->SetType(cpp_desc.GetType()); \ - any_desc->SetPersistable(cpp_desc.Persistable()); \ - if (cpp_desc.Name() != "feed" && cpp_desc.Name() != "fetch") { \ - any_desc->SetShape(cpp_desc.GetShape()); \ - any_desc->SetDataType(cpp_desc.GetDataType()); \ - } \ +#define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \ + template <> \ + void TransformVarDescCppToAny(const cpp::VarDesc &cpp_desc, \ + T *any_desc) { \ + any_desc->SetName(cpp_desc.Name()); \ + any_desc->SetType(cpp_desc.GetType()); \ + any_desc->SetPersistable(cpp_desc.Persistable()); \ + if (cpp_desc.Name() != "feed" && cpp_desc.Name() != "fetch") { \ + any_desc->SetShape(cpp_desc.GetShape()); \ + any_desc->SetDataType(cpp_desc.GetDataType()); \ + } \ } template <> -void TransformVarDescAnyToCpp(const pb::VarDesc &any_desc, cpp::VarDesc *cpp_desc) { +void TransformVarDescAnyToCpp(const pb::VarDesc &any_desc, + cpp::VarDesc *cpp_desc) { cpp_desc->SetName(any_desc.Name()); cpp_desc->SetType(any_desc.GetType()); cpp_desc->SetPersistable(any_desc.Persistable()); @@ -81,34 +83,41 @@ void OpOutputsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { template void OpAttrsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) { using AttrType = cpp::OpDescAPI::AttrType; - auto set_attr = [&](const std::string &name, AttrType type) { + auto set_attr = [&](const std::string &name, AttrType type) { switch (type) { case AttrType::INT: - cpp_desc->SetAttr(name, any_desc.template GetAttr(name)); + cpp_desc->SetAttr(name, + any_desc.template GetAttr(name)); break; case AttrType::FLOAT: cpp_desc->SetAttr(name, any_desc.template GetAttr(name)); break; case AttrType::STRING: - cpp_desc->SetAttr(name, any_desc.template GetAttr(name)); + cpp_desc->SetAttr( + name, any_desc.template GetAttr(name)); break; case AttrType::LONG: - cpp_desc->SetAttr(name, any_desc.template GetAttr(name)); + cpp_desc->SetAttr(name, + any_desc.template GetAttr(name)); break; 
case AttrType::INTS: - cpp_desc->SetAttr>(name, any_desc.template GetAttr>(name)); + cpp_desc->SetAttr>( + name, any_desc.template GetAttr>(name)); break; case AttrType::FLOATS: - cpp_desc->SetAttr>(name, any_desc.template GetAttr>(name)); + cpp_desc->SetAttr>( + name, any_desc.template GetAttr>(name)); break; case AttrType::BOOLEAN: cpp_desc->SetAttr(name, any_desc.template GetAttr(name)); break; case AttrType::STRINGS: - cpp_desc->SetAttr>(name, any_desc.template GetAttr>(name)); + cpp_desc->SetAttr>( + name, any_desc.template GetAttr>(name)); break; case AttrType::LONGS: - cpp_desc->SetAttr>(name, any_desc.template GetAttr>(name)); + cpp_desc->SetAttr>( + name, any_desc.template GetAttr>(name)); break; case AttrType::BLOCK: { auto i = any_desc.template GetAttr(name); @@ -132,7 +141,7 @@ void OpAttrsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) { template void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { using AttrType = cpp::OpDescAPI::AttrType; - auto set_attr = [&](const std::string &name, AttrType type) { + auto set_attr = [&](const std::string &name, AttrType type) { switch (type) { #define IMPL_ONE(type__, T) \ case AttrType::type__: \ @@ -176,81 +185,90 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { } /// For BlockDesc transform -#define TRANS_BLOCK_ANY_WITH_CPP_IMPL(T, NT) \ - template <> \ - void TransformBlockDescAnyToCpp(const NT::T &any_desc, cpp::BlockDesc *cpp_desc) { \ - NT::T desc = any_desc; \ - cpp_desc->SetIdx(desc.Idx()); \ - cpp_desc->SetParentIdx(desc.ParentIdx()); \ - cpp_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \ - \ - cpp_desc->ClearOps(); \ - for (size_t i = 0; i < desc.OpsSize(); ++i) { \ - auto any_op_desc = NT::OpDesc(desc.GetOp(i)); \ - auto *cpp_op_desc = cpp_desc->AddOp(); \ - TransformOpDescAnyToCpp(any_op_desc, cpp_op_desc); \ - } \ - \ - cpp_desc->ClearVars(); \ - for (size_t i = 0; i < desc.VarsSize(); ++i) { \ - auto any_var_desc = NT::VarDesc(desc.GetVar(i)); \ - auto *cpp_var_desc = cpp_desc->AddVar(); \ - TransformVarDescAnyToCpp(any_var_desc, cpp_var_desc); \ - } \ - } \ - \ - template <> \ - void TransformBlockDescCppToAny(const cpp::T &cpp_desc, NT::T *any_desc) { \ - auto desc = cpp_desc; \ - any_desc->SetIdx(desc.Idx()); \ - any_desc->SetParentIdx(desc.ParentIdx()); \ - any_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \ - \ - any_desc->ClearOps(); \ - for (size_t i = 0; i < desc.OpsSize(); ++i) { \ - auto *cpp_op_desc = desc.GetOp(i); \ - auto any_op_desc = NT::OpDesc(any_desc->AddOp()); \ - TransformOpDescCppToAny(*cpp_op_desc, &any_op_desc); \ - } \ - \ - any_desc->ClearVars(); \ - for (size_t i = 0; i < desc.VarsSize(); ++i) { \ - auto *cpp_var_desc = desc.GetVar(i); \ - auto any_var_desc = NT::VarDesc(any_desc->AddVar()); \ - TransformVarDescCppToAny(*cpp_var_desc, &any_var_desc); \ - } \ +#define TRANS_BLOCK_ANY_WITH_CPP_IMPL(T, NT) \ + template <> \ + void TransformBlockDescAnyToCpp(const NT::T &any_desc, \ + cpp::BlockDesc *cpp_desc) { \ + NT::T desc = any_desc; \ + cpp_desc->SetIdx(desc.Idx()); \ + cpp_desc->SetParentIdx(desc.ParentIdx()); \ + cpp_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \ + \ + cpp_desc->ClearOps(); \ + for (size_t i = 0; i < desc.OpsSize(); ++i) { \ + auto any_op_desc = NT::OpDesc(desc.GetOp(i)); \ + auto *cpp_op_desc = cpp_desc->AddOp(); \ + TransformOpDescAnyToCpp(any_op_desc, cpp_op_desc); \ + } \ + \ + cpp_desc->ClearVars(); \ + for (size_t i = 0; i < desc.VarsSize(); ++i) { \ + auto any_var_desc = \ + 
NT::VarDesc(desc.GetVar(i)); \ + auto *cpp_var_desc = cpp_desc->AddVar(); \ + TransformVarDescAnyToCpp(any_var_desc, cpp_var_desc); \ + } \ + } \ + \ + template <> \ + void TransformBlockDescCppToAny(const cpp::T &cpp_desc, \ + NT::T *any_desc) { \ + auto desc = cpp_desc; \ + any_desc->SetIdx(desc.Idx()); \ + any_desc->SetParentIdx(desc.ParentIdx()); \ + any_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \ + \ + any_desc->ClearOps(); \ + for (size_t i = 0; i < desc.OpsSize(); ++i) { \ + auto *cpp_op_desc = desc.GetOp(i); \ + auto any_op_desc = \ + NT::OpDesc(any_desc->AddOp()); \ + TransformOpDescCppToAny(*cpp_op_desc, &any_op_desc); \ + } \ + \ + any_desc->ClearVars(); \ + for (size_t i = 0; i < desc.VarsSize(); ++i) { \ + auto *cpp_var_desc = desc.GetVar(i); \ + auto any_var_desc = \ + NT::VarDesc(any_desc->AddVar()); \ + TransformVarDescCppToAny(*cpp_var_desc, &any_var_desc); \ + } \ } /// For ProgramDesc transform -#define TRANS_PROGRAM_ANY_WITH_CPP_IMPL(T, NT) \ - template <> \ - void TransformProgramDescAnyToCpp(const NT::T &any_desc, cpp::ProgramDesc *cpp_desc) { \ - NT::T desc = any_desc; \ - if (desc.HasVersion()) { \ - cpp_desc->SetVersion(desc.Version()); \ - } \ - \ - cpp_desc->ClearBlocks(); \ - for (size_t i = 0; i < desc.BlocksSize(); ++i) { \ - auto any_block_desc = NT::BlockDesc(desc.GetBlock(i)); \ - auto *cpp_block_desc = cpp_desc->AddBlock(); \ - TransformBlockDescAnyToCpp(any_block_desc, cpp_block_desc); \ - } \ - } \ - \ - template <> \ - void TransformProgramDescCppToAny(const cpp::T &cpp_desc, NT::T *any_desc) { \ - auto desc = cpp_desc; \ - if (desc.HasVersion()) { \ - any_desc->SetVersion(desc.Version()); \ - } \ - \ - any_desc->ClearBlocks(); \ - for (size_t i = 0; i < desc.BlocksSize(); ++i) { \ - auto *cpp_block_desc = desc.GetBlock(i); \ - auto any_block_desc = NT::BlockDesc(any_desc->AddBlock()); \ - TransformBlockDescCppToAny(*cpp_block_desc, &any_block_desc); \ - } \ +#define TRANS_PROGRAM_ANY_WITH_CPP_IMPL(T, NT) \ + template <> \ + void TransformProgramDescAnyToCpp(const NT::T &any_desc, \ + cpp::ProgramDesc *cpp_desc) { \ + NT::T desc = any_desc; \ + if (desc.HasVersion()) { \ + cpp_desc->SetVersion(desc.Version()); \ + } \ + \ + cpp_desc->ClearBlocks(); \ + for (size_t i = 0; i < desc.BlocksSize(); ++i) { \ + auto any_block_desc = \ + NT::BlockDesc(desc.GetBlock(i)); \ + auto *cpp_block_desc = cpp_desc->AddBlock(); \ + TransformBlockDescAnyToCpp(any_block_desc, cpp_block_desc); \ + } \ + } \ + \ + template <> \ + void TransformProgramDescCppToAny(const cpp::T &cpp_desc, \ + NT::T *any_desc) { \ + auto desc = cpp_desc; \ + if (desc.HasVersion()) { \ + any_desc->SetVersion(desc.Version()); \ + } \ + \ + any_desc->ClearBlocks(); \ + for (size_t i = 0; i < desc.BlocksSize(); ++i) { \ + auto *cpp_block_desc = desc.GetBlock(i); \ + auto any_block_desc = \ + NT::BlockDesc(any_desc->AddBlock()); \ + TransformBlockDescCppToAny(*cpp_block_desc, &any_block_desc); \ + } \ } TRANS_VAR_ANY_WITH_CPP_IMPL(pb::VarDesc); diff --git a/paddle/cinn/frontend/paddle/compatible_pb.h b/paddle/cinn/frontend/paddle/compatible_pb.h index e8478d555c392..130c9a2de33a4 100644 --- a/paddle/cinn/frontend/paddle/compatible_pb.h +++ b/paddle/cinn/frontend/paddle/compatible_pb.h @@ -23,11 +23,13 @@ namespace cinn::frontend::paddle { /// Transform an VarDesc from VarDescType to cpp format. 
template -void TransformVarDescAnyToCpp(const VarDescType& any_desc, cpp::VarDesc* cpp_desc); +void TransformVarDescAnyToCpp(const VarDescType& any_desc, + cpp::VarDesc* cpp_desc); /// Transform an VarDesc from cpp to VarDescType format. template -void TransformVarDescCppToAny(const cpp::VarDesc& cpp_desc, VarDescType* any_desc); +void TransformVarDescCppToAny(const cpp::VarDesc& cpp_desc, + VarDescType* any_desc); /// Transform an OpDesc from OpDescType to cpp format. template @@ -39,18 +41,22 @@ void TransformOpDescCppToAny(const cpp::OpDesc& cpp_desc, OpDescType* any_desc); /// Transform an BlockDesc from BlockDescType to cpp format. template -void TransformBlockDescAnyToCpp(const BlockDescType& any_desc, cpp::BlockDesc* cpp_desc); +void TransformBlockDescAnyToCpp(const BlockDescType& any_desc, + cpp::BlockDesc* cpp_desc); /// Transform an BlockDesc from cpp to BlockDescType format. template -void TransformBlockDescCppToAny(const cpp::BlockDesc& cpp_desc, BlockDescType* any_desc); +void TransformBlockDescCppToAny(const cpp::BlockDesc& cpp_desc, + BlockDescType* any_desc); /// Transform an ProgramDesc from ProgramDescType to cpp format. template -void TransformProgramDescAnyToCpp(const ProgramDescType& any_desc, cpp::ProgramDesc* cpp_desc); +void TransformProgramDescAnyToCpp(const ProgramDescType& any_desc, + cpp::ProgramDesc* cpp_desc); /// Transform an ProgramDesc from cpp to ProgramDescType format. template -void TransformProgramDescCppToAny(const cpp::ProgramDesc& cpp_desc, ProgramDescType* any_desc); +void TransformProgramDescCppToAny(const cpp::ProgramDesc& cpp_desc, + ProgramDescType* any_desc); } // namespace cinn::frontend::paddle diff --git a/paddle/cinn/frontend/paddle/cpp/desc_api.h b/paddle/cinn/frontend/paddle/cpp/desc_api.h index ae30474cd3ec5..555a02a2ee957 100644 --- a/paddle/cinn/frontend/paddle/cpp/desc_api.h +++ b/paddle/cinn/frontend/paddle/cpp/desc_api.h @@ -25,46 +25,47 @@ namespace cinn::frontend::paddle::cpp { /* * Compatible interfaces for all the different kinds of XXXDesc. All the XXXDesc * classes should implement this. - * ref to: https://github.com/PaddlePaddle/Paddle/blob/v2.4.1/paddle/fluid/framework/framework.proto#L118 + * ref to: + * https://github.com/PaddlePaddle/Paddle/blob/v2.4.1/paddle/fluid/framework/framework.proto#L118 */ class VarDescAPI { public: enum class Type { // Pod Types - BOOL = 0, + BOOL = 0, INT16 = 1, INT32 = 2, INT64 = 3, - FP16 = 4, - FP32 = 5, - FP64 = 6, + FP16 = 4, + FP32 = 5, + FP64 = 6, // Tensor is used in C++. - SIZE_T = 19, - UINT8 = 20, - INT8 = 21, - BF16 = 22, - COMPLEX64 = 23, + SIZE_T = 19, + UINT8 = 20, + INT8 = 21, + BF16 = 22, + COMPLEX64 = 23, COMPLEX128 = 24, // Other types that may need additional descriptions - LOD_TENSOR = 7, - SELECTED_ROWS = 8, - FEED_MINIBATCH = 9, - FETCH_LIST = 10, - STEP_SCOPES = 11, - LOD_RANK_TABLE = 12, + LOD_TENSOR = 7, + SELECTED_ROWS = 8, + FEED_MINIBATCH = 9, + FETCH_LIST = 10, + STEP_SCOPES = 11, + LOD_RANK_TABLE = 12, LOD_TENSOR_ARRAY = 13, - PLACE_LIST = 14, - READER = 15, + PLACE_LIST = 14, + READER = 15, // Any runtime decided variable type is raw // raw variables should manage their own allocations // in operators like nccl_op - RAW = 17, + RAW = 17, TUPLE = 18, - STRING = 25, - STRINGS = 26, - VOCAB = 27, + STRING = 25, + STRINGS = 26, + VOCAB = 27, FEED_LIST = 28, // The data type of phi::StringTensor PSTRING = 29, @@ -103,26 +104,27 @@ class VarDescAPI { class OpDescAPI { public: // The AttrType is used to make the proto::AttrType portable. 
-  // ref to https://github.com/PaddlePaddle/Paddle/blob/v2.4.1/paddle/fluid/framework/framework.proto#L25
+  // ref to
+  // https://github.com/PaddlePaddle/Paddle/blob/v2.4.1/paddle/fluid/framework/framework.proto#L25
   enum class AttrType {
-    INT = 0,
-    FLOAT = 1,
-    STRING = 2,
-    INTS = 3,
-    FLOATS = 4,
-    STRINGS = 5,
-    BOOLEAN = 6,
+    INT = 0,
+    FLOAT = 1,
+    STRING = 2,
+    INTS = 3,
+    FLOATS = 4,
+    STRINGS = 5,
+    BOOLEAN = 6,
     BOOLEANS = 7,
-    BLOCK = 8,
-    LONG = 9,
-    BLOCKS = 10,
-    LONGS = 11,
+    BLOCK = 8,
+    LONG = 9,
+    BLOCKS = 10,
+    LONGS = 11,
     FLOAT64S = 12,
-    VAR = 13,
-    VARS = 14,
-    FLOAT64 = 15,
-    SCALAR = 16,
-    SCALARS = 17
+    VAR = 13,
+    VARS = 14,
+    FLOAT64 = 15,
+    SCALAR = 16,
+    SCALARS = 17
   };
 
   virtual ~OpDescAPI() = default;
@@ -140,8 +142,10 @@ class OpDescAPI {
   /// Get parameters.
   virtual std::vector OutputArgumentNames() const = 0;
   /// Set a input given the parameter and arguments.
-  virtual void SetInput(const std::string& param, const std::vector& args) = 0;
-  virtual void SetOutput(const std::string& param, const std::vector& args) = 0;
+  virtual void SetInput(const std::string& param,
+                        const std::vector& args) = 0;
+  virtual void SetOutput(const std::string& param,
+                         const std::vector& args) = 0;
   /// Tell whether this desc has an attribute.
   virtual bool HasAttr(const std::string& name) const = 0;
 
diff --git a/paddle/cinn/frontend/paddle/cpp/op_desc.cc b/paddle/cinn/frontend/paddle/cpp/op_desc.cc
index 91abefd737837..35c790999f903 100644
--- a/paddle/cinn/frontend/paddle/cpp/op_desc.cc
+++ b/paddle/cinn/frontend/paddle/cpp/op_desc.cc
@@ -19,7 +19,8 @@ namespace cinn::frontend::paddle::cpp {
-inline std::string AttrTypeToString(paddle::cpp::OpDescAPI::AttrType attr_type) {
+inline std::string AttrTypeToString(
+    paddle::cpp::OpDescAPI::AttrType attr_type) {
   using AttrType = paddle::cpp::OpDescAPI::AttrType;
   switch (attr_type) {
 #define EXPAND_SWITCH_CASE(ATTR_TYPE) \
@@ -47,7 +48,7 @@ inline std::string AttrTypeToString(paddle::cpp::OpDescAPI::AttrType attr_type)
   template <>                                                   \
   void OpDesc::SetAttr(const std::string& name, const T& v) {   \
     attr_types_[name] = AttrType::repr__;                       \
-    attrs_[name] = v;                                         \
+    attrs_[name] = v;                                           \
   }
 
 SET_ATTR_IMPL(int32_t, INT);
@@ -65,24 +66,26 @@ SET_ATTR_IMPL(std::vector, LONGS);
 #undef SET_ATTR_IMPL
 
-std::pair FindAttr(const OpDesc& desc,
-                   const std::string& name) {
+std::pair
+FindAttr(const OpDesc& desc, const std::string& name) {
   auto it = desc.attrs().find(name);
-  CHECK(it != desc.attrs().end()) << "No attributes called " << name << " found";
+  CHECK(it != desc.attrs().end())
+      << "No attributes called " << name << " found";
   auto attr_it = desc.attr_types().find(name);
   CHECK(attr_it != desc.attr_types().end());
   return std::make_pair(it, attr_it);
 }
 
-#define GET_IMPL_ONE(T, repr__)                                                                  \
-  template <>                                                                                    \
-  T OpDesc::GetAttr(const std::string& name) const {                                             \
-    auto pair = FindAttr(*this, name);                                                           \
-    CHECK(pair.second->second == AttrType::repr__)                                               \
-        << "The op \"" << Type() << "\"'s attrbute \"" << pair.second->first                     \
-        << "\"'s type doesn't match the target type! Try get \"" << #repr__ << "\", but real \"" \
-        << AttrTypeToString(pair.second->second) << "\". Please check.";                         \
-    return absl::any_cast(pair.first->second);                                                   \
+#define GET_IMPL_ONE(T, repr__)                                               \
+  template <>                                                                 \
+  T OpDesc::GetAttr(const std::string& name) const {                          \
+    auto pair = FindAttr(*this, name);                                        \
+    CHECK(pair.second->second == AttrType::repr__)                            \
+        << "The op \"" << Type() << "\"'s attribute \"" << pair.second->first \
+        << "\"'s type doesn't match the target type!
Try get \"" << #repr__ \ + << "\", but real \"" << AttrTypeToString(pair.second->second) \ + << "\". Please check."; \ + return absl::any_cast(pair.first->second); \ } GET_IMPL_ONE(int32_t, INT); diff --git a/paddle/cinn/frontend/paddle/cpp/op_desc.h b/paddle/cinn/frontend/paddle/cpp/op_desc.h index bf8110f2f1b9d..a8b99f7699264 100644 --- a/paddle/cinn/frontend/paddle/cpp/op_desc.h +++ b/paddle/cinn/frontend/paddle/cpp/op_desc.h @@ -35,7 +35,7 @@ namespace cpp { */ class OpDesc : public OpDescAPI { public: - using attrs_t = std::map; + using attrs_t = std::map; using attr_types_t = std::map; protected: @@ -51,10 +51,18 @@ class OpDesc : public OpDescAPI { std::string Type() const override { return type_; } void SetType(const std::string& x) override { type_ = x; } - const std::map>& inputs() const { return inputs_; } - const std::map>& outputs() const { return outputs_; } - std::map>* mutable_inputs() { return &inputs_; } - std::map>* mutable_outputs() { return &outputs_; } + const std::map>& inputs() const { + return inputs_; + } + const std::map>& outputs() const { + return outputs_; + } + std::map>* mutable_inputs() { + return &inputs_; + } + std::map>* mutable_outputs() { + return &outputs_; + } bool HasInput(const std::string& param) const { auto it = inputs_.find(param); @@ -74,11 +82,19 @@ class OpDesc : public OpDescAPI { std::vector Output(const std::string& param) const override; - void SetInput(const std::string& param, const std::vector& args) override { inputs_[param] = args; } + void SetInput(const std::string& param, + const std::vector& args) override { + inputs_[param] = args; + } - void SetOutput(const std::string& param, const std::vector& args) override { outputs_[param] = args; } + void SetOutput(const std::string& param, + const std::vector& args) override { + outputs_[param] = args; + } - bool HasAttr(const std::string& name) const override { return attrs_.count(name); } + bool HasAttr(const std::string& name) const override { + return attrs_.count(name); + } AttrType GetAttrType(const std::string& name) const override { auto it = attr_types_.find(name); @@ -101,7 +117,9 @@ class OpDesc : public OpDescAPI { T GetAttr(const std::string& name) const; const std::map& attrs() const { return attrs_; } - const std::map& attr_types() const { return attr_types_; } + const std::map& attr_types() const { + return attr_types_; + } }; } // namespace cpp diff --git a/paddle/cinn/frontend/paddle/cpp/var_desc.cc b/paddle/cinn/frontend/paddle/cpp/var_desc.cc index 8a17b0c130194..f73c0e90ec67f 100644 --- a/paddle/cinn/frontend/paddle/cpp/var_desc.cc +++ b/paddle/cinn/frontend/paddle/cpp/var_desc.cc @@ -14,4 +14,5 @@ #include "paddle/cinn/frontend/paddle/cpp/var_desc.h" -namespace cinn::frontend::paddle::cpp {} // namespace cinn::frontend::paddle::cpp +namespace cinn::frontend::paddle::cpp { +} // namespace cinn::frontend::paddle::cpp diff --git a/paddle/cinn/frontend/paddle/model_parser.cc b/paddle/cinn/frontend/paddle/model_parser.cc index a21f29d91cdda..6feb6f60a33bf 100755 --- a/paddle/cinn/frontend/paddle/model_parser.cc +++ b/paddle/cinn/frontend/paddle/model_parser.cc @@ -47,7 +47,9 @@ int SizeOfType(framework_proto::VarType::Type type) { return -1; } -void TensorFromStream(std::istream &is, hlir::framework::_Tensor_ *tensor, const common::Target &target) { +void TensorFromStream(std::istream &is, + hlir::framework::_Tensor_ *tensor, + const common::Target &target) { using Type = framework_proto::VarType::Type; uint32_t version; is.read(reinterpret_cast(&version), 
sizeof(version)); @@ -66,7 +68,8 @@ void TensorFromStream(std::istream &is, hlir::framework::_Tensor_ *tensor, const // read tensor std::vector dims_vec; - std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims_vec)); + std::copy( + desc.dims().begin(), desc.dims().end(), std::back_inserter(dims_vec)); hlir::framework::Shape dims(dims_vec); tensor->Resize(dims); void *buf; @@ -93,14 +96,17 @@ void TensorFromStream(std::istream &is, hlir::framework::_Tensor_ *tensor, const is.read(static_cast(buf), size); } else if (target.arch == Target::Arch::NVGPU) { #ifdef CINN_WITH_CUDA - if (desc.data_type() != Type::VarType_Type_FP32) LOG(FATAL) << "[CUDA] The type is not fp32!!"; + if (desc.data_type() != Type::VarType_Type_FP32) + LOG(FATAL) << "[CUDA] The type is not fp32!!"; auto *data = tensor->mutable_data(target); tensor->set_type(Float(32)); std::vector temp(tensor->shape().numel()); // LOG(INFO) <<"[CUDA] The tensor's size is "<< tensor->shape().numel(); is.read(reinterpret_cast(temp.data()), size); - CUDA_CALL(cudaMemcpy( - reinterpret_cast(data), temp.data(), tensor->shape().numel() * sizeof(float), cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy(reinterpret_cast(data), + temp.data(), + tensor->shape().numel() * sizeof(float), + cudaMemcpyHostToDevice)); #else LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!"; #endif @@ -109,7 +115,9 @@ void TensorFromStream(std::istream &is, hlir::framework::_Tensor_ *tensor, const } } -void LoadLoDTensor(std::istream &is, hlir::framework::Variable *var, const common::Target &target) { +void LoadLoDTensor(std::istream &is, + hlir::framework::Variable *var, + const common::Target &target) { auto &tensor = absl::get(*var); uint32_t version{}; is.read(reinterpret_cast(&version), sizeof(version)); @@ -123,7 +131,8 @@ void LoadLoDTensor(std::istream &is, hlir::framework::Variable *var, const commo uint64_t size; is.read(reinterpret_cast(&size), sizeof(size)); std::vector tmp(size / sizeof(uint64_t)); - is.read(reinterpret_cast(tmp.data()), static_cast(size)); + is.read(reinterpret_cast(tmp.data()), + static_cast(size)); // lod[i] = tmp; } @@ -142,8 +151,10 @@ void ReadBinaryFile(const std::string &filename, std::string *contents) { fin.close(); } -std::unique_ptr LoadProgram(const std::string &path, bool program_from_memory) { - std::unique_ptr main_program(new framework_proto::ProgramDesc); +std::unique_ptr LoadProgram( + const std::string &path, bool program_from_memory) { + std::unique_ptr main_program( + new framework_proto::ProgramDesc); if (!program_from_memory) { std::string desc_str; ReadBinaryFile(path, &desc_str); @@ -157,15 +168,19 @@ std::unique_ptr LoadProgram(const std::string &pat void LoadParams(const std::string &path) {} // Load directly to CPU, and latter transfer to other devices. 
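A quick usage sketch for the loader that follows (the path and variable names are hypothetical; error handling elided):

// Read one saved parameter file into a host-side Variable first; callers
// can transfer the data to a device target afterwards.
hlir::framework::Variable w;
cinn::frontend::paddle::LoadParam("./model/fc_0.w_0", &w,
                                  common::DefaultHostTarget());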
-void LoadParam(const std::string &path, hlir::framework::Variable *out, const common::Target &target) { +void LoadParam(const std::string &path, + hlir::framework::Variable *out, + const common::Target &target) { std::ifstream fin(path, std::ios::binary); CHECK(fin.is_open()) << "failed to open file " << path; LoadLoDTensor(fin, out, target); } bool IsPersistable(const cpp::VarDesc &var) { - if (var.Persistable() && var.GetType() != cpp::VarDescAPI::Type::FEED_MINIBATCH && - var.GetType() != cpp::VarDescAPI::Type::FETCH_LIST && var.GetType() != cpp::VarDescAPI::Type::RAW) { + if (var.Persistable() && + var.GetType() != cpp::VarDescAPI::Type::FEED_MINIBATCH && + var.GetType() != cpp::VarDescAPI::Type::FETCH_LIST && + var.GetType() != cpp::VarDescAPI::Type::RAW) { return true; } return false; @@ -177,7 +192,7 @@ void LoadCombinedParamsPb(const std::string &path, bool params_from_memory, const common::Target &target) { CHECK(scope); - auto prog = cpp_prog; + auto prog = cpp_prog; auto &main_block_desc = *prog.GetBlock(0); // Get vars @@ -192,9 +207,11 @@ void LoadCombinedParamsPb(const std::string &path, // Load vars auto load_var_func = [&](std::istream &is) { for (size_t i = 0; i < paramlist.size(); ++i) { - auto *var = scope->Var(utils::TransValidVarName(paramlist[i])); + auto *var = scope->Var( + utils::TransValidVarName(paramlist[i])); // Error checking - CHECK(static_cast(is)) << "There is a problem with loading model parameters"; + CHECK(static_cast(is)) + << "There is a problem with loading model parameters"; LoadLoDTensor(is, var, target); } is.peek(); @@ -228,13 +245,14 @@ void LoadModelPb(const std::string &model_dir, VLOG(3) << "param_file is: " << param_file; // Load model VLOG(4) << "Start load model program..."; - std::string prog_path = model_dir + "/__model__"; + std::string prog_path = model_dir + "/__model__"; std::string param_file_temp = param_file; if (combined) { // prog_path = model_file; param_file_temp = model_dir + "/params"; } - framework_proto::ProgramDesc pb_proto_prog = *LoadProgram(prog_path, model_from_memory); + framework_proto::ProgramDesc pb_proto_prog = + *LoadProgram(prog_path, model_from_memory); pb::ProgramDesc pb_prog(&pb_proto_prog); // Transform to cpp::ProgramDesc TransformProgramDescAnyToCpp(pb_prog, cpp_prog); @@ -242,15 +260,18 @@ void LoadModelPb(const std::string &model_dir, // Load Params // NOTE: Only main block be used now. 
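Only vars passing the IsPersistable filter above are picked up as weights during the parameter loading below; an equivalent standalone sketch of that rule:

// A var counts as a weight when it is persistable and is not one of the
// bookkeeping types (feed/fetch/raw).
bool IsWeightVar(const cpp::VarDesc& var) {
  return var.Persistable() &&
         var.GetType() != cpp::VarDescAPI::Type::FEED_MINIBATCH &&
         var.GetType() != cpp::VarDescAPI::Type::FETCH_LIST &&
         var.GetType() != cpp::VarDescAPI::Type::RAW;
}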
VLOG(4) << "Start load model params..."; - CHECK(!(!combined && model_from_memory)) << "If you want use the model_from_memory," - << " you should load the combined model using cfg.set_model_buffer " - "interface."; + CHECK(!(!combined && model_from_memory)) + << "If you want use the model_from_memory," + << " you should load the combined model using cfg.set_model_buffer " + "interface."; if (combined) { - LoadCombinedParamsPb(param_file_temp, scope, *cpp_prog, model_from_memory, target); + LoadCombinedParamsPb( + param_file_temp, scope, *cpp_prog, model_from_memory, target); } else { auto main_block = pb_proto_prog.blocks(0); for (auto &var : main_block.vars()) { - if (var.name() == "feed" || var.name() == "fetch" || !var.persistable()) continue; + if (var.name() == "feed" || var.name() == "fetch" || !var.persistable()) + continue; std::string file_path = model_dir + "/" + var.name(); VLOG(4) << "reading weight " << var.name(); @@ -258,7 +279,10 @@ void LoadModelPb(const std::string &model_dir, std::ifstream file(file_path, std::ios::binary); switch (var.type().type()) { case framework_proto::VarType_Type_LOD_TENSOR: - LoadLoDTensor(file, scope->Var(utils::TransValidVarName(var.name())), target); + LoadLoDTensor(file, + scope->Var( + utils::TransValidVarName(var.name())), + target); break; default: LOG(FATAL) << "unknown weight type"; diff --git a/paddle/cinn/frontend/paddle/model_parser.h b/paddle/cinn/frontend/paddle/model_parser.h index bda1285666dc0..8bc10108b79de 100644 --- a/paddle/cinn/frontend/paddle/model_parser.h +++ b/paddle/cinn/frontend/paddle/model_parser.h @@ -35,32 +35,39 @@ void LoadModelPb(const std::string& model_dir, const std::string& param_file, hlir::framework::Scope* scope, cpp::ProgramDesc* cpp_prog, - bool combined = true, - bool model_from_memory = false, + bool combined = true, + bool model_from_memory = false, const common::Target& target = common::DefaultHostTarget()); // Read a __model__ file. -std::unique_ptr LoadProgram(const std::string& path, bool program_from_memory = false); +std::unique_ptr LoadProgram( + const std::string& path, bool program_from_memory = false); -void LoadLoDTensor(std::istream& is, hlir::framework::Variable* var, const common::Target& target); +void LoadLoDTensor(std::istream& is, + hlir::framework::Variable* var, + const common::Target& target); // Read a single file containing all the parameters. void LoadParams(const std::string& path); // Load a single parameter to an output tensor. 
-void LoadParam(const std::string& path, hlir::framework::Variable* out, const common::Target& target); +void LoadParam(const std::string& path, + hlir::framework::Variable* out, + const common::Target& target); -void LoadCombinedParamsPb(const std::string& path, - hlir::framework::Scope* scope, - const pb::ProgramDesc& prog, - bool params_from_memory = false, - const common::Target& target = common::DefaultHostTarget()); +void LoadCombinedParamsPb( + const std::string& path, + hlir::framework::Scope* scope, + const pb::ProgramDesc& prog, + bool params_from_memory = false, + const common::Target& target = common::DefaultHostTarget()); // LoDTensor to ostream void TensorToStream(std::ostream& os, const hlir::framework::_Tensor_& tensor); -void TensorFromStream(std::istream& is, - hlir::framework::_Tensor_* tensor, - const common::Target& target = common::DefaultHostTarget()); +void TensorFromStream( + std::istream& is, + hlir::framework::_Tensor_* tensor, + const common::Target& target = common::DefaultHostTarget()); void ReadBinaryFile(const std::string& filename, std::string* contents); } // namespace cinn::frontend::paddle diff --git a/paddle/cinn/frontend/paddle/pb/block_desc.cc b/paddle/cinn/frontend/paddle/pb/block_desc.cc index 9cea5f25b1cc8..0a7984535dc13 100644 --- a/paddle/cinn/frontend/paddle/pb/block_desc.cc +++ b/paddle/cinn/frontend/paddle/pb/block_desc.cc @@ -17,7 +17,8 @@ namespace cinn::frontend::paddle::pb { template <> -framework_proto::VarDesc* BlockDesc::GetVar(int32_t idx) { +framework_proto::VarDesc* BlockDesc::GetVar( + int32_t idx) { CHECK_LT(idx, VarsSize()) << "idx >= vars.size()"; return desc_->mutable_vars(idx); } @@ -28,7 +29,8 @@ framework_proto::VarDesc* BlockDesc::AddVar() { } template <> -framework_proto::OpDesc* BlockDesc::GetOp(int32_t idx) { +framework_proto::OpDesc* BlockDesc::GetOp( + int32_t idx) { CHECK_LT(idx, OpsSize()) << "idx >= ops.size()"; return desc_->mutable_ops(idx); } diff --git a/paddle/cinn/frontend/paddle/pb/block_desc.h b/paddle/cinn/frontend/paddle/pb/block_desc.h index e5229bf3c4aba..8d06089ba6cc5 100644 --- a/paddle/cinn/frontend/paddle/pb/block_desc.h +++ b/paddle/cinn/frontend/paddle/pb/block_desc.h @@ -26,7 +26,9 @@ class BlockDesc : public cpp::BlockDescAPI { public: BlockDesc() = delete; - explicit BlockDesc(framework_proto::BlockDesc* desc) : desc_(desc) { CHECK(desc_); } + explicit BlockDesc(framework_proto::BlockDesc* desc) : desc_(desc) { + CHECK(desc_); + } framework_proto::BlockDesc* Proto() { return desc_; } @@ -60,9 +62,13 @@ class BlockDesc : public cpp::BlockDescAPI { template T* AddOp(); - int32_t ForwardBlockIdx() const override { return desc_->forward_block_idx(); } + int32_t ForwardBlockIdx() const override { + return desc_->forward_block_idx(); + } - void SetForwardBlockIdx(int32_t idx) override { desc_->set_forward_block_idx(idx); } + void SetForwardBlockIdx(int32_t idx) override { + desc_->set_forward_block_idx(idx); + } private: framework_proto::BlockDesc* desc_; // not_own diff --git a/paddle/cinn/frontend/paddle/pb/op_desc.cc b/paddle/cinn/frontend/paddle/pb/op_desc.cc index 679f50b93c8ed..55a2594451635 100644 --- a/paddle/cinn/frontend/paddle/pb/op_desc.cc +++ b/paddle/cinn/frontend/paddle/pb/op_desc.cc @@ -16,14 +16,20 @@ namespace cinn::frontend::paddle::pb { -google::protobuf::internal::RepeatedPtrIterator FindAttr(framework_proto::OpDesc *desc, - const std::string &name) { +google::protobuf::internal::RepeatedPtrIterator +FindAttr(framework_proto::OpDesc *desc, const std::string &name) { auto &xs = 
*desc->mutable_attrs(); - auto it = std::find_if(xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) { return x.name() == name; }); + auto it = std::find_if( + xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) { + return x.name() == name; + }); if (it == xs.end()) { auto *attr = xs.Add(); attr->set_name(name); - it = std::find_if(xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) { return x.name() == name; }); + it = std::find_if( + xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) { + return x.name() == name; + }); } return it; } @@ -41,7 +47,8 @@ SET_IMPL_ONE(bool, BOOLEAN, b); SET_IMPL_ONE(int64_t, LONG, l); template <> -void OpDesc::SetAttr>(const std::string &name, const std::vector &v) { +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v) { auto it = FindAttr(desc_, name); it->set_type(framework_proto::INTS); it->clear_ints(); @@ -51,14 +58,16 @@ void OpDesc::SetAttr>(const std::string &name, const std::vecto } template <> -void OpDesc::SetAttr(const std::string &name, const std::string &v) { +void OpDesc::SetAttr(const std::string &name, + const std::string &v) { auto it = FindAttr(desc_, name); it->set_type(framework_proto::STRING); it->set_s(v.c_str()); } template <> -void OpDesc::SetAttr>(const std::string &name, const std::vector &v) { +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v) { auto it = FindAttr(desc_, name); it->set_type(framework_proto::FLOATS); it->clear_floats(); @@ -68,7 +77,8 @@ void OpDesc::SetAttr>(const std::string &name, const std::vec } template <> -void OpDesc::SetAttr>(const std::string &name, const std::vector &v) { +void OpDesc::SetAttr>( + const std::string &name, const std::vector &v) { auto it = FindAttr(desc_, name); it->set_type(framework_proto::STRINGS); it->clear_strings(); @@ -78,7 +88,8 @@ void OpDesc::SetAttr>(const std::string &name, const st } template <> -void OpDesc::SetAttr>(const std::string &name, const std::vector &v) { +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v) { auto it = FindAttr(desc_, name); it->set_type(framework_proto::LONGS); it->clear_longs(); @@ -86,10 +97,14 @@ void OpDesc::SetAttr>(const std::string &name, const std::v it->add_longs(i); } } -google::protobuf::internal::RepeatedPtrIterator GetFindAttr( - const framework_proto::OpDesc &desc, const std::string &name) { +google::protobuf::internal::RepeatedPtrIterator< + const framework_proto::OpDesc_Attr> +GetFindAttr(const framework_proto::OpDesc &desc, const std::string &name) { auto &xs = desc.attrs(); - auto it = std::find_if(xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) { return x.name() == name; }); + auto it = std::find_if( + xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) { + return x.name() == name; + }); return it; } diff --git a/paddle/cinn/frontend/paddle/pb/op_desc.h b/paddle/cinn/frontend/paddle/pb/op_desc.h index 51d2187e4f6e1..82e1477270fa4 100644 --- a/paddle/cinn/frontend/paddle/pb/op_desc.h +++ b/paddle/cinn/frontend/paddle/pb/op_desc.h @@ -22,7 +22,8 @@ namespace cinn::frontend::paddle::pb { namespace framework_proto = ::cinn::frontend::paddle::proto; -using Attribute = absl::variant, std::vector>; +using Attribute = + absl::variant, std::vector>; using VariableNameMap = std::map>; /* @@ -49,9 +50,12 @@ class OpDesc : public cpp::OpDescAPI { return GetArguments(desc_->inputs(), param); } - std::vector InputArgumentNames() const override { return GetArgumentNames(desc_->inputs()); } + std::vector 
InputArgumentNames() const override { + return GetArgumentNames(desc_->inputs()); + } - void SetInput(const std::string ¶m, const std::vector &args) override { + void SetInput(const std::string ¶m, + const std::vector &args) override { SetArgument(desc_->mutable_inputs(), param, args); } @@ -59,23 +63,30 @@ class OpDesc : public cpp::OpDescAPI { return GetArguments(desc_->outputs(), param); } - std::vector OutputArgumentNames() const override { return GetArgumentNames(desc_->outputs()); } + std::vector OutputArgumentNames() const override { + return GetArgumentNames(desc_->outputs()); + } - void SetOutput(const std::string ¶m, const std::vector &args) override { + void SetOutput(const std::string ¶m, + const std::vector &args) override { SetArgument(desc_->mutable_outputs(), param, args); } bool HasAttr(const std::string &name) const override { const auto &xs = desc_->attrs(); - auto it = - std::find_if(xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) { return x.name() == name; }); + auto it = std::find_if( + xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) { + return x.name() == name; + }); return it != xs.end(); } AttrType GetAttrType(const std::string &name) const override { const auto &xs = desc_->attrs(); - auto it = - std::find_if(xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) { return x.name() == name; }); + auto it = std::find_if( + xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) { + return x.name() == name; + }); CHECK(it != xs.end()); #define DEF_ONE(type__) \ case framework_proto::AttrType::type__: \ @@ -105,7 +116,10 @@ class OpDesc : public cpp::OpDescAPI { std::vector res; const auto &xs = desc_->attrs(); std::transform( - xs.begin(), xs.end(), std::back_inserter(res), [](const framework_proto::OpDesc_Attr &x) { return x.name(); }); + xs.begin(), + xs.end(), + std::back_inserter(res), + [](const framework_proto::OpDesc_Attr &x) { return x.name(); }); return res; } @@ -116,23 +130,32 @@ class OpDesc : public cpp::OpDescAPI { T GetAttr(const std::string &name) const; private: - std::vector GetArguments(const google::protobuf::RepeatedPtrField &xs, - const std::string ¶m) const { + std::vector GetArguments( + const google::protobuf::RepeatedPtrField &xs, + const std::string ¶m) const { std::vector res; auto it = std::find_if( - xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Var &it) { return it.parameter() == param; }); + xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Var &it) { + return it.parameter() == param; + }); CHECK(it != xs.end()); const auto &ys = it->arguments(); - std::transform(ys.begin(), ys.end(), std::back_inserter(res), [](const std::string &x) { return x; }); + std::transform(ys.begin(), + ys.end(), + std::back_inserter(res), + [](const std::string &x) { return x; }); return res; } - void SetArgument(google::protobuf::RepeatedPtrField *xs, - const std::string ¶m, - const std::vector &args) { + void SetArgument( + google::protobuf::RepeatedPtrField *xs, + const std::string ¶m, + const std::vector &args) { auto it = std::find_if( - xs->begin(), xs->end(), [&](const framework_proto::OpDesc_Var &it) { return it.parameter() == param; }); + xs->begin(), xs->end(), [&](const framework_proto::OpDesc_Var &it) { + return it.parameter() == param; + }); if (it == xs->end()) { auto *new_arg = xs->Add(); new_arg->set_parameter(param); @@ -148,11 +171,14 @@ class OpDesc : public cpp::OpDescAPI { } std::vector GetArgumentNames( - const google::protobuf::RepeatedPtrField &xs) const { + const 
google::protobuf::RepeatedPtrField &xs) + const { std::vector res; - std::transform(xs.begin(), xs.end(), std::back_inserter(res), [](const framework_proto::OpDesc_Var &x) { - return x.parameter(); - }); + std::transform( + xs.begin(), + xs.end(), + std::back_inserter(res), + [](const framework_proto::OpDesc_Var &x) { return x.parameter(); }); return res; } @@ -161,9 +187,11 @@ class OpDesc : public cpp::OpDescAPI { }; template <> -void OpDesc::SetAttr(const std::string &name, const std::string &v); +void OpDesc::SetAttr(const std::string &name, + const std::string &v); template <> -void OpDesc::SetAttr>(const std::string &name, const std::vector &v); +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v); } // namespace cinn::frontend::paddle::pb diff --git a/paddle/cinn/frontend/paddle/pb/program_desc.cc b/paddle/cinn/frontend/paddle/pb/program_desc.cc index dc4eef70f8235..77e0014b31071 100644 --- a/paddle/cinn/frontend/paddle/pb/program_desc.cc +++ b/paddle/cinn/frontend/paddle/pb/program_desc.cc @@ -20,13 +20,15 @@ namespace cinn::frontend::paddle::pb { template <> -framework_proto::BlockDesc* ProgramDesc::GetBlock(int32_t idx) { +framework_proto::BlockDesc* ProgramDesc::GetBlock( + int32_t idx) { CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()"; return desc_->mutable_blocks(idx); } template <> -framework_proto::BlockDesc* ProgramDesc::AddBlock() { +framework_proto::BlockDesc* +ProgramDesc::AddBlock() { return desc_->add_blocks(); } diff --git a/paddle/cinn/frontend/paddle/pb/program_desc.h b/paddle/cinn/frontend/paddle/pb/program_desc.h index f89b1739e6a93..9f2ac8246bc20 100644 --- a/paddle/cinn/frontend/paddle/pb/program_desc.h +++ b/paddle/cinn/frontend/paddle/pb/program_desc.h @@ -28,7 +28,9 @@ class ProgramDesc : public cpp::ProgramDescAPI { public: ProgramDesc() = delete; - explicit ProgramDesc(framework_proto::ProgramDesc *desc) : desc_(desc) { CHECK(desc_); } + explicit ProgramDesc(framework_proto::ProgramDesc *desc) : desc_(desc) { + CHECK(desc_); + } framework_proto::ProgramDesc *Proto() { return desc_; } @@ -48,7 +50,9 @@ class ProgramDesc : public cpp::ProgramDescAPI { int64_t Version() const override { return desc_->version().version(); } - void SetVersion(int64_t version) override { desc_->mutable_version()->set_version(version); } + void SetVersion(int64_t version) override { + desc_->mutable_version()->set_version(version); + } private: framework_proto::ProgramDesc *desc_; // not_own diff --git a/paddle/cinn/frontend/paddle/pb/var_desc.cc b/paddle/cinn/frontend/paddle/pb/var_desc.cc index 2ecd927e995e2..efee4f211d662 100644 --- a/paddle/cinn/frontend/paddle/pb/var_desc.cc +++ b/paddle/cinn/frontend/paddle/pb/var_desc.cc @@ -74,7 +74,8 @@ void VarDesc::SetShape(const std::vector &dims) { void VarDesc::SetTensorDescNum(size_t num) { switch (desc_->type().type()) { case framework_proto::VarType::READER: { - auto *lod_tensors_ptr = desc_->mutable_type()->mutable_reader()->mutable_lod_tensor(); + auto *lod_tensors_ptr = + desc_->mutable_type()->mutable_reader()->mutable_lod_tensor(); lod_tensors_ptr->Clear(); for (size_t i = 0; i < num; ++i) { lod_tensors_ptr->Add(); @@ -101,20 +102,25 @@ size_t VarDesc::GetTensorDescNum() const { return 0; } -void VarDesc::SetShapes(const std::vector> &multiple_dims) { +void VarDesc::SetShapes( + const std::vector> &multiple_dims) { if (multiple_dims.size() != GetTensorDescNum()) { VLOG(3) << "WARNING: The number of given shapes(" << multiple_dims.size() - << ") doesn't match the existing tensor number(" << 
GetTensorDescNum() + << ") doesn't match the existing tensor number(" + << GetTensorDescNum() << "). The Reader is going to be reinitialized."; SetTensorDescNum(multiple_dims.size()); } - std::vector tensors = mutable_tensor_descs(); + std::vector tensors = + mutable_tensor_descs(); for (size_t i = 0; i < multiple_dims.size(); ++i) { VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims()); } } -std::vector VarDesc::GetShape() const { return RepeatedToVector(tensor_desc().dims()); } +std::vector VarDesc::GetShape() const { + return RepeatedToVector(tensor_desc().dims()); +} std::vector> VarDesc::GetShapes() const { std::vector descs = tensor_descs(); @@ -150,14 +156,18 @@ void VarDesc::SetDataType(VarDescAPI::VarDataType data_type) { #undef SET_DATA_TYPE_CASE_ITEM } -void VarDesc::SetDataTypes(const std::vector &multiple_data_type) { +void VarDesc::SetDataTypes( + const std::vector &multiple_data_type) { if (multiple_data_type.size() != GetTensorDescNum()) { - VLOG(3) << "WARNING: The number of given data types(" << multiple_data_type.size() - << ") doesn't match the existing tensor number(" << GetTensorDescNum() + VLOG(3) << "WARNING: The number of given data types(" + << multiple_data_type.size() + << ") doesn't match the existing tensor number(" + << GetTensorDescNum() << "). The Reader is going to be reinitialized."; SetTensorDescNum(multiple_data_type.size()); } - std::vector tensor_descs = mutable_tensor_descs(); + std::vector tensor_descs = + mutable_tensor_descs(); for (size_t i = 0; i < multiple_data_type.size(); ++i) { tensor_descs[i]->set_data_type(multiple_data_type[i]); } @@ -215,26 +225,33 @@ void VarDesc::SetLoDLevel(int32_t lod_level) { desc_->mutable_type()->mutable_tensor_array()->set_lod_level(lod_level); break; default: - LOG(FATAL) << "Setting 'lod_level' is not supported by the type of var %s." << this->Name(); + LOG(FATAL) + << "Setting 'lod_level' is not supported by the type of var %s." + << this->Name(); } } void VarDesc::SetLoDLevels(const std::vector &multiple_lod_level) { if (multiple_lod_level.size() != GetTensorDescNum()) { - VLOG(3) << "WARNING: The number of given lod_levels(" << multiple_lod_level.size() - << ") doesn't match the existing tensor number(" << GetTensorDescNum() + VLOG(3) << "WARNING: The number of given lod_levels(" + << multiple_lod_level.size() + << ") doesn't match the existing tensor number(" + << GetTensorDescNum() << "). The Reader is going to be reinitialized."; SetTensorDescNum(multiple_lod_level.size()); } switch (desc_->type().type()) { case framework_proto::VarType::READER: { size_t i = 0; - for (auto &lod_tensor : *desc_->mutable_type()->mutable_reader()->mutable_lod_tensor()) { + for (auto &lod_tensor : + *desc_->mutable_type()->mutable_reader()->mutable_lod_tensor()) { lod_tensor.set_lod_level(multiple_lod_level[i++]); } } break; default: - LOG(FATAL) << "Setting 'lod_levels' is not supported by the type of var %s." << this->Name(); + LOG(FATAL) + << "Setting 'lod_levels' is not supported by the type of var %s." + << this->Name(); } } @@ -245,7 +262,9 @@ int32_t VarDesc::GetLoDLevel() const { case framework_proto::VarType::LOD_TENSOR_ARRAY: return desc_->type().tensor_array().lod_level(); default: - LOG(FATAL) << "Getting 'lod_level' is not supported by the type of var %s." << this->Name(); + LOG(FATAL) + << "Getting 'lod_level' is not supported by the type of var %s." 
+ << this->Name(); } return 0; } @@ -261,7 +280,9 @@ std::vector VarDesc::GetLoDLevels() const { return res; break; default: - LOG(FATAL) << "Getting 'lod_levels' is not supported by the type of var %s." << this->Name(); + LOG(FATAL) + << "Getting 'lod_levels' is not supported by the type of var %s." + << this->Name(); } return std::vector(); } @@ -277,12 +298,15 @@ const framework_proto::VarType::TensorDesc &VarDesc::tensor_desc() const { case framework_proto::VarType::LOD_TENSOR_ARRAY: return desc_->type().tensor_array().tensor(); default: - LOG(FATAL) << "Getting 'tensor_desc' is not supported by the type of var %s." << this->Name(); + LOG(FATAL) + << "Getting 'tensor_desc' is not supported by the type of var %s." + << this->Name(); } return framework_proto::VarDesc().type().lod_tensor().tensor(); } -std::vector VarDesc::tensor_descs() const { +std::vector VarDesc::tensor_descs() + const { CHECK(desc_->has_type()) << "The var type hasn't been set."; std::vector res; res.reserve(GetTensorDescNum()); @@ -293,9 +317,10 @@ std::vector VarDesc::tensor_descs() const } return res; default: - LOG(FATAL) << "Getting 'tensor_descs' is not supported by the type of var " - "%s." - << this->Name(); + LOG(FATAL) + << "Getting 'tensor_descs' is not supported by the type of var " + "%s." + << this->Name(); } return std::vector(); } @@ -319,21 +344,24 @@ framework_proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() { return nullptr; } -std::vector VarDesc::mutable_tensor_descs() { +std::vector +VarDesc::mutable_tensor_descs() { CHECK(desc_->has_type()) << "The var type hasn't been set."; CHECK(desc_->type().has_type()) << "The var type hasn't been set."; std::vector res; res.reserve(GetTensorDescNum()); switch (desc_->type().type()) { case framework_proto::VarType::READER: - for (auto &lod_tensor : *desc_->mutable_type()->mutable_reader()->mutable_lod_tensor()) { + for (auto &lod_tensor : + *desc_->mutable_type()->mutable_reader()->mutable_lod_tensor()) { res.push_back(lod_tensor.mutable_tensor()); } return res; default: - LOG(FATAL) << "Getting 'tensor_descs' is not supported by the type of var " - "%s." - << this->Name(); + LOG(FATAL) + << "Getting 'tensor_descs' is not supported by the type of var " + "%s." + << this->Name(); } return std::vector(); } diff --git a/paddle/cinn/frontend/paddle/pb/var_desc.h b/paddle/cinn/frontend/paddle/pb/var_desc.h index ccb4aa9f534f6..55b878d2474ae 100644 --- a/paddle/cinn/frontend/paddle/pb/var_desc.h +++ b/paddle/cinn/frontend/paddle/pb/var_desc.h @@ -28,15 +28,18 @@ namespace framework_proto = ::cinn::frontend::paddle::proto; // convert between std::vector and protobuf repeated. template -inline std::vector RepeatedToVector(const google::protobuf::RepeatedField &repeated_field) { +inline std::vector RepeatedToVector( + const google::protobuf::RepeatedField &repeated_field) { std::vector ret; ret.reserve(repeated_field.size()); - std::copy(repeated_field.begin(), repeated_field.end(), std::back_inserter(ret)); + std::copy( + repeated_field.begin(), repeated_field.end(), std::back_inserter(ret)); return ret; } template -inline void VectorToRepeated(const std::vector &vec, RepeatedField *repeated_field) { +inline void VectorToRepeated(const std::vector &vec, + RepeatedField *repeated_field) { repeated_field->Clear(); repeated_field->Reserve(vec.size()); for (const auto &elem : vec) { @@ -46,7 +49,8 @@ inline void VectorToRepeated(const std::vector &vec, RepeatedField *repeated_ // Specialize vector. 
template -inline void VectorToRepeated(const std::vector &vec, RepeatedField *repeated_field) { +inline void VectorToRepeated(const std::vector &vec, + RepeatedField *repeated_field) { repeated_field->Clear(); repeated_field->Reserve(vec.size()); for (auto elem : vec) { @@ -58,7 +62,9 @@ class VarDesc : public cpp::VarDescAPI { public: VarDesc() = delete; - explicit VarDesc(framework_proto::VarDesc *desc) : desc_(desc) { CHECK(desc_); } + explicit VarDesc(framework_proto::VarDesc *desc) : desc_(desc) { + CHECK(desc_); + } ::cinn::frontend::paddle::proto::VarDesc *Proto() { return desc_; } const framework_proto::VarDesc &ReadonlyProto() const { return *desc_; } @@ -81,7 +87,8 @@ class VarDesc : public cpp::VarDescAPI { void SetDataType(VarDescAPI::VarDataType data_type); - void SetDataTypes(const std::vector &multiple_data_type); + void SetDataTypes( + const std::vector &multiple_data_type); VarDescAPI::VarDataType GetDataType() const; @@ -101,7 +108,9 @@ class VarDesc : public cpp::VarDescAPI { bool Persistable() const override { return desc_->persistable(); } - void SetPersistable(bool persistable) override { desc_->set_persistable(persistable); } + void SetPersistable(bool persistable) override { + desc_->set_persistable(persistable); + } private: const framework_proto::VarType::TensorDesc &tensor_desc() const; diff --git a/paddle/cinn/frontend/paddle_model_convertor.cc b/paddle/cinn/frontend/paddle_model_convertor.cc index c9f978f99c4a9..0312a629e234e 100644 --- a/paddle/cinn/frontend/paddle_model_convertor.cc +++ b/paddle/cinn/frontend/paddle_model_convertor.cc @@ -32,29 +32,37 @@ namespace frontend { using cinn::utils::Attribute; -PaddleModelConvertor::PaddleModelConvertor() : PaddleModelConvertor(common::DefaultTarget(), nullptr, nullptr) {} +PaddleModelConvertor::PaddleModelConvertor() + : PaddleModelConvertor(common::DefaultTarget(), nullptr, nullptr) {} -PaddleModelConvertor::PaddleModelConvertor(const common::Target& target, - std::shared_ptr builder, - std::shared_ptr scope) +PaddleModelConvertor::PaddleModelConvertor( + const common::Target& target, + std::shared_ptr builder, + std::shared_ptr scope) : target_(target), builder_(builder), scope_(scope) { if (!builder_) { // do not need scope - builder_ = std::make_shared(cinn::UniqName("PaddleModelConvertor")); + builder_ = + std::make_shared(cinn::UniqName("PaddleModelConvertor")); } if (!scope_) { // do not need scope scope_ = hlir::framework::Scope::Create(); } - ctx_ = std::make_unique( - *scope_, target_, builder_.get(), &var_map_, &var_model_to_program_map_, &fetch_var_names_); + ctx_ = std::make_unique(*scope_, + target_, + builder_.get(), + &var_map_, + &var_model_to_program_map_, + &fetch_var_names_); } -void PaddleModelConvertor::PrepareRun(const paddle::cpp::BlockDesc& block_desc, OpMapperContext* ctx) { +void PaddleModelConvertor::PrepareRun(const paddle::cpp::BlockDesc& block_desc, + OpMapperContext* ctx) { std::unordered_map var_desc_map; // preserve var desc info lik shape and dtype for (int i = 0; i < block_desc.VarsSize(); i++) { - const auto& var_desc = block_desc.GetConstVar(i); + const auto& var_desc = block_desc.GetConstVar(i); var_desc_map[var_desc.Name()] = &var_desc; } @@ -63,16 +71,19 @@ void PaddleModelConvertor::PrepareRun(const paddle::cpp::BlockDesc& block_desc, if (op_desc.Type() == "feed") { for (const auto& var_name : op_desc.output_vars()) { - CHECK(var_desc_map.count(var_name)) << "Feed var [" << var_name << "] Not found in block"; - ctx->AddFeedInfo(var_name, 
utils::GetFeedInfoFromDesc(*var_desc_map[var_name])); + CHECK(var_desc_map.count(var_name)) + << "Feed var [" << var_name << "] Not found in block"; + ctx->AddFeedInfo(var_name, + utils::GetFeedInfoFromDesc(*var_desc_map[var_name])); } } } } -void PaddleModelConvertor::RunOp(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) { +void PaddleModelConvertor::RunOp(const paddle::cpp::OpDesc& op_desc, + const OpMapperContext& ctx) { const auto& op_type = op_desc.Type(); - auto kernel = OpMapperRegistry::Global()->Find(op_type); + auto kernel = OpMapperRegistry::Global()->Find(op_type); CHECK(kernel) << "Op [" << op_type << "] Not supported in OpMapper"; VLOG(4) << "Running Op " << op_type; kernel->Run(op_desc, ctx); @@ -80,11 +91,13 @@ void PaddleModelConvertor::RunOp(const paddle::cpp::OpDesc& op_desc, const OpMap std::unordered_map PaddleModelConvertor::GetFetchList( const std::unordered_set& fetch_name_list) const { - // the return map's key is paddle variable name, the value is the cinn fetch variable + // the return map's key is paddle variable name, the value is the cinn fetch + // variable const std::unordered_set* var_name_list = &fetch_name_list; if (fetch_name_list.empty()) { // if paddle var list is empty, fetch the program's fetch var instead - CHECK(!fetch_var_names_.empty()) << "Should not fetch empty variable in CINN."; + CHECK(!fetch_var_names_.empty()) + << "Should not fetch empty variable in CINN."; var_name_list = &fetch_var_names_; } @@ -92,7 +105,8 @@ std::unordered_map PaddleModelConvertor::GetFetchList( fetch_list.reserve(var_name_list->size()); for (const auto& pd_name : *var_name_list) { CHECK(var_model_to_program_map_.count(pd_name)) - << "Cannot find cinn variable [" << pd_name << "] in var_model_to_program_map_"; + << "Cannot find cinn variable [" << pd_name + << "] in var_model_to_program_map_"; auto norm_pd_name = pd_name; // remove inplace output's suffix auto pos = pd_name.find(paddle::InplaceOutSuffix); @@ -104,26 +118,41 @@ std::unordered_map PaddleModelConvertor::GetFetchList( return fetch_list; } -Program PaddleModelConvertor::LoadModel(const std::string& model_dir, - bool is_combined, - const std::unordered_map>& feed) { +Program PaddleModelConvertor::LoadModel( + const std::string& model_dir, + bool is_combined, + const std::unordered_map>& feed) { paddle::cpp::ProgramDesc program_desc; - paddle::LoadModelPb(model_dir, "__model__", "", scope_.get(), &program_desc, is_combined, false, target_); - CHECK_EQ(program_desc.BlocksSize(), 1) << "CINN can only support the model with a single block"; + paddle::LoadModelPb(model_dir, + "__model__", + "", + scope_.get(), + &program_desc, + is_combined, + false, + target_); + CHECK_EQ(program_desc.BlocksSize(), 1) + << "CINN can only support the model with a single block"; auto* block_desc = program_desc.GetBlock(0); // Set feeds shape for (int i = 0; i < block_desc->VarsSize(); i++) { - auto* var_desc = block_desc->GetVar(i); + auto* var_desc = block_desc->GetVar(i); const auto var_name = var_desc->Name(); if (feed.count(var_name)) { const auto& var_shape = feed.at(var_name); - VLOG(4) << "Update var " << var_name << "'s shape to: " << cinn::utils::Join(var_shape, ", "); + VLOG(4) << "Update var " << var_name + << "'s shape to: " << cinn::utils::Join(var_shape, ", "); var_desc->SetShape(var_shape); } } - OpMapperContext ctx(*scope_, target_, builder_.get(), &var_map_, &var_model_to_program_map_, &fetch_var_names_); + OpMapperContext ctx(*scope_, + target_, + builder_.get(), + &var_map_, + 
&var_model_to_program_map_, + &fetch_var_names_); PrepareRun(*block_desc, &ctx); for (int i = 0; i < block_desc->OpsSize(); i++) { @@ -133,10 +162,13 @@ Program PaddleModelConvertor::LoadModel(const std::string& model_dir, return builder_->Build(); } -void SetOpDescAttr(const std::string& attr_name, const Attribute& attr_value, paddle::cpp::OpDesc* op_desc) { +void SetOpDescAttr(const std::string& attr_name, + const Attribute& attr_value, + paddle::cpp::OpDesc* op_desc) { class Visitor { public: - Visitor(paddle::cpp::OpDesc* op_desc, const std::string& attr_name) : op_desc_(op_desc), attr_name_(attr_name) {} + Visitor(paddle::cpp::OpDesc* op_desc, const std::string& attr_name) + : op_desc_(op_desc), attr_name_(attr_name) {} #define VISITOR_EXPAND(TYPE) \ void operator()(const TYPE& v) { op_desc_->SetAttr(attr_name_, v); } @@ -162,11 +194,12 @@ void SetOpDescAttr(const std::string& attr_name, const Attribute& attr_value, pa absl::visit(Visitor{op_desc, attr_name}, attr_value); } -void PaddleModelConvertor::RunOp(const std::string& op_type, - const std::map<std::string, std::vector<std::string>>& inputs, - const std::map<std::string, std::vector<std::string>>& outputs, - const std::map<std::string, Attribute>& attrs, - const OpMapperContext& ctx) { +void PaddleModelConvertor::RunOp( + const std::string& op_type, + const std::map<std::string, std::vector<std::string>>& inputs, + const std::map<std::string, std::vector<std::string>>& outputs, + const std::map<std::string, Attribute>& attrs, + const OpMapperContext& ctx) { paddle::cpp::OpDesc op_desc; op_desc.SetType(op_type); for (const auto& in_pair : inputs) { @@ -182,10 +215,11 @@ void PaddleModelConvertor::RunOp(const std::string& op_type, RunOp(op_desc, ctx); } -void PaddleModelConvertor::RunOp(const std::string& op_type, - const std::map<std::string, std::vector<std::string>>& inputs, - const std::map<std::string, std::vector<std::string>>& outputs, - const std::map<std::string, Attribute>& attrs) { +void PaddleModelConvertor::RunOp( + const std::string& op_type, + const std::map<std::string, std::vector<std::string>>& inputs, + const std::map<std::string, std::vector<std::string>>& outputs, + const std::map<std::string, Attribute>& attrs) { RunOp(op_type, inputs, outputs, attrs, *ctx_); } diff --git a/paddle/cinn/frontend/paddle_model_convertor.h b/paddle/cinn/frontend/paddle_model_convertor.h index 8d7eebe289084..ee83223d8c965 100644 --- a/paddle/cinn/frontend/paddle_model_convertor.h +++ b/paddle/cinn/frontend/paddle_model_convertor.h @@ -31,50 +31,62 @@ namespace cinn { namespace frontend { // Transform paddle model to CINN fronted::Program object. -// The paddle model is readed from __model__ file in model_dir, the PaddleModelConvertor -// will run each op's kernel registered in OpMapper, each kernel will add instruction in -// NetBuilder, after running all op of model, it will invoke its Build function and -// finally return the complete fronted::Program object. -// Note that if anyone op not registered, the program will failed and aborted. +// The paddle model is read from the __model__ file in model_dir; the +// PaddleModelConvertor runs each op's kernel registered in OpMapper, and each +// kernel adds instructions to the NetBuilder. After running all ops of the +// model, it invokes the Build function and finally returns the complete +// frontend::Program object. Note that if any op is not registered, the +// program will fail and abort. 
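The comment above describes the whole conversion pipeline: each Paddle op type is looked up in the OpMapper registry (OpMapperRegistry::Global()->Find(op_type) in the .cc earlier in this diff), and the found kernel appends instructions to the NetBuilder. A toy model of that dispatch shape, using hypothetical ToyOpDesc / ToyNetBuilder types rather than CINN's classes:

    #include <functional>
    #include <iostream>
    #include <map>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Hypothetical stand-ins for paddle::cpp::OpDesc and NetBuilder.
    struct ToyOpDesc { std::string type; };
    struct ToyNetBuilder {
      std::vector<std::string> instrs;
      void Append(const std::string& instr) { instrs.push_back(instr); }
    };

    using MapperFn = std::function<void(const ToyOpDesc&, ToyNetBuilder*)>;

    // Registry from op type to mapper kernel, mirroring OpMapperRegistry::Find.
    std::map<std::string, MapperFn>& Registry() {
      static std::map<std::string, MapperFn> registry = {
          {"relu",
           [](const ToyOpDesc&, ToyNetBuilder* b) { b->Append("max(x, 0)"); }},
          {"scale",
           [](const ToyOpDesc&, ToyNetBuilder* b) { b->Append("x * s + b"); }},
      };
      return registry;
    }

    void RunOp(const ToyOpDesc& op, ToyNetBuilder* builder) {
      auto it = Registry().find(op.type);
      // Like the real convertor, an unregistered op aborts the conversion.
      if (it == Registry().end())
        throw std::runtime_error("Op [" + op.type + "] not supported");
      it->second(op, builder);
    }

    int main() {
      ToyNetBuilder builder;
      for (const auto& op : {ToyOpDesc{"relu"}, ToyOpDesc{"scale"}})
        RunOp(op, &builder);
      for (const auto& instr : builder.instrs) std::cout << instr << "\n";
    }

As in the real convertor, lookup failure is fatal by design: an unregistered op type aborts the whole conversion instead of being silently skipped.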
class PaddleModelConvertor { public: PaddleModelConvertor(); PaddleModelConvertor(const common::Target& target, - std::shared_ptr<NetBuilder> builder = nullptr, + std::shared_ptr<NetBuilder> builder = nullptr, std::shared_ptr<hlir::framework::Scope> scope = nullptr); // prepare feed variable before run CINN op - void PrepareRun(const paddle::cpp::BlockDesc& block_desc, OpMapperContext* ctx); + void PrepareRun(const paddle::cpp::BlockDesc& block_desc, + OpMapperContext* ctx); - // RunOp accept OpDesc and global run context then run it's kernel registered in OpMapper. - static void RunOp(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx); - - static void RunOp(const std::string& op_type, - const std::map<std::string, std::vector<std::string>>& inputs, - const std::map<std::string, std::vector<std::string>>& outputs, - const std::map<std::string, Attribute>& attrs, + // RunOp accepts an OpDesc and the global run context, then runs its kernel + // registered in OpMapper. + static void RunOp(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx); + static void RunOp( + const std::string& op_type, + const std::map<std::string, std::vector<std::string>>& inputs, + const std::map<std::string, std::vector<std::string>>& outputs, + const std::map<std::string, Attribute>& attrs, + const OpMapperContext& ctx); + void RunOp(const std::string& op_type, const std::map<std::string, std::vector<std::string>>& inputs, const std::map<std::string, std::vector<std::string>>& outputs, const std::map<std::string, Attribute>& attrs); - void CreateInput(const std::string& dtype, const cinn::utils::ShapeType& shape, const std::string& name); + void CreateInput(const std::string& dtype, + const cinn::utils::ShapeType& shape, + const std::string& name); Program operator()(); - // operator() accept the modle's directory, and return the fronted::Program object. - Program LoadModel(const std::string& model_dir, - bool is_combined = false, - const std::unordered_map<std::string, std::vector<int64_t>>& feed = {}); + // LoadModel accepts the model's directory and returns the frontend::Program + // object. + Program LoadModel( + const std::string& model_dir, + bool is_combined = false, + const std::unordered_map<std::string, std::vector<int64_t>>& feed = {}); // return the internal variable map - const std::unordered_map<std::string, Variable>& var_map() const { return var_map_; } + const std::unordered_map<std::string, Variable>& var_map() const { + return var_map_; + } // return the map from the variable name in paddle model to cinn program. 
- const std::unordered_map& var_model_to_program_map() const { + const std::unordered_map& var_model_to_program_map() + const { return var_model_to_program_map_; } diff --git a/paddle/cinn/frontend/paddle_model_convertor_test.cc b/paddle/cinn/frontend/paddle_model_convertor_test.cc index 7e985b1c570d7..f4e42859a0ad4 100644 --- a/paddle/cinn/frontend/paddle_model_convertor_test.cc +++ b/paddle/cinn/frontend/paddle_model_convertor_test.cc @@ -28,7 +28,7 @@ namespace frontend { template void RandomInput(const Target& target, hlir::framework::Tensor tensor, - T low = static_cast(0), + T low = static_cast(0), T high = static_cast(1)) { std::vector vec; InitRandomVector(&vec, tensor->shape().numel(), low, high); @@ -36,7 +36,10 @@ void RandomInput(const Target& target, } template <> -void RandomInput(const Target& target, hlir::framework::Tensor tensor, bool low, bool high) { +void RandomInput(const Target& target, + hlir::framework::Tensor tensor, + bool low, + bool high) { std::vector vec_int; InitRandomVector(&vec_int, tensor->shape().numel(), 0, 1); @@ -54,7 +57,8 @@ void RunProgram(const Target& target, Program* prog) { input_names.emplace_back(var->id); } - LOG(INFO) << "The Program's inputs are [" << cinn::utils::Join(input_names, ", ") << "]"; + LOG(INFO) << "The Program's inputs are [" + << cinn::utils::Join(input_names, ", ") << "]"; auto passes = DefaultTrainingOptimizeOptions(); @@ -93,8 +97,9 @@ TEST(PaddleModelConvertor, basic) { model_transform.LoadModel(FLAGS_model_dir); auto program = model_transform(); - const auto& var_map = model_transform.var_map(); - const auto& var_model_to_program_map = model_transform.var_model_to_program_map(); + const auto& var_map = model_transform.var_map(); + const auto& var_model_to_program_map = + model_transform.var_model_to_program_map(); ASSERT_FALSE(var_map.empty()); ASSERT_FALSE(var_model_to_program_map.empty()); diff --git a/paddle/cinn/frontend/paddle_model_to_program.cc b/paddle/cinn/frontend/paddle_model_to_program.cc index 5d3a129ad92e6..316712ff40e61 100644 --- a/paddle/cinn/frontend/paddle_model_to_program.cc +++ b/paddle/cinn/frontend/paddle_model_to_program.cc @@ -28,12 +28,12 @@ using utils::TransValidVarName; void MoveData(float* data, int i, int M, int N) { float temp = data[i]; - int cur = i; // current data index - int pre = (cur % M) * N + cur / M; + int cur = i; // current data index + int pre = (cur % M) * N + cur / M; while (pre != i) { data[cur] = data[pre]; - cur = pre; - pre = (cur % M) * N + cur / M; + cur = pre; + pre = (cur % M) * N + cur / M; } data[cur] = temp; } @@ -63,7 +63,7 @@ void PaddleModelToProgram::AddOpMapper_feed() { VLOG(2) << "Model get feed [" << outs[0] << "]"; CHECK(input_shape_map_.count(outs[0])); auto input_shape = input_shape_map_[outs[0]]; - auto input = net_builder_->CreateInput(Float(32), input_shape, outs[0]); + auto input = net_builder_->CreateInput(Float(32), input_shape, outs[0]); AddVar(outs[0], input); }; } @@ -83,7 +83,7 @@ void PaddleModelToProgram::AddOpMapper_scale() { op_mappers_["scale"] = [&](const paddle::cpp::OpDesc& op_desc) { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - auto x = GetVar(utils::TransValidVarName(x_name)); + auto x = GetVar(utils::TransValidVarName(x_name)); float scale{}; float bias{}; if (op_desc.HasAttr("scale")) { // the old model format @@ -91,10 +91,12 @@ void PaddleModelToProgram::AddOpMapper_scale() { } else { // the newly refactored format // load scale tensor CHECK_EQ(op_desc.Input("ScaleTensor").size(), 1UL); - 
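MoveData above (from paddle_model_to_program.cc) rotates one permutation cycle of an in-place matrix transpose: for destination index cur, the source index is (cur % M) * N + cur / M, and the cycle is shifted with a single temporary. Here is a self-contained sketch of a full in-place transpose built on that stepping rule; the cycle-minimum test is an assumption about how TransposeData avoids moving elements twice, since its body is not shown in this diff:

    #include <iostream>
    #include <vector>

    // In-place transpose of an M x N row-major matrix (result is N x M).
    // For destination index cur, the source index is (cur % M) * N + cur / M,
    // exactly the stepping rule used by MoveData above. Each permutation cycle
    // is rotated once; a cycle is processed only from its smallest index, so
    // no element moves twice.
    void TransposeInPlace(std::vector<float>* data, int M, int N) {
      const long total = static_cast<long>(M) * N;
      for (long start = 0; start < total; ++start) {
        // Walk the cycle containing `start`; skip it unless `start` is its
        // minimum index.
        long cur = start;
        long pre = (cur % M) * N + cur / M;
        bool is_cycle_min = true;
        while (pre != start) {
          if (pre < start) { is_cycle_min = false; break; }
          cur = pre;
          pre = (cur % M) * N + cur / M;
        }
        if (!is_cycle_min) continue;
        // Rotate the cycle with one temporary, as MoveData does.
        float temp = (*data)[start];
        cur = start;
        pre = (cur % M) * N + cur / M;
        while (pre != start) {
          (*data)[cur] = (*data)[pre];
          cur = pre;
          pre = (cur % M) * N + cur / M;
        }
        (*data)[cur] = temp;
      }
    }

    int main() {
      // 2 x 3 matrix {{1,2,3},{4,5,6}} becomes 3 x 2 {{1,4},{2,5},{3,6}}.
      std::vector<float> m = {1, 2, 3, 4, 5, 6};
      TransposeInPlace(&m, 2, 3);
      for (float v : m) std::cout << v << " ";  // prints: 1 4 2 5 3 6
    }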
auto* scale_tensor_var = scope_->FindVar(op_desc.Input("ScaleTensor").front()); + auto* scale_tensor_var = + scope_->FindVar(op_desc.Input("ScaleTensor").front()); CHECK(scale_tensor_var) << "No scale tensor found in the scope"; - auto& scale_tensor = absl::get(*scale_tensor_var); - scale = scale_tensor->mutable_data(common::DefaultHostTarget())[0]; + auto& scale_tensor = + absl::get(*scale_tensor_var); + scale = scale_tensor->mutable_data(common::DefaultHostTarget())[0]; } if (op_desc.HasAttr("bias")) { // the old model format bias = op_desc.GetAttr("bias"); @@ -116,9 +118,9 @@ void PaddleModelToProgram::AddOpMapper_mul() { auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Y").size(), 1UL); auto y_name = op_desc.Input("Y").front(); - auto x = GetVar(utils::TransValidVarName(x_name)); + auto x = GetVar(utils::TransValidVarName(x_name)); TransposeVar(TransValidVarName(y_name)); - auto y = GetVar(utils::TransValidVarName(y_name)); + auto y = GetVar(utils::TransValidVarName(y_name)); int x_num_col_dims = op_desc.GetAttr("x_num_col_dims"); int y_num_col_dims = op_desc.GetAttr("y_num_col_dims"); @@ -127,7 +129,8 @@ void PaddleModelToProgram::AddOpMapper_mul() { VLOG(4) << "x shape: " << utils::Join(x->shape, ","); VLOG(4) << "y shape: " << utils::Join(y->shape, ","); - const auto& out = net_builder_->Mul(x, y, x_num_col_dims, y_num_col_dims, true); + const auto& out = + net_builder_->Mul(x, y, x_num_col_dims, y_num_col_dims, true); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); @@ -141,12 +144,12 @@ void PaddleModelToProgram::AddOpMapper_matmul() { CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Input("Y").size(), 1UL); - auto y_name = op_desc.Input("Y").front(); - auto x = GetVar(utils::TransValidVarName(x_name)); - auto y = GetVar(utils::TransValidVarName(y_name)); + auto y_name = op_desc.Input("Y").front(); + auto x = GetVar(utils::TransValidVarName(x_name)); + auto y = GetVar(utils::TransValidVarName(y_name)); bool trans_a = op_desc.GetAttr("transpose_X"); bool trans_b = op_desc.GetAttr("transpose_Y"); - float alpha = op_desc.GetAttr("alpha"); + float alpha = op_desc.GetAttr("alpha"); VLOG(4) << "x shape: " << utils::Join(x->shape, ","); VLOG(4) << "y shape: " << utils::Join(y->shape, ","); auto out = net_builder_->Matmul(x, y, trans_a, trans_b, alpha); @@ -160,8 +163,8 @@ void PaddleModelToProgram::AddOpMapper_matmul() { void PaddleModelToProgram::AddOpMapper_reshape2() { op_mappers_["reshape2"] = [&](const paddle::cpp::OpDesc& op_desc) { CHECK_EQ(op_desc.Input("X").size(), 1UL); - auto x_name = op_desc.Input("X").front(); - auto x = GetVar(utils::TransValidVarName(x_name)); + auto x_name = op_desc.Input("X").front(); + auto x = GetVar(utils::TransValidVarName(x_name)); std::vector shape = op_desc.GetAttr>("shape"); VLOG(4) << "x shape: " << utils::Join(x->shape, ","); auto out = net_builder_->Reshape(x, shape); @@ -197,8 +200,8 @@ void PaddleModelToProgram::AddOpMapper_assign() { auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = GetVar(TransValidVarName(x_name)); - auto out = net_builder_->Identity(x); + auto x = GetVar(TransValidVarName(x_name)); + auto out = net_builder_->Identity(x); AddVar(TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; }; @@ -227,9 +230,10 @@ void PaddleModelToProgram::AddOpMapper_fill_constant() { Variable out; switch (dtype) { 
-#define DO(desc, type) \ - case ::cinn::frontend::paddle::proto::VarType::Type::VarType_Type_##desc: \ - out = net_builder_->FillConstant(shapes, value, str_value, force_cpu); \ +#define DO(desc, type) \ + case ::cinn::frontend::paddle::proto::VarType::Type::VarType_Type_##desc: \ + out = \ + net_builder_->FillConstant(shapes, value, str_value, force_cpu); \ break; DO(BOOL, bool); DO(FP32, float); @@ -249,7 +253,7 @@ void PaddleModelToProgram::AddOpMapper_transpose2() { auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = GetVar(TransValidVarName(x_name)); + auto x = GetVar(TransValidVarName(x_name)); CHECK(op_desc.HasAttr("axis")); auto axis = op_desc.GetAttr>("axis"); @@ -266,7 +270,7 @@ void PaddleModelToProgram::AddOpMapper_exp() { auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = GetVar(TransValidVarName(x_name)); + auto x = GetVar(TransValidVarName(x_name)); auto out = net_builder_->Exp(x); @@ -281,8 +285,8 @@ void PaddleModelToProgram::AddOpMapper_relu() { auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = GetVar(TransValidVarName(x_name)); - auto out = net_builder_->Relu(x); + auto x = GetVar(TransValidVarName(x_name)); + auto out = net_builder_->Relu(x); AddVar(TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; @@ -302,7 +306,7 @@ void PaddleModelToProgram::AddOpMapper_softmax() { } else { axis = static_cast(-1); } - auto x = GetVar(TransValidVarName(x_name)); + auto x = GetVar(TransValidVarName(x_name)); auto out = net_builder_->Softmax(x, {axis}); AddVar(TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; @@ -317,10 +321,10 @@ void PaddleModelToProgram::AddOpMapper_elementwise_add() { auto y_name = op_desc.Input("Y").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - int axis = op_desc.GetAttr("axis"); + int axis = op_desc.GetAttr("axis"); - auto x = GetVar(TransValidVarName(x_name)); - auto y = GetVar(TransValidVarName(y_name)); + auto x = GetVar(TransValidVarName(x_name)); + auto y = GetVar(TransValidVarName(y_name)); auto out = net_builder_->Add(x, y, axis); AddVar(TransValidVarName(out_name), out); @@ -336,10 +340,10 @@ void PaddleModelToProgram::AddOpMapper_elementwise_mul() { auto y_name = op_desc.Input("Y").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - int axis = op_desc.GetAttr("axis"); + int axis = op_desc.GetAttr("axis"); - auto x = GetVar(TransValidVarName(x_name)); - auto y = GetVar(TransValidVarName(y_name)); + auto x = GetVar(TransValidVarName(x_name)); + auto y = GetVar(TransValidVarName(y_name)); auto out = net_builder_->Multiply(x, y, axis); AddVar(TransValidVarName(out_name), out); @@ -358,8 +362,8 @@ void PaddleModelToProgram::AddOpMapper_elementwise_div() { CHECK(op_desc.HasAttr("axis")); int axis = op_desc.GetAttr("axis"); - auto x = GetVar(TransValidVarName(x_name)); - auto y = GetVar(TransValidVarName(y_name)); + auto x = GetVar(TransValidVarName(x_name)); + auto y = GetVar(TransValidVarName(y_name)); auto out = net_builder_->Divide(x, y, axis); AddVar(TransValidVarName(out_name), out); @@ -378,8 +382,8 @@ void PaddleModelToProgram::AddOpMapper_elementwise_sub() { CHECK(op_desc.HasAttr("axis")); int axis = 
op_desc.GetAttr("axis"); - auto x = GetVar(TransValidVarName(x_name)); - auto y = GetVar(TransValidVarName(y_name)); + auto x = GetVar(TransValidVarName(x_name)); + auto y = GetVar(TransValidVarName(y_name)); auto out = net_builder_->Subtract(x, y, axis); AddVar(TransValidVarName(out_name), out); @@ -396,10 +400,11 @@ void PaddleModelToProgram::AddOpMapper_relu6() { absl::flat_hash_map attrs; CHECK(op_desc.HasAttr("threshold")); - CHECK_EQ(op_desc.GetAttr("threshold"), 6.0f) << "Threshold of Relu6 is not 6! To be implemented."; + CHECK_EQ(op_desc.GetAttr("threshold"), 6.0f) + << "Threshold of Relu6 is not 6! To be implemented."; attrs["threshold"] = op_desc.GetAttr("threshold"); - auto x = GetVar(TransValidVarName(x_name)); + auto x = GetVar(TransValidVarName(x_name)); auto out = net_builder_->Relu6(x); AddVar(TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; @@ -431,9 +436,11 @@ void PaddleModelToProgram::AddOpMapper_depthwise_conv2d() { auto y = GetVar(TransValidVarName(y_name)); Variable out; if (target_.arch == Target::Arch::X86) { - out = net_builder_->Conv2d(x, y, strides, paddings, dilations, groups, data_format); + out = net_builder_->Conv2d( + x, y, strides, paddings, dilations, groups, data_format); } else { - out = net_builder_->DepthwiseConv2d(x, y, strides, paddings, dilations, groups, data_format); + out = net_builder_->DepthwiseConv2d( + x, y, strides, paddings, dilations, groups, data_format); } AddVar(TransValidVarName(out_name), out); @@ -463,9 +470,10 @@ void PaddleModelToProgram::AddOpMapper_conv2d() { if (data_format == "AnyLayout") { data_format = "NCHW"; } - auto x = GetVar(TransValidVarName(x_name)); - auto y = GetVar(TransValidVarName(y_name)); - auto out = net_builder_->Conv2d(x, y, strides, paddings, dilations, groups, data_format); + auto x = GetVar(TransValidVarName(x_name)); + auto y = GetVar(TransValidVarName(y_name)); + auto out = net_builder_->Conv2d( + x, y, strides, paddings, dilations, groups, data_format); AddVar(TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; @@ -498,9 +506,17 @@ void PaddleModelToProgram::AddOpMapper_pool2d() { CHECK(op_desc.HasAttr("adaptive")); auto adaptive = op_desc.GetAttr("adaptive"); - auto x = GetVar(TransValidVarName(x_name)); - auto out = net_builder_->Pool2d( - x, pool_type, ksize, strides, paddings, ceil_mode, exclusive, global_pooling, data_format, adaptive); + auto x = GetVar(TransValidVarName(x_name)); + auto out = net_builder_->Pool2d(x, + pool_type, + ksize, + strides, + paddings, + ceil_mode, + exclusive, + global_pooling, + data_format, + adaptive); AddVar(TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; @@ -522,10 +538,10 @@ void PaddleModelToProgram::AddOpMapper_batchnorm() { CHECK(!op_desc.Output("Y").empty()); auto out_name = op_desc.Output("Y").front(); - auto x = GetVar(TransValidVarName(x_name)); - auto scale = GetVar(TransValidVarName(scale_name)); - auto bias = GetVar(TransValidVarName(bias_name)); - auto mean = GetVar(TransValidVarName(mean_name)); + auto x = GetVar(TransValidVarName(x_name)); + auto scale = GetVar(TransValidVarName(scale_name)); + auto bias = GetVar(TransValidVarName(bias_name)); + auto mean = GetVar(TransValidVarName(mean_name)); auto variance = GetVar(TransValidVarName(variance_name)); CHECK(op_desc.HasAttr("epsilon")); auto epsilon = op_desc.GetAttr("epsilon"); @@ -535,7 +551,8 @@ void PaddleModelToProgram::AddOpMapper_batchnorm() { // auto data_format = op_desc.GetAttr("data_format"); 
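The op mappers in this file all follow the same attribute protocol: guard with CHECK(op_desc.HasAttr(name)), then read a typed value with op_desc.GetAttr<T>(name). A minimal sketch of such a typed attribute map, using std::variant as a stand-in for CINN's actual attribute storage (ToyOpDesc is hypothetical):

    #include <iostream>
    #include <map>
    #include <stdexcept>
    #include <string>
    #include <variant>
    #include <vector>

    // Hypothetical stand-in for the attribute storage behind OpDesc.
    class ToyOpDesc {
     public:
      using Attr = std::variant<bool, int, float, std::string, std::vector<int>>;

      void SetAttr(const std::string& name, Attr v) { attrs_[name] = std::move(v); }
      bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; }

      // Typed lookup: a missing name and a mismatched type both fail loudly,
      // which is what the CHECK(...) + GetAttr<T>(...) pairs above rely on.
      template <typename T>
      T GetAttr(const std::string& name) const {
        auto it = attrs_.find(name);
        if (it == attrs_.end()) throw std::runtime_error("missing attr: " + name);
        return std::get<T>(it->second);  // throws std::bad_variant_access on mismatch
      }

     private:
      std::map<std::string, Attr> attrs_;
    };

    int main() {
      ToyOpDesc op;
      op.SetAttr("epsilon", 1e-5f);
      op.SetAttr("data_format", std::string("NCHW"));
      if (op.HasAttr("epsilon"))
        std::cout << "epsilon = " << op.GetAttr<float>("epsilon") << "\n";
      std::cout << op.GetAttr<std::string>("data_format") << "\n";
    }

Failing loudly on an absent or wrongly-typed attribute is the point: a silently defaulted value would produce a wrong program instead of an error.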
std::string data_format = "NCHW"; - auto out = net_builder_->BatchNorm(x, scale, bias, mean, variance, epsilon, momentum, data_format, true); + auto out = net_builder_->BatchNorm( + x, scale, bias, mean, variance, epsilon, momentum, data_format, true); AddVar(TransValidVarName(out_name), out[0]); var_model_to_program_map_[out_name] = out[0]->id; @@ -548,8 +565,8 @@ void PaddleModelToProgram::AddOpMapper_sigmoid() { auto x_name = op_desc.Input("X").front(); CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); - auto x = GetVar(TransValidVarName(x_name)); - auto out = net_builder_->Sigmoid(x); + auto x = GetVar(TransValidVarName(x_name)); + auto out = net_builder_->Sigmoid(x); AddVar(TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; @@ -570,8 +587,8 @@ void PaddleModelToProgram::AddOpMapper_slice() { auto end = op_desc.GetAttr>("ends"); CHECK(op_desc.HasAttr("axes")); auto axes = op_desc.GetAttr>("axes"); - auto x = GetVar(TransValidVarName(x_name)); - auto out = net_builder_->Slice(x, axes, starts, end); + auto x = GetVar(TransValidVarName(x_name)); + auto out = net_builder_->Slice(x, axes, starts, end); AddVar(TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; @@ -589,9 +606,11 @@ void PaddleModelToProgram::AddOpMapper_dropout_infer() { CHECK(op_desc.HasAttr("dropout_prob")); auto dropout_prob = op_desc.GetAttr("dropout_prob"); CHECK(op_desc.HasAttr("dropout_implementation")); - auto dropout_implementation = op_desc.GetAttr("dropout_implementation"); - auto x = GetVar(TransValidVarName(x_name)); - auto out = net_builder_->DropoutInfer(x, dropout_prob, dropout_implementation); + auto dropout_implementation = + op_desc.GetAttr("dropout_implementation"); + auto x = GetVar(TransValidVarName(x_name)); + auto out = + net_builder_->DropoutInfer(x, dropout_prob, dropout_implementation); AddVar(TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; @@ -600,7 +619,7 @@ void PaddleModelToProgram::AddOpMapper_dropout_infer() { void PaddleModelToProgram::AddOp(const paddle::cpp::OpDesc& op_desc) { const auto& op_type = op_desc.Type(); - auto it = op_mappers_.find(op_type); + auto it = op_mappers_.find(op_type); if (it != op_mappers_.end()) { it->second(op_desc); return; @@ -616,23 +635,30 @@ void PaddleModelToProgram::TransposeVar(const std::string& name) { auto& tensor = absl::get(*var); if (target_.arch == Target::Arch::X86) { float* data = tensor->mutable_data(target_); - CHECK(tensor->shape().size() == 2) << "The y data's shape size of op [mul] is not equal to 2! Please check."; + CHECK(tensor->shape().size() == 2) + << "The y data's shape size of op [mul] is not equal to 2! Please " + "check."; TransposeData(data, tensor->shape().data()[0], tensor->shape().data()[1]); } else if (target_.arch == Target::Arch::NVGPU) { #ifdef CINN_WITH_CUDA // To use cublas mul api, there is no need to transpose data. #ifndef CINN_WITH_CUDNN std::vector data(tensor->shape().numel()); - CUDA_CALL(cudaMemcpy(data.data(), - reinterpret_cast(tensor->mutable_data(target_)), - tensor->shape().numel() * sizeof(float), - cudaMemcpyDeviceToHost)); - CHECK(tensor->shape().size() == 2) << "The y data's shape size of op [mul] is not equal to 2! 
Please check."; - TransposeData(data.data(), tensor->shape().data()[0], tensor->shape().data()[1]); - CUDA_CALL(cudaMemcpy(reinterpret_cast(tensor->mutable_data(target_)), - data.data(), - tensor->shape().numel() * sizeof(float), - cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy( + data.data(), + reinterpret_cast(tensor->mutable_data(target_)), + tensor->shape().numel() * sizeof(float), + cudaMemcpyDeviceToHost)); + CHECK(tensor->shape().size() == 2) + << "The y data's shape size of op [mul] is not equal to 2! Please " + "check."; + TransposeData( + data.data(), tensor->shape().data()[0], tensor->shape().data()[1]); + CUDA_CALL(cudaMemcpy( + reinterpret_cast(tensor->mutable_data(target_)), + data.data(), + tensor->shape().numel() * sizeof(float), + cudaMemcpyHostToDevice)); #endif #else LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!"; @@ -662,21 +688,27 @@ void PaddleModelToProgram::ReverseHWVar(const std::string& name) { auto& tensor = absl::get(*var); if (target_.arch == Target::Arch::X86) { float* data = tensor->mutable_data(target_); - CHECK(tensor->shape().size() == 4) << "The y data's shape size of op [conv2d] is not equal to 4! Please check."; + CHECK(tensor->shape().size() == 4) + << "The y data's shape size of op [conv2d] is not equal to 4! Please " + "check."; ReverseHWData(data, tensor->shape().data()); } else if (target_.arch == Target::Arch::NVGPU) { #ifdef CINN_WITH_CUDA std::vector data(tensor->shape().numel()); - CUDA_CALL(cudaMemcpy(data.data(), - reinterpret_cast(tensor->mutable_data(target_)), - tensor->shape().numel() * sizeof(float), - cudaMemcpyDeviceToHost)); - CHECK(tensor->shape().size() == 4) << "The y data's shape size of op [conv2d] is not equal to 4! Please check."; + CUDA_CALL(cudaMemcpy( + data.data(), + reinterpret_cast(tensor->mutable_data(target_)), + tensor->shape().numel() * sizeof(float), + cudaMemcpyDeviceToHost)); + CHECK(tensor->shape().size() == 4) + << "The y data's shape size of op [conv2d] is not equal to 4! 
Please " + "check."; ReverseHWData(data.data(), tensor->shape().data()); - CUDA_CALL(cudaMemcpy(reinterpret_cast(tensor->mutable_data(target_)), - data.data(), - tensor->shape().numel() * sizeof(float), - cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy( + reinterpret_cast(tensor->mutable_data(target_)), + data.data(), + tensor->shape().numel() * sizeof(float), + cudaMemcpyHostToDevice)); #else LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!"; #endif @@ -711,10 +743,19 @@ Variable PaddleModelToProgram::GetVar(const std::string& name) { return Variable(); } -std::unique_ptr PaddleModelToProgram::operator()(const std::string& model_dir, bool is_combined) { +std::unique_ptr PaddleModelToProgram::operator()( + const std::string& model_dir, bool is_combined) { paddle::cpp::ProgramDesc program_desc; - paddle::LoadModelPb(model_dir, "__model__", "", scope_, &program_desc, is_combined, false, target_); - CHECK_EQ(program_desc.BlocksSize(), 1) << "CINN can only support the model with a single block"; + paddle::LoadModelPb(model_dir, + "__model__", + "", + scope_, + &program_desc, + is_combined, + false, + target_); + CHECK_EQ(program_desc.BlocksSize(), 1) + << "CINN can only support the model with a single block"; auto* block_desc = program_desc.GetBlock(0); for (int i = 0; i < block_desc->OpsSize(); i++) { @@ -724,7 +765,9 @@ std::unique_ptr PaddleModelToProgram::operator()(const std::string& mod return std::unique_ptr(new Program(net_builder_->Build())); } -void PaddleModelToProgram::AddVar(const std::string& name, const Variable& var, bool replace) { +void PaddleModelToProgram::AddVar(const std::string& name, + const Variable& var, + bool replace) { CheckVarNameValid(name); if (replace == false) { CHECK(!var_map_.count(name)) << "Duplicate variable [" << name << "] found"; diff --git a/paddle/cinn/frontend/paddle_model_to_program.h b/paddle/cinn/frontend/paddle_model_to_program.h index 18a9874760446..ab520e608de37 100644 --- a/paddle/cinn/frontend/paddle_model_to_program.h +++ b/paddle/cinn/frontend/paddle_model_to_program.h @@ -40,9 +40,10 @@ namespace frontend { class PaddleModelToProgram { public: - explicit PaddleModelToProgram(hlir::framework::Scope* scope, - std::unordered_map> input_shape_map, - const common::Target& target) + explicit PaddleModelToProgram( + hlir::framework::Scope* scope, + std::unordered_map> input_shape_map, + const common::Target& target) : scope_(scope), input_shape_map_(input_shape_map), target_(target), @@ -76,7 +77,8 @@ class PaddleModelToProgram { AddOpMapper_exp(); } - std::unique_ptr operator()(const std::string& model_dir, bool is_combined); + std::unique_ptr operator()(const std::string& model_dir, + bool is_combined); // Add an Instruction to a program given a Paddle-format \p op_desc. 
void AddOp(const paddle::cpp::OpDesc& op_desc); @@ -109,12 +111,19 @@ class PaddleModelToProgram { void AddOpMapper_exp(); // @} - const absl::flat_hash_map& var_map() const { return var_map_; } - const absl::flat_hash_map& var_model_to_program_map() { return var_model_to_program_map_; } + const absl::flat_hash_map& var_map() const { + return var_map_; + } + const absl::flat_hash_map& + var_model_to_program_map() { + return var_model_to_program_map_; + } const absl::flat_hash_set& fetch_names() { return fetch_names_; } protected: - void AddVar(const std::string& name, const Variable& var, bool replace = false); + void AddVar(const std::string& name, + const Variable& var, + bool replace = false); Variable GetVar(const std::string& name); @@ -124,7 +133,9 @@ class PaddleModelToProgram { private: // op mapper - absl::flat_hash_map> op_mappers_; + absl::flat_hash_map> + op_mappers_; std::unordered_map> input_shape_map_; // net builder std::unique_ptr net_builder_; diff --git a/paddle/cinn/frontend/pass/auto_broadcast.cc b/paddle/cinn/frontend/pass/auto_broadcast.cc index e68e260d533d3..558105a44ad25 100644 --- a/paddle/cinn/frontend/pass/auto_broadcast.cc +++ b/paddle/cinn/frontend/pass/auto_broadcast.cc @@ -18,12 +18,12 @@ #include #include +#include "glog/logging.h" #include "paddle/cinn/frontend/net_builder.h" #include "paddle/cinn/frontend/program_pass.h" #include "paddle/cinn/frontend/syntax.h" #include "paddle/cinn/utils/string.h" #include "paddle/cinn/utils/type_defs.h" -#include "glog/logging.h" namespace cinn { namespace frontend { @@ -47,7 +47,8 @@ class AutoBroadcastPass : public ProgramPass { axis = output_shape.size() - input_shape.size(); } CHECK_LE(axis + input_shape.size(), output_shape.size()) - << "Cannot Broadcast from shape=[" << cinn::utils::Join(input_shape, ", ") << "] to shape=[" + << "Cannot Broadcast from shape=[" + << cinn::utils::Join(input_shape, ", ") << "] to shape=[" << cinn::utils::Join(output_shape, ", ") << "] with axis=" << axis; for (int idx = 0; idx < input_shape.size(); ++idx) { broadcast_axes.push_back(axis++); @@ -57,13 +58,14 @@ class AutoBroadcastPass : public ProgramPass { } void InsertBroadcastTo(NetBuilder* builder, Instruction* broadcast_op) { - const auto& instr = *broadcast_op; + const auto& instr = *broadcast_op; const auto& op_name = instr->op_type; - const auto& op_pattern_dict_ = - &cinn::hlir::framework::Operator::GetAttrs("OpPattern"); + const auto& op_pattern_dict_ = &cinn::hlir::framework::Operator::GetAttrs< + cinn::hlir::framework::OpPatternKind>("OpPattern"); const auto* op = cinn::hlir::framework::Operator::Get(op_name); - if (!op_pattern_dict_->Find(op) || (*op_pattern_dict_)[op] != cinn::hlir::framework::kBroadcast) { + if (!op_pattern_dict_->Find(op) || + (*op_pattern_dict_)[op] != cinn::hlir::framework::kBroadcast) { // no set OpPattern or not broadcast kind operator, skip builder->AppendInstruction(instr); return; @@ -75,7 +77,8 @@ class AutoBroadcastPass : public ProgramPass { } const auto& outputs = instr.GetOutputs(); - CHECK_EQ(outputs.size(), 1) << "The broadcast operator should has and only has one output"; + CHECK_EQ(outputs.size(), 1) + << "The broadcast operator should has and only has one output"; const auto& output = outputs.front(); int axis = -1; @@ -93,7 +96,10 @@ class AutoBroadcastPass : public ProgramPass { // else insert broadcast_to need_insert = true; - auto new_var = builder->BroadcastTo(input, output->shape, GetBroadcastAxes(input->shape, output->shape, axis)); + auto new_var = builder->BroadcastTo( + 
input, + output->shape, + GetBroadcastAxes(input->shape, output->shape, axis)); new_inputs.emplace_back(new_var); } } @@ -133,7 +139,8 @@ class AutoBroadcastPass : public ProgramPass { } // namespace cinn CINN_REGISTER_HELPER(AutoBroadcast) { - CINN_REGISTER_PROGRAM_PASS(AutoBroadcast, cinn::frontend::pass::AutoBroadcastPass); + CINN_REGISTER_PROGRAM_PASS(AutoBroadcast, + cinn::frontend::pass::AutoBroadcastPass); return true; } diff --git a/paddle/cinn/frontend/pass/auto_cast.cc b/paddle/cinn/frontend/pass/auto_cast.cc index f1bf636d0dab4..838ff8b06f1dd 100644 --- a/paddle/cinn/frontend/pass/auto_cast.cc +++ b/paddle/cinn/frontend/pass/auto_cast.cc @@ -17,16 +17,17 @@ #include #include +#include "glog/logging.h" #include "paddle/cinn/frontend/net_builder.h" #include "paddle/cinn/frontend/program_pass.h" -#include "glog/logging.h" namespace cinn { namespace frontend { namespace pass { namespace { -using CastImplFunc = std::function; +using CastImplFunc = + std::function; bool IsInputHasFP16OrBF16(const std::vector& inputs) { return std::find_if(inputs.begin(), inputs.end(), [](const Variable& var) { @@ -34,15 +35,17 @@ bool IsInputHasFP16OrBF16(const std::vector& inputs) { }) != inputs.end(); } -Instruction CreateNewCastInstruction(const Variable& input, const Variable& output) { +Instruction CreateNewCastInstruction(const Variable& input, + const Variable& output) { Instruction new_cast_instr("cast", {input}); - new_cast_instr->outputs = {output}; - new_cast_instr->attrs = {{"dtype", common::Type2Str(output->type)}}; + new_cast_instr->outputs = {output}; + new_cast_instr->attrs = {{"dtype", common::Type2Str(output->type)}}; new_cast_instr->attrs_ordered = {{"dtype", common::Type2Str(output->type)}}; return new_cast_instr; } -Instruction CreateNewIdentityInstruction(const Variable& input, const Variable& output) { +Instruction CreateNewIdentityInstruction(const Variable& input, + const Variable& output) { Instruction new_identity_instr("identity", {input}); new_identity_instr->outputs = {output}; return new_identity_instr; @@ -65,11 +68,13 @@ void CommonCastImpl(NetBuilder* builder, const Instruction& instr) { casted_inputs.emplace_back(casted_var); } // Run fp32 op - const auto& outputs = builder->CustomInstr(instr->op_type, casted_inputs, instr->attrs); + const auto& outputs = + builder->CustomInstr(instr->op_type, casted_inputs, instr->attrs); // Cast all fp32 outputs to fp16/bf16 for (int i = 0; i < outputs.size(); ++i) { if (outputs[i]->type.is_float(32)) { - builder->AppendInstruction(CreateNewCastInstruction(outputs[i], instr->outputs[i])); + builder->AppendInstruction( + CreateNewCastInstruction(outputs[i], instr->outputs[i])); } } } @@ -115,26 +120,36 @@ static std::unordered_map need_cast_list = { // Except input [X], BatchNormTrain's Input should all be fp32 CHECK_EQ(instr->inputs.size(), 5UL) - << "The number of the given inputs is not equal to the required for op " << instr->op_type; + << "The number of the given inputs is not equal to the required for " + "op " + << instr->op_type; CHECK(instr->inputs[1]->type.is_float(32)) - << instr->op_type << "'s input [scale] should be float32, but here " << instr->inputs[1]->type; + << instr->op_type << "'s input [scale] should be float32, but here " + << instr->inputs[1]->type; CHECK(instr->inputs[2]->type.is_float(32)) - << instr->op_type << "'s input [bias] should be float32, but here " << instr->inputs[1]->type; + << instr->op_type << "'s input [bias] should be float32, but here " + << instr->inputs[1]->type; 
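CommonCastImpl above is the generic rule of the AutoCast pass: cast any fp16/bf16 inputs up to fp32, run the op unchanged, and cast the fp32 outputs back down; the batch_norm entries in need_cast_list only differ in pinning certain inputs to fp32. A type-level toy of that wrap pattern, with float standing in for the low precision and double for fp32 (nothing here is CINN API):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // A "kernel" implemented only in high precision (double), standing in for
    // an op that the compiler only supports in fp32.
    std::vector<double> HighPrecisionSum(const std::vector<double>& a,
                                         const std::vector<double>& b) {
      std::vector<double> out(a.size());
      for (size_t i = 0; i < a.size(); ++i) out[i] = a[i] + b[i];
      return out;
    }

    // Mirrors CommonCastImpl's shape: cast low-precision inputs up, run the
    // high-precision kernel, cast the result back down. Float plays the role
    // of fp16/bf16 here and double the role of fp32.
    std::vector<float> RunWithAutoCast(const std::vector<float>& a,
                                       const std::vector<float>& b) {
      std::vector<double> a_casted(a.begin(), a.end());  // cast inputs up
      std::vector<double> b_casted(b.begin(), b.end());
      std::vector<double> out = HighPrecisionSum(a_casted, b_casted);
      return std::vector<float>(out.begin(), out.end());  // cast outputs down
    }

    int main() {
      std::vector<float> r = RunWithAutoCast({1.5f, 2.5f}, {0.5f, 0.5f});
      for (float v : r) std::cout << v << " ";  // prints: 2 3
    }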
CHECK(instr->inputs[3]->type.is_float(32)) - << instr->op_type << "'s input [moving_mean] should be float32, but here " << instr->inputs[1]->type; + << instr->op_type + << "'s input [moving_mean] should be float32, but here " + << instr->inputs[1]->type; CHECK(instr->inputs[4]->type.is_float(32)) - << instr->op_type << "'s input [moving_variance] should be float32, but here " << instr->inputs[1]->type; + << instr->op_type + << "'s input [moving_variance] should be float32, but here " + << instr->inputs[1]->type; // Cast input [X] from fp16/bf16 to fp32 - const auto& x = instr->inputs[0]; + const auto& x = instr->inputs[0]; const auto& x_casted = builder->Cast(x, "float32"); auto casted_inputs = instr->inputs; - casted_inputs[0] = x_casted; + casted_inputs[0] = x_casted; // Run fp32 function - const auto& outputs = builder->CustomInstr(instr->op_type, casted_inputs, instr->attrs); + const auto& outputs = + builder->CustomInstr(instr->op_type, casted_inputs, instr->attrs); // Cast output [Y] to fp16/bf16, no other output - builder->AppendInstruction(CreateNewCastInstruction(outputs[0], instr->outputs[0])); + builder->AppendInstruction( + CreateNewCastInstruction(outputs[0], instr->outputs[0])); }}, {"batch_norm_train", [](NetBuilder* builder, const Instruction& instr) { @@ -146,29 +161,40 @@ static std::unordered_map need_cast_list = { // Except input [X], BatchNormTrain's Input should all be fp32 CHECK_EQ(instr->inputs.size(), 5UL) - << "The number of the given inputs is not equal to the required for op " << instr->op_type; + << "The number of the given inputs is not equal to the required for " + "op " + << instr->op_type; CHECK(instr->inputs[1]->type.is_float(32)) - << instr->op_type << "'s input [scale] should be float32, but here " << instr->inputs[1]->type; + << instr->op_type << "'s input [scale] should be float32, but here " + << instr->inputs[1]->type; CHECK(instr->inputs[2]->type.is_float(32)) - << instr->op_type << "'s input [bias] should be float32, but here " << instr->inputs[1]->type; + << instr->op_type << "'s input [bias] should be float32, but here " + << instr->inputs[1]->type; CHECK(instr->inputs[3]->type.is_float(32)) - << instr->op_type << "'s input [moving_mean] should be float32, but here " << instr->inputs[1]->type; + << instr->op_type + << "'s input [moving_mean] should be float32, but here " + << instr->inputs[1]->type; CHECK(instr->inputs[4]->type.is_float(32)) - << instr->op_type << "'s input [moving_variance] should be float32, but here " << instr->inputs[1]->type; + << instr->op_type + << "'s input [moving_variance] should be float32, but here " + << instr->inputs[1]->type; // Cast input [X] from fp16/bf16 to fp32 - const auto& x = instr->inputs[0]; + const auto& x = instr->inputs[0]; const auto& x_casted = builder->Cast(x, "float32"); auto casted_inputs = instr->inputs; - casted_inputs[0] = x_casted; + casted_inputs[0] = x_casted; // Run fp32 function - const auto& outputs = builder->CustomInstr(instr->op_type, casted_inputs, instr->attrs); + const auto& outputs = + builder->CustomInstr(instr->op_type, casted_inputs, instr->attrs); // Cast output [Y] to fp16/bf16 - builder->AppendInstruction(CreateNewCastInstruction(outputs[0], instr->outputs[0])); + builder->AppendInstruction( + CreateNewCastInstruction(outputs[0], instr->outputs[0])); // Identity other output for (int i = 1; i < outputs.size(); ++i) { - builder->AppendInstruction(CreateNewIdentityInstruction(outputs[i], instr->outputs[i])); + builder->AppendInstruction( + 
CreateNewIdentityInstruction(outputs[i], instr->outputs[i])); } }}, {"batch_norm_grad", [](NetBuilder* builder, const Instruction& instr) { @@ -180,34 +206,45 @@ static std::unordered_map need_cast_list = { // Except input [X], BatchNormTrain's Input should all be fp32 CHECK_EQ(instr->inputs.size(), 5UL) - << "The number of the given inputs is not equal to the required for op " << instr->op_type; + << "The number of the given inputs is not equal to the required for " + "op " + << instr->op_type; CHECK_EQ(instr->inputs[0]->type, instr->inputs[1]->type) - << instr->op_type << "'s input [Y@GRAD] and input [X] 's type should be the same"; + << instr->op_type + << "'s input [Y@GRAD] and input [X] 's type should be the same"; CHECK(instr->inputs[2]->type.is_float(32)) - << instr->op_type << "'s input [scale] should be float32, but here " << instr->inputs[1]->type; + << instr->op_type << "'s input [scale] should be float32, but here " + << instr->inputs[1]->type; CHECK(instr->inputs[3]->type.is_float(32)) - << instr->op_type << "'s input [save_mean] should be float32, but here " << instr->inputs[1]->type; + << instr->op_type + << "'s input [save_mean] should be float32, but here " + << instr->inputs[1]->type; CHECK(instr->inputs[4]->type.is_float(32)) - << instr->op_type << "'s input [save_variance] should be float32, but here " << instr->inputs[1]->type; + << instr->op_type + << "'s input [save_variance] should be float32, but here " + << instr->inputs[1]->type; // Cast input [Y@GRAD] from fp16/bf16 to fp32 - const auto& y_grad = instr->inputs[0]; + const auto& y_grad = instr->inputs[0]; const auto& y_grad_casted = builder->Cast(y_grad, "float32"); // Cast input [X] from fp16/bf16 to fp32 - const auto& x = instr->inputs[1]; + const auto& x = instr->inputs[1]; const auto& x_casted = builder->Cast(x, "float32"); auto casted_inputs = instr->inputs; - casted_inputs[0] = y_grad_casted; - casted_inputs[1] = x_casted; + casted_inputs[0] = y_grad_casted; + casted_inputs[1] = x_casted; // Run fp32 function - const auto& outputs = builder->CustomInstr(instr->op_type, casted_inputs, instr->attrs); + const auto& outputs = + builder->CustomInstr(instr->op_type, casted_inputs, instr->attrs); // Cast output [X@GRAD] to fp16/bf16 - builder->AppendInstruction(CreateNewCastInstruction(outputs[0], instr->outputs[0])); + builder->AppendInstruction( + CreateNewCastInstruction(outputs[0], instr->outputs[0])); // Identity other output for (int i = 1; i < outputs.size(); ++i) { - builder->AppendInstruction(CreateNewIdentityInstruction(outputs[i], instr->outputs[i])); + builder->AppendInstruction( + CreateNewIdentityInstruction(outputs[i], instr->outputs[i])); } }}}; } // namespace diff --git a/paddle/cinn/frontend/pass/auto_cast_test.cc b/paddle/cinn/frontend/pass/auto_cast_test.cc index 5cf5ae333133e..4b570a2755cdb 100644 --- a/paddle/cinn/frontend/pass/auto_cast_test.cc +++ b/paddle/cinn/frontend/pass/auto_cast_test.cc @@ -33,53 +33,59 @@ namespace cinn::frontend { TEST(AutoCast, Exp) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(common::Float16(), {4, 5, 3}, "X"); - auto out = builder.Exp(x); + auto x = builder.CreateInput(common::Float16(), {4, 5, 3}, "X"); + auto out = builder.Exp(x); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); - std::pair, std::vector> passes{{}, {"AutoCast", "Decomposer"}}; + std::pair, std::vector> passes{ + {}, {"AutoCast", "Decomposer"}}; CompareProgramPassResult(&program, target, {out->id}, -2, passes); } TEST(AutoCast, 
Exp_bf16) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(common::BFloat16(), {4, 5, 3}, "X"); - auto out = builder.Exp(x); + auto x = builder.CreateInput(common::BFloat16(), {4, 5, 3}, "X"); + auto out = builder.Exp(x); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); - std::pair, std::vector> passes{{}, {"AutoCast", "Decomposer"}}; + std::pair, std::vector> passes{ + {}, {"AutoCast", "Decomposer"}}; CompareProgramPassResult(&program, target, {out->id}, -2, passes); } TEST(AutoCast, BatchNorm) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(common::Float16(), {128, 64, 112, 112}, "X"); - auto scale = builder.FillConstant({64}, 1.0f, "scale", "float32"); - auto bias = builder.FillConstant({64}, 0.0f, "bias", "float32"); - auto mean = builder.FillConstant({64}, 0.0f, "mean", "float32"); + auto x = builder.CreateInput(common::Float16(), {128, 64, 112, 112}, "X"); + auto scale = builder.FillConstant({64}, 1.0f, "scale", "float32"); + auto bias = builder.FillConstant({64}, 0.0f, "bias", "float32"); + auto mean = builder.FillConstant({64}, 0.0f, "mean", "float32"); auto variance = builder.FillConstant({64}, 1.0f, "variance", "float32"); - auto out = builder.BatchNorm(x, scale, bias, mean, variance, 1e-5f, 0.9f, "NCHW", false); - auto program = builder.Build(); + auto out = builder.BatchNorm( + x, scale, bias, mean, variance, 1e-5f, 0.9f, "NCHW", false); + auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); - std::pair, std::vector> passes{{}, {"AutoCast", "Decomposer"}}; + std::pair, std::vector> passes{ + {}, {"AutoCast", "Decomposer"}}; CompareProgramPassResult(&program, target, {out[0]->id}, -2, passes); } TEST(AutoCast, BatchNorm_bf16) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(common::BFloat16(), {128, 64, 112, 112}, "X"); - auto scale = builder.FillConstant({64}, 1.0f, "scale", "float32"); - auto bias = builder.FillConstant({64}, 0.0f, "bias", "float32"); - auto mean = builder.FillConstant({64}, 0.0f, "mean", "float32"); + auto x = builder.CreateInput(common::BFloat16(), {128, 64, 112, 112}, "X"); + auto scale = builder.FillConstant({64}, 1.0f, "scale", "float32"); + auto bias = builder.FillConstant({64}, 0.0f, "bias", "float32"); + auto mean = builder.FillConstant({64}, 0.0f, "mean", "float32"); auto variance = builder.FillConstant({64}, 1.0f, "variance", "float32"); - auto out = builder.BatchNorm(x, scale, bias, mean, variance, 1e-5f, 0.9f, "NCHW", false); - auto program = builder.Build(); + auto out = builder.BatchNorm( + x, scale, bias, mean, variance, 1e-5f, 0.9f, "NCHW", false); + auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); - std::pair, std::vector> passes{{}, {"AutoCast", "Decomposer"}}; + std::pair, std::vector> passes{ + {}, {"AutoCast", "Decomposer"}}; CompareProgramPassResult(&program, target, {out[0]->id}, -2, passes); } diff --git a/paddle/cinn/frontend/pass/cast_collapsing.cc b/paddle/cinn/frontend/pass/cast_collapsing.cc index 373d462a58875..5fc40d407029f 100644 --- a/paddle/cinn/frontend/pass/cast_collapsing.cc +++ b/paddle/cinn/frontend/pass/cast_collapsing.cc @@ -27,18 +27,26 @@ namespace cinn::frontend::pass { class CastKey { public: - CastKey(const std::string& input_id, const std::string& cast_type) { SetKey(input_id, cast_type); } + CastKey(const std::string& input_id, const std::string& cast_type) { + SetKey(input_id, cast_type); + } void SetKey(const std::string& input_id, const 
std::string& cast_type) { - input_id_ = input_id; + input_id_ = input_id; cast_type_ = cast_type; } - bool operator==(const CastKey& other) const { return cast_type_ == other.cast_type_ && input_id_ == other.input_id_; } - bool operator!=(const CastKey& other) const { return !this->operator==(other); } + bool operator==(const CastKey& other) const { + return cast_type_ == other.cast_type_ && input_id_ == other.input_id_; + } + bool operator!=(const CastKey& other) const { + return !this->operator==(other); + } struct Hash { - size_t operator()(const CastKey& key) const { return std::hash<std::string>()(key.input_id_ + key.cast_type_); } + size_t operator()(const CastKey& key) const { + return std::hash<std::string>()(key.input_id_ + key.cast_type_); + } }; private: @@ -51,7 +59,8 @@ class CastCollapsingPass : public ProgramPass { public: using ProgramPass::ProgramPass; using OutputToOpMap = std::unordered_map<std::string, Instruction*>; - using InputToOpMap = std::unordered_map<std::string, std::unordered_set<Instruction*>>; + using InputToOpMap = + std::unordered_map<std::string, std::unordered_set<Instruction*>>; protected: void Clear() override {} @@ -81,7 +90,8 @@ class CastCollapsingPass : public ProgramPass { // the useless cast op need to remove from program std::unordered_set<Instruction*> remove_instrs; - FoldingCastVertical(all_cast, fetch_ids, in2instr, out2instr, &remove_instrs); + FoldingCastVertical( + all_cast, fetch_ids, in2instr, out2instr, &remove_instrs); for (auto instr : remove_instrs) { if (all_cast.count(instr)) { @@ -89,9 +99,10 @@ class CastCollapsingPass : public ProgramPass { } } // TODO(thisjiang): reopen after CINN support recompute for performance - // due to recompute unsupported, if the op output to two group, it will also create a new group, - // so that the horizontal fuse will not improve performance. - // FoldingCastHorizontal(all_cast, fetch_ids, in2instr, out2instr, &remove_instrs); + // Because recompute is unsupported, if an op's output feeds two groups, it + // will also create a new group, so the horizontal fuse will not improve + // performance. + // FoldingCastHorizontal(all_cast, fetch_ids, in2instr, out2instr, + // &remove_instrs); NetBuilder builder("cast_collapsing_builder"); for (auto& var : program->GetInputs()) { @@ -106,11 +117,12 @@ class CastCollapsingPass : public ProgramPass { } private: - void FoldingCastVertical(const std::unordered_set<Instruction*>& all_cast, - const std::unordered_set<std::string>& fetch_ids, - const InputToOpMap& in2instr, - const OutputToOpMap& out2instr, - std::unordered_set<Instruction*>* remove_instrs) const { + void FoldingCastVertical( + const std::unordered_set<Instruction*>& all_cast, + const std::unordered_set<std::string>& fetch_ids, + const InputToOpMap& in2instr, + const OutputToOpMap& out2instr, + std::unordered_set<Instruction*>* remove_instrs) const { if (all_cast.size() == 1) { return; } @@ -124,16 +136,19 @@ class CastCollapsingPass : public ProgramPass { // Fuse cast from front to back, the fuse path is unique auto first_cast = FindFirstCast(cast, out2instr); - TryFuseCast(first_cast, fetch_ids, in2instr, remove_instrs, &visited_instrs); + TryFuseCast( + first_cast, fetch_ids, in2instr, remove_instrs, &visited_instrs); } } - Instruction* FindFirstCast(Instruction* cast, const OutputToOpMap& out2instr) const { + Instruction* FindFirstCast(Instruction* cast, + const OutputToOpMap& out2instr) const { auto first_cast = cast; auto input_name = (*first_cast)->inputs.front()->id; // Q: Why check whether cast's input in out2instr ? - // A: The input may be the input of the graph other than another op's output. + // A: The input may be an input of the graph rather than another op's + // output. 
// Obviously, the cast op is the first cast in the situation. while (out2instr.count(input_name)) { auto instr = out2instr.at(input_name); @@ -155,30 +170,33 @@ std::unordered_set* visited_instrs) const { visited_instrs->insert(cast); - const auto& input = (*cast)->inputs.front(); - const auto& input_name = input->id; + const auto& input = (*cast)->inputs.front(); + const auto& input_name = input->id; const auto& input_dtype = input->type; - const auto& output = (*cast)->outputs.front(); - const auto& output_name = output->id; + const auto& output = (*cast)->outputs.front(); + const auto& output_name = output->id; const auto& output_dtype = output->type; const auto& dtype = cast->GetAttrs("dtype"); - const auto& cast_info = output_name + "=cast(" + input_name + ", dtype=" + dtype + ")"; + const auto& cast_info = + output_name + "=cast(" + input_name + ", dtype=" + dtype + ")"; bool can_remove = !fetch_ids.count(output_name); if (CheckCastBorder(cast, in2instr)) { if (can_remove) { - VLOG(4) << "The op " << cast_info << " is a output op of graph, cannot fuse, remove."; + VLOG(4) << "The op " << cast_info + << " is an output op of the graph, cannot fuse, remove."; // this cast not used by any other op, remove remove_instrs->insert(cast); } else { if (input_dtype == output_dtype) { - VLOG(4) << "The cast op " << cast_info << " is fetched but useless, replace with identity."; - // cannot remove, however, the transpose is useless, we can replace the cast with identity for more - // fusion opportunity + VLOG(4) << "The cast op " << cast_info + << " is fetched but useless, replace with identity."; + // cannot remove; however, the cast is useless, so we can replace + // it with identity for more fusion opportunity ReplaceWithIdentity(cast); } // else the transpose is fetched and helpful, ignore @@ -190,8 +208,10 @@ class CastCollapsingPass : public ProgramPass { const auto& out_instrs = in2instr.at(output_name); if (input_dtype == output_dtype) { if (!can_remove) { - VLOG(4) << "The cast op " << cast_info << " is useless but fetched, replace with identity."; - // cannot remove, but we can replace the cast with indentiy for more fusion opportunity + VLOG(4) << "The cast op " << cast_info + << " is useless but fetched, replace with identity."; + // cannot remove, but we can replace the cast with identity for more + // fusion opportunity ReplaceWithIdentity(cast); } else { VLOG(4) << "The cast op " << cast_info << " is useless, remove."; @@ -204,7 +224,8 @@ class CastCollapsingPass : public ProgramPass { for (auto instr : out_instrs) { if ("cast" == (*instr)->op_type) { // if the next instruction is cast op, continue fuse - TryFuseCast(instr, fetch_ids, in2instr, remove_instrs, visited_instrs); + TryFuseCast( + instr, fetch_ids, in2instr, remove_instrs, visited_instrs); } } } @@ -212,7 +233,8 @@ class CastCollapsingPass : public ProgramPass { } if (!CheckOutputContainCast(cast, in2instr)) { - VLOG(4) << "The cast op " << cast_info << " doesn't has output link to cast, skip."; + VLOG(4) << "The cast op " << cast_info + << " doesn't have any output linked to a cast, skip."; return; } @@ -222,16 +244,18 @@ class CastCollapsingPass : public ProgramPass { if ("cast" != (*instr)->op_type) { // the cast was used by other non-cast op, cannot remove, skip can_remove = false; - VLOG(4) << "Fuse cast of " << cast_info << " was used by " << (*instr)->op_type << ", cannot remove."; + VLOG(4) << "Fuse cast of " << cast_info << " was used by " << (*instr)->op_type << ", 
cannot remove."; continue; } const auto& next_dtype = instr->GetAttrs("dtype"); - VLOG(4) << "Fuse cast of " << cast_info << " and cast of " << (*instr)->outputs.front()->id << "=cast(" + VLOG(4) << "Fuse cast of " << cast_info << " and cast of " + << (*instr)->outputs.front()->id << "=cast(" << (*instr)->inputs.front()->id << ", dtype=" << next_dtype << ")" - << " into cast of " << (*instr)->outputs.front()->id << "=cast(" << input_name << ", dtype=" << next_dtype - << ")"; + << " into cast of " << (*instr)->outputs.front()->id << "=cast(" + << input_name << ", dtype=" << next_dtype << ")"; auto fused_cast = FuseCastImpl(cast, instr, next_dtype); @@ -248,15 +272,16 @@ class CastCollapsingPass : public ProgramPass { } } - // check whether the op is the border op of graph, in other words, its output var was not - // used by any op in graph. + // check whether the op is the border op of graph, in other words, its output + // var was not used by any op in graph. bool CheckCastBorder(Instruction* cast, const InputToOpMap& in2instr) const { const auto& output_name = (*cast)->outputs.front()->id; return !in2instr.count(output_name) || in2instr.at(output_name).empty(); } // check whether the op's output ops has cast, if not, no cast need folding - bool CheckOutputContainCast(Instruction* cast, const InputToOpMap& in2instr) const { + bool CheckOutputContainCast(Instruction* cast, + const InputToOpMap& in2instr) const { const auto& output_name = (*cast)->outputs.front()->id; for (auto instr : in2instr.at(output_name)) { if ("cast" == (*instr)->op_type) { @@ -267,17 +292,23 @@ class CastCollapsingPass : public ProgramPass { return false; } - // replace the op's input variable whose name is `old_input_name` to `new_input`, note we need keep the input list - // order - void ReplaceInputVariable(Instruction* op, const std::string& old_input_name, const Variable& new_input) const { + // replace the op's input variable whose name is `old_input_name` to + // `new_input`, note we need keep the input list order + void ReplaceInputVariable(Instruction* op, + const std::string& old_input_name, + const Variable& new_input) const { auto find_input = [&](const std::string& input_name) { return std::find_if( - (*op)->inputs.begin(), (*op)->inputs.end(), [&](const Variable& v) { return input_name == v->id; }); + (*op)->inputs.begin(), (*op)->inputs.end(), [&](const Variable& v) { + return input_name == v->id; + }); }; // Why Loop : To avoid the op's inputs are the same variable ! 
- for (auto it = find_input(old_input_name); it != (*op)->inputs.end(); it = find_input(old_input_name)) { - // erase previous fill_constant output var and replace to new fill_constant output var + for (auto it = find_input(old_input_name); it != (*op)->inputs.end(); + it = find_input(old_input_name)) { + // erase the previous input var and replace it with the new + // input var auto next_it = (*op)->inputs.erase(it); // keep the input place same, it's very important (*op)->inputs.insert(next_it, new_input); @@ -292,38 +323,44 @@ class CastCollapsingPass : public ProgramPass { } // fuse the two cast dtype into the second cast, replace its input and dtype - Instruction* FuseCastImpl(Instruction* cast1, Instruction* cast2, const std::string& fused_dtype) const { + Instruction* FuseCastImpl(Instruction* cast1, + Instruction* cast2, + const std::string& fused_dtype) const { (*cast2)->inputs.front() = (*cast1)->inputs.front(); cast2->SetAttr("dtype", fused_dtype); return cast2; } - // if the casts have the same input and dtype, they can folding into one, the redundance should remove - void FoldingCastHorizontal(const std::unordered_set& all_cast, - const std::unordered_set& fetch_ids, - const InputToOpMap& in2instr, - const OutputToOpMap& out2instr, - std::unordered_set* remove_instrs) const { + // if the casts have the same input and dtype, they can be folded into one, + // and the redundant ones should be removed + void FoldingCastHorizontal( + const std::unordered_set& all_cast, + const std::unordered_set& fetch_ids, + const InputToOpMap& in2instr, + const OutputToOpMap& out2instr, + std::unordered_set* remove_instrs) const { std::unordered_map first_cast_map; for (auto cast : all_cast) { if (("cast" != (*cast)->op_type) || remove_instrs->count(cast)) { continue; } - const auto& input_id = (*cast)->inputs.front()->id; + const auto& input_id = (*cast)->inputs.front()->id; const auto& output_id = (*cast)->outputs.front()->id; - const auto& dtype = cast->GetAttrs("dtype"); + const auto& dtype = cast->GetAttrs("dtype"); CastKey key(input_id, dtype); if (!first_cast_map.count(key)) { - VLOG(4) << "The cast, whose output [" << output_id << "], cannot remove because it is the first cast ! "; + VLOG(4) << "The cast, whose output [" << output_id + << "], cannot be removed because it is the first cast ! "; first_cast_map.emplace(key, &(*cast)->outputs.front()); continue; } if (fetch_ids.find(output_id) != fetch_ids.end()) { // the cast's output variable was fetched, skip - VLOG(4) << "Cannot remove cast, because the output [" << output_id << "] was fetched by other op ! "; + VLOG(4) << "Cannot remove cast, because the output [" << output_id + << "] was fetched by other op ! 
"; continue; } @@ -341,7 +378,8 @@ class CastCollapsingPass : public ProgramPass { } // namespace cinn::frontend::pass CINN_REGISTER_HELPER(CastCollapsing) { - CINN_REGISTER_PROGRAM_PASS(CastCollapsing, ::cinn::frontend::pass::CastCollapsingPass); + CINN_REGISTER_PROGRAM_PASS(CastCollapsing, + ::cinn::frontend::pass::CastCollapsingPass); return true; } diff --git a/paddle/cinn/frontend/pass/cast_collapsing_test.cc b/paddle/cinn/frontend/pass/cast_collapsing_test.cc index 8effc53f330b1..8384002c872cf 100644 --- a/paddle/cinn/frontend/pass/cast_collapsing_test.cc +++ b/paddle/cinn/frontend/pass/cast_collapsing_test.cc @@ -37,17 +37,18 @@ TEST(CastCollapsing, FuseTwoCast) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto x_t = builder.Cast(x, "float16"); - auto out = builder.Cast(x_t, "float32"); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto x_t = builder.Cast(x, "float16"); + auto out = builder.Cast(x_t, "float32"); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; - absl::c_transform(std::vector{x.id()}, std::back_inserter(input_ids), [](absl::string_view id) { - return std::string(id); - }); - std::pair, std::vector> passes{{"Decomposer"}, {"CastCollapsing"}}; + absl::c_transform(std::vector{x.id()}, + std::back_inserter(input_ids), + [](absl::string_view id) { return std::string(id); }); + std::pair, std::vector> passes{ + {"Decomposer"}, {"CastCollapsing"}}; CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } @@ -56,18 +57,19 @@ TEST(CastCollapsing, FuseThreeCast) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto x_1t = builder.Cast(x, "int32"); - auto x_2t = builder.Cast(x_1t, "int64"); - auto out = builder.Cast(x_2t, "float32"); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto x_1t = builder.Cast(x, "int32"); + auto x_2t = builder.Cast(x_1t, "int64"); + auto out = builder.Cast(x_2t, "float32"); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; - absl::c_transform(std::vector{x.id()}, std::back_inserter(input_ids), [](absl::string_view id) { - return std::string(id); - }); - std::pair, std::vector> passes{{"Decomposer"}, {"CastCollapsing"}}; + absl::c_transform(std::vector{x.id()}, + std::back_inserter(input_ids), + [](absl::string_view id) { return std::string(id); }); + std::pair, std::vector> passes{ + {"Decomposer"}, {"CastCollapsing"}}; CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } @@ -76,16 +78,17 @@ TEST(CastCollapsing, ReplaceUselessCastWithIndentity) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto out = builder.Cast(x, "float32"); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto out = builder.Cast(x, "float32"); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; - absl::c_transform(std::vector{x.id()}, std::back_inserter(input_ids), [](absl::string_view id) { - return std::string(id); - }); - std::pair, std::vector> passes{{"Decomposer"}, {"CastCollapsing"}}; + absl::c_transform(std::vector{x.id()}, + std::back_inserter(input_ids), + [](absl::string_view id) { return std::string(id); }); + std::pair, std::vector> passes{ + {"Decomposer"}, {"CastCollapsing"}}; CompareResult(&program, target, input_ids, 
{out->id}, 0, passes, 123, true); } @@ -94,19 +97,20 @@ TEST(CastCollapsing, FuseCastToUseless) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto x_1t = builder.Cast(x, "int32"); - auto x_2t = builder.Cast(x_1t, "int64"); - auto x_3t = builder.Cast(x_2t, "float32"); - auto out = builder.Add(x_3t, x_3t); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto x_1t = builder.Cast(x, "int32"); + auto x_2t = builder.Cast(x_1t, "int64"); + auto x_3t = builder.Cast(x_2t, "float32"); + auto out = builder.Add(x_3t, x_3t); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; - absl::c_transform(std::vector{x.id()}, std::back_inserter(input_ids), [](absl::string_view id) { - return std::string(id); - }); - std::pair, std::vector> passes{{"Decomposer"}, {"CastCollapsing"}}; + absl::c_transform(std::vector{x.id()}, + std::back_inserter(input_ids), + [](absl::string_view id) { return std::string(id); }); + std::pair, std::vector> passes{ + {"Decomposer"}, {"CastCollapsing"}}; CompareResult(&program, target, input_ids, {out->id}, 3, passes, 123, true); } @@ -115,22 +119,30 @@ TEST(TransposeCollapsing, FuseTransposeWithMultiOutput) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto x_1t = builder.Cast(x, "int32"); - auto x_2t = builder.Cast(x_1t, "float32"); - auto x_3t = builder.Cast(x_2t, "int32"); - auto out1 = builder.Transpose(x_1t, {0, 2, 1}); - auto out2 = builder.Transpose(x_2t, {0, 2, 1}); - auto out3 = builder.Transpose(x_3t, {0, 2, 1}); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto x_1t = builder.Cast(x, "int32"); + auto x_2t = builder.Cast(x_1t, "float32"); + auto x_3t = builder.Cast(x_2t, "int32"); + auto out1 = builder.Transpose(x_1t, {0, 2, 1}); + auto out2 = builder.Transpose(x_2t, {0, 2, 1}); + auto out3 = builder.Transpose(x_3t, {0, 2, 1}); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; - absl::c_transform(std::vector{x.id()}, std::back_inserter(input_ids), [](absl::string_view id) { - return std::string(id); - }); - std::pair, std::vector> passes{{"Decomposer"}, {"CastCollapsing"}}; - CompareResult(&program, target, input_ids, {out1->id, out2->id, out3->id}, 1, passes, 123, true); + absl::c_transform(std::vector{x.id()}, + std::back_inserter(input_ids), + [](absl::string_view id) { return std::string(id); }); + std::pair, std::vector> passes{ + {"Decomposer"}, {"CastCollapsing"}}; + CompareResult(&program, + target, + input_ids, + {out1->id, out2->id, out3->id}, + 1, + passes, + 123, + true); } TEST(TransposeCollapsing, FuseTwoSecTranspose) { @@ -138,22 +150,24 @@ TEST(TransposeCollapsing, FuseTwoSecTranspose) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto x_1t = builder.Cast(x, "int32"); - auto x_2t = builder.Cast(x_1t, "float32"); - auto out1 = builder.Reshape(x_2t, {5, 3, 4}); - auto x_3t = builder.Cast(out1, "int32"); - auto x_4t = builder.Cast(x_3t, "float32"); - auto out2 = builder.Transpose(x_2t, {0, 2, 1}); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto x_1t = builder.Cast(x, "int32"); + auto x_2t = builder.Cast(x_1t, "float32"); + auto out1 = builder.Reshape(x_2t, {5, 3, 4}); + auto x_3t = builder.Cast(out1, "int32"); + auto x_4t = builder.Cast(x_3t, "float32"); + auto out2 = builder.Transpose(x_2t, {0, 2, 1}); auto program = 
builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; - absl::c_transform(std::vector{x.id()}, std::back_inserter(input_ids), [](absl::string_view id) { - return std::string(id); - }); - std::pair, std::vector> passes{{"Decomposer"}, {"CastCollapsing"}}; - CompareResult(&program, target, input_ids, {out1->id, out2->id}, 4, passes, 123, true); + absl::c_transform(std::vector{x.id()}, + std::back_inserter(input_ids), + [](absl::string_view id) { return std::string(id); }); + std::pair, std::vector> passes{ + {"Decomposer"}, {"CastCollapsing"}}; + CompareResult( + &program, target, input_ids, {out1->id, out2->id}, 4, passes, 123, true); } TEST(TransposeCollapsing, FuseTwoHorizontalTranspose) { @@ -161,18 +175,19 @@ TEST(TransposeCollapsing, FuseTwoHorizontalTranspose) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto y_t1 = builder.Cast(x, "int32"); - auto y_t2 = builder.Cast(x, "int32"); - auto out = builder.Add(y_t1, y_t2); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto y_t1 = builder.Cast(x, "int32"); + auto y_t2 = builder.Cast(x, "int32"); + auto out = builder.Add(y_t1, y_t2); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; - absl::c_transform(std::vector{x.id()}, std::back_inserter(input_ids), [](absl::string_view id) { - return std::string(id); - }); - std::pair, std::vector> passes{{"Decomposer"}, {"CastCollapsing"}}; + absl::c_transform(std::vector{x.id()}, + std::back_inserter(input_ids), + [](absl::string_view id) { return std::string(id); }); + std::pair, std::vector> passes{ + {"Decomposer"}, {"CastCollapsing"}}; CompareResult(&program, target, input_ids, {out->id}, 0, passes, 123, true); } @@ -181,19 +196,20 @@ TEST(TransposeCollapsing, FuseVerAndHorTranspose) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto y_t1 = builder.Cast(x, "int32"); - auto y_t2 = builder.Cast(y_t1, "float32"); - auto y_t3 = builder.Cast(x, "float32"); - auto out = builder.Add(y_t2, y_t3); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto y_t1 = builder.Cast(x, "int32"); + auto y_t2 = builder.Cast(y_t1, "float32"); + auto y_t3 = builder.Cast(x, "float32"); + auto out = builder.Add(y_t2, y_t3); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; - absl::c_transform(std::vector{x.id()}, std::back_inserter(input_ids), [](absl::string_view id) { - return std::string(id); - }); - std::pair, std::vector> passes{{"Decomposer"}, {"CastCollapsing"}}; + absl::c_transform(std::vector{x.id()}, + std::back_inserter(input_ids), + [](absl::string_view id) { return std::string(id); }); + std::pair, std::vector> passes{ + {"Decomposer"}, {"CastCollapsing"}}; CompareResult(&program, target, input_ids, {out->id}, 3, passes, 123, true); } diff --git a/paddle/cinn/frontend/pass/dead_code_eliminate.cc b/paddle/cinn/frontend/pass/dead_code_eliminate.cc index 1df236d1e94bd..0c093cf75fd02 100644 --- a/paddle/cinn/frontend/pass/dead_code_eliminate.cc +++ b/paddle/cinn/frontend/pass/dead_code_eliminate.cc @@ -24,8 +24,8 @@ namespace pass { // Program maybe has some unused instructions. `DeadCodeEliminate` will remove // these instructions. The way to find unused instructions is to traverse all -// instructions to determine whether its output is used by other instructions in the -// same subgraph or in the `fetch_ids`. 
+// instructions to determine whether its output is used by other instructions in +// the same subgraph or in the `fetch_ids`. class DeadCodeEliminatePass : public ProgramPass { public: using ProgramPass::ProgramPass; @@ -44,7 +44,7 @@ class DeadCodeEliminatePass : public ProgramPass { std::unordered_set remove_idxs; for (int i = program->size() - 1; i >= 0; --i) { const auto& instr = (*program)[i]; - bool can_remove = true; + bool can_remove = true; for (const auto& out : instr->outputs) { if (inputs.count(out->id) || fetch_ids.count(out->id)) { can_remove = false; @@ -79,9 +79,11 @@ class DeadCodeEliminatePass : public ProgramPass { } private: - bool CheckFetchIds(const Program& program, const std::unordered_set& fetch_ids) { + bool CheckFetchIds(const Program& program, + const std::unordered_set& fetch_ids) { if (fetch_ids.empty()) { - // If fetch_ids is not specified, all output vars are considered as fetch vars. + // If fetch_ids is not specified, all output vars are considered as fetch + // vars. return false; } @@ -96,7 +98,9 @@ class DeadCodeEliminatePass : public ProgramPass { bool res = true; for (auto& id : fetch_ids) { if (!outputs.count(id)) { - LOG(WARNING) << id << " in fetch_ids is not output of any instruction in program."; + LOG(WARNING) + << id + << " in fetch_ids is not output of any instruction in program."; res = false; } } @@ -110,7 +114,8 @@ class DeadCodeEliminatePass : public ProgramPass { } // namespace cinn CINN_REGISTER_HELPER(DeadCodeEliminate) { - CINN_REGISTER_PROGRAM_PASS(DeadCodeEliminate, cinn::frontend::pass::DeadCodeEliminatePass); + CINN_REGISTER_PROGRAM_PASS(DeadCodeEliminate, + cinn::frontend::pass::DeadCodeEliminatePass); return true; } diff --git a/paddle/cinn/frontend/pass/dead_code_eliminate_test.cc b/paddle/cinn/frontend/pass/dead_code_eliminate_test.cc index 5651381d690b2..7e418f394dae3 100644 --- a/paddle/cinn/frontend/pass/dead_code_eliminate_test.cc +++ b/paddle/cinn/frontend/pass/dead_code_eliminate_test.cc @@ -35,20 +35,22 @@ TEST(DeadCodeEliminate, remove_single) { // | | // NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {32, 16}, "x"); - auto identity_1 = builder.Identity(x); - auto identity_2 = builder.Identity(x); + auto x = builder.CreateInput(Float(32), {32, 16}, "x"); + auto identity_1 = builder.Identity(x); + auto identity_2 = builder.Identity(x); auto reduce_sum_1 = builder.ReduceSum(x, {0, 1}); auto reduce_sum_2 = builder.ReduceSum(x, {0, 1}); - auto program = builder.Build(); + auto program = builder.Build(); PassTest tester; - std::vector input_names = {x.id().data()}; + std::vector input_names = {x.id().data()}; std::vector output_names = {identity_1->id, reduce_sum_2->id}; common::Target target = common::DefaultNVGPUTarget(); - std::pair, std::vector> passes{{"Decomposer"}, {"DeadCodeEliminate"}}; - CompareResult(&program, target, input_names, output_names, 2, passes, 123, true); + std::pair, std::vector> passes{ + {"Decomposer"}, {"DeadCodeEliminate"}}; + CompareResult( + &program, target, input_names, output_names, 2, passes, 123, true); } TEST(DeadCodeEliminate, remove_multiple) { @@ -63,19 +65,21 @@ TEST(DeadCodeEliminate, remove_multiple) { // | // NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {32, 16}, "x"); - auto identity_1 = builder.Transpose(x, {1, 0}); + auto x = builder.CreateInput(Float(32), {32, 16}, "x"); + auto identity_1 = builder.Transpose(x, {1, 0}); auto reduce_sum_1 = builder.ReduceSum(x, {0, 1}); - auto mul_1 = builder.Matmul(x, identity_1); - 
auto program = builder.Build(); + auto mul_1 = builder.Matmul(x, identity_1); + auto program = builder.Build(); PassTest tester; - std::vector input_names = {x.id().data()}; + std::vector input_names = {x.id().data()}; std::vector output_names = {reduce_sum_1->id}; common::Target target = common::DefaultNVGPUTarget(); - std::pair, std::vector> passes{{"Decomposer"}, {"DeadCodeEliminate"}}; - CompareResult(&program, target, input_names, output_names, 2, passes, 123, true); + std::pair, std::vector> passes{ + {"Decomposer"}, {"DeadCodeEliminate"}}; + CompareResult( + &program, target, input_names, output_names, 2, passes, 123, true); } } // namespace cinn::frontend diff --git a/paddle/cinn/frontend/pass/decomposer.cc b/paddle/cinn/frontend/pass/decomposer.cc index e3d816ac87622..b18ac57be73f3 100755 --- a/paddle/cinn/frontend/pass/decomposer.cc +++ b/paddle/cinn/frontend/pass/decomposer.cc @@ -41,8 +41,9 @@ class DecomposerPass : public ProgramPass { absl::flat_hash_map var_map; DecomposerContext context(&builder, &var_map); for (size_t i = 0; i < prog->size(); i++) { - auto instr = (*prog)[i]; - auto decomposer = InstrDecomposerRegistry::Global()->Find(instr->op_type, target); + auto instr = (*prog)[i]; + auto decomposer = + InstrDecomposerRegistry::Global()->Find(instr->op_type, target); if (decomposer) { VLOG(3) << "Run decomposer of op " << instr->op_type; decomposer->Run(instr, context); @@ -79,7 +80,8 @@ class DecomposerPass : public ProgramPass { } // namespace cinn CINN_REGISTER_HELPER(Decomposer) { - CINN_REGISTER_PROGRAM_PASS(Decomposer, ::cinn::frontend::pass::DecomposerPass); + CINN_REGISTER_PROGRAM_PASS(Decomposer, + ::cinn::frontend::pass::DecomposerPass); return true; } diff --git a/paddle/cinn/frontend/pass/decomposer_test.cc b/paddle/cinn/frontend/pass/decomposer_test.cc index 4c5f2dc15a526..c6b434edba0a9 100644 --- a/paddle/cinn/frontend/pass/decomposer_test.cc +++ b/paddle/cinn/frontend/pass/decomposer_test.cc @@ -37,18 +37,20 @@ Program CreateAddProgram() { constexpr int N = 24; NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {M, N}); - auto b = builder.CreateInput(Float(32), {M, N}); - auto c = builder.Relu(a); - auto d = builder.Add(b, c); + auto a = builder.CreateInput(Float(32), {M, N}); + auto b = builder.CreateInput(Float(32), {M, N}); + auto c = builder.Relu(a); + auto d = builder.Add(b, c); auto program = builder.Build(); return program; } TEST(DecomposePassRegistry, basic) { - ASSERT_NE(cinn::frontend::ProgramPassRegistry::Global()->Find("Decomposer"), nullptr); - ASSERT_EQ(cinn::frontend::ProgramPassRegistry::Global()->Find("Test"), nullptr); + ASSERT_NE(cinn::frontend::ProgramPassRegistry::Global()->Find("Decomposer"), + nullptr); + ASSERT_EQ(cinn::frontend::ProgramPassRegistry::Global()->Find("Test"), + nullptr); } TEST(DecomposePass, basic) { diff --git a/paddle/cinn/frontend/pass/expand_zero_dim_pass.cc b/paddle/cinn/frontend/pass/expand_zero_dim_pass.cc index 293c92389b639..92dfecab6dd09 100644 --- a/paddle/cinn/frontend/pass/expand_zero_dim_pass.cc +++ b/paddle/cinn/frontend/pass/expand_zero_dim_pass.cc @@ -17,9 +17,9 @@ #include #include +#include "glog/logging.h" #include "paddle/cinn/frontend/net_builder.h" #include "paddle/cinn/frontend/program_pass.h" -#include "glog/logging.h" namespace cinn { namespace frontend { @@ -50,7 +50,8 @@ class ExpandZeroDimPass : public ProgramPass { } for (auto& output : instr->outputs) { if (output->shape.empty()) { - VLOG(4) << "Change output 0D-Tensor " << output->id << " to 1D-Tensor"; + 
VLOG(4) << "Change output 0D-Tensor " << output->id + << " to 1D-Tensor"; output->shape.push_back(1); } } @@ -67,7 +68,8 @@ class ExpandZeroDimPass : public ProgramPass { } // namespace cinn CINN_REGISTER_HELPER(ExpandZeroDim) { - CINN_REGISTER_PROGRAM_PASS(ExpandZeroDim, cinn::frontend::pass::ExpandZeroDimPass); + CINN_REGISTER_PROGRAM_PASS(ExpandZeroDim, + cinn::frontend::pass::ExpandZeroDimPass); return true; } diff --git a/paddle/cinn/frontend/pass/expand_zero_dim_pass_test.cc b/paddle/cinn/frontend/pass/expand_zero_dim_pass_test.cc index 47854e9c7b608..51f80a25035d0 100644 --- a/paddle/cinn/frontend/pass/expand_zero_dim_pass_test.cc +++ b/paddle/cinn/frontend/pass/expand_zero_dim_pass_test.cc @@ -33,13 +33,17 @@ namespace cinn { namespace frontend { -int GetSize(std::vector& shape) { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); } +int GetSize(std::vector& shape) { + return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); +} -std::unordered_map> GetInputRandom(const std::vector&& inputs) { +std::unordered_map> GetInputRandom( + const std::vector&& inputs) { std::unordered_map> input_data; for (auto input : inputs) { input_data[input->id] = std::vector(GetSize(input->shape)); - InitRandomVector(&input_data[input->id], input_data[input->id].size(), 0.0f, 1.0f, 1e-3); + InitRandomVector( + &input_data[input->id], input_data[input->id].size(), 0.0f, 1.0f, 1e-3); } return input_data; @@ -50,7 +54,8 @@ std::unordered_map RunWithProgram( const Target& target, const std::unordered_map>& input_data, const std::unordered_set& fetch_ids) { - auto graph = std::make_shared(program, fetch_ids, target); + auto graph = + std::make_shared(program, fetch_ids, target); auto scope = hlir::framework::BuildScope(target, graph); hlir::framework::ApplyPasses(graph.get(), {"InferShape"}); @@ -75,11 +80,11 @@ std::unordered_map RunWithProgram( TEST(ExpandZeroDimPass, expand_zero_dim_1) { NetBuilder builder("expand_zero_dim_1"); - auto x = builder.CreateInput(Float(32), {}, "x"); - auto y = builder.CreateInput(Float(32), {}, "y"); - auto out = builder.Add(x, y); + auto x = builder.CreateInput(Float(32), {}, "x"); + auto y = builder.CreateInput(Float(32), {}, "y"); + auto out = builder.Add(x, y); auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); size_t origin_size = program.size(); VLOG(1) << "Program Before ExpandZeroDimPass:\n" << program; @@ -105,8 +110,8 @@ TEST(ExpandZeroDimPass, expand_zero_dim_1) { } */ auto input_data = GetInputRandom({x, y}); - auto fetch_ids = {out->id}; - auto outputs = RunWithProgram(program, target, input_data, fetch_ids); + auto fetch_ids = {out->id}; + auto outputs = RunWithProgram(program, target, input_data, fetch_ids); for (auto iter : outputs) { // output var_1: shape=[1] ASSERT_EQ(iter.second->shape().data().size(), 1); @@ -115,11 +120,11 @@ TEST(ExpandZeroDimPass, expand_zero_dim_1) { TEST(ExpandZeroDimPass, expand_zero_dim_2) { NetBuilder builder("expand_zero_dim_1"); - auto x = builder.CreateInput(Float(32), {3, 5}, "x"); - auto y = builder.CreateInput(Float(32), {}, "y"); - auto out = builder.Add(x, y); + auto x = builder.CreateInput(Float(32), {3, 5}, "x"); + auto y = builder.CreateInput(Float(32), {}, "y"); + auto out = builder.Add(x, y); auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); size_t origin_size = program.size(); VLOG(1) << "Program Before ExpandZeroDimPass:\n" << program; @@ -145,8 
+150,8 @@ TEST(ExpandZeroDimPass, expand_zero_dim_2) { } */ auto input_data = GetInputRandom({x, y}); - auto fetch_ids = {out->id}; - auto outputs = RunWithProgram(program, target, input_data, fetch_ids); + auto fetch_ids = {out->id}; + auto outputs = RunWithProgram(program, target, input_data, fetch_ids); for (auto iter : outputs) { // output var_1: shape=[3, 5] ASSERT_EQ(iter.second->shape().data().size(), 2); diff --git a/paddle/cinn/frontend/pass/fill_constant_folding.cc b/paddle/cinn/frontend/pass/fill_constant_folding.cc index c0a5dd51c4734..c6ee33bc6c79c 100644 --- a/paddle/cinn/frontend/pass/fill_constant_folding.cc +++ b/paddle/cinn/frontend/pass/fill_constant_folding.cc @@ -33,27 +33,38 @@ using cinn::utils::ShapeType; class FillConstantKey { public: - FillConstantKey(const ShapeType& shape, Attribute value, const std::string& dtype, bool force_cpu) { + FillConstantKey(const ShapeType& shape, + Attribute value, + const std::string& dtype, + bool force_cpu) { SetKey(shape, value, dtype, force_cpu); } - void SetKey(const ShapeType& shape, Attribute value, const std::string& dtype, bool force_cpu) { - shape_ = shape; - value_ = value; + void SetKey(const ShapeType& shape, + Attribute value, + const std::string& dtype, + bool force_cpu) { + shape_ = shape; + value_ = value; force_cpu_ = force_cpu; - dtype_ = dtype; + dtype_ = dtype; } bool operator==(const FillConstantKey& other) const { - return shape_ == other.shape_ && value_ == other.value_ && force_cpu_ == other.force_cpu_ && dtype_ == other.dtype_; + return shape_ == other.shape_ && value_ == other.value_ && + force_cpu_ == other.force_cpu_ && dtype_ == other.dtype_; + } + bool operator!=(const FillConstantKey& other) const { + return !this->operator==(other); } - bool operator!=(const FillConstantKey& other) const { return !this->operator==(other); } struct Hash { size_t operator()(const FillConstantKey& key) const { std::ostringstream hash_str; - std::for_each(key.shape_.begin(), key.shape_.end(), [&](const DimType& dim) { hash_str << dim; }); + std::for_each(key.shape_.begin(), + key.shape_.end(), + [&](const DimType& dim) { hash_str << dim; }); hash_str << utils::Attribute2String(key.value_); hash_str << key.force_cpu_; @@ -75,7 +86,8 @@ class FillConstantKey { class FillConstantFoldingPass : public ProgramPass { public: using ProgramPass::ProgramPass; - using InputToOpMap = std::unordered_map>; + using InputToOpMap = + std::unordered_map>; protected: void Clear() override {} @@ -85,9 +97,12 @@ class FillConstantFoldingPass : public ProgramPass { const common::Target& target) const override { auto in2instr = GetInputToOpMap(program); - // `fill_constant_map` is used to represent the first fill_constant and its output variable - std::unordered_map fill_constant_map; - // `remove_instrs` is used to represent Instructions of which type is fill_constant to be deleted. + // `fill_constant_map` is used to represent the first fill_constant and its + // output variable + std::unordered_map + fill_constant_map; + // `remove_instrs` is used to represent Instructions of which type is + // fill_constant to be deleted. std::unordered_set remove_instrs; for (int i = 0; i < program->size(); ++i) { @@ -99,16 +114,18 @@ class FillConstantFoldingPass : public ProgramPass { } CHECK_EQ(instr->outputs.size(), 1UL) - << "The fill_constant op should has one, and only one output ! Please check."; + << "The fill_constant op should has one, and only one output ! 
" + "Please check."; const auto& shape = instr.GetAttrs("shape"); - auto value = instr->attrs.at("value"); + auto value = instr->attrs.at("value"); const auto& dtype = instr.GetAttrs("dtype"); - auto force_cpu = instr.GetAttrs("force_cpu"); + auto force_cpu = instr.GetAttrs("force_cpu"); FillConstantKey key(shape, value, dtype, force_cpu); if (!fill_constant_map.count(key)) { - VLOG(4) << "The fill_constant, whose output is Var [" << instr->outputs[0]->id + VLOG(4) << "The fill_constant, whose output is Var [" + << instr->outputs[0]->id << "], cannot remove because it is the first fill_costant ! "; // retain the first fill constant op node fill_constant_map.emplace(key, &instr->outputs[0]); @@ -117,12 +134,13 @@ class FillConstantFoldingPass : public ProgramPass { if (fetch_ids.find(instr->outputs[0]->id) != fetch_ids.end()) { // the fill constant's output variable was fetched, skip - VLOG(4) << "Cannot remove fill_constant, because Var [" << instr->outputs[0]->id - << "] was fetched by other op ! "; + VLOG(4) << "Cannot remove fill_constant, because Var [" + << instr->outputs[0]->id << "] was fetched by other op ! "; continue; } - VLOG(4) << "Try remove fill_constant, whose output is Var [" << instr->outputs[0]->id << "]. "; + VLOG(4) << "Try remove fill_constant, whose output is Var [" + << instr->outputs[0]->id << "]. "; remove_instrs.insert(&instr); auto constant_name = instr->outputs[0]->id; @@ -155,25 +173,32 @@ class FillConstantFoldingPass : public ProgramPass { return in2instr; } - static void ReLinkFillConstant(const InputToOpMap& in2instr, const std::string& input_var_name, Variable* out_var) { + static void ReLinkFillConstant(const InputToOpMap& in2instr, + const std::string& input_var_name, + Variable* out_var) { if (!in2instr.count(input_var_name)) { LOG(WARNING) << "Var [" << input_var_name << "] not used by other op ! "; return; } - VLOG(4) << "Try replace the input Var [" << input_var_name << "] to [" << (*out_var)->id + VLOG(4) << "Try replace the input Var [" << input_var_name << "] to [" + << (*out_var)->id << "], because the fill_constant will be folding."; const auto& output_ops = in2instr.at(input_var_name); for (auto op : output_ops) { auto find_input = [&](const std::string& input_name) { return std::find_if( - (*op)->inputs.begin(), (*op)->inputs.end(), [&](const Variable& var) { return var->id == input_name; }); + (*op)->inputs.begin(), + (*op)->inputs.end(), + [&](const Variable& var) { return var->id == input_name; }); }; // Why Loop : To avoid the op's inputs are the same variable ! 
- for (auto it = find_input(input_var_name); it != (*op)->inputs.end(); it = find_input(input_var_name)) { - // erase previous fill_constant output var and replace to new fill_constant output var + for (auto it = find_input(input_var_name); it != (*op)->inputs.end(); + it = find_input(input_var_name)) { + // erase previous fill_constant output var and replace to new + // fill_constant output var auto next_it = (*op)->inputs.erase(it); // keep the input place same, it's very important (*op)->inputs.insert(next_it, *out_var); @@ -185,7 +210,8 @@ class FillConstantFoldingPass : public ProgramPass { } // namespace cinn::frontend::pass CINN_REGISTER_HELPER(FillConstantFolding) { - CINN_REGISTER_PROGRAM_PASS(FillConstantFolding, ::cinn::frontend::pass::FillConstantFoldingPass); + CINN_REGISTER_PROGRAM_PASS(FillConstantFolding, + ::cinn::frontend::pass::FillConstantFoldingPass); return true; } diff --git a/paddle/cinn/frontend/pass/fill_constant_folding_test.cc b/paddle/cinn/frontend/pass/fill_constant_folding_test.cc index 62bf732e10558..48faf488c82a7 100644 --- a/paddle/cinn/frontend/pass/fill_constant_folding_test.cc +++ b/paddle/cinn/frontend/pass/fill_constant_folding_test.cc @@ -31,7 +31,9 @@ namespace cinn::frontend { -std::vector RunWithProgram(const Program& program, const Target& target, Variable out) { +std::vector RunWithProgram(const Program& program, + const Target& target, + Variable out) { auto graph = std::make_shared(program, target); auto scope = hlir::framework::BuildScope(target, graph); @@ -47,13 +49,13 @@ std::vector RunWithProgram(const Program& program, const Target& target, TEST(TransposeFolding, FoldTwoFillConstant) { NetBuilder builder("net_builder"); - auto x = builder.FillConstant({32, 32}, 1.0f, "x"); - auto y = builder.FillConstant({32, 32}, 1.0f, "y"); + auto x = builder.FillConstant({32, 32}, 1.0f, "x"); + auto y = builder.FillConstant({32, 32}, 1.0f, "y"); auto transpose_x = builder.Transpose(x, {1, 0}); auto transpose_y = builder.Transpose(y, {1, 0}); - auto out = builder.Add(transpose_x, transpose_y); - auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto out = builder.Add(transpose_x, transpose_y); + auto program = builder.Build(); + auto target = common::DefaultTarget(); size_t origin_size = program.size(); VLOG(1) << "Program Before FillConstantFolding:\n" << program; @@ -89,12 +91,12 @@ TEST(TransposeFolding, FoldTwoFillConstant) { TEST(TransposeFolding, FoldTwoFillConstantWithSameOuput) { NetBuilder builder("net_builder"); - auto x = builder.FillConstant({32, 32}, 1.0f, "x"); - auto y = builder.FillConstant({32, 32}, 1.0f, "y"); + auto x = builder.FillConstant({32, 32}, 1.0f, "x"); + auto y = builder.FillConstant({32, 32}, 1.0f, "y"); auto transpose_x = builder.Transpose(x, {1, 0}); - auto out = builder.Add(y, y); - auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto out = builder.Add(y, y); + auto program = builder.Build(); + auto target = common::DefaultTarget(); size_t origin_size = program.size(); VLOG(1) << "Program Before FillConstantFolding:\n" << program; @@ -127,13 +129,13 @@ TEST(TransposeFolding, FoldTwoFillConstantWithSameOuput) { TEST(TransposeFolding, FoldThreeFillConstant) { NetBuilder builder("net_builder"); - auto x = builder.FillConstant({32, 32}, 1.0f, "x"); - auto y = builder.FillConstant({32, 32}, 1.0f, "y"); - auto z = builder.FillConstant({32, 32}, 1.0f, "z"); - auto transpose_x = builder.Transpose(x, {1, 0}); - auto out = builder.Add(y, z); - auto program = builder.Build(); - 
auto target = common::DefaultTarget(); + auto x = builder.FillConstant({32, 32}, 1.0f, "x"); + auto y = builder.FillConstant({32, 32}, 1.0f, "y"); + auto z = builder.FillConstant({32, 32}, 1.0f, "z"); + auto transpose_x = builder.Transpose(x, {1, 0}); + auto out = builder.Add(y, z); + auto program = builder.Build(); + auto target = common::DefaultTarget(); size_t origin_size = program.size(); VLOG(1) << "Program Before FillConstantFolding:\n" << program; // Program { @@ -166,15 +168,15 @@ TEST(TransposeFolding, FoldThreeFillConstant) { TEST(TransposeFolding, FoldThreeFillConstantWithOneDiff) { NetBuilder builder("net_builder"); - auto x = builder.FillConstant({32, 32}, 1.0f, "x"); - auto y = builder.FillConstant({32, 32}, 1.0f, "y"); - auto z = builder.FillConstant({32, 32}, 0.0f, "z"); + auto x = builder.FillConstant({32, 32}, 1.0f, "x"); + auto y = builder.FillConstant({32, 32}, 1.0f, "y"); + auto z = builder.FillConstant({32, 32}, 0.0f, "z"); auto transpose_x = builder.Transpose(x, {1, 0}); - auto out = builder.Add(y, z); - auto program = builder.Build(); - auto target = common::DefaultTarget(); - auto graph = std::make_shared(program, target); - auto scope = hlir::framework::BuildScope(target, graph); + auto out = builder.Add(y, z); + auto program = builder.Build(); + auto target = common::DefaultTarget(); + auto graph = std::make_shared(program, target); + auto scope = hlir::framework::BuildScope(target, graph); size_t origin_size = program.size(); VLOG(1) << "Program Before FillConstantFolding:\n" << program; diff --git a/paddle/cinn/frontend/pass/fill_constant_rewriter.cc b/paddle/cinn/frontend/pass/fill_constant_rewriter.cc index 058fd628efa0c..569d1ba77f859 100644 --- a/paddle/cinn/frontend/pass/fill_constant_rewriter.cc +++ b/paddle/cinn/frontend/pass/fill_constant_rewriter.cc @@ -19,9 +19,9 @@ #include #include +#include "glog/logging.h" #include "paddle/cinn/frontend/net_builder.h" #include "paddle/cinn/frontend/program_pass.h" -#include "glog/logging.h" namespace cinn { namespace frontend { @@ -51,83 +51,96 @@ namespace pass { } \ } -static std::unordered_map> rewriter_ops = { - {"reshape", - [](const Instruction& fill_constant, Instruction* instr) -> void { - (*instr)->op_type = "fill_constant"; - (*instr)->inputs.clear(); - // the outputs keep same - - CHECK((*instr)->attrs.count("shape")) << "The reshape op should has attribute [shape]!"; - auto new_shape = (*instr)->attrs.at("shape"); - (*instr)->attrs = fill_constant->attrs; - (*instr)->attrs["shape"] = new_shape; - }}, - {"scale", - [](const Instruction& fill_constant, Instruction* instr) -> void { - (*instr)->op_type = "fill_constant"; - (*instr)->inputs.clear(); - // the outputs keep same - - auto scale = (*instr)->attrs.count("scale") ? instr->GetAttrs("scale") : 1.0f; - auto bias = (*instr)->attrs.count("bias") ? instr->GetAttrs("bias") : 0.0f; - auto bias_after_scale = - (*instr)->attrs.count("bias_after_scale") ? 
instr->GetAttrs("bias_after_scale") : true; - - (*instr)->attrs = fill_constant->attrs; - - const auto& old_attr = fill_constant->attrs.at("value"); - auto& new_attr = (*instr)->attrs.at("value"); - if (bias_after_scale) { - auto scale_func = [&](const auto& value) -> decltype(auto) { - return value * static_cast(scale) + static_cast(bias); - }; - FILL_CONSTANT_VALUE_REWRITE(old_attr, scale_func, new_attr) - } else { - auto scale_func = [&](const auto& value) -> decltype(auto) { - return (value + static_cast(bias)) * static_cast(scale); - }; - FILL_CONSTANT_VALUE_REWRITE(old_attr, scale_func, new_attr) - } - }}, - {"cast", - [](const Instruction& fill_constant, Instruction* instr) -> void { - (*instr)->op_type = "fill_constant"; - (*instr)->inputs.clear(); - // the outputs keep same - - CHECK((*instr)->attrs.count("dtype")) << "The cast op should has attribute [dtype]!"; - auto cast_dtype = instr->GetAttrs("dtype"); - - (*instr)->attrs = fill_constant->attrs; - (*instr)->attrs["dtype"] = cast_dtype; - }}, - {"broadcast_to", - [](const Instruction& fill_constant, Instruction* instr) -> void { - (*instr)->op_type = "fill_constant"; - (*instr)->inputs.clear(); - // the outputs keep same - - CHECK((*instr)->attrs.count("out_shape")) << "The cast op should has attribute [out_shape]!"; - auto out_shape = instr->GetAttrs>("out_shape"); - - (*instr)->attrs = fill_constant->attrs; - (*instr)->attrs["shape"] = out_shape; - }}, - {"slice", - [](const Instruction& fill_constant, Instruction* instr) -> void { - (*instr)->op_type = "fill_constant"; - (*instr)->inputs.clear(); - // the outputs keep same - - (*instr)->attrs = fill_constant->attrs; - (*instr)->attrs["shape"] = (*instr)->outputs[0]->shape; - }}, - MATH_FUNC_REWRITER(abs), - MATH_FUNC_REWRITER(log), - MATH_FUNC_REWRITER(log2), - MATH_FUNC_REWRITER(log10), - MATH_FUNC_REWRITER(tanh)}; +static std::unordered_map> + rewriter_ops = { + {"reshape", + [](const Instruction& fill_constant, Instruction* instr) -> void { + (*instr)->op_type = "fill_constant"; + (*instr)->inputs.clear(); + // the outputs keep same + + CHECK((*instr)->attrs.count("shape")) + << "The reshape op should has attribute [shape]!"; + auto new_shape = (*instr)->attrs.at("shape"); + (*instr)->attrs = fill_constant->attrs; + (*instr)->attrs["shape"] = new_shape; + }}, + {"scale", + [](const Instruction& fill_constant, Instruction* instr) -> void { + (*instr)->op_type = "fill_constant"; + (*instr)->inputs.clear(); + // the outputs keep same + + auto scale = (*instr)->attrs.count("scale") + ? instr->GetAttrs("scale") + : 1.0f; + auto bias = (*instr)->attrs.count("bias") + ? instr->GetAttrs("bias") + : 0.0f; + auto bias_after_scale = + (*instr)->attrs.count("bias_after_scale") + ? 
instr->GetAttrs("bias_after_scale") + : true; + + (*instr)->attrs = fill_constant->attrs; + + const auto& old_attr = fill_constant->attrs.at("value"); + auto& new_attr = (*instr)->attrs.at("value"); + if (bias_after_scale) { + auto scale_func = [&](const auto& value) -> decltype(auto) { + return value * static_cast(scale) + + static_cast(bias); + }; + FILL_CONSTANT_VALUE_REWRITE(old_attr, scale_func, new_attr) + } else { + auto scale_func = [&](const auto& value) -> decltype(auto) { + return (value + static_cast(bias)) * + static_cast(scale); + }; + FILL_CONSTANT_VALUE_REWRITE(old_attr, scale_func, new_attr) + } + }}, + {"cast", + [](const Instruction& fill_constant, Instruction* instr) -> void { + (*instr)->op_type = "fill_constant"; + (*instr)->inputs.clear(); + // the outputs keep same + + CHECK((*instr)->attrs.count("dtype")) + << "The cast op should has attribute [dtype]!"; + auto cast_dtype = instr->GetAttrs("dtype"); + + (*instr)->attrs = fill_constant->attrs; + (*instr)->attrs["dtype"] = cast_dtype; + }}, + {"broadcast_to", + [](const Instruction& fill_constant, Instruction* instr) -> void { + (*instr)->op_type = "fill_constant"; + (*instr)->inputs.clear(); + // the outputs keep same + + CHECK((*instr)->attrs.count("out_shape")) + << "The cast op should has attribute [out_shape]!"; + auto out_shape = instr->GetAttrs>("out_shape"); + + (*instr)->attrs = fill_constant->attrs; + (*instr)->attrs["shape"] = out_shape; + }}, + {"slice", + [](const Instruction& fill_constant, Instruction* instr) -> void { + (*instr)->op_type = "fill_constant"; + (*instr)->inputs.clear(); + // the outputs keep same + + (*instr)->attrs = fill_constant->attrs; + (*instr)->attrs["shape"] = (*instr)->outputs[0]->shape; + }}, + MATH_FUNC_REWRITER(abs), + MATH_FUNC_REWRITER(log), + MATH_FUNC_REWRITER(log2), + MATH_FUNC_REWRITER(log10), + MATH_FUNC_REWRITER(tanh)}; #undef FILL_CONSTANT_VALUE_REWRITE #undef MATH_FUNC_REWRITER @@ -152,7 +165,8 @@ class FillConstantRewriterPass : public ProgramPass { RewriteFillConstant(instr, input2instr, fetch_ids, &remove_instr); } } - VLOG(3) << "FillConstantRewriterPass Remove " << remove_instr.size() << " instruction"; + VLOG(3) << "FillConstantRewriterPass Remove " << remove_instr.size() + << " instruction"; NetBuilder builder("reshape_rewritter_builder"); for (auto& var : program->GetInputs()) { @@ -170,7 +184,8 @@ class FillConstantRewriterPass : public ProgramPass { } private: - using Input2Instr = std::unordered_map>; + using Input2Instr = + std::unordered_map>; Input2Instr GetInput2Instr(Program* program) { Input2Instr input2instr; @@ -185,12 +200,14 @@ class FillConstantRewriterPass : public ProgramPass { return input2instr; } - void RewriteFillConstant(const Instruction& fill_constant, - const Input2Instr& input2instr, - const std::unordered_set& fetch_ids, - std::unordered_set* remove_instr) { + void RewriteFillConstant( + const Instruction& fill_constant, + const Input2Instr& input2instr, + const std::unordered_set& fetch_ids, + std::unordered_set* remove_instr) { CHECK_EQ(fill_constant->op_type, std::string("fill_constant")); - CHECK_EQ(fill_constant->outputs.size(), 1UL) << "The fill_constant op should just has one output! Please check."; + CHECK_EQ(fill_constant->outputs.size(), 1UL) + << "The fill_constant op should just has one output! 
Please check."; const auto& out = fill_constant->outputs[0]; if (!input2instr.count(out->id)) { @@ -220,7 +237,8 @@ class FillConstantRewriterPass : public ProgramPass { } // namespace cinn CINN_REGISTER_HELPER(FillConstantRewriter) { - CINN_REGISTER_PROGRAM_PASS(FillConstantRewriter, cinn::frontend::pass::FillConstantRewriterPass); + CINN_REGISTER_PROGRAM_PASS(FillConstantRewriter, + cinn::frontend::pass::FillConstantRewriterPass); return true; } diff --git a/paddle/cinn/frontend/pass/fill_constant_rewriter_test.cc b/paddle/cinn/frontend/pass/fill_constant_rewriter_test.cc index 3c215d66f2b96..7823e1d63e493 100644 --- a/paddle/cinn/frontend/pass/fill_constant_rewriter_test.cc +++ b/paddle/cinn/frontend/pass/fill_constant_rewriter_test.cc @@ -30,17 +30,19 @@ TEST(FillConstantRewriter, remove_reshape_single) { // | | // NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {32, 16}, "x"); - auto identity_1 = builder.Identity(x); - auto reshape_1 = builder.Reshape(x, {32, 16}); + auto x = builder.CreateInput(Float(32), {32, 16}, "x"); + auto identity_1 = builder.Identity(x); + auto reshape_1 = builder.Reshape(x, {32, 16}); auto reduce_sum_1 = builder.ReduceSum(identity_1, {0}); auto reduce_sum_2 = builder.ReduceSum(reshape_1, {1}); PassTest tester; - std::vector input_names = {x.id().data()}; - std::vector output_names = {reduce_sum_1->id, reduce_sum_2->id}; - std::vector program_passes = {"FillConstantRewriter", "RemoveIdentity"}; - int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names); + std::vector input_names = {x.id().data()}; + std::vector output_names = {reduce_sum_1->id, reduce_sum_2->id}; + std::vector program_passes = {"FillConstantRewriter", + "RemoveIdentity"}; + int num_removed_ops = + tester.RunAndCheck(builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 2); } @@ -53,17 +55,20 @@ TEST(FillConstantRewriter, remove_reshape_with_fill_constant) { // | // NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {32, 16}, "x"); - auto constant_1 = builder.FillConstant({16, 32}, static_cast(1.0), "constant_1"); - auto reshape_1 = builder.Reshape(constant_1, {32, 16}); - auto reshape_2 = builder.Reshape(x, {32, 16}); - auto add_1 = builder.Add(reshape_1, reshape_2); + auto x = builder.CreateInput(Float(32), {32, 16}, "x"); + auto constant_1 = builder.FillConstant( + {16, 32}, static_cast(1.0), "constant_1"); + auto reshape_1 = builder.Reshape(constant_1, {32, 16}); + auto reshape_2 = builder.Reshape(x, {32, 16}); + auto add_1 = builder.Add(reshape_1, reshape_2); PassTest tester; - std::vector input_names = {x.id().data()}; - std::vector output_names = {add_1->id}; - std::vector program_passes = {"FillConstantRewriter", "RemoveIdentity"}; - int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names); + std::vector input_names = {x.id().data()}; + std::vector output_names = {add_1->id}; + std::vector program_passes = {"FillConstantRewriter", + "RemoveIdentity"}; + int num_removed_ops = + tester.RunAndCheck(builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 2); } @@ -76,17 +81,19 @@ TEST(FillConstantRewriter, remove_scale_single) { // | | // NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {32, 16}, "x"); - auto identity_1 = builder.Identity(x); - auto scale_1 = builder.Scale(x, 1.0f); + auto x = builder.CreateInput(Float(32), {32, 16}, "x"); + auto identity_1 = builder.Identity(x); + auto 
scale_1 = builder.Scale(x, 1.0f); auto reduce_sum_1 = builder.ReduceSum(identity_1, {0}); auto reduce_sum_2 = builder.ReduceSum(scale_1, {1}); PassTest tester; - std::vector input_names = {x.id().data()}; - std::vector output_names = {reduce_sum_1->id, reduce_sum_2->id}; - std::vector program_passes = {"FillConstantRewriter", "RemoveIdentity"}; - int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names); + std::vector input_names = {x.id().data()}; + std::vector output_names = {reduce_sum_1->id, reduce_sum_2->id}; + std::vector program_passes = {"FillConstantRewriter", + "RemoveIdentity"}; + int num_removed_ops = + tester.RunAndCheck(builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 2); } @@ -99,17 +106,19 @@ TEST(FillConstantRewriter, remove_scale_with_fill_constant) { // | // NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {32, 16}, "x"); + auto x = builder.CreateInput(Float(32), {32, 16}, "x"); auto constant_1 = builder.FillConstant({32, 16}, 128.0f, "constant_1"); - auto scale_1 = builder.Scale(constant_1, -1.0f); - auto scale_2 = builder.Scale(x, 1.0f); - auto add_1 = builder.Add(scale_1, scale_2); + auto scale_1 = builder.Scale(constant_1, -1.0f); + auto scale_2 = builder.Scale(x, 1.0f); + auto add_1 = builder.Add(scale_1, scale_2); PassTest tester; - std::vector input_names = {x.id().data()}; - std::vector output_names = {add_1->id}; - std::vector program_passes = {"FillConstantRewriter", "RemoveIdentity"}; - int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names); + std::vector input_names = {x.id().data()}; + std::vector output_names = {add_1->id}; + std::vector program_passes = {"FillConstantRewriter", + "RemoveIdentity"}; + int num_removed_ops = + tester.RunAndCheck(builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 2); } @@ -126,20 +135,22 @@ TEST(FillConstantRewriter, remove_multi_scale_with_fill_constant) { // | // NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {32, 16}, "x"); + auto x = builder.CreateInput(Float(32), {32, 16}, "x"); auto constant_1 = builder.FillConstant({32, 16}, 128.0f, "constant_1"); - auto scale_1 = builder.Scale(constant_1, -1.0f); - auto scale_2 = builder.Scale(scale_1, 2.0f, 10.0f); - auto scale_3 = builder.Scale(scale_2, 3.0f, 1.0f, false); + auto scale_1 = builder.Scale(constant_1, -1.0f); + auto scale_2 = builder.Scale(scale_1, 2.0f, 10.0f); + auto scale_3 = builder.Scale(scale_2, 3.0f, 1.0f, false); - auto x_1 = builder.Scale(x, 1.0f); + auto x_1 = builder.Scale(x, 1.0f); auto add_1 = builder.Add(scale_3, x_1); PassTest tester; - std::vector input_names = {x.id().data()}; - std::vector output_names = {add_1->id}; - std::vector program_passes = {"FillConstantRewriter", "RemoveIdentity"}; - int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names); + std::vector input_names = {x.id().data()}; + std::vector output_names = {add_1->id}; + std::vector program_passes = {"FillConstantRewriter", + "RemoveIdentity"}; + int num_removed_ops = + tester.RunAndCheck(builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 4); } @@ -147,13 +158,16 @@ TEST(FillConstantRewriter, two_fill_constant) { // fill_constant({16, 32}) fill_constant({16, 32}) NetBuilder builder("net_builder"); auto constant_1 = builder.FillConstant({32, 16}, 128.0f, "constant_1"); - auto constant_2 = builder.FillConstant({32, 16}, -128.0f, 
"constant_2"); + auto constant_2 = + builder.FillConstant({32, 16}, -128.0f, "constant_2"); PassTest tester; - std::vector input_names = {}; - std::vector output_names = {constant_1->id, constant_2->id}; - std::vector program_passes = {"FillConstantRewriter", "RemoveIdentity"}; - int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names); + std::vector input_names = {}; + std::vector output_names = {constant_1->id, constant_2->id}; + std::vector program_passes = {"FillConstantRewriter", + "RemoveIdentity"}; + int num_removed_ops = + tester.RunAndCheck(builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 0); } diff --git a/paddle/cinn/frontend/pass/gemm_rewriter.cc b/paddle/cinn/frontend/pass/gemm_rewriter.cc index 1306eaba19718..9a43ea4ade125 100644 --- a/paddle/cinn/frontend/pass/gemm_rewriter.cc +++ b/paddle/cinn/frontend/pass/gemm_rewriter.cc @@ -17,9 +17,9 @@ #include #include +#include "glog/logging.h" #include "paddle/cinn/frontend/net_builder.h" #include "paddle/cinn/frontend/program_pass.h" -#include "glog/logging.h" namespace cinn { namespace frontend { @@ -94,38 +94,43 @@ class GemmRewriterPass : public ProgramPass { } // Fuse the pattern of `matmul + add` - bool DoGemmFusion(NetBuilder* builder, const Instruction& instr, const std::unordered_set& fetch_ids) { - CHECK_EQ(instr->inputs.size(), 2) << "elementwise should have only two inputs"; + bool DoGemmFusion(NetBuilder* builder, + const Instruction& instr, + const std::unordered_set& fetch_ids) { + CHECK_EQ(instr->inputs.size(), 2) + << "elementwise should have only two inputs"; std::vector inputs; - bool trans_a = false; - bool trans_b = false; + bool trans_a = false; + bool trans_b = false; bool trans_out = false; - float alpha = 1.f; + float alpha = 1.f; std::unordered_set dot_instrs{"matmul", "cublas_matmul"}; for (auto& var : instr->inputs) { auto it = output2instr_.find(var.get()); if (it != output2instr_.end() && dot_instrs.count(it->second->op_type)) { - // If the output var of matmul is consumed by more than one instruction or - // a fetch var, just skip to fuse it. + // If the output var of matmul is consumed by more than one instruction + // or a fetch var, just skip to fuse it. CHECK_GT(var_used_count_.count(var.get()), 0) << "The input(" << var->id << ")" - << "should be included in var_used_count_. Please check the CollectInfo method."; + << "should be included in var_used_count_. Please check the " + "CollectInfo method."; if ((var_used_count_.at(var.get()) > 1) || fetch_ids.count(var->id)) { continue; } auto& matmul_instr = it->second; // check inputs of cublas_gemm - auto& bias = instr->inputs[0].get() == var.get() ? instr->inputs[1] : instr->inputs[0]; + auto& bias = instr->inputs[0].get() == var.get() ? 
instr->inputs[1] + : instr->inputs[0]; auto& matmul_inputs = matmul_instr->inputs; - int lhs_dim_size = matmul_inputs[0]->shape.size(); - int rhs_dim_size = matmul_inputs[1]->shape.size(); - int bias_dim_size = bias->shape.size(); + int lhs_dim_size = matmul_inputs[0]->shape.size(); + int rhs_dim_size = matmul_inputs[1]->shape.size(); + int bias_dim_size = bias->shape.size(); // only support the condition below: // 1) two-dim matrix multiply, such as m * k, k * n // 2) three-dim tensor multiply, such as b * m * k, b * k * n - if (!((lhs_dim_size == 2 || lhs_dim_size == 3) && lhs_dim_size == rhs_dim_size && - rhs_dim_size == bias_dim_size)) { + if (!((lhs_dim_size == 2 || lhs_dim_size == 3) && + lhs_dim_size == rhs_dim_size && rhs_dim_size == bias_dim_size)) { continue; } // set inputs of cublas_gemm @@ -157,10 +162,12 @@ class GemmRewriterPass : public ProgramPass { VLOG(4) << "-- The trans_a of GEMM: " << std::boolalpha << trans_a; VLOG(4) << "-- The trans_b of GEMM: " << std::boolalpha << trans_b; VLOG(4) << "-- The trans_out of GEMM: " << std::boolalpha << trans_out; - const auto& new_outs = builder->CustomInstr( - "cublas_gemm", - inputs, - {{"trans_a", trans_a}, {"trans_b", trans_b}, {"trans_out", trans_out}, {"alpha", alpha}}); + const auto& new_outs = builder->CustomInstr("cublas_gemm", + inputs, + {{"trans_a", trans_a}, + {"trans_b", trans_b}, + {"trans_out", trans_out}, + {"alpha", alpha}}); auto new_out = new_outs[0]; auto old_out = instr.GetOutput(0); new_out.set_id(old_out->id); @@ -178,8 +185,8 @@ class GemmRewriterPass : public ProgramPass { auto& instr = (*prog)[i]; if (instr->op_type == "matmul") { auto& matmul_inputs = instr->inputs; - int lhs_dim_size = matmul_inputs[0]->shape.size(); - int rhs_dim_size = matmul_inputs[1]->shape.size(); + int lhs_dim_size = matmul_inputs[0]->shape.size(); + int rhs_dim_size = matmul_inputs[1]->shape.size(); // only support the condition below: // 1) two-dim matrix multiply, such as m * k, k * n // 2) three-dim tensor multiply, such as b * m * k, b * k * n diff --git a/paddle/cinn/frontend/pass/gemm_rewriter_test.cc b/paddle/cinn/frontend/pass/gemm_rewriter_test.cc index f84aebd6b2c39..88a4f7482f48e 100755 --- a/paddle/cinn/frontend/pass/gemm_rewriter_test.cc +++ b/paddle/cinn/frontend/pass/gemm_rewriter_test.cc @@ -35,12 +35,12 @@ TEST(GemmRwriter, BatchedTransLeft) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {3, 6, 8}, "A"); - auto b = builder.Transpose(a, {0, 2, 1}); - auto c = builder.CreateInput(Float(32), {3, 6, 7}, "C"); - auto d = builder.Matmul(b, c); - auto e = builder.CreateInput(Float(32), {3, 8, 7}, "E"); - auto out = builder.Add(d, e); + auto a = builder.CreateInput(Float(32), {3, 6, 8}, "A"); + auto b = builder.Transpose(a, {0, 2, 1}); + auto c = builder.CreateInput(Float(32), {3, 6, 7}, "C"); + auto d = builder.Matmul(b, c); + auto e = builder.CreateInput(Float(32), {3, 8, 7}, "E"); + auto out = builder.Add(d, e); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -48,8 +48,9 @@ TEST(GemmRwriter, BatchedTransLeft) { absl::c_transform(std::vector{a.id(), c.id(), e.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - std::pair, std::vector> passes{{"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter"}}; + std::pair, std::vector> passes{ + {"Decomposer", "RemoveIdentity"}, + {"TransposeFoldingInput", "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 1, passes,
123, true); } @@ -58,12 +59,12 @@ TEST(GemmRwriter, BatchedTransRight) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {3, 8, 6}, "A"); - auto b = builder.CreateInput(Float(32), {3, 7, 6}, "B"); - auto c = builder.Transpose(b, {0, 2, 1}); - auto e = builder.Matmul(a, c); - auto f = builder.CreateInput(Float(32), {3, 8, 7}, "F"); - auto out = builder.Add(e, f); + auto a = builder.CreateInput(Float(32), {3, 8, 6}, "A"); + auto b = builder.CreateInput(Float(32), {3, 7, 6}, "B"); + auto c = builder.Transpose(b, {0, 2, 1}); + auto e = builder.Matmul(a, c); + auto f = builder.CreateInput(Float(32), {3, 8, 7}, "F"); + auto out = builder.Add(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -71,8 +72,9 @@ TEST(GemmRwriter, BatchedTransRight) { absl::c_transform(std::vector{a.id(), b.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - std::pair, std::vector> passes{{"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter"}}; + std::pair, std::vector> passes{ + {"Decomposer", "RemoveIdentity"}, + {"TransposeFoldingInput", "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } @@ -81,13 +83,13 @@ TEST(GemmRwriter, BatchedTransTwo) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {3, 6, 8}, "A"); - auto b = builder.Transpose(a, {0, 2, 1}); - auto c = builder.CreateInput(Float(32), {3, 7, 6}, "C"); - auto d = builder.Transpose(c, {0, 2, 1}); - auto e = builder.Matmul(b, d); - auto f = builder.CreateInput(Float(32), {3, 8, 7}, "F"); - auto out = builder.Add(e, f); + auto a = builder.CreateInput(Float(32), {3, 6, 8}, "A"); + auto b = builder.Transpose(a, {0, 2, 1}); + auto c = builder.CreateInput(Float(32), {3, 7, 6}, "C"); + auto d = builder.Transpose(c, {0, 2, 1}); + auto e = builder.Matmul(b, d); + auto f = builder.CreateInput(Float(32), {3, 8, 7}, "F"); + auto out = builder.Add(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -95,8 +97,9 @@ TEST(GemmRwriter, BatchedTransTwo) { absl::c_transform(std::vector{a.id(), c.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter"}); + auto passes = std::make_pair( + std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } @@ -105,11 +108,11 @@ TEST(GemmRwriter, BatchedNoTrans) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {3, 8, 6}, "A"); - auto b = builder.CreateInput(Float(32), {3, 6, 7}, "B"); - auto e = builder.Matmul(a, b); - auto f = builder.CreateInput(Float(32), {3, 8, 7}, "F"); - auto out = builder.Add(e, f); + auto a = builder.CreateInput(Float(32), {3, 8, 6}, "A"); + auto b = builder.CreateInput(Float(32), {3, 6, 7}, "B"); + auto e = builder.Matmul(a, b); + auto f = builder.CreateInput(Float(32), {3, 8, 7}, "F"); + auto out = builder.Add(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -117,8 +120,9 @@ TEST(GemmRwriter, BatchedNoTrans) { absl::c_transform(std::vector{a.id(), b.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto 
passes = std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter"}); + auto passes = std::make_pair( + std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 0, passes, 123, true); } @@ -127,12 +131,12 @@ TEST(GemmRwriter, TransLeft) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {6, 8}, "A"); - auto b = builder.Transpose(a, {1, 0}); - auto c = builder.CreateInput(Float(32), {6, 7}, "C"); - auto d = builder.Matmul(b, c); - auto e = builder.CreateInput(Float(32), {8, 7}, "E"); - auto out = builder.Add(d, e); + auto a = builder.CreateInput(Float(32), {6, 8}, "A"); + auto b = builder.Transpose(a, {1, 0}); + auto c = builder.CreateInput(Float(32), {6, 7}, "C"); + auto d = builder.Matmul(b, c); + auto e = builder.CreateInput(Float(32), {8, 7}, "E"); + auto out = builder.Add(d, e); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -140,8 +144,9 @@ TEST(GemmRwriter, TransLeft) { absl::c_transform(std::vector{a.id(), c.id(), e.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter"}); + auto passes = std::make_pair( + std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } @@ -150,12 +155,12 @@ TEST(GemmRwriter, TransRight) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {8, 6}, "A"); - auto b = builder.CreateInput(Float(32), {7, 6}, "B"); - auto c = builder.Transpose(b, {1, 0}); - auto e = builder.Matmul(a, c); - auto f = builder.CreateInput(Float(32), {8, 7}, "F"); - auto out = builder.Add(e, f); + auto a = builder.CreateInput(Float(32), {8, 6}, "A"); + auto b = builder.CreateInput(Float(32), {7, 6}, "B"); + auto c = builder.Transpose(b, {1, 0}); + auto e = builder.Matmul(a, c); + auto f = builder.CreateInput(Float(32), {8, 7}, "F"); + auto out = builder.Add(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -163,8 +168,9 @@ TEST(GemmRwriter, TransRight) { absl::c_transform(std::vector{a.id(), b.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter"}); + auto passes = std::make_pair( + std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } @@ -173,13 +179,13 @@ TEST(GemmRwriter, TransTwo) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {6, 8}, "A"); - auto b = builder.Transpose(a, {1, 0}); - auto c = builder.CreateInput(Float(32), {7, 6}, "C"); - auto d = builder.Transpose(c, {1, 0}); - auto e = builder.Matmul(b, d); - auto f = builder.CreateInput(Float(32), {8, 7}, "F"); - auto out = builder.Add(e, f); + auto a = builder.CreateInput(Float(32), {6, 8}, "A"); + auto b = builder.Transpose(a, {1, 0}); + auto c = builder.CreateInput(Float(32), {7, 6}, "C"); + auto d = builder.Transpose(c, {1, 0}); + auto e = builder.Matmul(b, d); + auto f = 
builder.CreateInput(Float(32), {8, 7}, "F"); + auto out = builder.Add(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -187,8 +193,9 @@ TEST(GemmRwriter, TransTwo) { absl::c_transform(std::vector{a.id(), c.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter"}); + auto passes = std::make_pair( + std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } @@ -197,11 +204,11 @@ TEST(GemmRwriter, NoTrans) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {8, 6}, "A"); - auto b = builder.CreateInput(Float(32), {6, 7}, "B"); - auto e = builder.Matmul(a, b); - auto f = builder.CreateInput(Float(32), {8, 7}, "F"); - auto out = builder.Add(e, f); + auto a = builder.CreateInput(Float(32), {8, 6}, "A"); + auto b = builder.CreateInput(Float(32), {6, 7}, "B"); + auto e = builder.Matmul(a, b); + auto f = builder.CreateInput(Float(32), {8, 7}, "F"); + auto out = builder.Add(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -209,8 +216,9 @@ TEST(GemmRwriter, NoTrans) { absl::c_transform(std::vector{a.id(), b.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter"}); + auto passes = std::make_pair( + std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 0, passes, 123, true); } @@ -219,22 +227,22 @@ TEST(GemmRwriter, BatchedComplex) { return; } NetBuilder builder("net_builder"); - auto a = builder.FillConstant({2, 20}, 2.0f, "A"); - auto b = builder.FillConstant({16, 2, 20}, 2.0f, "B"); - auto c = builder.Transpose(b, {0, 2, 1}); - auto d = builder.CreateInput(Float(32), {121, 20}, "D"); - auto e = builder.BroadcastTo(d, {16, 121, 20}, {1, 2}); - auto f = builder.Matmul(e, c); - auto x = builder.FillConstant({16, 2, 20}, 1.0f, "X"); - auto y = builder.Transpose(x, {0, 2, 1}); - auto z = builder.CreateInput(Float(32), {16, 20, 121}, "Z"); - auto l = builder.Transpose(z, {0, 2, 1}); - auto m = builder.Matmul(l, y); - auto n = builder.Matmul(d, a, false, true); - auto o = builder.BroadcastTo(n, {16, n->shape[0], n->shape[1]}, {1, 2}); - auto p = builder.Subtract(f, o); - auto q = builder.Add(f, m); - auto out = builder.Add(p, q); + auto a = builder.FillConstant({2, 20}, 2.0f, "A"); + auto b = builder.FillConstant({16, 2, 20}, 2.0f, "B"); + auto c = builder.Transpose(b, {0, 2, 1}); + auto d = builder.CreateInput(Float(32), {121, 20}, "D"); + auto e = builder.BroadcastTo(d, {16, 121, 20}, {1, 2}); + auto f = builder.Matmul(e, c); + auto x = builder.FillConstant({16, 2, 20}, 1.0f, "X"); + auto y = builder.Transpose(x, {0, 2, 1}); + auto z = builder.CreateInput(Float(32), {16, 20, 121}, "Z"); + auto l = builder.Transpose(z, {0, 2, 1}); + auto m = builder.Matmul(l, y); + auto n = builder.Matmul(d, a, false, true); + auto o = builder.BroadcastTo(n, {16, n->shape[0], n->shape[1]}, {1, 2}); + auto p = builder.Subtract(f, o); + auto q = builder.Add(f, m); + auto out = builder.Add(p, q); auto program = builder.Build(); 
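A note on the tests in this file: before GemmRewriterPass rewrites `matmul + add` into `cublas_gemm` (see the DoGemmFusion hunk in gemm_rewriter.cc above), it requires both matmul operands and the bias to share a rank of 2 or 3. A minimal standalone sketch of that rank predicate follows; the names here (CanFuseGemm and the shape parameters) are illustrative, not part of the patch:

#include <cstddef>
#include <vector>

// matmul(lhs, rhs) + bias is fusible only when both matmul operands are
// rank-2 (m * k, k * n) or rank-3 (b * m * k, b * k * n) and the bias
// rank matches the operand ranks.
static bool CanFuseGemm(const std::vector<int>& lhs_shape,
                        const std::vector<int>& rhs_shape,
                        const std::vector<int>& bias_shape) {
  const std::size_t rank = lhs_shape.size();
  return (rank == 2 || rank == 3) && rank == rhs_shape.size() &&
         rhs_shape.size() == bias_shape.size();
}

int main() {
  // Matches the 2-D tests above: {8, 6} x {6, 7} + {8, 7} is fusible.
  return CanFuseGemm({8, 6}, {6, 7}, {8, 7}) ? 0 : 1;
}

The integer passed to CompareResult in each test is the expected drop in program size, which follows from how many transposes fold into the matmul and whether this predicate admits the fusion.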
common::Target target = common::DefaultNVGPUTarget(); @@ -242,8 +250,9 @@ TEST(GemmRwriter, BatchedComplex) { absl::c_transform(std::vector{d.id(), z.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter"}); + auto passes = std::make_pair( + std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 4, passes, 123, false); } @@ -252,19 +261,19 @@ TEST(GemmRwriter, Complex) { return; } NetBuilder builder("net_builder"); - auto a = builder.FillConstant({2, 20}, 2.0f, "A"); - auto b = builder.Transpose(a, {1, 0}); - auto c = builder.CreateInput(Float(32), {121, 20}, "C"); - auto d = builder.Matmul(c, b); - auto x = builder.FillConstant({2, 20}, 1.0f, "X"); - auto y = builder.Transpose(x, {1, 0}); - auto z = builder.CreateInput(Float(32), {20, 121}, "Z"); - auto l = builder.Transpose(z, {1, 0}); - auto m = builder.Matmul(l, y); - auto n = builder.Matmul(c, a, false, true); - auto p = builder.Subtract(d, n); - auto q = builder.Add(d, m); - auto out = builder.Add(p, q); + auto a = builder.FillConstant({2, 20}, 2.0f, "A"); + auto b = builder.Transpose(a, {1, 0}); + auto c = builder.CreateInput(Float(32), {121, 20}, "C"); + auto d = builder.Matmul(c, b); + auto x = builder.FillConstant({2, 20}, 1.0f, "X"); + auto y = builder.Transpose(x, {1, 0}); + auto z = builder.CreateInput(Float(32), {20, 121}, "Z"); + auto l = builder.Transpose(z, {1, 0}); + auto m = builder.Matmul(l, y); + auto n = builder.Matmul(c, a, false, true); + auto p = builder.Subtract(d, n); + auto q = builder.Add(d, m); + auto out = builder.Add(p, q); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -272,8 +281,9 @@ TEST(GemmRwriter, Complex) { absl::c_transform(std::vector{c.id(), z.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter"}); + auto passes = std::make_pair( + std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 3, passes, 123, false); } diff --git a/paddle/cinn/frontend/pass/pass_test_helper.h b/paddle/cinn/frontend/pass/pass_test_helper.h index 1c8fac1c3365a..5c57835e1f161 100644 --- a/paddle/cinn/frontend/pass/pass_test_helper.h +++ b/paddle/cinn/frontend/pass/pass_test_helper.h @@ -50,9 +50,10 @@ inline void PrintMatrix(const std::vector& mat, int bs, int m, int n) { return; } const auto min_max = std::minmax_element(mat.begin(), mat.end()); - int min = static_cast(*min_max.first); - int max = static_cast(*min_max.second); - auto ele_width = std::max(std::to_string(min).length(), std::to_string(max).length()); + int min = static_cast(*min_max.first); + int max = static_cast(*min_max.second); + auto ele_width = + std::max(std::to_string(min).length(), std::to_string(max).length()); std::cout << "\n" << std::string((ele_width + 2) * n - 1, '-') << "\n"; for (int b = 0; b < bs; b++) { for (int i = 0; i < m; i++) { @@ -77,23 +78,28 @@ inline void RunGraph(std::shared_ptr graph, VLOG(3) << "Graph Viz:\n" << graph->Visualize(); BuildScope(target, graph, scope); hlir::framework::GraphCompiler::CompileOptions options; - options.attached_code = ""; + 
options.attached_code = ""; options.with_instantiate_variables = true; hlir::framework::GraphCompiler gc(target, scope, graph); - auto runtime_program = - gc.Build(options, std::unordered_set(output_ids.begin(), output_ids.end())).runtime_program; + auto runtime_program = gc.Build(options, + std::unordered_set( + output_ids.begin(), output_ids.end())) + .runtime_program; runtime_program->Execute(); } -inline std::vector RunProgram(const Program& program, - const common::Target& target, - const std::vector& input_ids, - const std::vector& output_ids, - const std::vector& graph_passes, - int seed = -1, - bool print_tensor = false) { - std::unordered_set outputs_set{output_ids.begin(), output_ids.end()}; - auto graph = std::make_shared(program, outputs_set, target); +inline std::vector RunProgram( + const Program& program, + const common::Target& target, + const std::vector& input_ids, + const std::vector& output_ids, + const std::vector& graph_passes, + int seed = -1, + bool print_tensor = false) { + std::unordered_set outputs_set{output_ids.begin(), + output_ids.end()}; + auto graph = + std::make_shared(program, outputs_set, target); auto scope = hlir::framework::BuildScope(target, graph); for (auto& input_id : input_ids) { scope->Var(input_id); @@ -102,7 +108,10 @@ inline std::vector RunProgram(const Program& program, if (print_tensor) { auto tensor_data = GetTensorData(input_tensor, target); if (input_tensor->shape().data().size() == 2) { - PrintMatrix(tensor_data, 1, input_tensor->shape().data()[0], input_tensor->shape().data()[1]); + PrintMatrix(tensor_data, + 1, + input_tensor->shape().data()[0], + input_tensor->shape().data()[1]); } else if (input_tensor->shape().data().size() == 3) { PrintMatrix(tensor_data, input_tensor->shape().data()[0], @@ -115,10 +124,13 @@ inline std::vector RunProgram(const Program& program, RunGraph(graph, target, scope, output_ids, graph_passes); auto output_tensor = scope->GetTensor(output_ids.front()); - auto output_data = GetTensorData(output_tensor, target); + auto output_data = GetTensorData(output_tensor, target); if (print_tensor) { if (output_tensor->shape().data().size() == 2) { - PrintMatrix(output_data, 1, output_tensor->shape().data()[0], output_tensor->shape().data()[1]); + PrintMatrix(output_data, + 1, + output_tensor->shape().data()[0], + output_tensor->shape().data()[1]); } else if (output_tensor->shape().data().size() == 3) { PrintMatrix(output_data, output_tensor->shape().data()[0], @@ -131,21 +143,25 @@ inline std::vector RunProgram(const Program& program, struct OptimizeConfig { struct PassGroup; - OptimizeConfig(const PassGroup& program_passes) : program_passes{program_passes} { + OptimizeConfig(const PassGroup& program_passes) + : program_passes{program_passes} { if (FLAGS_cinn_use_op_fusion) { - graph_passes = {{"OpFusionPass", "FusionMergePass"}, {"OpFusionPass", "FusionMergePass"}}; + graph_passes = {{"OpFusionPass", "FusionMergePass"}, + {"OpFusionPass", "FusionMergePass"}}; } } OptimizeConfig(const PassGroup& program_passes, const PassGroup& graph_passes) : program_passes{program_passes}, graph_passes{graph_passes} {} - OptimizeConfig(const std::pair, std::vector>& program_passes) { + OptimizeConfig(const std::pair, + std::vector>& program_passes) { this->program_passes.ctrl = program_passes.first; - this->program_passes.exp = program_passes.second; + this->program_passes.exp = program_passes.second; if (FLAGS_cinn_use_op_fusion) { - graph_passes = {{"TransToCustomCallPass", "OpFusionPass", "FusionMergePass"}, - 
{"TransToCustomCallPass", "OpFusionPass", "FusionMergePass"}}; + graph_passes = { + {"TransToCustomCallPass", "OpFusionPass", "FusionMergePass"}, + {"TransToCustomCallPass", "OpFusionPass", "FusionMergePass"}}; } } @@ -165,15 +181,22 @@ inline void CompareResult(Program* program, const std::vector& output_ids, size_t size_diff, const OptimizeConfig& passes, - int seed = -1, + int seed = -1, bool print_tensor = false) { - std::unordered_set fetch_ids(output_ids.begin(), output_ids.end()); + std::unordered_set fetch_ids(output_ids.begin(), + output_ids.end()); // apply common passes ProgramPass::Apply(program, fetch_ids, target, passes.program_passes.ctrl); // get original program size auto origin_size = program->size(); // get original output - auto origin_out = RunProgram(*program, target, input_ids, output_ids, passes.graph_passes.ctrl, seed, print_tensor); + auto origin_out = RunProgram(*program, + target, + input_ids, + output_ids, + passes.graph_passes.ctrl, + seed, + print_tensor); // apply fused passes ProgramPass::Apply(program, fetch_ids, target, passes.program_passes.exp); @@ -182,7 +205,13 @@ inline void CompareResult(Program* program, auto fused_size = program->size(); ASSERT_EQ(size_diff, origin_size - fused_size); // get fused output - auto fused_out = RunProgram(*program, target, input_ids, output_ids, passes.graph_passes.exp, seed, print_tensor); + auto fused_out = RunProgram(*program, + target, + input_ids, + output_ids, + passes.graph_passes.exp, + seed, + print_tensor); ASSERT_EQ(origin_out.size(), fused_out.size()); for (size_t i = 0; i < origin_out.size(); ++i) { @@ -190,11 +219,12 @@ inline void CompareResult(Program* program, } } -inline bool CompareProgramPassResult(Program* program, - const common::Target& target, - const std::unordered_set& fetch_ids, - const size_t size_diff, - const OptimizeConfig& passes) { +inline bool CompareProgramPassResult( + Program* program, + const common::Target& target, + const std::unordered_set& fetch_ids, + const size_t size_diff, + const OptimizeConfig& passes) { // apply common passes ProgramPass::Apply(program, fetch_ids, target, passes.program_passes.ctrl); // get original program size diff --git a/paddle/cinn/frontend/pass/program_topoerror_test.cc b/paddle/cinn/frontend/pass/program_topoerror_test.cc index 141d3dceb1ff6..ca691444b7b5b 100644 --- a/paddle/cinn/frontend/pass/program_topoerror_test.cc +++ b/paddle/cinn/frontend/pass/program_topoerror_test.cc @@ -50,13 +50,13 @@ TEST(TransposeFoldingInput, TransposeWithMultiMamtul) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {2, 2}, "X"); - auto y = builder.CreateInput(Float(32), {2, 2}, "Y"); + auto x = builder.CreateInput(Float(32), {2, 2}, "X"); + auto y = builder.CreateInput(Float(32), {2, 2}, "Y"); auto transpose_y = builder.Transpose(y, {1, 0}); - auto dot1 = builder.Matmul(x, transpose_y); - auto dot2 = builder.Matmul(transpose_y, x); - auto out = builder.Add(dot1, dot2); - auto program = builder.Build(); + auto dot1 = builder.Matmul(x, transpose_y); + auto dot2 = builder.Matmul(transpose_y, x); + auto out = builder.Add(dot1, dot2); + auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; @@ -64,7 +64,11 @@ TEST(TransposeFoldingInput, TransposeWithMultiMamtul) { std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ - {"Decomposer"}, {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", 
"GemmRewriter"}}; + {"Decomposer"}, + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } diff --git a/paddle/cinn/frontend/pass/remove_identity.cc b/paddle/cinn/frontend/pass/remove_identity.cc index c5f8d3853cefe..bd80a45701ded 100644 --- a/paddle/cinn/frontend/pass/remove_identity.cc +++ b/paddle/cinn/frontend/pass/remove_identity.cc @@ -16,9 +16,9 @@ #include #include +#include "glog/logging.h" #include "paddle/cinn/frontend/net_builder.h" #include "paddle/cinn/frontend/program_pass.h" -#include "glog/logging.h" namespace cinn { namespace frontend { @@ -33,48 +33,57 @@ namespace pass { } \ } -static std::unordered_map> identity_ops = { - {"identity", [](const Instruction& instr) -> bool { return true; }}, - {"scale", - [](const Instruction& instr) -> bool { - bool bias_zero = !instr->attrs.count("bias") || instr.GetAttrs("bias") == 0.0f; - bool scale_one = !instr->attrs.count("scale") || instr.GetAttrs("scale") == 1.0f; - return bias_zero && scale_one; - }}, - {"cast", - [](const Instruction& instr) -> bool { - const auto& input_dtype = instr->inputs[0]->type; - const auto& output_dtype = instr->outputs[0]->type; - return input_dtype == output_dtype; - }}, - {"transpose", - [](const Instruction& instr) -> bool { - const auto& input_shape = instr->inputs[0]->shape; - const auto& axis = instr.GetAttrs>("axis"); - - bool can_remove = (input_shape.size() == axis.size()); - if (can_remove) { - for (int i = 0; i < axis.size(); ++i) { - if (axis[i] != i) { - can_remove = false; - break; +static std::unordered_map> + identity_ops = { + {"identity", [](const Instruction& instr) -> bool { return true; }}, + {"scale", + [](const Instruction& instr) -> bool { + bool bias_zero = !instr->attrs.count("bias") || + instr.GetAttrs("bias") == 0.0f; + bool scale_one = !instr->attrs.count("scale") || + instr.GetAttrs("scale") == 1.0f; + return bias_zero && scale_one; + }}, + {"cast", + [](const Instruction& instr) -> bool { + const auto& input_dtype = instr->inputs[0]->type; + const auto& output_dtype = instr->outputs[0]->type; + return input_dtype == output_dtype; + }}, + {"transpose", + [](const Instruction& instr) -> bool { + const auto& input_shape = instr->inputs[0]->shape; + const auto& axis = instr.GetAttrs>("axis"); + + bool can_remove = (input_shape.size() == axis.size()); + if (can_remove) { + for (int i = 0; i < axis.size(); ++i) { + if (axis[i] != i) { + can_remove = false; + break; + } + } } - } - } - - return can_remove; - }}, - {"concat", [](const Instruction& instr) -> bool { return (instr->inputs.size() == 1); }}, - {"split", [](const Instruction& instr) -> bool { return (instr->outputs.size() == 1); }}, - SHAPE_SAME_REMOVE(broadcast_to), - SHAPE_SAME_REMOVE(reduce_sum), - SHAPE_SAME_REMOVE(reduce_prod), - SHAPE_SAME_REMOVE(reduce_max), - SHAPE_SAME_REMOVE(reduce_min), - SHAPE_SAME_REMOVE(reduce_all), - SHAPE_SAME_REMOVE(reduce_any), - SHAPE_SAME_REMOVE(slice), - SHAPE_SAME_REMOVE(reshape)}; + + return can_remove; + }}, + {"concat", + [](const Instruction& instr) -> bool { + return (instr->inputs.size() == 1); + }}, + {"split", + [](const Instruction& instr) -> bool { + return (instr->outputs.size() == 1); + }}, + SHAPE_SAME_REMOVE(broadcast_to), + SHAPE_SAME_REMOVE(reduce_sum), + SHAPE_SAME_REMOVE(reduce_prod), + SHAPE_SAME_REMOVE(reduce_max), + SHAPE_SAME_REMOVE(reduce_min), + SHAPE_SAME_REMOVE(reduce_all), + SHAPE_SAME_REMOVE(reduce_any), + SHAPE_SAME_REMOVE(slice), + 
SHAPE_SAME_REMOVE(reshape)}; #undef SHAPE_SAME_REMOVE @@ -82,7 +91,9 @@ namespace { bool check_reduce_to_reshape(const Instruction& instr) { const auto& input_shape = instr->inputs[0]->shape; - auto dims = instr->attrs.count("dim") ? instr.GetAttrs>("dim") : std::vector(); + auto dims = instr->attrs.count("dim") + ? instr.GetAttrs>("dim") + : std::vector(); if (dims.empty()) { for (int i = 0; i < input_shape.size(); ++i) { @@ -99,18 +110,20 @@ bool check_reduce_to_reshape(const Instruction& instr) { } } // namespace -static std::unordered_map> reshape_ops = { - {"reduce_sum", check_reduce_to_reshape}, - {"reduce_prod", check_reduce_to_reshape}, - {"reduce_max", check_reduce_to_reshape}, - {"reduce_min", check_reduce_to_reshape}, - {"reduce_all", check_reduce_to_reshape}, - {"reduce_any", check_reduce_to_reshape}}; +static std::unordered_map> + reshape_ops = {{"reduce_sum", check_reduce_to_reshape}, + {"reduce_prod", check_reduce_to_reshape}, + {"reduce_max", check_reduce_to_reshape}, + {"reduce_min", check_reduce_to_reshape}, + {"reduce_all", check_reduce_to_reshape}, + {"reduce_any", check_reduce_to_reshape}}; -// RemoveIdentityPass will remove the identity instructions in following patterns: +// RemoveIdentityPass will remove the identity instructions in following +// patterns: // // 1. When varB is not in fetch_ids, the identity and varB will be removed. -// When varB is in fetch_ids and varA is not in fetch_ids, the identity and varA will be removed. +// When varB is in fetch_ids and varA is not in fetch_ids, the identity and +// varA will be removed. // instrA instrA // | varA | // identity => | varA/varB @@ -147,16 +160,21 @@ class RemoveIdentityPass : public ProgramPass { auto& instr = (*program)[i]; if (replace_identity_idxs_.count(i)) { - VLOG(4) << "Replace op " << instr->outputs[0]->id << "[" << cinn::utils::Join(instr->outputs[0]->shape, ", ") + VLOG(4) << "Replace op " << instr->outputs[0]->id << "[" + << cinn::utils::Join(instr->outputs[0]->shape, ", ") << "]=" << instr->op_type << "{" << instr->inputs[0]->id << "[" - << cinn::utils::Join(instr->inputs[0]->shape, ", ") << "]} to identity"; + << cinn::utils::Join(instr->inputs[0]->shape, ", ") + << "]} to identity"; instr->op_type = "identity"; instr->attrs.clear(); - } else if (reshape_ops.count(instr->op_type) && reshape_ops.at(instr->op_type)(instr)) { - VLOG(4) << "Replace op " << instr->outputs[0]->id << "[" << cinn::utils::Join(instr->outputs[0]->shape, ", ") + } else if (reshape_ops.count(instr->op_type) && + reshape_ops.at(instr->op_type)(instr)) { + VLOG(4) << "Replace op " << instr->outputs[0]->id << "[" + << cinn::utils::Join(instr->outputs[0]->shape, ", ") << "]=" << instr->op_type << "{" << instr->inputs[0]->id << "[" - << cinn::utils::Join(instr->inputs[0]->shape, ", ") << "]} to reshape"; + << cinn::utils::Join(instr->inputs[0]->shape, ", ") + << "]} to reshape"; instr->op_type = "reshape"; instr->attrs.clear(); @@ -187,7 +205,8 @@ class RemoveIdentityPass : public ProgramPass { } private: - void CollectInfo(const Program& program, const std::unordered_set& fetch_ids) { + void CollectInfo(const Program& program, + const std::unordered_set& fetch_ids) { remove_idxs_.clear(); origin2new_.clear(); @@ -204,13 +223,16 @@ class RemoveIdentityPass : public ProgramPass { if (!identity_ops.at(instr->op_type)(instr)) { continue; } - CHECK_EQ(instr->inputs.size(), 1) << instr->op_type << " should have only 1 input. But here " << instr; - CHECK_EQ(instr->outputs.size(), 1) << instr->op_type << " should have only 1 output. 
But here " << instr; + CHECK_EQ(instr->inputs.size(), 1) + << instr->op_type << " should have only 1 input. But here " << instr; + CHECK_EQ(instr->outputs.size(), 1) + << instr->op_type << " should have only 1 output. But here " << instr; - auto& input_var = instr->inputs[0]; + auto& input_var = instr->inputs[0]; auto& output_var = instr->outputs[0]; - bool can_input_var_removed = !feed_ids.count(input_var->id) && !fetch_ids.count(input_var->id); + bool can_input_var_removed = + !feed_ids.count(input_var->id) && !fetch_ids.count(input_var->id); bool can_output_var_removed = !fetch_ids.count(output_var->id); if (can_input_var_removed || can_output_var_removed) { bool updated = false; @@ -231,10 +253,10 @@ class RemoveIdentityPass : public ProgramPass { for (auto& v : origin2new_) { const auto& reserved_var = v.second; - auto iter = origin2new_.find(reserved_var.get()); + auto iter = origin2new_.find(reserved_var.get()); if (iter != origin2new_.end()) { - VLOG(4) << "Update " << v.first->id << " -> " << reserved_var->id << " to " << v.first->id << " -> " - << iter->second->id; + VLOG(4) << "Update " << v.first->id << " -> " << reserved_var->id + << " to " << v.first->id << " -> " << iter->second->id; origin2new_[v.first] = iter->second; } } @@ -249,7 +271,8 @@ class RemoveIdentityPass : public ProgramPass { bool UpdateOrigin2New(const Variable& origin, const Variable& new_var) { if (!origin2new_.count(origin.get())) { if (origin2new_.count(new_var.get())) { - VLOG(4) << "Add " << origin->id << " -> " << origin2new_[new_var.get()]->id; + VLOG(4) << "Add " << origin->id << " -> " + << origin2new_[new_var.get()]->id; origin2new_.emplace(origin.get(), origin2new_[new_var.get()]); } else { VLOG(4) << "Add " << origin->id << " -> " << new_var->id; @@ -270,7 +293,8 @@ class RemoveIdentityPass : public ProgramPass { } // namespace cinn CINN_REGISTER_HELPER(RemoveIdentity) { - CINN_REGISTER_PROGRAM_PASS(RemoveIdentity, cinn::frontend::pass::RemoveIdentityPass); + CINN_REGISTER_PROGRAM_PASS(RemoveIdentity, + cinn::frontend::pass::RemoveIdentityPass); return true; } diff --git a/paddle/cinn/frontend/pass/remove_identity_test.cc b/paddle/cinn/frontend/pass/remove_identity_test.cc index 227f076e66cbb..13ad1e1a70019 100644 --- a/paddle/cinn/frontend/pass/remove_identity_test.cc +++ b/paddle/cinn/frontend/pass/remove_identity_test.cc @@ -28,17 +28,19 @@ TEST(RemoveIdentity, remove_single) { // | | // NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {32, 16}, "x"); - auto identity_1 = builder.Identity(x); - auto identity_2 = builder.Identity(x); + auto x = builder.CreateInput(Float(32), {32, 16}, "x"); + auto identity_1 = builder.Identity(x); + auto identity_2 = builder.Identity(x); auto reduce_sum_1 = builder.ReduceSum(identity_1, {0}); auto reduce_sum_2 = builder.ReduceSum(identity_2, {1}); PassTest tester; - std::vector input_names = {x.id().data()}; - std::vector output_names = {reduce_sum_1->id}; - std::vector program_passes = {"RemoveIdentity", "DeadCodeEliminate"}; - int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names); + std::vector input_names = {x.id().data()}; + std::vector output_names = {reduce_sum_1->id}; + std::vector program_passes = {"RemoveIdentity", + "DeadCodeEliminate"}; + int num_removed_ops = + tester.RunAndCheck(builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 3); } @@ -51,16 +53,17 @@ TEST(RemoveIdentity, remove_branch) { // | | // NetBuilder builder("net_builder"); - auto x = 
builder.CreateInput(Float(32), {32, 16}, "x"); - auto identity_1 = builder.Identity(x); + auto x = builder.CreateInput(Float(32), {32, 16}, "x"); + auto identity_1 = builder.Identity(x); auto reduce_sum_1 = builder.ReduceSum(identity_1, {0}); auto reduce_sum_2 = builder.ReduceSum(identity_1, {1}); PassTest tester; - std::vector input_names = {x.id().data()}; - std::vector output_names = {reduce_sum_1->id, reduce_sum_2->id}; + std::vector input_names = {x.id().data()}; + std::vector output_names = {reduce_sum_1->id, reduce_sum_2->id}; std::vector program_passes = {"RemoveIdentity"}; - int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names); + int num_removed_ops = + tester.RunAndCheck(builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 1); } @@ -77,18 +80,19 @@ TEST(RemoveIdentity, remove_multiple) { // | // NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {32, 16}, "x"); - auto y = builder.CreateInput(Float(32), {32, 16}, "y"); + auto x = builder.CreateInput(Float(32), {32, 16}, "x"); + auto y = builder.CreateInput(Float(32), {32, 16}, "y"); auto identity_1 = builder.Identity(x); auto identity_2 = builder.Identity(identity_1); auto identity_3 = builder.Identity(identity_2); - auto mul_1 = builder.Add(identity_3, y); + auto mul_1 = builder.Add(identity_3, y); PassTest tester; - std::vector input_names = {x.id().data(), y.id().data()}; - std::vector output_names = {mul_1->id}; + std::vector input_names = {x.id().data(), y.id().data()}; + std::vector output_names = {mul_1->id}; std::vector program_passes = {"RemoveIdentity"}; - int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names); + int num_removed_ops = + tester.RunAndCheck(builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 3); } @@ -105,18 +109,19 @@ TEST(RemoveIdentity, cannot_remove_fetch) { // | // NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {32, 16}, "x"); - auto y = builder.CreateInput(Float(32), {32, 16}, "y"); - auto relu_1 = builder.Relu(x); + auto x = builder.CreateInput(Float(32), {32, 16}, "x"); + auto y = builder.CreateInput(Float(32), {32, 16}, "y"); + auto relu_1 = builder.Relu(x); auto identity_1 = builder.Identity(relu_1); auto identity_2 = builder.Identity(identity_1); - auto mul_1 = builder.Add(identity_2, y); + auto mul_1 = builder.Add(identity_2, y); PassTest tester; - std::vector input_names = {x.id().data(), y.id().data()}; - std::vector output_names = {identity_2->id, mul_1->id}; + std::vector input_names = {x.id().data(), y.id().data()}; + std::vector output_names = {identity_2->id, mul_1->id}; std::vector program_passes = {"RemoveIdentity"}; - int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names); + int num_removed_ops = + tester.RunAndCheck(builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 1); } diff --git a/paddle/cinn/frontend/pass/test_helper.h b/paddle/cinn/frontend/pass/test_helper.h index 00f46012f3b00..ba5e0058e8e88 100644 --- a/paddle/cinn/frontend/pass/test_helper.h +++ b/paddle/cinn/frontend/pass/test_helper.h @@ -42,9 +42,11 @@ std::vector GeneratedRandomVector(size_t numel) { } template -void CopyFromVector(const std::vector& src, hlir::framework::Tensor tensor, Target target) { +void CopyFromVector(const std::vector& src, + hlir::framework::Tensor tensor, + Target target) { size_t numel = tensor->shape().numel(); - auto* dst = 
tensor->mutable_data(target); + auto* dst = tensor->mutable_data(target); #ifdef CINN_WITH_CUDA cudaMemcpy(dst, src.data(), numel * sizeof(T), cudaMemcpyHostToDevice); @@ -56,7 +58,7 @@ void CopyFromVector(const std::vector& src, hlir::framework::Tensor tensor, T template std::vector CopyToVector(const hlir::framework::Tensor tensor) { size_t numel = tensor->shape().numel(); - auto* src = tensor->data(); + auto* src = tensor->data(); std::vector dst(numel); #ifdef CINN_WITH_CUDA @@ -81,14 +83,17 @@ class PassTest { CHECK(IsValid(program)) << "The origin program is not valid."; int origin_program_size = program.size(); LOG(INFO) << "Run origin program"; - std::unordered_map> origin_outputs = Execute(program, input_names, output_names); + std::unordered_map> origin_outputs = + Execute(program, input_names, output_names); - std::unordered_set fetch_var_ids(output_names.begin(), output_names.end()); + std::unordered_set fetch_var_ids(output_names.begin(), + output_names.end()); ProgramPass::Apply(&program, fetch_var_ids, target_, program_passes); int optimized_program_size = program.size(); CHECK(IsValid(program)) << "The optimized program is not valid."; LOG(INFO) << "Run optimized program"; - std::unordered_map> optimized_outputs = Execute(program, input_names, output_names); + std::unordered_map> optimized_outputs = + Execute(program, input_names, output_names); for (auto name : output_names) { LOG(INFO) << "Check output name=" << name; @@ -100,20 +105,23 @@ class PassTest { } protected: - std::unordered_map> Execute(const Program& program, - const std::vector& input_names, - const std::vector& output_names) { + std::unordered_map> Execute( + const Program& program, + const std::vector& input_names, + const std::vector& output_names) { LOG(INFO) << program; - std::unordered_set fetch_var_ids(output_names.begin(), output_names.end()); - auto graph = std::make_shared(program, fetch_var_ids, target_); + std::unordered_set fetch_var_ids(output_names.begin(), + output_names.end()); + auto graph = std::make_shared( + program, fetch_var_ids, target_); hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses()); auto scope = hlir::framework::BuildScope(target_, graph); hlir::framework::GraphCompiler gc(target_, scope, graph); hlir::framework::GraphCompiler::CompileOptions options; options.with_instantiate_variables = true; - auto result = gc.Build(options, std::move(fetch_var_ids)); - auto runtime_program = std::move(result.runtime_program); + auto result = gc.Build(options, std::move(fetch_var_ids)); + auto runtime_program = std::move(result.runtime_program); for (auto& name : input_names) { SetInputTensor(name, scope); @@ -122,26 +130,29 @@ class PassTest { std::unordered_map> outputs; for (auto& name : output_names) { - auto tensor = scope->GetTensor(name); + auto tensor = scope->GetTensor(name); std::vector vec = CopyToVector(tensor); outputs.emplace(name, vec); } return outputs; } - void SetInputTensor(const std::string& name, std::shared_ptr scope) { + void SetInputTensor(const std::string& name, + std::shared_ptr scope) { scope->Var(name); auto tensor = scope->GetTensor(name); if (!inputs_.count(name)) { - std::vector vec = GeneratedRandomVector(tensor->shape().numel()); + std::vector vec = + GeneratedRandomVector(tensor->shape().numel()); inputs_.emplace(name, vec); } auto iter = inputs_.find(name); CopyFromVector(iter->second, tensor, target_); } - void CheckOutput(const std::vector& actual, const std::vector& expect) { + void CheckOutput(const std::vector& actual, + const 
std::vector& expect) { CHECK_EQ(actual.size(), expect.size()); for (size_t i = 0; i < expect.size(); ++i) { ASSERT_FLOAT_EQ(actual[i], expect[i]); } @@ -168,7 +179,8 @@ class PassTest { // The inputs should be fed, or be other instructions' output. for (auto& var : instr->inputs) { if (!inputs.count(var->id) && !outputs.count(var->id)) { - LOG(INFO) << "The input " << var->id << " of " << i << "-th instruction (" << instr + LOG(INFO) << "The input " << var->id << " of " << i + << "-th instruction (" << instr << ") is not the output of any other instructions."; valid = false; } diff --git a/paddle/cinn/frontend/pass/transpose_collapsing.cc b/paddle/cinn/frontend/pass/transpose_collapsing.cc index 37055a48a13b5..ecf71ae55a0aa 100644 --- a/paddle/cinn/frontend/pass/transpose_collapsing.cc +++ b/paddle/cinn/frontend/pass/transpose_collapsing.cc @@ -30,22 +30,31 @@ using cinn::utils::ShapeType; class TransposeKey { public: - TransposeKey(const std::string& input_id, const ShapeType& axis) { SetKey(input_id, axis); } + TransposeKey(const std::string& input_id, const ShapeType& axis) { + SetKey(input_id, axis); + } void SetKey(const std::string& input_id, const ShapeType& axis) { input_id_ = input_id; - axis_ = axis; + axis_ = axis; } - bool operator==(const TransposeKey& other) const { return axis_ == other.axis_ && input_id_ == other.input_id_; } - bool operator!=(const TransposeKey& other) const { return !this->operator==(other); } + bool operator==(const TransposeKey& other) const { + return axis_ == other.axis_ && input_id_ == other.input_id_; + } + bool operator!=(const TransposeKey& other) const { + return !this->operator==(other); + } struct Hash { size_t operator()(const TransposeKey& key) const { std::string ret; ret.append(key.input_id_); - std::for_each(key.axis_.begin(), key.axis_.end(), [&](const DimType& dim) { ret.append(std::to_string(dim)); }); + std::for_each( + key.axis_.begin(), key.axis_.end(), [&](const DimType& dim) { + ret.append(std::to_string(dim)); + }); return std::hash()(ret); } @@ -61,7 +70,8 @@ class TransposeCollapsingPass : public ProgramPass { public: using ProgramPass::ProgramPass; using OutputToOpMap = std::unordered_map; - using InputToOpMap = std::unordered_map>; + using InputToOpMap = + std::unordered_map>; protected: void Clear() override {} @@ -91,7 +101,8 @@ class TransposeCollapsingPass : public ProgramPass { // the useless transpose ops need to be removed from the program std::unordered_set remove_instrs; - FoldingTransposeVertical(all_transpose, fetch_ids, in2instr, out2instr, &remove_instrs); + FoldingTransposeVertical( + all_transpose, fetch_ids, in2instr, out2instr, &remove_instrs); for (auto instr : remove_instrs) { if (all_transpose.count(instr)) { @@ -99,9 +110,10 @@ class TransposeCollapsingPass : public ProgramPass { } } // TODO(thisjiang): reopen after CINN supports recompute for performance - // because recompute is unsupported, if an op outputs to two groups, it will also create a new group, - // so the horizontal fuse will not improve performance. - // FoldingTransposeHorizontal(all_transpose, fetch_ids, in2instr, out2instr, &remove_instrs); + // because recompute is unsupported, if an op outputs to two groups, it will + // also create a new group, so the horizontal fuse will not improve + // performance. 
FoldingTransposeHorizontal(all_transpose, fetch_ids, + // in2instr, out2instr, &remove_instrs); NetBuilder builder("transpose_collapsing_builder"); for (auto& var : program->GetInputs()) { @@ -116,80 +128,90 @@ class TransposeCollapsingPass : public ProgramPass { } private: - void FoldingTransposeVertical(const std::unordered_set& all_transpose, - const std::unordered_set& fetch_ids, - const InputToOpMap& in2instr, - const OutputToOpMap& out2instr, - std::unordered_set* remove_instrs) const { + void FoldingTransposeVertical( + const std::unordered_set& all_transpose, + const std::unordered_set& fetch_ids, + const InputToOpMap& in2instr, + const OutputToOpMap& out2instr, + std::unordered_set* remove_instrs) const { if (all_transpose.size() == 1) { return; } // the transpose ops that should not be removed std::unordered_set visited_instrs; for (auto transpose : all_transpose) { - if (("transpose" != (*transpose)->op_type) || visited_instrs.count(transpose)) { + if (("transpose" != (*transpose)->op_type) || + visited_instrs.count(transpose)) { // the transpose op had been fused, skip continue; } // Fuse transpose from front to back, the fuse path is unique auto first_transpose = FindFirstTranspose(transpose, out2instr); - TryFuseTranspose(first_transpose, fetch_ids, in2instr, remove_instrs, &visited_instrs); + TryFuseTranspose( + first_transpose, fetch_ids, in2instr, remove_instrs, &visited_instrs); } } - Instruction* FindFirstTranspose(Instruction* transpose, const OutputToOpMap& out2instr) const { + Instruction* FindFirstTranspose(Instruction* transpose, + const OutputToOpMap& out2instr) const { auto first_transpose = transpose; auto input_name = (*first_transpose)->inputs.front()->id; // Q: Why check whether the transpose's input is in out2instr? - // A: The input may be a graph input rather than another op's output. + // A: The input may be a graph input rather than another op's + // output. // Obviously, the transpose op is the first transpose in that situation. while (out2instr.count(input_name)) { auto instr = out2instr.at(input_name); if ("transpose" != (*instr)->op_type) { - // if the input of a transpose is not the output of another transpose, it is the first transpose. + // if the input of a transpose is not the output of another transpose, + // it is the first transpose. break; } - input_name = (*instr)->inputs.front()->id; + input_name = (*instr)->inputs.front()->id; first_transpose = instr; } return first_transpose; } - void TryFuseTranspose(Instruction* transpose, - const std::unordered_set& fetch_ids, - const InputToOpMap& in2instr, - std::unordered_set* remove_instrs, - std::unordered_set* visited_instrs) const { + void TryFuseTranspose( + Instruction* transpose, + const std::unordered_set& fetch_ids, + const InputToOpMap& in2instr, + std::unordered_set* remove_instrs, + std::unordered_set* visited_instrs) const { visited_instrs->insert(transpose); - const auto& input = (*transpose)->inputs.front(); + const auto& input = (*transpose)->inputs.front(); const auto& input_name = input->id; - const auto& output = (*transpose)->outputs.front(); + const auto& output = (*transpose)->outputs.front(); const auto& output_name = output->id; const auto& axis = transpose->GetAttrs("axis"); CHECK_EQ(axis.size(), input->shape.size()) - << "The transpose's axis size should equal the input variable's shape size, but the transpose of [" + << "The transpose's axis size should equal the input variable's shape " "size, but the transpose of [" << input->id << "] does not! 
Please check."; bool can_remove = !fetch_ids.count(output_name); if (CheckTransposeBorder(transpose, in2instr)) { if (can_remove) { - VLOG(4) << "The transpose op {input[" << input_name << "], output[" << output_name << "], axis[" - << cinn::utils::Join(axis, ",") << "]} is a output op of graph, connot fuse, remove."; + VLOG(4) << "The transpose op {input[" << input_name << "], output[" + << output_name << "], axis[" << cinn::utils::Join(axis, ",") + << "]} is a output op of graph, connot fuse, remove."; // this transpose not used by any other op, remove remove_instrs->insert(transpose); } else { if (CheckTransposeUseless(axis)) { - VLOG(4) << "The transpose op {input[" << input_name << "], output[" << output_name << "], axis[" - << cinn::utils::Join(axis, ",") << "]} is fetched but useless, replace with identity."; - // cannot remove, however, the transpsoe is useless, we can replace the transpose with indentiy for more - // fusion opportunity + VLOG(4) << "The transpose op {input[" << input_name << "], output[" + << output_name << "], axis[" << cinn::utils::Join(axis, ",") + << "]} is fetched but useless, replace with identity."; + // cannot remove, however, the transpsoe is useless, we can replace + // the transpose with indentiy for more fusion opportunity ReplaceWithIdentity(transpose); } // else the transpsoe is fetched and helpful, ignore @@ -201,13 +223,16 @@ class TransposeCollapsingPass : public ProgramPass { const auto& out_instrs = in2instr.at(output_name); if (CheckTransposeUseless(axis)) { if (!can_remove) { - VLOG(4) << "The transpose op {input[" << input_name << "], output[" << output_name << "], axis[" - << cinn::utils::Join(axis, ",") << "]} is useless but fetched, replace with identity."; - // cannot remove, but we can replace the transpose with indentiy for more fusion opportunity + VLOG(4) << "The transpose op {input[" << input_name << "], output[" + << output_name << "], axis[" << cinn::utils::Join(axis, ",") + << "]} is useless but fetched, replace with identity."; + // cannot remove, but we can replace the transpose with indentiy for + // more fusion opportunity ReplaceWithIdentity(transpose); } else { - VLOG(4) << "The transpose op {input[" << input_name << "], output[" << output_name << "], axis[" - << cinn::utils::Join(axis, ",") << "]} is useless, remove."; + VLOG(4) << "The transpose op {input[" << input_name << "], output[" + << output_name << "], axis[" << cinn::utils::Join(axis, ",") + << "]} is useless, remove."; for (auto instr : out_instrs) { // replace the input to transpose's input ReplaceInputVariable(instr, output_name, input); @@ -217,7 +242,8 @@ class TransposeCollapsingPass : public ProgramPass { for (auto instr : out_instrs) { if ("transpose" == (*instr)->op_type) { // if the next instruction is transpose op, continue fuse - TryFuseTranspose(instr, fetch_ids, in2instr, remove_instrs, visited_instrs); + TryFuseTranspose( + instr, fetch_ids, in2instr, remove_instrs, visited_instrs); } } } @@ -225,8 +251,9 @@ class TransposeCollapsingPass : public ProgramPass { } if (!CheckOutputContainTranspose(transpose, in2instr)) { - VLOG(4) << "The transpose op {input[" << input_name << "], output[" << output_name << "], axis[" - << cinn::utils::Join(axis, ",") << "]} doesn't has output link to transpose, skip."; + VLOG(4) << "The transpose op {input[" << input_name << "], output[" + << output_name << "], axis[" << cinn::utils::Join(axis, ",") + << "]} doesn't has output link to transpose, skip."; return; } @@ -236,8 +263,9 @@ class TransposeCollapsingPass : 
public ProgramPass { if ("transpose" != (*instr)->op_type) { // the transpose was used by other non-transpose op, cannot remove, skip can_remove = false; - VLOG(4) << "Fuse transpose of {input[" << input_name << "], output[" << output_name << "], axis [" - << cinn::utils::Join(axis, ",") << "]} was used by " << (*instr)->op_type << ", cannot remove."; + VLOG(4) << "Fuse transpose of {input[" << input_name << "], output[" + << output_name << "], axis [" << cinn::utils::Join(axis, ",") + << "]} was used by " << (*instr)->op_type << ", cannot remove."; continue; } @@ -246,14 +274,18 @@ class TransposeCollapsingPass : public ProgramPass { // step | axis | after_transpose // 1 | [0, 2, 1] | [0, 2, 1] // 2 | [2, 1, 0] | [1, 2, 0] - // so we can fuse tranpose([0, 2, 1]) and tranpose([2, 1, 0]) into tranpose([1, 2, 0]) + // so we can fuse tranpose([0, 2, 1]) and tranpose([2, 1, 0]) into + // tranpose([1, 2, 0]) const auto& fused_axis = FuseTransposeAxis(axis, next_axis); - VLOG(4) << "Fuse transpose of {input[" << input_name << "], output[" << output_name << "], axis [" - << cinn::utils::Join(axis, ",") << "]} and transpose of {input[" << (*instr)->inputs.front()->id - << "], output[" << (*instr)->outputs.front()->id << "], axis [" << cinn::utils::Join(next_axis, ",") - << "]} into transpose of {input[" << input_name << "], output[" << (*instr)->outputs.front()->id - << "], axis[" << cinn::utils::Join(fused_axis, ",") << "]}."; + VLOG(4) << "Fuse transpose of {input[" << input_name << "], output[" + << output_name << "], axis [" << cinn::utils::Join(axis, ",") + << "]} and transpose of {input[" << (*instr)->inputs.front()->id + << "], output[" << (*instr)->outputs.front()->id << "], axis [" + << cinn::utils::Join(next_axis, ",") + << "]} into transpose of {input[" << input_name << "], output[" + << (*instr)->outputs.front()->id << "], axis[" + << cinn::utils::Join(fused_axis, ",") << "]}."; auto fused_transpose = FuseTransposeImpl(transpose, instr, fused_axis); @@ -261,25 +293,30 @@ class TransposeCollapsingPass : public ProgramPass { } if (can_remove) { - VLOG(4) << "Remove transpose of {input[" << input_name << "], output[" << output_name << "], axis [" - << cinn::utils::Join(axis, ",") << "]}."; + VLOG(4) << "Remove transpose of {input[" << input_name << "], output[" + << output_name << "], axis [" << cinn::utils::Join(axis, ",") + << "]}."; remove_instrs->insert(transpose); } for (auto instr : next_fused_instrs) { - TryFuseTranspose(instr, fetch_ids, in2instr, remove_instrs, visited_instrs); + TryFuseTranspose( + instr, fetch_ids, in2instr, remove_instrs, visited_instrs); } } - // check whether the op is the border op of graph, in other words, its output var was not - // used by any op in graph. - bool CheckTransposeBorder(Instruction* transpose, const InputToOpMap& in2instr) const { + // check whether the op is the border op of graph, in other words, its output + // var was not used by any op in graph. 
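The step table above illustrates how two back-to-back transposes compose into one. A standalone sketch of that composition follows; the update rule is inferred from the worked example (axis [0, 2, 1] followed by [2, 1, 0] gives [1, 2, 0]) rather than copied from the pass, since the loop body of FuseTransposeAxis falls outside the hunks shown here:

#include <cassert>
#include <cstddef>
#include <vector>

using ShapeType = std::vector<int>;  // stand-in for cinn::utils::ShapeType

// Applying transpose(old_axis) and then transpose(new_axis) equals one
// transpose whose i-th output dimension comes from old_axis[new_axis[i]].
ShapeType ComposeTransposeAxis(const ShapeType& old_axis,
                               const ShapeType& new_axis) {
  ShapeType fused(new_axis.size());
  for (std::size_t i = 0; i < new_axis.size(); ++i) {
    fused[i] = old_axis[new_axis[i]];
  }
  return fused;
}

int main() {
  // Matches the table: [0, 2, 1] then [2, 1, 0] -> [1, 2, 0].
  assert((ComposeTransposeAxis({0, 2, 1}, {2, 1, 0}) == ShapeType{1, 2, 0}));
  return 0;
}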
+ bool CheckTransposeBorder(Instruction* transpose, + const InputToOpMap& in2instr) const { const auto& output_name = (*transpose)->outputs.front()->id; return !in2instr.count(output_name) || in2instr.at(output_name).empty(); } - // check whether the op's output ops contain a transpose; if not, there is no transpose to fold - bool CheckOutputContainTranspose(Instruction* transpose, const InputToOpMap& in2instr) const { + // check whether the op's output ops contain a transpose; if not, there is no + // transpose to fold + bool CheckOutputContainTranspose(Instruction* transpose, + const InputToOpMap& in2instr) const { const auto& output_name = (*transpose)->outputs.front()->id; for (auto instr : in2instr.at(output_name)) { if ("transpose" == (*instr)->op_type) { @@ -290,7 +327,8 @@ class TransposeCollapsingPass : public ProgramPass { return false; } - // if the transpose axis is like {0, 1, 2, 3, 4, 5}, the transpose is useless and should be removed + // if the transpose axis is like {0, 1, 2, 3, 4, 5}, the transpose is useless + // and should be removed bool CheckTransposeUseless(const ShapeType& axis) const { for (int i = 0; i < axis.size(); ++i) { if (axis[i] != i) { @@ -300,17 +338,23 @@ class TransposeCollapsingPass : public ProgramPass { return true; } - // replace the op's input variable whose name is `old_input_name` with `new_input`; note we need to keep the input - // list order - void ReplaceInputVariable(Instruction* op, const std::string& old_input_name, const Variable& new_input) const { + // replace the op's input variable whose name is `old_input_name` with + // `new_input`; note we need to keep the input list order + void ReplaceInputVariable(Instruction* op, + const std::string& old_input_name, + const Variable& new_input) const { auto find_input = [&](const std::string& input_name) { return std::find_if( - (*op)->inputs.begin(), (*op)->inputs.end(), [&](const Variable& v) { return input_name == v->id; }); + (*op)->inputs.begin(), (*op)->inputs.end(), [&](const Variable& v) { + return input_name == v->id; + }); }; // Why loop: the op may take the same variable as input more than once! 
- for (auto it = find_input(old_input_name); it != (*op)->inputs.end(); it = find_input(old_input_name)) { - // erase previous fill_constant output var and replace to new fill_constant output var + for (auto it = find_input(old_input_name); it != (*op)->inputs.end(); + it = find_input(old_input_name)) { + // erase the previous fill_constant output var and replace it with the new + // fill_constant output var auto next_it = (*op)->inputs.erase(it); // keep the input place same, it's very important (*op)->inputs.insert(next_it, new_input); @@ -324,10 +368,13 @@ class TransposeCollapsingPass : public ProgramPass { return op; } - // compute the fused axis of `old_axis` and `new_axis`, like [0, 2, 1] + [2, 1, 0] = [1, 2, 0] - ShapeType FuseTransposeAxis(const ShapeType& old_axis, const ShapeType& new_axis) const { + // compute the fused axis of `old_axis` and `new_axis`, like [0, 2, 1] + [2, + // 1, 0] = [1, 2, 0] + ShapeType FuseTransposeAxis(const ShapeType& old_axis, + const ShapeType& new_axis) const { CHECK_EQ(old_axis.size(), new_axis.size()) - << "The transpose axis size should be " << old_axis.size() << ", but here " << new_axis.size(); + << "The transpose axis size should be " << old_axis.size() + << ", but here " << new_axis.size(); ShapeType axis = old_axis; for (int i = 0; i < new_axis.size(); ++i) { @@ -336,28 +383,35 @@ class TransposeCollapsingPass : public ProgramPass { return axis; } - // fuse the two transpose axis into the second transpose, replace its input and axis - Instruction* FuseTransposeImpl(Instruction* transpose1, Instruction* transpose2, const ShapeType& fused_axis) const { + // fuse the two transpose axes into the second transpose, replacing its + // input and axis + Instruction* FuseTransposeImpl(Instruction* transpose1, + Instruction* transpose2, + const ShapeType& fused_axis) const { (*transpose2)->inputs.front() = (*transpose1)->inputs.front(); transpose2->SetAttr("axis", fused_axis); return transpose2; } - // if the transposes have the same input and axis, they can folding into one, the redundance should remove - void FoldingTransposeHorizontal(const std::unordered_set& all_transpose, - const std::unordered_set& fetch_ids, - const InputToOpMap& in2instr, - const OutputToOpMap& out2instr, - std::unordered_set* remove_instrs) const { - std::unordered_map first_transpose_map; + // if several transposes have the same input and axis, they can be folded + // into one and the redundant ones should be removed + void FoldingTransposeHorizontal( + const std::unordered_set& all_transpose, + const std::unordered_set& fetch_ids, + const InputToOpMap& in2instr, + const OutputToOpMap& out2instr, + std::unordered_set* remove_instrs) const { + std::unordered_map + first_transpose_map; for (auto transpose : all_transpose) { - if (("transpose" != (*transpose)->op_type) || remove_instrs->count(transpose)) { + if (("transpose" != (*transpose)->op_type) || + remove_instrs->count(transpose)) { continue; } - const auto& input_id = (*transpose)->inputs.front()->id; + const auto& input_id = (*transpose)->inputs.front()->id; const auto& output_id = (*transpose)->outputs.front()->id; - const auto& axis = transpose->GetAttrs("axis"); + const auto& axis = transpose->GetAttrs("axis"); TransposeKey key(input_id, axis); if (!first_transpose_map.count(key)) { @@ -369,7 +423,8 @@ class TransposeCollapsingPass : public ProgramPass { if (fetch_ids.find(output_id) != fetch_ids.end()) { // the transpose's output variable was fetched, skip - VLOG(4) << "Cannot remove transpose, because the output [" << output_id << "] was fetched 
by other op ! "; + VLOG(4) << "Cannot remove transpose, because the output [" << output_id + << "] was fetched by other op ! "; continue; } @@ -387,7 +442,8 @@ class TransposeCollapsingPass : public ProgramPass { } // namespace cinn::frontend::pass CINN_REGISTER_HELPER(TransposeCollapsing) { - CINN_REGISTER_PROGRAM_PASS(TransposeCollapsing, ::cinn::frontend::pass::TransposeCollapsingPass); + CINN_REGISTER_PROGRAM_PASS(TransposeCollapsing, + ::cinn::frontend::pass::TransposeCollapsingPass); return true; } diff --git a/paddle/cinn/frontend/pass/transpose_collapsing_test.cc b/paddle/cinn/frontend/pass/transpose_collapsing_test.cc index 0809661544ea1..2c1e88ecca12d 100644 --- a/paddle/cinn/frontend/pass/transpose_collapsing_test.cc +++ b/paddle/cinn/frontend/pass/transpose_collapsing_test.cc @@ -38,22 +38,27 @@ void SetInputData(const hlir::framework::Tensor& tensor, Target target) { } #ifdef CINN_WITH_CUDA if (target == common::DefaultNVGPUTarget()) { - cudaMemcpy(data, host_memory.data(), tensor->shape().numel() * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(data, + host_memory.data(), + tensor->shape().numel() * sizeof(float), + cudaMemcpyHostToDevice); return; } #endif CHECK(target == common::DefaultHostTarget()); std::copy(host_memory.begin(), host_memory.end(), data); } -std::vector> RunWithProgram(const Program& program, - const Target& target, - const std::vector& input_names, - const std::vector& out_ids) { +std::vector> RunWithProgram( + const Program& program, + const Target& target, + const std::vector& input_names, + const std::vector& out_ids) { std::unordered_set fetch_list; for (auto id : out_ids) { fetch_list.insert(id); } - auto graph = std::make_shared(program, fetch_list, target); + auto graph = + std::make_shared(program, fetch_list, target); auto scope = hlir::framework::BuildScope(target, graph); for (const auto& in_name : input_names) { @@ -69,18 +74,19 @@ std::vector> RunWithProgram(const Program& program, std::vector> outputs; for (const auto& out_id : out_ids) { - outputs.emplace_back(GetTensorData(scope->GetTensor(out_id), target)); + outputs.emplace_back( + GetTensorData(scope->GetTensor(out_id), target)); } return outputs; } TEST(TransposeCollapsing, FuseTwoTranspose) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto x_t = builder.Transpose(x, {0, 2, 1}); - auto out = builder.Transpose(x_t, {2, 1, 0}); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto x_t = builder.Transpose(x, {0, 2, 1}); + auto out = builder.Transpose(x_t, {2, 1, 0}); auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); std::initializer_list fetch_list = {out->id}; @@ -114,12 +120,12 @@ TEST(TransposeCollapsing, FuseTwoTranspose) { TEST(TransposeCollapsing, FuseThreeTranspose) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto x_1t = builder.Transpose(x, {0, 2, 1}); - auto x_2t = builder.Transpose(x_1t, {2, 1, 0}); - auto out = builder.Transpose(x_2t, {1, 2, 0}); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto x_1t = builder.Transpose(x, {0, 2, 1}); + auto x_2t = builder.Transpose(x_1t, {2, 1, 0}); + auto out = builder.Transpose(x_2t, {1, 2, 0}); auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); std::initializer_list fetch_list = {out->id}; @@ -154,11 +160,11 @@ TEST(TransposeCollapsing, FuseThreeTranspose) { 
TEST(TransposeCollapsing, RemoveUselessTranspose) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto x_t = builder.Transpose(x, {0, 1, 2}); - auto out = builder.Add(x, x_t); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto x_t = builder.Transpose(x, {0, 1, 2}); + auto out = builder.Add(x, x_t); auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); std::initializer_list fetch_list = {out->id}; @@ -190,10 +196,10 @@ TEST(TransposeCollapsing, RemoveUselessTranspose) { TEST(TransposeCollapsing, ReplaceUselessTransposeWithIndentity) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto out = builder.Transpose(x, {0, 1, 2}); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto out = builder.Transpose(x, {0, 1, 2}); auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); std::initializer_list fetch_list = {out->id}; @@ -227,13 +233,13 @@ TEST(TransposeCollapsing, ReplaceUselessTransposeWithIndentity) { TEST(TransposeCollapsing, FuseTransposeToUseless) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto x_1t = builder.Transpose(x, {0, 2, 1}); - auto x_2t = builder.Transpose(x_1t, {0, 2, 1}); - auto x_3t = builder.Transpose(x_2t, {0, 2, 1}); - auto out = builder.Add(x_3t, x_3t); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto x_1t = builder.Transpose(x, {0, 2, 1}); + auto x_2t = builder.Transpose(x_1t, {0, 2, 1}); + auto x_3t = builder.Transpose(x_2t, {0, 2, 1}); + auto out = builder.Add(x_3t, x_3t); auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); std::initializer_list fetch_list = {out->id}; @@ -270,17 +276,18 @@ TEST(TransposeCollapsing, FuseTransposeToUseless) { TEST(TransposeCollapsing, FuseTransposeWithMultiOutput) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto x_1t = builder.Transpose(x, {0, 2, 1}); - auto x_2t = builder.Transpose(x_1t, {2, 0, 1}); - auto x_3t = builder.Transpose(x_2t, {2, 1, 0}); - auto out1 = builder.Sqrt(x_1t); - auto out2 = builder.Sqrt(x_2t); - auto out3 = builder.Sqrt(x_3t); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto x_1t = builder.Transpose(x, {0, 2, 1}); + auto x_2t = builder.Transpose(x_1t, {2, 0, 1}); + auto x_3t = builder.Transpose(x_2t, {2, 1, 0}); + auto out1 = builder.Sqrt(x_1t); + auto out2 = builder.Sqrt(x_2t); + auto out3 = builder.Sqrt(x_3t); auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); - std::initializer_list fetch_list = {out1->id, out2->id, out3->id}; + std::initializer_list fetch_list = { + out1->id, out2->id, out3->id}; size_t origin_size = program.size(); VLOG(1) << "Program before pass:\n" << program; @@ -321,15 +328,15 @@ TEST(TransposeCollapsing, FuseTransposeWithMultiOutput) { TEST(TransposeCollapsing, FuseTwoSecTranspose) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto x_1t = builder.Transpose(x, {0, 2, 1}); - auto x_2t = builder.Transpose(x_1t, {2, 1, 0}); - auto out1 = builder.Reshape(x_2t, {5, 3, 4}); - auto x_3t = builder.Transpose(out1, {0, 2, 1}); - auto x_4t = builder.Transpose(x_3t, {2, 1, 0}); - auto out2 = builder.Sqrt(x_4t); + auto x = 
builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto x_1t = builder.Transpose(x, {0, 2, 1}); + auto x_2t = builder.Transpose(x_1t, {2, 1, 0}); + auto out1 = builder.Reshape(x_2t, {5, 3, 4}); + auto x_3t = builder.Transpose(out1, {0, 2, 1}); + auto x_4t = builder.Transpose(x_3t, {2, 1, 0}); + auto out2 = builder.Sqrt(x_4t); auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); std::initializer_list fetch_list = {out1->id, out2->id}; @@ -370,12 +377,12 @@ TEST(TransposeCollapsing, FuseTwoSecTranspose) { TEST(TransposeCollapsing, FuseTwoHorizontalTranspose) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto y_t1 = builder.Transpose(x, {0, 2, 1}); - auto y_t2 = builder.Transpose(x, {0, 2, 1}); - auto out = builder.Add(y_t1, y_t2); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto y_t1 = builder.Transpose(x, {0, 2, 1}); + auto y_t2 = builder.Transpose(x, {0, 2, 1}); + auto out = builder.Add(y_t1, y_t2); auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); std::initializer_list fetch_list = {out->id}; @@ -411,13 +418,13 @@ TEST(TransposeCollapsing, FuseTwoHorizontalTranspose) { TEST(TransposeCollapsing, FuseVerAndHorTranspose) { NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto y_t1 = builder.Transpose(x, {0, 2, 1}); - auto y_t2 = builder.Transpose(y_t1, {2, 1, 0}); - auto y_t3 = builder.Transpose(x, {1, 2, 0}); - auto out = builder.Add(y_t2, y_t3); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto y_t1 = builder.Transpose(x, {0, 2, 1}); + auto y_t2 = builder.Transpose(y_t1, {2, 1, 0}); + auto y_t3 = builder.Transpose(x, {1, 2, 0}); + auto out = builder.Add(y_t2, y_t3); auto program = builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); std::initializer_list fetch_list = {out->id}; diff --git a/paddle/cinn/frontend/pass/transpose_folding_base.h b/paddle/cinn/frontend/pass/transpose_folding_base.h index 489b520a1b494..fbcb384b1704d 100644 --- a/paddle/cinn/frontend/pass/transpose_folding_base.h +++ b/paddle/cinn/frontend/pass/transpose_folding_base.h @@ -28,14 +28,18 @@ namespace cinn::frontend::pass { class TransposeFoldingBase : public ProgramPass { public: using ProgramPass::ProgramPass; - using In2InstrType = absl::flat_hash_map>; + using In2InstrType = + absl::flat_hash_map>; using Out2InstrType = absl::flat_hash_map; protected: virtual void set_target_instrs() = 0; // the ops which can be folded into matmul - void set_fold_instrs() { fold_instrs_ = {"transpose", "scale", "broadcast_to"}; } + void set_fold_instrs() { + fold_instrs_ = {"transpose", "scale", "broadcast_to"}; + } - // the ops which cannot folding but can ignore when it place between into folding op and matmul + // the ops which cannot be folded themselves but can be skipped over when + // they sit between a folding op and matmul void set_skip_instrs() { skip_instrs_ = {"cast", "identity"}; } void Clear() override { @@ -64,12 +68,14 @@ class TransposeFoldingBase : public ProgramPass { } }
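// [Editorial aside: for orientation, `in2instr` maps a variable name to the
// set of instructions that consume it, and `out2instr` maps a variable name
// to the instruction that produces it. A minimal sketch of how such maps are
// built in one pass over the program; OpNode is a simplified stand-in, not
// the real cinn::frontend::Instruction.]
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct OpNode {
  std::string op_type;
  std::vector<std::string> inputs;   // input variable names
  std::vector<std::string> outputs;  // output variable names
};

void BuildVarOpMaps(
    std::vector<OpNode>* program,
    std::unordered_map<std::string, std::unordered_set<OpNode*>>* in2instr,
    std::unordered_map<std::string, OpNode*>* out2instr) {
  for (OpNode& op : *program) {
    for (const auto& in : op.inputs) (*in2instr)[in].insert(&op);
    for (const auto& out : op.outputs) (*out2instr)[out] = &op;
  }
}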
- // `remove_instrs` is used to represent Instructions of which type is transpose to be deleted. + // `remove_instrs` collects the transpose Instructions that are to be + // deleted. absl::flat_hash_set remove_instrs; for (size_t i = 0; i < program->size(); ++i) { auto& instr = (*program)[i]; if (target_instrs_.count(instr->op_type)) { - DoMatmulFoldOptimize(&instr, out2instr, in2instr, fetch_ids, &remove_instrs); + DoMatmulFoldOptimize( + &instr, out2instr, in2instr, fetch_ids, &remove_instrs); } } @@ -91,21 +97,26 @@ class TransposeFoldingBase : public ProgramPass { const Out2InstrType& out2instr, const In2InstrType& in2instr, bool from_input) const { - if (!fold_instrs_.count((*instr)->op_type) && !skip_instrs_.count((*instr)->op_type)) { + if (!fold_instrs_.count((*instr)->op_type) && + !skip_instrs_.count((*instr)->op_type)) { return {}; } - CHECK_EQ((*instr)->inputs.size(), 1UL) << "The op " << (*instr)->op_type << " should has 1 input."; - CHECK_EQ((*instr)->outputs.size(), 1UL) << "The op " << (*instr)->op_type << " should has 1 output."; + CHECK_EQ((*instr)->inputs.size(), 1UL) + << "The op " << (*instr)->op_type << " should have 1 input."; + CHECK_EQ((*instr)->outputs.size(), 1UL) + << "The op " << (*instr)->op_type << " should have 1 output."; - VLOG(5) << "Try get matmul's folding instructions begin from [" << (*instr)->inputs[0]->id << "]"; + VLOG(5) << "Try to get matmul's folding instructions, beginning from [" + << (*instr)->inputs[0]->id << "]"; if (!from_input && in2instr.at((*instr)->inputs[0]->id).size() != 1UL) { // the matmul's output should only link to one op - VLOG(5) << "The var [" << (*instr)->inputs[0]->id << "] link to many op, cannot fold into matmul! Please check."; + VLOG(5) << "The var [" << (*instr)->inputs[0]->id + << "] links to many ops, cannot fold into matmul! Please check."; return {}; } - std::vector res = {instr}; + std::vector res = {instr}; std::unordered_set visited = {(*instr)->op_type}; auto cur_instr = instr; @@ -141,18 +152,21 @@ class TransposeFoldingBase : public ProgramPass { return res; } - bool CanFold(const Instruction* instr, const std::unordered_set& visited_instr) const { + bool CanFold(const Instruction* instr, + const std::unordered_set& visited_instr) const { if (!instr) { return false; } const auto& instr_type = (*instr)->op_type; - if ((!fold_instrs_.count(instr_type) && !skip_instrs_.count(instr_type)) || visited_instr.count(instr_type)) { + if ((!fold_instrs_.count(instr_type) && !skip_instrs_.count(instr_type)) || + visited_instr.count(instr_type)) { return false; } if (instr_type == "transpose") { if (visited_instr.count("broadcast_to")) { - // if transpose after broadcast_to, cannot fold because shape has changed + // if the transpose comes after broadcast_to, it cannot be folded + // because the shape has changed return false; } } @@ -191,17 +205,22 @@ class TransposeFoldingBase : public ProgramPass { return false; } - float bias = scale->attrs.count("bias") ? absl::get(scale->attrs.at("bias")) : 0.0f; + float bias = scale->attrs.count("bias") + ? 
absl::get(scale->attrs.at("bias")) + : 0.0f; return (bias == 0.0f); } - bool CanSkip(const Instruction& instr) const { return skip_instrs_.count(instr->op_type); } + bool CanSkip(const Instruction& instr) const { + return skip_instrs_.count(instr->op_type); + } - virtual void DoMatmulFoldOptimize(Instruction* instr, - const Out2InstrType& out2instr, - const In2InstrType& in2instr, - const std::unordered_set& fetch_ids, - absl::flat_hash_set* remove_instrs) const = 0; + virtual void DoMatmulFoldOptimize( + Instruction* instr, + const Out2InstrType& out2instr, + const In2InstrType& in2instr, + const std::unordered_set& fetch_ids, + absl::flat_hash_set* remove_instrs) const = 0; std::unordered_set target_instrs_; std::unordered_set fold_instrs_; diff --git a/paddle/cinn/frontend/pass/transpose_folding_input.cc b/paddle/cinn/frontend/pass/transpose_folding_input.cc index da1c22d806795..e98cf2ceaf057 100644 --- a/paddle/cinn/frontend/pass/transpose_folding_input.cc +++ b/paddle/cinn/frontend/pass/transpose_folding_input.cc @@ -24,25 +24,31 @@ namespace cinn::frontend::pass { -// Pass `TransposeFoldingInput` folds transpose into dot, then both of them can be implemented by a -// GEMM kernel. For each dot operator, try folding every input that belong output of transpose. -// If output of tranpose in `fetch_ids`, keep the operator. +// Pass `TransposeFoldingInput` folds transpose into dot so that both can be +// implemented by a single GEMM kernel. For each dot operator, try folding +// every input that is the output of a transpose. If the output of a transpose +// is in `fetch_ids`, keep that operator. class TransposeFoldingInputPass : public TransposeFoldingBase { public: using TransposeFoldingBase::TransposeFoldingBase; protected: - void set_target_instrs() override { TransposeFoldingBase::target_instrs_ = {"matmul"}; } + void set_target_instrs() override { + TransposeFoldingBase::target_instrs_ = {"matmul"}; + } - bool IsValidBroadCast(const Instruction& broadcast, const Instruction& dot, const int input_id) const { + bool IsValidBroadCast(const Instruction& broadcast, + const Instruction& dot, + const int input_id) const { if ("broadcast_to" != broadcast->op_type) { return false; } - // check whether the output shape can infer from another input, if not, cannot remove this broadcast - int next_id = (input_id + 1) % dot->inputs.size(); + // check whether the output shape can be inferred from the other input; if + // not, this broadcast cannot be removed + int next_id = (input_id + 1) % dot->inputs.size(); const auto& next_shape = dot->inputs[next_id]->shape; - const auto& out_shape = dot->outputs[0]->shape; + const auto& out_shape = dot->outputs[0]->shape; if (next_shape.size() != out_shape.size()) { return false; @@ -56,12 +62,14 @@ class TransposeFoldingInputPass : public TransposeFoldingBase { return true; } - void DoMatmulFoldOptimize(Instruction* dot, - const Out2InstrType& out2instr, - const In2InstrType& in2instr, - const std::unordered_set& fetch_ids, - absl::flat_hash_set* remove_instrs) const override { - CHECK_EQ((*dot)->inputs.size(), 2UL) << "The matmul should only have two inputs."; + void DoMatmulFoldOptimize( + Instruction* dot, + const Out2InstrType& out2instr, + const In2InstrType& in2instr, + const std::unordered_set& fetch_ids, + absl::flat_hash_set* remove_instrs) const override { + CHECK_EQ((*dot)->inputs.size(), 2UL) + << "The matmul should only have two inputs."; auto debug_info = [](const std::vector& instrs) { std::stringstream ss; @@ -76,7 +84,8 @@ class TransposeFoldingInputPass : public 
TransposeFoldingBase { if (iter != out2instr.end()) { // for example: x -> scale -> y -> transpose -> z -> dot // fold_instrs = {"transpose", "scale"} - const auto& fold_instrs = GetFoldInstruction(iter->second, out2instr, in2instr, true); + const auto& fold_instrs = + GetFoldInstruction(iter->second, out2instr, in2instr, true); if (fold_instrs.empty()) { continue; @@ -92,10 +101,14 @@ class TransposeFoldingInputPass : public TransposeFoldingBase { if (IsValidTranspose(*instr)) { // fold transpose into trans_a/trans_b if (i == 0) { - bool trans_a = (*dot)->attrs.count("trans_a") ? absl::get((*dot)->attrs.at("trans_a")) : false; + bool trans_a = (*dot)->attrs.count("trans_a") + ? absl::get((*dot)->attrs.at("trans_a")) + : false; dot->SetAttr("trans_a", static_cast(trans_a ^ true)); } else if (i == 1) { - bool trans_b = (*dot)->attrs.count("trans_b") ? absl::get((*dot)->attrs.at("trans_b")) : false; + bool trans_b = (*dot)->attrs.count("trans_b") + ? absl::get((*dot)->attrs.at("trans_b")) + : false; dot->SetAttr("trans_b", static_cast(trans_b ^ true)); } else { LOG(FATAL) << "The matmul should only have two inputs."; @@ -106,9 +119,13 @@ class TransposeFoldingInputPass : public TransposeFoldingBase { } else if (IsValidScale(*instr)) { // assume C = alpha * A * B + beta * C // fold scale into alpha - float scale = (*instr)->attrs.count("scale") ? absl::get((*instr)->attrs.at("scale")) : 1.0f; + float scale = (*instr)->attrs.count("scale") + ? absl::get((*instr)->attrs.at("scale")) + : 1.0f; - float alpha = (*dot)->attrs.count("alpha") ? absl::get((*dot)->attrs.at("alpha")) : 1.0f; + float alpha = (*dot)->attrs.count("alpha") + ? absl::get((*dot)->attrs.at("alpha")) + : 1.0f; dot->SetAttr("alpha", alpha * scale); } else if (IsValidBroadCast(*instr, *dot, i)) { // nothing to do, can fold directly @@ -134,16 +151,21 @@ class TransposeFoldingInputPass : public TransposeFoldingBase { } // check whether the instruction can be removed - const auto& out_name = (*instr)->outputs[0]->id; + const auto& out_name = (*instr)->outputs[0]->id; const auto& out_instrs = in2instr.at(out_name); - bool can_remove = std::all_of(out_instrs.begin(), out_instrs.end(), [&](Instruction* out_instr) { - // the transpose had linked to not matmul op, cannot remove - return target_instrs_.count((*out_instr)->op_type) || (out_instr == next_instr); - }); + bool can_remove = std::all_of( + out_instrs.begin(), + out_instrs.end(), + [&](Instruction* out_instr) { + // the transpose links to a non-matmul op, cannot remove + return target_instrs_.count((*out_instr)->op_type) || + (out_instr == next_instr); + }); if (can_remove && !fetch_ids.count(out_name)) { - // the transpose is only link to matmul and its output is not in fetch_ids, should remove + // the transpose only links to matmul and its output is not in + // fetch_ids, so it should be removed remove_instrs->insert(instr); } } @@ -155,7 +177,8 @@ class TransposeFoldingInputPass : public TransposeFoldingBase { } // namespace cinn::frontend::pass CINN_REGISTER_HELPER(TransposeFoldingInput) { - CINN_REGISTER_PROGRAM_PASS(TransposeFoldingInput, ::cinn::frontend::pass::TransposeFoldingInputPass); + CINN_REGISTER_PROGRAM_PASS(TransposeFoldingInput, + ::cinn::frontend::pass::TransposeFoldingInputPass); return true; } diff --git a/paddle/cinn/frontend/pass/transpose_folding_input_test.cc b/paddle/cinn/frontend/pass/transpose_folding_input_test.cc index b9c0d188ca7be..63daa6f7d0a1f 100644 --- a/paddle/cinn/frontend/pass/transpose_folding_input_test.cc +++ 
b/paddle/cinn/frontend/pass/transpose_folding_input_test.cc @@ -48,11 +48,11 @@ TEST(TransposeFoldingInput, FoldIntoDotBatchedCase1) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y"); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y"); auto transpose_x = builder.Transpose(x, {0, 2, 1}); - auto out = builder.Matmul(transpose_x, y); - auto program = builder.Build(); + auto out = builder.Matmul(transpose_x, y); + auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; @@ -61,7 +61,10 @@ TEST(TransposeFoldingInput, FoldIntoDotBatchedCase1) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } @@ -70,11 +73,11 @@ TEST(TransposeFoldingInput, FoldIntoDotBachedCase2) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X"); - auto y = builder.CreateInput(Float(32), {4, 6, 5}, "Y"); + auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X"); + auto y = builder.CreateInput(Float(32), {4, 6, 5}, "Y"); auto transpose_y = builder.Transpose(y, {0, 2, 1}); - auto out = builder.Matmul(x, transpose_y); - auto program = builder.Build(); + auto out = builder.Matmul(x, transpose_y); + auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; @@ -83,7 +86,10 @@ TEST(TransposeFoldingInput, FoldIntoDotBachedCase2) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } @@ -92,12 +98,12 @@ TEST(TransposeFoldingInput, FoldIntoDotBachedCase3) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); - auto y = builder.CreateInput(Float(32), {4, 6, 5}, "Y"); + auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X"); + auto y = builder.CreateInput(Float(32), {4, 6, 5}, "Y"); auto transpose_x = builder.Transpose(x, {0, 2, 1}); auto transpose_y = builder.Transpose(y, {0, 2, 1}); - auto out = builder.Matmul(transpose_x, transpose_y); - auto program = builder.Build(); + auto out = builder.Matmul(transpose_x, transpose_y); + auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; @@ -106,7 +112,10 @@ TEST(TransposeFoldingInput, FoldIntoDotBachedCase3) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } @@ -115,11 +124,11 @@ TEST(TransposeFoldingInput, FoldIntoDotCase1) { return; } NetBuilder builder("net_builder"); - auto x = 
builder.CreateInput(Float(32), {2, 3}, "X"); - auto y = builder.CreateInput(Float(32), {2, 3}, "Y"); + auto x = builder.CreateInput(Float(32), {2, 3}, "X"); + auto y = builder.CreateInput(Float(32), {2, 3}, "Y"); auto transpose_y = builder.Transpose(y, {1, 0}); - auto out = builder.Matmul(x, transpose_y); - auto program = builder.Build(); + auto out = builder.Matmul(x, transpose_y); + auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; @@ -128,7 +137,10 @@ TEST(TransposeFoldingInput, FoldIntoDotCase1) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } @@ -137,15 +149,15 @@ TEST(TransposeFoldingInput, FoldIntoDotCase2) { return; } NetBuilder builder("net_builder"); - auto a = builder.FillConstant({2, 20}, 2.0f, "A"); - auto b = builder.Transpose(a, {1, 0}); - auto c = builder.CreateInput(Float(32), {121, 20}, "C"); - auto d = builder.Matmul(c, b); - auto x = builder.FillConstant({2, 20}, 1.0f, "X"); - auto y = builder.Transpose(x, {1, 0}); - auto z = builder.CreateInput(Float(32), {121, 20}, "Z"); - auto q = builder.Matmul(z, y); - auto out = builder.Add(d, q); + auto a = builder.FillConstant({2, 20}, 2.0f, "A"); + auto b = builder.Transpose(a, {1, 0}); + auto c = builder.CreateInput(Float(32), {121, 20}, "C"); + auto d = builder.Matmul(c, b); + auto x = builder.FillConstant({2, 20}, 1.0f, "X"); + auto y = builder.Transpose(x, {1, 0}); + auto z = builder.CreateInput(Float(32), {121, 20}, "Z"); + auto q = builder.Matmul(z, y); + auto out = builder.Add(d, q); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -155,7 +167,10 @@ TEST(TransposeFoldingInput, FoldIntoDotCase2) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } @@ -164,11 +179,11 @@ TEST(TransposeFoldingInput, TransposeOutInFetchIds) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {2, 3}, "X"); - auto y = builder.CreateInput(Float(32), {2, 3}, "Y"); + auto x = builder.CreateInput(Float(32), {2, 3}, "X"); + auto y = builder.CreateInput(Float(32), {2, 3}, "Y"); auto transpose_y = builder.Transpose(y, {1, 0}); - auto out = builder.Matmul(x, transpose_y); - auto program = builder.Build(); + auto out = builder.Matmul(x, transpose_y); + auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; @@ -177,8 +192,18 @@ TEST(TransposeFoldingInput, TransposeOutInFetchIds) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; - CompareResult(&program, target, input_ids, {out->id, transpose_y->id}, 0, passes, 123, true); + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; + CompareResult(&program, + target, + 
input_ids, + {out->id, transpose_y->id}, + 0, + passes, + 123, + true); } TEST(TransposeFoldingInput, TransposeOutUsedByOtherInstrs) { @@ -186,12 +211,12 @@ TEST(TransposeFoldingInput, TransposeOutUsedByOtherInstrs) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {2, 2}, "X"); - auto y = builder.CreateInput(Float(32), {2, 2}, "Y"); + auto x = builder.CreateInput(Float(32), {2, 2}, "X"); + auto y = builder.CreateInput(Float(32), {2, 2}, "Y"); auto transpose_y = builder.Transpose(y, {1, 0}); - auto dot = builder.Matmul(x, transpose_y); - auto out = builder.Add(transpose_y, dot); - auto program = builder.Build(); + auto dot = builder.Matmul(x, transpose_y); + auto out = builder.Add(transpose_y, dot); + auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; @@ -200,7 +225,10 @@ TEST(TransposeFoldingInput, TransposeOutUsedByOtherInstrs) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 0, passes, 123, true); } @@ -213,10 +241,10 @@ TEST(TransposeFoldingInput, TransposeTwiceWithMatmul) { auto y = builder.CreateInput(Float(32), {10201, 20}, "Y"); auto z = builder.CreateInput(Float(32), {10201, 2}, "Z"); - auto x_t = builder.Transpose(x, {1, 0}); - auto x_t_t = builder.Transpose(x_t, {1, 0}); - auto dot1 = builder.Matmul(y, x_t); - auto dot2 = builder.Matmul(z, x_t_t); + auto x_t = builder.Transpose(x, {1, 0}); + auto x_t_t = builder.Transpose(x_t, {1, 0}); + auto dot1 = builder.Matmul(y, x_t); + auto dot2 = builder.Matmul(z, x_t_t); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -226,8 +254,12 @@ TEST(TransposeFoldingInput, TransposeTwiceWithMatmul) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; - CompareResult(&program, target, input_ids, {dot1->id, dot2->id}, 1, passes, 123, true); + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; + CompareResult( + &program, target, input_ids, {dot1->id, dot2->id}, 1, passes, 123, true); } TEST(TransposeFoldingInput, TransposeWithMultiMamtul) { @@ -235,13 +267,13 @@ TEST(TransposeFoldingInput, TransposeWithMultiMamtul) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {2, 2}, "X"); - auto y = builder.CreateInput(Float(32), {2, 2}, "Y"); + auto x = builder.CreateInput(Float(32), {2, 2}, "X"); + auto y = builder.CreateInput(Float(32), {2, 2}, "Y"); auto transpose_y = builder.Transpose(y, {1, 0}); - auto dot1 = builder.Matmul(x, transpose_y); - auto dot2 = builder.Matmul(transpose_y, x); - auto out = builder.Add(dot1, dot2); - auto program = builder.Build(); + auto dot1 = builder.Matmul(x, transpose_y); + auto dot2 = builder.Matmul(transpose_y, x); + auto out = builder.Add(dot1, dot2); + auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; @@ -250,7 +282,10 @@ TEST(TransposeFoldingInput, TransposeWithMultiMamtul) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", 
"RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } diff --git a/paddle/cinn/frontend/pass/transpose_folding_output.cc b/paddle/cinn/frontend/pass/transpose_folding_output.cc index b65fb0e7ce385..740e85d8b8df2 100644 --- a/paddle/cinn/frontend/pass/transpose_folding_output.cc +++ b/paddle/cinn/frontend/pass/transpose_folding_output.cc @@ -29,13 +29,16 @@ class TransposeFoldingOutputPass : public TransposeFoldingBase { using TransposeFoldingBase::TransposeFoldingBase; protected: - void set_target_instrs() override { TransposeFoldingBase::target_instrs_ = {"cublas_matmul"}; } + void set_target_instrs() override { + TransposeFoldingBase::target_instrs_ = {"cublas_matmul"}; + } - void DoMatmulFoldOptimize(Instruction* dot, - const Out2InstrType& out2instr, - const In2InstrType& in2instr, - const std::unordered_set& fetch_ids, - absl::flat_hash_set* remove_instrs) const override { + void DoMatmulFoldOptimize( + Instruction* dot, + const Out2InstrType& out2instr, + const In2InstrType& in2instr, + const std::unordered_set& fetch_ids, + absl::flat_hash_set* remove_instrs) const override { const auto& gemm_out_name = (*dot)->outputs[0]->id; auto debug_info = [](const std::vector& instrs) { @@ -46,11 +49,13 @@ class TransposeFoldingOutputPass : public TransposeFoldingBase { return ss.str(); }; - if (in2instr.contains(gemm_out_name) && in2instr.at(gemm_out_name).size() == 1) { + if (in2instr.contains(gemm_out_name) && + in2instr.at(gemm_out_name).size() == 1) { // for example: dot -> x -> scale -> y -> transpose -> z // fold_instrs = {"scale", "transpose"} // ensure the foldiong structions's output only link to one op - const auto& fold_instrs = GetFoldInstruction(*in2instr.at(gemm_out_name).begin(), out2instr, in2instr, false); + const auto& fold_instrs = GetFoldInstruction( + *in2instr.at(gemm_out_name).begin(), out2instr, in2instr, false); VLOG(4) << "Fold Instruction: [" << debug_info(fold_instrs) << "]" << " into output of matmul: " << *dot; @@ -61,12 +66,14 @@ class TransposeFoldingOutputPass : public TransposeFoldingBase { bool shape_has_changed = false; for (int i = fold_instrs.size() - 1; i >= 0; --i) { - auto instr = fold_instrs[i]; + auto instr = fold_instrs[i]; auto prev_instr = (i == 0) ? dot : fold_instrs[i - 1]; if (IsValidTranspose(*instr)) { // As for cublas_matmul, we can continue to set the `trans_out` attr. - bool trans_out = (*dot)->attrs.count("trans_out") ? absl::get((*dot)->attrs.at("trans_out")) : false; + bool trans_out = (*dot)->attrs.count("trans_out") + ? absl::get((*dot)->attrs.at("trans_out")) + : false; dot->SetAttr("trans_out", static_cast(trans_out ^ true)); // shape has changed, the ignore op should update shape @@ -74,10 +81,16 @@ class TransposeFoldingOutputPass : public TransposeFoldingBase { } else if (IsValidScale(*instr)) { // assume C = alpha * A * B + beta * C // fold scale into alpha/beta - float scale = (*instr)->attrs.count("scale") ? absl::get((*instr)->attrs.at("scale")) : 1.0f; + float scale = (*instr)->attrs.count("scale") + ? absl::get((*instr)->attrs.at("scale")) + : 1.0f; - float alpha = (*dot)->attrs.count("alpha") ? absl::get((*dot)->attrs.at("alpha")) : 1.0f; - float beta = (*dot)->attrs.count("beta") ? absl::get((*dot)->attrs.at("beta")) : 0.0f; + float alpha = (*dot)->attrs.count("alpha") + ? 
absl::get((*dot)->attrs.at("alpha")) + : 1.0f; + float beta = (*dot)->attrs.count("beta") + ? absl::get((*dot)->attrs.at("beta")) + : 0.0f; dot->SetAttr("alpha", alpha * scale); dot->SetAttr("beta", beta * scale); @@ -105,7 +118,9 @@ class TransposeFoldingOutputPass : public TransposeFoldingBase { } // namespace cinn::frontend::pass CINN_REGISTER_HELPER(TransposeFoldingOutput) { - CINN_REGISTER_PROGRAM_PASS(TransposeFoldingOutput, ::cinn::frontend::pass::TransposeFoldingOutputPass); + CINN_REGISTER_PROGRAM_PASS( + TransposeFoldingOutput, + ::cinn::frontend::pass::TransposeFoldingOutputPass); return true; } diff --git a/paddle/cinn/frontend/pass/transpose_folding_output_test.cc b/paddle/cinn/frontend/pass/transpose_folding_output_test.cc index f9d0083343a30..4004acbd8d0ea 100755 --- a/paddle/cinn/frontend/pass/transpose_folding_output_test.cc +++ b/paddle/cinn/frontend/pass/transpose_folding_output_test.cc @@ -34,13 +34,13 @@ TEST(TransposeFoldingOutput, BatchedMatmulTransLeft) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {3, 6, 8}, "A"); - auto b = builder.Transpose(a, {0, 2, 1}); - auto c = builder.CreateInput(Float(32), {3, 6, 7}, "C"); - auto d = builder.Matmul(b, c); - auto e = builder.Transpose(d, {0, 2, 1}); - auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); - auto out = builder.Subtract(e, f); + auto a = builder.CreateInput(Float(32), {3, 6, 8}, "A"); + auto b = builder.Transpose(a, {0, 2, 1}); + auto c = builder.CreateInput(Float(32), {3, 6, 7}, "C"); + auto d = builder.Matmul(b, c); + auto e = builder.Transpose(d, {0, 2, 1}); + auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); + auto out = builder.Subtract(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -50,7 +50,10 @@ TEST(TransposeFoldingOutput, BatchedMatmulTransLeft) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } @@ -59,13 +62,13 @@ TEST(TransposeFoldingOutput, BatchedGemmTransLeft) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {3, 6, 8}, "A"); - auto b = builder.Transpose(a, {0, 2, 1}); - auto c = builder.CreateInput(Float(32), {3, 6, 7}, "C"); - auto d = builder.Matmul(b, c); - auto e = builder.Transpose(d, {0, 2, 1}); - auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); - auto out = builder.Add(e, f); + auto a = builder.CreateInput(Float(32), {3, 6, 8}, "A"); + auto b = builder.Transpose(a, {0, 2, 1}); + auto c = builder.CreateInput(Float(32), {3, 6, 7}, "C"); + auto d = builder.Matmul(b, c); + auto e = builder.Transpose(d, {0, 2, 1}); + auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); + auto out = builder.Add(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -75,7 +78,10 @@ TEST(TransposeFoldingOutput, BatchedGemmTransLeft) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 2, passes, 
123, true); } @@ -84,13 +90,13 @@ TEST(TransposeFoldingOutput, BatchedMatmulTransRight) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {3, 8, 6}, "A"); - auto b = builder.CreateInput(Float(32), {3, 7, 6}, "B"); - auto c = builder.Transpose(b, {0, 2, 1}); - auto d = builder.Matmul(a, c); - auto e = builder.Transpose(d, {0, 2, 1}); - auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); - auto out = builder.Subtract(e, f); + auto a = builder.CreateInput(Float(32), {3, 8, 6}, "A"); + auto b = builder.CreateInput(Float(32), {3, 7, 6}, "B"); + auto c = builder.Transpose(b, {0, 2, 1}); + auto d = builder.Matmul(a, c); + auto e = builder.Transpose(d, {0, 2, 1}); + auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); + auto out = builder.Subtract(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -100,7 +106,10 @@ TEST(TransposeFoldingOutput, BatchedMatmulTransRight) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } @@ -109,13 +118,13 @@ TEST(TransposeFoldingOutput, BatchedGemmTransRight) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {3, 8, 6}, "A"); - auto b = builder.CreateInput(Float(32), {3, 7, 6}, "B"); - auto c = builder.Transpose(b, {0, 2, 1}); - auto d = builder.Matmul(a, c); - auto e = builder.Transpose(d, {0, 2, 1}); - auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); - auto out = builder.Add(e, f); + auto a = builder.CreateInput(Float(32), {3, 8, 6}, "A"); + auto b = builder.CreateInput(Float(32), {3, 7, 6}, "B"); + auto c = builder.Transpose(b, {0, 2, 1}); + auto d = builder.Matmul(a, c); + auto e = builder.Transpose(d, {0, 2, 1}); + auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); + auto out = builder.Add(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -125,7 +134,10 @@ TEST(TransposeFoldingOutput, BatchedGemmTransRight) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } @@ -134,14 +146,14 @@ TEST(TransposeFoldingOutput, BatchedMatmulTransTwo) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {3, 6, 8}, "A"); - auto b = builder.Transpose(a, {0, 2, 1}); - auto c = builder.CreateInput(Float(32), {3, 7, 6}, "C"); - auto d = builder.Transpose(c, {0, 2, 1}); - auto e = builder.Matmul(b, d); - auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); - auto g = builder.Transpose(e, {0, 2, 1}); - auto out = builder.Subtract(f, g); + auto a = builder.CreateInput(Float(32), {3, 6, 8}, "A"); + auto b = builder.Transpose(a, {0, 2, 1}); + auto c = builder.CreateInput(Float(32), {3, 7, 6}, "C"); + auto d = builder.Transpose(c, {0, 2, 1}); + auto e = builder.Matmul(b, d); + auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); + auto g = builder.Transpose(e, {0, 2, 1}); + auto out = builder.Subtract(f, g); 
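// [Editorial aside: both folding passes boil down to small attribute
// arithmetic on the GEMM, assuming C = alpha * op_a(A) * op_b(B) + beta * C
// as stated in the pass comments. A hedged sketch with a plain struct, not
// the real Instruction attrs map.]
struct GemmAttrs {  // hypothetical stand-in
  bool trans_a = false, trans_b = false, trans_out = false;
  float alpha = 1.0f, beta = 0.0f;
};

// A folded transpose toggles the matching flag, so two transposes on the
// same operand cancel out (cf. TransposeTwiceWithMatmul above).
void FoldTranspose(GemmAttrs* g, int which) {
  if (which == 0) {
    g->trans_a ^= true;
  } else if (which == 1) {
    g->trans_b ^= true;
  } else {
    g->trans_out ^= true;
  }
}

// A folded scale(x) = s * x (its bias must be 0, see IsValidScale) multiplies
// the coefficients: only alpha on the input side, alpha and beta on the
// output side.
void FoldScale(GemmAttrs* g, float s, bool output_side) {
  g->alpha *= s;
  if (output_side) g->beta *= s;
}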
auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -149,9 +161,12 @@ TEST(TransposeFoldingOutput, BatchedMatmulTransTwo) { absl::c_transform(std::vector{a.id(), c.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair( - std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}); + auto passes = + std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 3, passes, 123, true); } @@ -160,14 +175,14 @@ TEST(TransposeFoldingOutput, BatchedGemmTransTwo) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {3, 6, 8}, "A"); - auto b = builder.Transpose(a, {0, 2, 1}); - auto c = builder.CreateInput(Float(32), {3, 7, 6}, "C"); - auto d = builder.Transpose(c, {0, 2, 1}); - auto e = builder.Matmul(b, d); - auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); - auto g = builder.Transpose(e, {0, 2, 1}); - auto out = builder.Add(f, g); + auto a = builder.CreateInput(Float(32), {3, 6, 8}, "A"); + auto b = builder.Transpose(a, {0, 2, 1}); + auto c = builder.CreateInput(Float(32), {3, 7, 6}, "C"); + auto d = builder.Transpose(c, {0, 2, 1}); + auto e = builder.Matmul(b, d); + auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); + auto g = builder.Transpose(e, {0, 2, 1}); + auto out = builder.Add(f, g); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -175,9 +190,12 @@ TEST(TransposeFoldingOutput, BatchedGemmTransTwo) { absl::c_transform(std::vector{a.id(), c.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair( - std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}); + auto passes = + std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 3, passes, 123, true); } @@ -186,12 +204,12 @@ TEST(TransposeFoldingOutput, BatchedMatmulNoTrans) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {3, 8, 6}, "A"); - auto c = builder.CreateInput(Float(32), {3, 6, 7}, "C"); - auto e = builder.Matmul(a, c); - auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); - auto g = builder.Transpose(e, {0, 2, 1}); - auto out = builder.Subtract(f, g); + auto a = builder.CreateInput(Float(32), {3, 8, 6}, "A"); + auto c = builder.CreateInput(Float(32), {3, 6, 7}, "C"); + auto e = builder.Matmul(a, c); + auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); + auto g = builder.Transpose(e, {0, 2, 1}); + auto out = builder.Subtract(f, g); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -199,9 +217,12 @@ TEST(TransposeFoldingOutput, BatchedMatmulNoTrans) { absl::c_transform(std::vector{a.id(), c.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair( - std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}); + auto passes = + 
std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } @@ -210,12 +231,12 @@ TEST(TransposeFoldingOutput, BatchedGemmNoTrans) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {3, 8, 6}, "A"); - auto c = builder.CreateInput(Float(32), {3, 6, 7}, "C"); - auto e = builder.Matmul(a, c); - auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); - auto g = builder.Transpose(e, {0, 2, 1}); - auto out = builder.Add(f, g); + auto a = builder.CreateInput(Float(32), {3, 8, 6}, "A"); + auto c = builder.CreateInput(Float(32), {3, 6, 7}, "C"); + auto e = builder.Matmul(a, c); + auto f = builder.CreateInput(Float(32), {3, 7, 8}, "F"); + auto g = builder.Transpose(e, {0, 2, 1}); + auto out = builder.Add(f, g); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -223,9 +244,12 @@ TEST(TransposeFoldingOutput, BatchedGemmNoTrans) { absl::c_transform(std::vector{a.id(), c.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair( - std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}); + auto passes = + std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } @@ -234,13 +258,13 @@ TEST(TransposeFoldingOutput, MatmulTransLeft) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {6, 8}, "A"); - auto b = builder.Transpose(a, {1, 0}); - auto c = builder.CreateInput(Float(32), {6, 7}, "C"); - auto d = builder.Matmul(b, c); - auto e = builder.Transpose(d, {1, 0}); - auto f = builder.CreateInput(Float(32), {7, 8}, "F"); - auto out = builder.Subtract(e, f); + auto a = builder.CreateInput(Float(32), {6, 8}, "A"); + auto b = builder.Transpose(a, {1, 0}); + auto c = builder.CreateInput(Float(32), {6, 7}, "C"); + auto d = builder.Matmul(b, c); + auto e = builder.Transpose(d, {1, 0}); + auto f = builder.CreateInput(Float(32), {7, 8}, "F"); + auto out = builder.Subtract(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -250,7 +274,10 @@ TEST(TransposeFoldingOutput, MatmulTransLeft) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } @@ -259,13 +286,13 @@ TEST(TransposeFoldingOutput, GemmTransLeft) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {6, 8}, "A"); - auto b = builder.Transpose(a, {1, 0}); - auto c = builder.CreateInput(Float(32), {6, 7}, "C"); - auto d = builder.Matmul(b, c); - auto e = builder.Transpose(d, {1, 0}); - auto f = builder.CreateInput(Float(32), {7, 8}, "F"); - auto out = builder.Add(e, f); + auto a = builder.CreateInput(Float(32), {6, 8}, "A"); + auto b = builder.Transpose(a, {1, 0}); + auto c = builder.CreateInput(Float(32), {6, 7}, "C"); + auto d 
= builder.Matmul(b, c); + auto e = builder.Transpose(d, {1, 0}); + auto f = builder.CreateInput(Float(32), {7, 8}, "F"); + auto out = builder.Add(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -275,7 +302,10 @@ TEST(TransposeFoldingOutput, GemmTransLeft) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } @@ -284,13 +314,13 @@ TEST(TransposeFoldingOutput, MatmulTransRight) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {8, 6}, "A"); - auto b = builder.CreateInput(Float(32), {7, 6}, "B"); - auto c = builder.Transpose(b, {1, 0}); - auto d = builder.Matmul(a, c); - auto e = builder.Transpose(d, {1, 0}); - auto f = builder.CreateInput(Float(32), {7, 8}, "F"); - auto out = builder.Subtract(e, f); + auto a = builder.CreateInput(Float(32), {8, 6}, "A"); + auto b = builder.CreateInput(Float(32), {7, 6}, "B"); + auto c = builder.Transpose(b, {1, 0}); + auto d = builder.Matmul(a, c); + auto e = builder.Transpose(d, {1, 0}); + auto f = builder.CreateInput(Float(32), {7, 8}, "F"); + auto out = builder.Subtract(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -300,7 +330,10 @@ TEST(TransposeFoldingOutput, MatmulTransRight) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } @@ -309,13 +342,13 @@ TEST(TransposeFoldingOutput, GemmTransRight) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {8, 6}, "A"); - auto b = builder.CreateInput(Float(32), {7, 6}, "B"); - auto c = builder.Transpose(b, {1, 0}); - auto d = builder.Matmul(a, c); - auto e = builder.Transpose(d, {1, 0}); - auto f = builder.CreateInput(Float(32), {7, 8}, "F"); - auto out = builder.Add(e, f); + auto a = builder.CreateInput(Float(32), {8, 6}, "A"); + auto b = builder.CreateInput(Float(32), {7, 6}, "B"); + auto c = builder.Transpose(b, {1, 0}); + auto d = builder.Matmul(a, c); + auto e = builder.Transpose(d, {1, 0}); + auto f = builder.CreateInput(Float(32), {7, 8}, "F"); + auto out = builder.Add(e, f); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -325,7 +358,10 @@ TEST(TransposeFoldingOutput, GemmTransRight) { [](absl::string_view id) { return std::string(id); }); std::pair, std::vector> passes{ {"Decomposer", "RemoveIdentity"}, - {"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}}; + {"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}}; CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } @@ -334,14 +370,14 @@ TEST(TransposeFoldingOutput, MatmulTransTwo) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {6, 8}, "A"); - auto b = builder.Transpose(a, {1, 0}); - auto c = builder.CreateInput(Float(32), {7, 6}, "C"); - auto d = 
builder.Transpose(c, {1, 0}); - auto e = builder.Matmul(b, d); - auto f = builder.CreateInput(Float(32), {7, 8}, "F"); - auto g = builder.Transpose(e, {1, 0}); - auto out = builder.Subtract(f, g); + auto a = builder.CreateInput(Float(32), {6, 8}, "A"); + auto b = builder.Transpose(a, {1, 0}); + auto c = builder.CreateInput(Float(32), {7, 6}, "C"); + auto d = builder.Transpose(c, {1, 0}); + auto e = builder.Matmul(b, d); + auto f = builder.CreateInput(Float(32), {7, 8}, "F"); + auto g = builder.Transpose(e, {1, 0}); + auto out = builder.Subtract(f, g); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -349,9 +385,12 @@ TEST(TransposeFoldingOutput, MatmulTransTwo) { absl::c_transform(std::vector{a.id(), c.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair( - std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}); + auto passes = + std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 3, passes, 123, true); } @@ -360,14 +399,14 @@ TEST(TransposeFoldingOutput, GemmTransTwo) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {6, 8}, "A"); - auto b = builder.Transpose(a, {1, 0}); - auto c = builder.CreateInput(Float(32), {7, 6}, "C"); - auto d = builder.Transpose(c, {1, 0}); - auto e = builder.Matmul(b, d); - auto f = builder.CreateInput(Float(32), {7, 8}, "F"); - auto g = builder.Transpose(e, {1, 0}); - auto out = builder.Add(f, g); + auto a = builder.CreateInput(Float(32), {6, 8}, "A"); + auto b = builder.Transpose(a, {1, 0}); + auto c = builder.CreateInput(Float(32), {7, 6}, "C"); + auto d = builder.Transpose(c, {1, 0}); + auto e = builder.Matmul(b, d); + auto f = builder.CreateInput(Float(32), {7, 8}, "F"); + auto g = builder.Transpose(e, {1, 0}); + auto out = builder.Add(f, g); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -375,9 +414,12 @@ TEST(TransposeFoldingOutput, GemmTransTwo) { absl::c_transform(std::vector{a.id(), c.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair( - std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}); + auto passes = + std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 3, passes, 123, true); } @@ -386,12 +428,12 @@ TEST(TransposeFoldingOutput, MatmulNoTrans) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {8, 6}, "A"); - auto c = builder.CreateInput(Float(32), {6, 7}, "C"); - auto e = builder.Matmul(a, c); - auto f = builder.CreateInput(Float(32), {7, 8}, "F"); - auto g = builder.Transpose(e, {1, 0}); - auto out = builder.Subtract(f, g); + auto a = builder.CreateInput(Float(32), {8, 6}, "A"); + auto c = builder.CreateInput(Float(32), {6, 7}, "C"); + auto e = builder.Matmul(a, c); + auto f = builder.CreateInput(Float(32), {7, 8}, "F"); + auto g = builder.Transpose(e, {1, 0}); + auto out = builder.Subtract(f, g); auto program = builder.Build(); 
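// [Editorial aside: every CompareResult call in these tests uses the same two
// pass stages, gathered here once for reference. GemmRewriter appears twice,
// presumably because it lowers "matmul" (the target of TransposeFoldingInput)
// to "cublas_matmul", which is the target TransposeFoldingOutput declares in
// set_target_instrs above.]
#include <string>
#include <utility>
#include <vector>

const std::pair<std::vector<std::string>, std::vector<std::string>>
    kFoldingPasses{{"Decomposer", "RemoveIdentity"},
                   {"TransposeFoldingInput",
                    "GemmRewriter",
                    "TransposeFoldingOutput",
                    "GemmRewriter"}};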
common::Target target = common::DefaultNVGPUTarget(); @@ -399,9 +441,12 @@ TEST(TransposeFoldingOutput, MatmulNoTrans) { absl::c_transform(std::vector{a.id(), c.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair( - std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}); + auto passes = + std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } @@ -410,12 +455,12 @@ TEST(TransposeFoldingOutput, GemmNoTrans) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {8, 6}, "A"); - auto c = builder.CreateInput(Float(32), {6, 7}, "C"); - auto e = builder.Matmul(a, c); - auto f = builder.CreateInput(Float(32), {7, 8}, "F"); - auto g = builder.Transpose(e, {1, 0}); - auto out = builder.Add(f, g); + auto a = builder.CreateInput(Float(32), {8, 6}, "A"); + auto c = builder.CreateInput(Float(32), {6, 7}, "C"); + auto e = builder.Matmul(a, c); + auto f = builder.CreateInput(Float(32), {7, 8}, "F"); + auto g = builder.Transpose(e, {1, 0}); + auto out = builder.Add(f, g); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -423,9 +468,12 @@ TEST(TransposeFoldingOutput, GemmNoTrans) { absl::c_transform(std::vector{a.id(), c.id(), f.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair( - std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}); + auto passes = + std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } @@ -434,27 +482,27 @@ TEST(TransposeFoldingOutput, BatchedComplex) { return; } NetBuilder builder("net_builder"); - auto a = builder.FillConstant({20, 2}, 2.0f, "A"); - auto b = builder.FillConstant({16, 2, 20}, 2.0f, "B"); - auto c = builder.Transpose(b, {0, 2, 1}); - auto d = builder.CreateInput(Float(32), {121, 20}, "D"); - auto e = builder.BroadcastTo(d, {16, 121, 20}, {1, 2}); - auto f = builder.Matmul(e, c); - auto x = builder.FillConstant({16, 2, 20}, 1.0f, "X"); - auto y = builder.Transpose(x, {0, 2, 1}); - auto z = builder.CreateInput(Float(32), {16, 20, 121}, "Z"); - auto l = builder.Transpose(z, {0, 2, 1}); - auto m = builder.Matmul(l, y); - auto n = builder.Matmul(d, a); - auto o = builder.BroadcastTo(n, {16, n->shape[0], n->shape[1]}, {1, 2}); - auto p = builder.Subtract(f, o); - auto q = builder.Transpose(f, {0, 2, 1}); - auto u = builder.Transpose(m, {0, 2, 1}); - auto v = builder.Add(q, u); - auto w = builder.Matmul(v, p); - auto i = builder.Transpose(w, {2, 1, 0}); - auto j = builder.FillConstant({2, 2, 16}, 3.14f, "I"); - auto out = builder.Add(i, j); + auto a = builder.FillConstant({20, 2}, 2.0f, "A"); + auto b = builder.FillConstant({16, 2, 20}, 2.0f, "B"); + auto c = builder.Transpose(b, {0, 2, 1}); + auto d = builder.CreateInput(Float(32), {121, 20}, "D"); + auto e = builder.BroadcastTo(d, {16, 121, 20}, {1, 2}); + auto f = builder.Matmul(e, c); + auto x = builder.FillConstant({16, 2, 20}, 1.0f, "X"); + auto y = 
builder.Transpose(x, {0, 2, 1}); + auto z = builder.CreateInput(Float(32), {16, 20, 121}, "Z"); + auto l = builder.Transpose(z, {0, 2, 1}); + auto m = builder.Matmul(l, y); + auto n = builder.Matmul(d, a); + auto o = builder.BroadcastTo(n, {16, n->shape[0], n->shape[1]}, {1, 2}); + auto p = builder.Subtract(f, o); + auto q = builder.Transpose(f, {0, 2, 1}); + auto u = builder.Transpose(m, {0, 2, 1}); + auto v = builder.Add(q, u); + auto w = builder.Matmul(v, p); + auto i = builder.Transpose(w, {2, 1, 0}); + auto j = builder.FillConstant({2, 2, 16}, 3.14f, "I"); + auto out = builder.Add(i, j); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -462,9 +510,12 @@ TEST(TransposeFoldingOutput, BatchedComplex) { absl::c_transform(std::vector{d.id(), z.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair( - std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}); + auto passes = + std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 5, passes, 123, false); } @@ -473,24 +524,24 @@ TEST(TransposeFoldingOutput, Complex) { return; } NetBuilder builder("net_builder"); - auto a = builder.FillConstant({2, 20}, 2.0f, "A"); - auto b = builder.Transpose(a, {1, 0}); // 20 * 2 - auto c = builder.CreateInput(Float(32), {121, 20}, "C"); - auto f = builder.Matmul(c, b); // 121 * 2 - auto x = builder.FillConstant({2, 20}, 1.0f, "X"); - auto z = builder.CreateInput(Float(32), {121, 20}, "Z"); - auto l = builder.Transpose(z, {1, 0}); // 20 * 121 - auto y = builder.Matmul(x, l); // 2 * 121 - auto m = builder.Transpose(y, {1, 0}); // 121 * 2 - auto n = builder.Matmul(z, a, false, true); - auto p = builder.Subtract(f, n); - auto q = builder.Transpose(f, {1, 0}); - auto u = builder.Transpose(m, {1, 0}); - auto v = builder.Add(q, u); - auto w = builder.Matmul(v, p); - auto i = builder.Transpose(w, {1, 0}); - auto j = builder.FillConstant({2, 2}, 3.14f, "I"); - auto out = builder.Add(i, j); + auto a = builder.FillConstant({2, 20}, 2.0f, "A"); + auto b = builder.Transpose(a, {1, 0}); // 20 * 2 + auto c = builder.CreateInput(Float(32), {121, 20}, "C"); + auto f = builder.Matmul(c, b); // 121 * 2 + auto x = builder.FillConstant({2, 20}, 1.0f, "X"); + auto z = builder.CreateInput(Float(32), {121, 20}, "Z"); + auto l = builder.Transpose(z, {1, 0}); // 20 * 121 + auto y = builder.Matmul(x, l); // 2 * 121 + auto m = builder.Transpose(y, {1, 0}); // 121 * 2 + auto n = builder.Matmul(z, a, false, true); + auto p = builder.Subtract(f, n); + auto q = builder.Transpose(f, {1, 0}); + auto u = builder.Transpose(m, {1, 0}); + auto v = builder.Add(q, u); + auto w = builder.Matmul(v, p); + auto i = builder.Transpose(w, {1, 0}); + auto j = builder.FillConstant({2, 2}, 3.14f, "I"); + auto out = builder.Add(i, j); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -498,10 +549,13 @@ TEST(TransposeFoldingOutput, Complex) { absl::c_transform(std::vector{c.id(), z.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair( - std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{ - "TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter", 
"TransposeFoldingOutput"}); + auto passes = + std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter", + "TransposeFoldingOutput"}); CompareResult(&program, target, input_ids, {out->id}, 5, passes, 123, false); } @@ -510,29 +564,31 @@ TEST(TransposeFoldingOutput, MultiTransCaseOne) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {2, 10}, "A"); - auto b = builder.CreateInput(Float(32), {10, 50}, "B"); - auto c = builder.Matmul(a, b); // 2 * 50 - auto d = builder.Transpose(c, {1, 0}); // 50 * 2 - auto e = builder.CreateInput(Float(32), {50, 2}, "E"); - auto f = builder.Add(d, e); - auto g = builder.Transpose(f, {1, 0}); - auto h = builder.CreateInput(Float(32), {2, 50}, "H"); - auto out = builder.Add(h, g); + auto a = builder.CreateInput(Float(32), {2, 10}, "A"); + auto b = builder.CreateInput(Float(32), {10, 50}, "B"); + auto c = builder.Matmul(a, b); // 2 * 50 + auto d = builder.Transpose(c, {1, 0}); // 50 * 2 + auto e = builder.CreateInput(Float(32), {50, 2}, "E"); + auto f = builder.Add(d, e); + auto g = builder.Transpose(f, {1, 0}); + auto h = builder.CreateInput(Float(32), {2, 50}, "H"); + auto out = builder.Add(h, g); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); std::vector input_ids; - absl::c_transform(std::vector{a.id(), b.id(), e.id(), h.id()}, - std::back_inserter(input_ids), - [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", - "GemmRewriter", - "TransposeFoldingOutput", - "GemmRewriter", - "TransposeFoldingOutput", - "GemmRewriter"}); + absl::c_transform( + std::vector{a.id(), b.id(), e.id(), h.id()}, + std::back_inserter(input_ids), + [](absl::string_view id) { return std::string(id); }); + auto passes = + std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, true); } @@ -541,13 +597,13 @@ TEST(TransposeFoldingOutput, MultiTransCaseTwo) { return; } NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {2, 10}, "A"); - auto b = builder.CreateInput(Float(32), {10, 50}, "B"); - auto c = builder.Matmul(a, b); // 2 * 50 - auto d = builder.Transpose(c, {1, 0}); // 50 * 2 - auto g = builder.Transpose(d, {1, 0}); - auto h = builder.CreateInput(Float(32), {2, 50}, "H"); - auto out = builder.Add(h, g); + auto a = builder.CreateInput(Float(32), {2, 10}, "A"); + auto b = builder.CreateInput(Float(32), {10, 50}, "B"); + auto c = builder.Matmul(a, b); // 2 * 50 + auto d = builder.Transpose(c, {1, 0}); // 50 * 2 + auto g = builder.Transpose(d, {1, 0}); + auto h = builder.CreateInput(Float(32), {2, 50}, "H"); + auto out = builder.Add(h, g); auto program = builder.Build(); common::Target target = common::DefaultNVGPUTarget(); @@ -555,13 +611,14 @@ TEST(TransposeFoldingOutput, MultiTransCaseTwo) { absl::c_transform(std::vector{a.id(), b.id(), h.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, - std::vector{"TransposeFoldingInput", - "GemmRewriter", - "TransposeFoldingOutput", - "GemmRewriter", - "TransposeFoldingOutput", - 
"GemmRewriter"}); + auto passes = + std::make_pair(std::vector{"Decomposer", "RemoveIdentity"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, true); } diff --git a/paddle/cinn/frontend/pass/transpose_scale_folding_test.cc b/paddle/cinn/frontend/pass/transpose_scale_folding_test.cc index e39382ea6f4ca..296ba7fba96a8 100644 --- a/paddle/cinn/frontend/pass/transpose_scale_folding_test.cc +++ b/paddle/cinn/frontend/pass/transpose_scale_folding_test.cc @@ -34,10 +34,10 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase1) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X"); - auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y"); + auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X"); + auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y"); auto scale_x = builder.Scale(x); - auto out = builder.Matmul(scale_x, y); + auto out = builder.Matmul(scale_x, y); auto program = builder.Build(); common::Target target = common::DefaultTarget(); @@ -45,9 +45,12 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase1) { absl::c_transform(std::vector{x.id(), y.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair( - std::vector{"Decomposer"}, - std::vector{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}); + auto passes = + std::make_pair(std::vector{"Decomposer"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, false); } @@ -56,10 +59,10 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase2) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X"); - auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y"); + auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X"); + auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y"); auto scale_x = builder.Scale(x, 2.0f); - auto out = builder.Matmul(scale_x, y); + auto out = builder.Matmul(scale_x, y); auto program = builder.Build(); common::Target target = common::DefaultTarget(); @@ -67,9 +70,12 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase2) { absl::c_transform(std::vector{x.id(), y.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - auto passes = std::make_pair( - std::vector{"Decomposer"}, - std::vector{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"}); + auto passes = + std::make_pair(std::vector{"Decomposer"}, + std::vector{"TransposeFoldingInput", + "GemmRewriter", + "TransposeFoldingOutput", + "GemmRewriter"}); CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, false); } @@ -78,10 +84,10 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase3) { return; } NetBuilder builder("net_builder"); - auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X"); - auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y"); + auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X"); + auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y"); auto scale_x = builder.Scale(x, 2.0f, 1.0f); - auto out = builder.Matmul(scale_x, y); + auto out = builder.Matmul(scale_x, y); auto program = builder.Build(); common::Target target = common::DefaultTarget(); @@ -89,9 +95,12 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase3) { 
diff --git a/paddle/cinn/frontend/pass/transpose_scale_folding_test.cc b/paddle/cinn/frontend/pass/transpose_scale_folding_test.cc
index e39382ea6f4ca..296ba7fba96a8 100644
--- a/paddle/cinn/frontend/pass/transpose_scale_folding_test.cc
+++ b/paddle/cinn/frontend/pass/transpose_scale_folding_test.cc
@@ -34,10 +34,10 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase1) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x       = builder.CreateInput(Float(32), {4, 3, 5}, "X");
-  auto y       = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
+  auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X");
+  auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
   auto scale_x = builder.Scale(x);
-  auto out     = builder.Matmul(scale_x, y);
+  auto out = builder.Matmul(scale_x, y);
   auto program = builder.Build();

   common::Target target = common::DefaultTarget();
@@ -45,9 +45,12 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase1) {
   absl::c_transform(std::vector<absl::string_view>{x.id(), y.id()},
                     std::back_inserter(input_ids),
                     [](absl::string_view id) { return std::string(id); });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, false);
 }
@@ -56,10 +59,10 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase2) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x       = builder.CreateInput(Float(32), {4, 3, 5}, "X");
-  auto y       = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
+  auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X");
+  auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
   auto scale_x = builder.Scale(x, 2.0f);
-  auto out     = builder.Matmul(scale_x, y);
+  auto out = builder.Matmul(scale_x, y);
   auto program = builder.Build();

   common::Target target = common::DefaultTarget();
@@ -67,9 +70,12 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase2) {
   absl::c_transform(std::vector<absl::string_view>{x.id(), y.id()},
                     std::back_inserter(input_ids),
                     [](absl::string_view id) { return std::string(id); });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, false);
 }
@@ -78,10 +84,10 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase3) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x       = builder.CreateInput(Float(32), {4, 3, 5}, "X");
-  auto y       = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
+  auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X");
+  auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
   auto scale_x = builder.Scale(x, 2.0f, 1.0f);
-  auto out     = builder.Matmul(scale_x, y);
+  auto out = builder.Matmul(scale_x, y);
   auto program = builder.Build();

   common::Target target = common::DefaultTarget();
@@ -89,9 +95,12 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase3) {
   absl::c_transform(std::vector<absl::string_view>{x.id(), y.id()},
                     std::back_inserter(input_ids),
                     [](absl::string_view id) { return std::string(id); });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 0, passes, 123, false);
 }
@@ -100,10 +109,10 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase4) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x       = builder.CreateInput(Float(32), {4, 3, 5}, "X");
-  auto y       = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
+  auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X");
+  auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
   auto scale_y = builder.Scale(y, 2.0f);
-  auto out     = builder.Matmul(x, scale_y);
+  auto out = builder.Matmul(x, scale_y);
   auto program = builder.Build();

   common::Target target = common::DefaultTarget();
@@ -111,9 +120,12 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase4) {
   absl::c_transform(std::vector<absl::string_view>{x.id(), y.id()},
                     std::back_inserter(input_ids),
                     [](absl::string_view id) { return std::string(id); });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 1, passes, 123, false);
 }
@@ -122,11 +134,11 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase5) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x       = builder.CreateInput(Float(32), {4, 3, 5}, "X");
-  auto y       = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
+  auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X");
+  auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
   auto scale_x = builder.Scale(x, 2.0f);
   auto scale_y = builder.Scale(y, 2.0f);
-  auto out     = builder.Matmul(scale_x, scale_y);
+  auto out = builder.Matmul(scale_x, scale_y);
   auto program = builder.Build();

   common::Target target = common::DefaultTarget();
@@ -134,9 +146,12 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase5) {
   absl::c_transform(std::vector<absl::string_view>{x.id(), y.id()},
                     std::back_inserter(input_ids),
                     [](absl::string_view id) { return std::string(id); });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, false);
 }
@@ -145,22 +160,25 @@ TEST(ScaleFolding, FoldIntoDotBatchedCase6) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x        = builder.CreateInput(Float(32), {4, 3, 5}, "X");
-  auto y        = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
-  auto scale_x  = builder.Scale(x, 2.0f);
-  auto scale_y  = builder.Scale(y, 2.0f);
+  auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X");
+  auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
+  auto scale_x = builder.Scale(x, 2.0f);
+  auto scale_y = builder.Scale(y, 2.0f);
   auto orig_out = builder.Matmul(scale_x, scale_y);
-  auto out      = builder.Scale(orig_out, 2.0f);
-  auto program  = builder.Build();
+  auto out = builder.Scale(orig_out, 2.0f);
+  auto program = builder.Build();

   common::Target target = common::DefaultTarget();
   std::vector<std::string> input_ids;
   absl::c_transform(std::vector<absl::string_view>{x.id(), y.id()},
                     std::back_inserter(input_ids),
                     [](absl::string_view id) { return std::string(id); });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 3, passes, 123, false);
 }
@@ -169,25 +187,28 @@ TEST(TransposeScaleFolding, BatchComplexCase1) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x           = builder.CreateInput(Float(32), {4, 5, 3}, "X");
-  auto y           = builder.CreateInput(Float(32), {4, 6, 5}, "Y");
+  auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X");
+  auto y = builder.CreateInput(Float(32), {4, 6, 5}, "Y");
   auto transpose_x = builder.Transpose(x, {0, 2, 1});
-  auto scale_x     = builder.Scale(transpose_x, 2.0f);
+  auto scale_x = builder.Scale(transpose_x, 2.0f);
   auto transpose_y = builder.Transpose(y, {0, 2, 1});
-  auto scale_y     = builder.Scale(transpose_y, 2.0f);
-  auto orig_out    = builder.Matmul(scale_x, scale_y);
-  auto scale_out   = builder.Scale(orig_out, 2.0f);
-  auto out         = builder.Transpose(scale_out, {0, 2, 1});
-  auto program     = builder.Build();
+  auto scale_y = builder.Scale(transpose_y, 2.0f);
+  auto orig_out = builder.Matmul(scale_x, scale_y);
+  auto scale_out = builder.Scale(orig_out, 2.0f);
+  auto out = builder.Transpose(scale_out, {0, 2, 1});
+  auto program = builder.Build();

   common::Target target = common::DefaultTarget();
   std::vector<std::string> input_ids;
   absl::c_transform(std::vector<absl::string_view>{x.id(), y.id()},
                     std::back_inserter(input_ids),
                     [](absl::string_view id) { return std::string(id); });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 6, passes, 123, false);
 }
@@ -196,25 +217,28 @@ TEST(TransposeScaleFolding, BatchComplexCase2) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x             = builder.CreateInput(Float(32), {4, 5, 3}, "X");
-  auto y             = builder.CreateInput(Float(32), {4, 6, 5}, "Y");
-  auto scale_x       = builder.Scale(x, 2.0f);
-  auto transpose_x   = builder.Transpose(scale_x, {0, 2, 1});
-  auto scale_y       = builder.Scale(y, 2.0f);
-  auto transpose_y   = builder.Transpose(scale_y, {0, 2, 1});
-  auto orig_out      = builder.Matmul(transpose_x, transpose_y);
+  auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X");
+  auto y = builder.CreateInput(Float(32), {4, 6, 5}, "Y");
+  auto scale_x = builder.Scale(x, 2.0f);
+  auto transpose_x = builder.Transpose(scale_x, {0, 2, 1});
+  auto scale_y = builder.Scale(y, 2.0f);
+  auto transpose_y = builder.Transpose(scale_y, {0, 2, 1});
+  auto orig_out = builder.Matmul(transpose_x, transpose_y);
   auto transpose_out = builder.Transpose(orig_out, {0, 2, 1});
-  auto out           = builder.Scale(transpose_out, 2.0f);
-  auto program       = builder.Build();
+  auto out = builder.Scale(transpose_out, 2.0f);
+  auto program = builder.Build();

   common::Target target = common::DefaultTarget();
   std::vector<std::string> input_ids;
   absl::c_transform(std::vector<absl::string_view>{x.id(), y.id()},
                     std::back_inserter(input_ids),
                     [](absl::string_view id) { return std::string(id); });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 6, passes, 123, false);
 }
@@ -223,21 +247,24 @@ TEST(TransposeScaleFolding, BatchComplexCase3) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x           = builder.CreateInput(Float(32), {4, 5, 3}, "X");
-  auto y           = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
+  auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X");
+  auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
   auto transpose_x = builder.Transpose(x, {0, 2, 1});
-  auto scale_y     = builder.Scale(y, 2.0f);
-  auto out         = builder.Matmul(transpose_x, scale_y);
-  auto program     = builder.Build();
+  auto scale_y = builder.Scale(y, 2.0f);
+  auto out = builder.Matmul(transpose_x, scale_y);
+  auto program = builder.Build();

   common::Target target = common::DefaultTarget();
   std::vector<std::string> input_ids;
   absl::c_transform(std::vector<absl::string_view>{x.id(), y.id()},
                     std::back_inserter(input_ids),
                     [](absl::string_view id) { return std::string(id); });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, false);
 }
@@ -246,20 +273,23 @@ TEST(TransposeScaleFolding, BatchComplexCase4) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x           = builder.CreateInput(Float(32), {4, 5, 3}, "X");
+  auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X");
   auto transpose_x = builder.Transpose(x, {0, 2, 1});
-  auto scale_x     = builder.Scale(x, 2.0f);
-  auto out         = builder.Matmul(transpose_x, scale_x);
-  auto program     = builder.Build();
+  auto scale_x = builder.Scale(x, 2.0f);
+  auto out = builder.Matmul(transpose_x, scale_x);
+  auto program = builder.Build();

   common::Target target = common::DefaultTarget();
   std::vector<std::string> input_ids;
-  absl::c_transform(std::vector<absl::string_view>{x.id()}, std::back_inserter(input_ids), [](absl::string_view id) {
-    return std::string(id);
-  });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  absl::c_transform(std::vector<absl::string_view>{x.id()},
+                    std::back_inserter(input_ids),
+                    [](absl::string_view id) { return std::string(id); });
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 2, passes, 123, false);
 }
@@ -268,24 +298,27 @@ TEST(TransposeScaleFolding, BatchComplexCase5) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x           = builder.CreateInput(Float(32), {4, 5, 3}, "X");
-  auto y           = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
-  auto z           = builder.FillConstant<float>({4, 3, 6}, 1.0f, "Z");
+  auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X");
+  auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
+  auto z = builder.FillConstant<float>({4, 3, 6}, 1.0f, "Z");
   auto transpose_x = builder.Transpose(x, {0, 2, 1});
-  auto scale_y     = builder.Scale(y, 2.0f);
-  auto out_matmul  = builder.Matmul(transpose_x, scale_y);
+  auto scale_y = builder.Scale(y, 2.0f);
+  auto out_matmul = builder.Matmul(transpose_x, scale_y);
   auto transpose_o = builder.Transpose(out_matmul, {0, 2, 1});
-  auto out         = builder.Matmul(transpose_o, z);
-  auto program     = builder.Build();
+  auto out = builder.Matmul(transpose_o, z);
+  auto program = builder.Build();

   common::Target target = common::DefaultTarget();
   std::vector<std::string> input_ids;
   absl::c_transform(std::vector<absl::string_view>{x.id(), y.id()},
                     std::back_inserter(input_ids),
                     [](absl::string_view id) { return std::string(id); });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 3, passes, 123, false);
 }
@@ -294,22 +327,25 @@ TEST(TransposeScaleFolding, BatchComplexCase6) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x           = builder.CreateInput(Float(32), {20, 3}, "X");
-  auto reshape_x   = builder.Reshape(x, {4, 5, 3});
-  auto scale_x     = builder.Scale(reshape_x, 2.0f);
+  auto x = builder.CreateInput(Float(32), {20, 3}, "X");
+  auto reshape_x = builder.Reshape(x, {4, 5, 3});
+  auto scale_x = builder.Scale(reshape_x, 2.0f);
   auto transpose_x = builder.Transpose(scale_x, {0, 2, 1});
-  auto out_matmul  = builder.Matmul(scale_x, transpose_x);
-  auto out         = builder.Transpose(out_matmul, {0, 2, 1});
-  auto program     = builder.Build();
+  auto out_matmul = builder.Matmul(scale_x, transpose_x);
+  auto out = builder.Transpose(out_matmul, {0, 2, 1});
+  auto program = builder.Build();

   common::Target target = common::DefaultTarget();
   std::vector<std::string> input_ids;
-  absl::c_transform(std::vector<absl::string_view>{x.id()}, std::back_inserter(input_ids), [](absl::string_view id) {
-    return std::string(id);
-  });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  absl::c_transform(std::vector<absl::string_view>{x.id()},
+                    std::back_inserter(input_ids),
+                    [](absl::string_view id) { return std::string(id); });
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 3, passes, 123, false);
 }
@@ -318,24 +354,27 @@ TEST(TransposeBroadCastFolding, BatchComplexCase1) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x           = builder.CreateInput(Float(32), {4, 5, 3}, "X");
-  auto y           = builder.CreateInput(Float(32), {5, 6}, "Y");
+  auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X");
+  auto y = builder.CreateInput(Float(32), {5, 6}, "Y");
   auto transpose_x = builder.Transpose(x, {0, 2, 1});
-  auto scale_y     = builder.Scale(y, 2.0f);
+  auto scale_y = builder.Scale(y, 2.0f);
   auto broadcast_y = builder.BroadcastTo(scale_y, {4, 5, 6});
-  auto out_matmul  = builder.Matmul(transpose_x, broadcast_y);
-  auto out_trans   = builder.Transpose(out_matmul, {0, 2, 1});
-  auto out         = builder.Scale(out_trans, 2.0f);
-  auto program     = builder.Build();
+  auto out_matmul = builder.Matmul(transpose_x, broadcast_y);
+  auto out_trans = builder.Transpose(out_matmul, {0, 2, 1});
+  auto out = builder.Scale(out_trans, 2.0f);
+  auto program = builder.Build();

   common::Target target = common::DefaultTarget();
   std::vector<std::string> input_ids;
   absl::c_transform(std::vector<absl::string_view>{x.id(), y.id()},
                     std::back_inserter(input_ids),
                     [](absl::string_view id) { return std::string(id); });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 5, passes, 123, false);
 }
@@ -344,26 +383,29 @@ TEST(TransposeBroadCastFolding, BatchComplexCase2) {
     return;
   }
   NetBuilder builder("net_builder");
-  auto x           = builder.CreateInput(Float(32), {4, 5, 3}, "X");
-  auto y           = builder.CreateInput(Float(32), {5, 6}, "Y");
+  auto x = builder.CreateInput(Float(32), {4, 5, 3}, "X");
+  auto y = builder.CreateInput(Float(32), {5, 6}, "Y");
   auto transpose_x = builder.Transpose(x, {0, 2, 1});
-  auto cast_x      = builder.Cast(transpose_x, "float32");
-  auto scale_y     = builder.Scale(y, 2.0f);
+  auto cast_x = builder.Cast(transpose_x, "float32");
+  auto scale_y = builder.Scale(y, 2.0f);
   auto broadcast_y = builder.BroadcastTo(scale_y, {4, 5, 6});
-  auto out_matmul  = builder.Matmul(cast_x, broadcast_y);
-  auto out_cast    = builder.Cast(out_matmul, "float32");
-  auto out_trans   = builder.Transpose(out_cast, {0, 2, 1});
-  auto out         = builder.Scale(out_trans, 2.0f);
-  auto program     = builder.Build();
+  auto out_matmul = builder.Matmul(cast_x, broadcast_y);
+  auto out_cast = builder.Cast(out_matmul, "float32");
+  auto out_trans = builder.Transpose(out_cast, {0, 2, 1});
+  auto out = builder.Scale(out_trans, 2.0f);
+  auto program = builder.Build();

   common::Target target = common::DefaultTarget();
   std::vector<std::string> input_ids;
   absl::c_transform(std::vector<absl::string_view>{x.id(), y.id()},
                     std::back_inserter(input_ids),
                     [](absl::string_view id) { return std::string(id); });
-  auto passes = std::make_pair(
-      std::vector<std::string>{"Decomposer"},
-      std::vector<std::string>{"TransposeFoldingInput", "GemmRewriter", "TransposeFoldingOutput", "GemmRewriter"});
+  auto passes =
+      std::make_pair(std::vector<std::string>{"Decomposer"},
+                     std::vector<std::string>{"TransposeFoldingInput",
+                                              "GemmRewriter",
+                                              "TransposeFoldingOutput",
+                                              "GemmRewriter"});
   CompareResult(&program, target, input_ids, {out->id}, 5, passes, 123, false);
 }
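The scale-folding cases follow the same recipe with Scale in place of, or combined with, Transpose: since a pure scale is linear, Matmul(Scale(x, s), y) equals Scale(Matmul(x, y), s), so the multiplier can be absorbed into the fused GEMM. A sketch of the simplest foldable and non-foldable pair, mirroring FoldIntoDotBatchedCase2 and Case3 above (the Case3 expectation of 0 folds suggests a non-zero bias blocks folding, since 2x + 1 is affine rather than linear):

// Sketch only, not part of the patch; mirrors the foldable/non-foldable
// pair from the tests above.
#include "paddle/cinn/frontend/net_builder.h"

cinn::frontend::Program BuildScaleCases() {
  using cinn::common::Float;
  cinn::frontend::NetBuilder builder("scale_sketch");
  auto x = builder.CreateInput(Float(32), {4, 3, 5}, "X");
  auto y = builder.CreateInput(Float(32), {4, 5, 6}, "Y");
  // Pure scale: expected to fold into the GEMM (cf. Case2, 1 fold).
  auto foldable = builder.Matmul(builder.Scale(x, 2.0f), y);
  // Scale with bias: expected to stay in the program (cf. Case3, 0 folds).
  auto blocked = builder.Matmul(builder.Scale(x, 2.0f, 1.0f), y);
  return builder.Build();
}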
diff --git a/paddle/cinn/frontend/program_pass.cc b/paddle/cinn/frontend/program_pass.cc
index ec2116fa92dd0..1cd0903f97a03 100644
--- a/paddle/cinn/frontend/program_pass.cc
+++ b/paddle/cinn/frontend/program_pass.cc
@@ -32,13 +32,15 @@ void ProgramPass::Apply(Program* prog,
   }
   for (const auto* pass : fpass) {
     int before = prog->size();
-    cinn::hlir::framework::PassPrinter::GetInstance()->PassBegin(pass->name(), *prog);
+    cinn::hlir::framework::PassPrinter::GetInstance()->PassBegin(pass->name(),
+                                                                 *prog);
     pass->ApplyImpl(prog, fetch_ids, target);
     const_cast<ProgramPass*>(pass)->Clear();
     int after = prog->size();
-    cinn::hlir::framework::PassPrinter::GetInstance()->PassEnd(pass->name(), *prog);
-    VLOG(1) << "Apply " << pass->name() << " pass, program size: " << before << " -> " << after
-            << ", diff: " << after - before;
+    cinn::hlir::framework::PassPrinter::GetInstance()->PassEnd(pass->name(),
+                                                               *prog);
+    VLOG(1) << "Apply " << pass->name() << " pass, program size: " << before
+            << " -> " << after << ", diff: " << after - before;
   }
 }
diff --git a/paddle/cinn/frontend/program_pass.h b/paddle/cinn/frontend/program_pass.h
index 94c86e30b75ce..4e54c70fc5ee4 100755
--- a/paddle/cinn/frontend/program_pass.h
+++ b/paddle/cinn/frontend/program_pass.h
@@ -84,7 +84,8 @@ class ProgramPassRegistry : public Registry<ProgramPass> {
     return pass;
   }

-  inline ProgramPass* __REGISTER_OR_GET__(const std::string& name, ProgramPass* pass) {
+  inline ProgramPass* __REGISTER_OR_GET__(const std::string& name,
+                                          ProgramPass* pass) {
     if (!fmap_.count(name)) {
       return __REGISTER__(name, pass);
     } else {
@@ -108,9 +109,10 @@ class ProgramPassRegistry : public Registry<ProgramPass> {
  * CINN_REGISTER_PROGRAM_PASS(decompose, DecomposerPass());
  * \endcode
  */
-#define CINN_REGISTER_PROGRAM_PASS(PassType, PassClass)          \
-  static ::cinn::frontend::ProgramPass* __make_##PassType##__ =  \
-      ::cinn::frontend::ProgramPassRegistry::Global()->__REGISTER_OR_GET__(#PassType, new PassClass{#PassType})
+#define CINN_REGISTER_PROGRAM_PASS(PassType, PassClass)                    \
+  static ::cinn::frontend::ProgramPass* __make_##PassType##__ =            \
+      ::cinn::frontend::ProgramPassRegistry::Global()->__REGISTER_OR_GET__( \
+          #PassType, new PassClass{#PassType})

 }  // namespace frontend
 }  // namespace cinn
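For readers outside this file: the macro above is the whole registration story. A pass class is registered once, statically, and later looked up by name when a pass list such as {"Decomposer", "RemoveIdentity"} is applied. Expanded by hand, using the decompose/DecomposerPass example from the header's own doc comment (so the class name is taken from there, not invented):

// What CINN_REGISTER_PROGRAM_PASS(decompose, DecomposerPass) expands to,
// modulo the line continuations:
static ::cinn::frontend::ProgramPass* __make_decompose__ =
    ::cinn::frontend::ProgramPassRegistry::Global()->__REGISTER_OR_GET__(
        "decompose", new DecomposerPass{"decompose"});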
diff --git a/paddle/cinn/frontend/syntax.cc b/paddle/cinn/frontend/syntax.cc
index 92f39973142fa..b2d4e9a54e968 100644
--- a/paddle/cinn/frontend/syntax.cc
+++ b/paddle/cinn/frontend/syntax.cc
@@ -41,19 +41,22 @@ void Instruction::PrepareOutputs() {
   }
 }

-Instruction::Instruction(absl::string_view op_type, const std::vector<Variable>& inputs, Program* parent)
+Instruction::Instruction(absl::string_view op_type,
+                         const std::vector<Variable>& inputs,
+                         Program* parent)
     : common::Shared<_Instruction_>(common::make_shared<_Instruction_>()) {
-  get()->op_type        = std::string(op_type);
+  get()->op_type = std::string(op_type);
   get()->parent_program = parent;
-  get()->inputs         = inputs;
+  get()->inputs = inputs;
   PrepareOutputs();
 }

 Placeholder::operator Variable() const { return var_; }

-Variable Program::conv2d(const Variable& a,
-                         const Variable& b,
-                         const absl::flat_hash_map<std::string, attr_t>& attr_store) {
+Variable Program::conv2d(
+    const Variable& a,
+    const Variable& b,
+    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   Instruction instr("conv2d");
   instr.SetInputs({a, b});
   for (auto& iter : attr_store) {
@@ -63,7 +66,9 @@ Variable Program::conv2d(const Variable& a,
   return instr.GetOutput(0);
 }

-Variable Program::layout_transform(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
+Variable Program::layout_transform(
+    const Variable& a,
+    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   Instruction instr("layout_transform");
   instr.SetInputs({a});
   for (auto& iter : attr_store) {
@@ -73,9 +78,10 @@ Variable Program::layout_transform(const Variable& a, const absl::flat_hash_map
   return instr.GetOutput(0);
 }

-Variable Program::conv2d_NCHWc(const Variable& a,
-                               const Variable& b,
-                               const absl::flat_hash_map<std::string, attr_t>& attr_store) {
+Variable Program::conv2d_NCHWc(
+    const Variable& a,
+    const Variable& b,
+    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   Instruction instr("conv2d_NCHWc");
   instr.SetInputs({a, b});
   for (auto& iter : attr_store) {
@@ -85,9 +91,10 @@ Variable Program::conv2d_NCHWc(const Variable& a,
   return instr.GetOutput(0);
 }

-Variable Program::depthwise_conv2d(const Variable& a,
-                                   const Variable& b,
-                                   const absl::flat_hash_map<std::string, attr_t>& attr_store) {
+Variable Program::depthwise_conv2d(
+    const Variable& a,
+    const Variable& b,
+    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   Instruction instr("depthwise_conv2d");
   instr.SetInputs({a, b});
   for (auto& iter : attr_store) {
@@ -97,7 +104,9 @@ Variable Program::depthwise_conv2d(const Variable& a,
   return instr.GetOutput(0);
 }

-Variable Program::pool2d(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
+Variable Program::pool2d(
+    const Variable& a,
+    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   Instruction instr("pool2d");
   instr.SetInputs({a});
   for (auto& iter : attr_store) {
@@ -107,12 +116,13 @@ Variable Program::pool2d(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   return instr.GetOutput(0);
 }

-Variable Program::batchnorm(const Variable& a,
-                            const Variable& scale,
-                            const Variable& bias,
-                            const Variable& mean,
-                            const Variable& variance,
-                            const absl::flat_hash_map<std::string, attr_t>& attr_store) {
+Variable Program::batchnorm(
+    const Variable& a,
+    const Variable& scale,
+    const Variable& bias,
+    const Variable& mean,
+    const Variable& variance,
+    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   Instruction instr("batch_norm");
   instr.SetInputs({a, scale, bias, mean, variance});
   for (auto& iter : attr_store) {
@@ -123,7 +133,8 @@ Variable Program::batchnorm(const Variable& a,
 }

 template <typename PrimType>
-Variable Program::primitive_const_scalar(PrimType value, const std::string& name) {
+Variable Program::primitive_const_scalar(PrimType value,
+                                         const std::string& name) {
   Instruction instr("const_scalar");
   instr.SetInputs({});
   instr.SetAttr("value", value);
@@ -131,15 +142,17 @@ Variable Program::primitive_const_scalar(PrimType value, const std::string& name
   auto out = instr.GetOutput(0);
   out.set_id(name);
   auto out_type = type_of<PrimType>();
-  CHECK(out_type.is_float() || out_type.is_int() || out_type.is_bool()) << "no supported type: " << out_type;
+  CHECK(out_type.is_float() || out_type.is_int() || out_type.is_bool())
+      << "no supported type: " << out_type;
   out->type = out_type;
   out.set_const(true);
   return out;
 }

-Variable Program::primitive_broadcast_to(const Variable& a,
-                                         const std::vector<int>& out_shape,
-                                         const std::vector<int>& broadcast_axes) {
+Variable Program::primitive_broadcast_to(
+    const Variable& a,
+    const std::vector<int>& out_shape,
+    const std::vector<int>& broadcast_axes) {
   Instruction instr("broadcast_to");
   instr.SetInputs({a});
   instr.SetAttr("out_shape", out_shape);
@@ -148,59 +161,65 @@ Variable Program::primitive_broadcast_to(const Variable& a,
   return instr.GetOutput(0);
 }

-Variable Program::fused_meta_batchnorm_inference(const Variable& a,
-                                                 const Variable& scale,
-                                                 const Variable& bias,
-                                                 const Variable& mean,
-                                                 const Variable& variance,
-                                                 const absl::flat_hash_map<std::string, attr_t>& attr_store) {
+Variable Program::fused_meta_batchnorm_inference(
+    const Variable& a,
+    const Variable& scale,
+    const Variable& bias,
+    const Variable& mean,
+    const Variable& variance,
+    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   float epsilon = 0.00001f;
   if (attr_store.find("epsilon") != attr_store.end()) {
     epsilon = absl::get<float>(attr_store.at("epsilon"));
   }
-  auto eps_var       = primitive_const_scalar<float>(epsilon, common::UniqName("epsilon"));
+  auto eps_var =
+      primitive_const_scalar<float>(epsilon, common::UniqName("epsilon"));
   CHECK(!scale->shape.empty()) << "scale's shape is empty.";
   auto broadcast_eps = primitive_broadcast_to(eps_var, scale->shape, {0});
-  auto var_add_eps   = add(variance, broadcast_eps);
-  auto rsrqt_var     = primitive_rsqrt(var_add_eps);
-  auto new_scale     = multiply(rsrqt_var, scale);
-  auto neg_mean      = primitive_negative(mean);
-  auto new_shift     = multiply(new_scale, neg_mean);
-  auto shift_bias    = add(new_shift, bias);
+  auto var_add_eps = add(variance, broadcast_eps);
+  auto rsrqt_var = primitive_rsqrt(var_add_eps);
+  auto new_scale = multiply(rsrqt_var, scale);
+  auto neg_mean = primitive_negative(mean);
+  auto new_shift = multiply(new_scale, neg_mean);
+  auto shift_bias = add(new_shift, bias);
   CHECK(!a->shape.empty()) << "variable a's shape is empty.";
-  auto broadcast_new_scale  = primitive_broadcast_to(new_scale, a->shape, {1});
-  auto broadcast_shift_bias = primitive_broadcast_to(shift_bias, a->shape, {1});
-  auto temp_out             = multiply(broadcast_new_scale, a);
-  auto bn_out               = add(temp_out, broadcast_shift_bias);
+  auto broadcast_new_scale = primitive_broadcast_to(new_scale, a->shape, {1});
+  auto broadcast_shift_bias =
+      primitive_broadcast_to(shift_bias, a->shape, {1});
+  auto temp_out = multiply(broadcast_new_scale, a);
+  auto bn_out = add(temp_out, broadcast_shift_bias);

   return bn_out;
 }

-Variable Program::fused_batchnorm_inference(const Variable& a,
-                                            const Variable& scale,
-                                            const Variable& bias,
-                                            const Variable& mean,
-                                            const Variable& variance,
-                                            const absl::flat_hash_map<std::string, attr_t>& attr_store) {
+Variable Program::fused_batchnorm_inference(
+    const Variable& a,
+    const Variable& scale,
+    const Variable& bias,
+    const Variable& mean,
+    const Variable& variance,
+    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   float epsilon = 0.00001f;
   if (attr_store.find("epsilon") != attr_store.end()) {
     epsilon = absl::get<float>(attr_store.at("epsilon"));
   }
-  auto eps_var     = primitive_const_scalar<float>(epsilon, common::UniqName("epsilon"));
+  auto eps_var =
+      primitive_const_scalar<float>(epsilon, common::UniqName("epsilon"));
   CHECK(!scale->shape.empty()) << "scale's shape is empty.";
   auto var_add_eps = elementwise_add(variance, eps_var);
-  auto rsrqt_var   = primitive_rsqrt(var_add_eps);
-  auto new_scale   = elementwise_mul(rsrqt_var, scale);
-  auto neg_mean    = primitive_negative(mean);
-  auto new_shift   = elementwise_mul(new_scale, neg_mean);
-  auto shift_bias  = elementwise_add(new_shift, bias);
-  auto temp_out    = elementwise_mul(a, new_scale, 1);
-  auto bn_out      = elementwise_add(temp_out, shift_bias, 1);
+  auto rsrqt_var = primitive_rsqrt(var_add_eps);
+  auto new_scale = elementwise_mul(rsrqt_var, scale);
+  auto neg_mean = primitive_negative(mean);
+  auto new_shift = elementwise_mul(new_scale, neg_mean);
+  auto shift_bias = elementwise_add(new_shift, bias);
+  auto temp_out = elementwise_mul(a, new_scale, 1);
+  auto bn_out = elementwise_add(temp_out, shift_bias, 1);

   return bn_out;
 }

-Variable Program::scale(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
+Variable Program::scale(
+    const Variable& a,
+    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   Instruction instr("scale", {a});
   for (auto& iter : attr_store) {
     instr.SetAttr(iter.first, iter.second);
@@ -209,7 +228,9 @@ Variable Program::scale(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   return instr.GetOutput(0);
 }

-Variable Program::softmax(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
+Variable Program::softmax(
+    const Variable& a,
+    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   Instruction instr("softmax", {a});
   for (auto& iter : attr_store) {
     instr.SetAttr(iter.first, iter.second);
@@ -224,7 +245,9 @@ Variable Program::sigmoid(const Variable& a) {
   return instr.GetOutput(0);
 }

-Variable Program::slice(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
+Variable Program::slice(
+    const Variable& a,
+    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   Instruction instr("slice", {a});
   for (auto& iter : attr_store) {
     instr.SetAttr(iter.first, iter.second);
@@ -233,7 +256,9 @@ Variable Program::slice(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   return instr.GetOutput(0);
 }

-Variable Program::dropout_infer(const Variable& a, const absl::flat_hash_map<std::string, attr_t>& attr_store) {
+Variable Program::dropout_infer(
+    const Variable& a,
+    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
   Instruction instr("dropout_infer", {a});
   for (auto& iter : attr_store) {
     instr.SetAttr(iter.first, iter.second);
@@ -253,7 +278,8 @@ const Instruction& Program::operator[](size_t i) const {
 }

 std::ostream& operator<<(std::ostream& os, const Variable& x) {
-  os << "Var(" << x->id << ": shape=[" << utils::Join(x->shape, ", ") << "], dtype=" << x->type;
+  os << "Var(" << x->id << ": shape=[" << utils::Join(x->shape, ", ")
+     << "], dtype=" << x->type;
   if (x->is_const) {
     os << ", CONST";
   }
@@ -270,11 +296,12 @@ std::tuple<std::unique_ptr<Program>,
           absl::flat_hash_map<std::string, Variable>,
           absl::flat_hash_map<std::string, std::string>,
           absl::flat_hash_set<std::string>>
-LoadPaddleProgram(const std::string& model_dir,
-                  Scope* scope,
-                  std::unordered_map<std::string, std::vector<int>>& input_shape_map,
-                  bool is_combined,
-                  const common::Target& target) {
+LoadPaddleProgram(
+    const std::string& model_dir,
+    Scope* scope,
+    std::unordered_map<std::string, std::vector<int>>& input_shape_map,
+    bool is_combined,
+    const common::Target& target) {
   VLOG(1) << "Loading Paddle model from " << model_dir;
   PaddleModelToProgram paddle_to_program(scope, input_shape_map, target);
   return std::make_tuple(paddle_to_program(model_dir, is_combined),
@@ -286,15 +313,18 @@ LoadPaddleProgram(const std::string& model_dir,
 void Program::SetInputs(const std::vector<Variable>& xs) {
   CHECK(!xs.empty()) << "At least one input is needed for a program!";
   for (int i = 0; i < xs.size(); i++) {
-    CHECK(!xs[i]->shape.empty()) << "Found " << i << "-th input's shape is not set yet";
-    CHECK(!xs[i]->type.is_unk()) << "Found " << i << "-th input's type is not set yet";
+    CHECK(!xs[i]->shape.empty())
+        << "Found " << i << "-th input's shape is not set yet";
+    CHECK(!xs[i]->type.is_unk())
+        << "Found " << i << "-th input's type is not set yet";
     inputs_.push_back(xs[i]);
   }
 }

 void Program::Validate() const {
-  // Existing some program don't have input, such as a program only has `fill_constant`
-  // CHECK(!inputs_.empty()) << "Inputs of the program is not set yet";
+  // Some programs have no inputs (e.g. one built only from `fill_constant`),
+  // so the following check stays disabled:
+  // CHECK(!inputs_.empty()) << "Inputs of the program is not set yet";
   CHECK(!instrs_.empty()) << "No instruction is added yet";
 }

@@ -383,41 +413,50 @@ SYNTAX_PRIM_BINARY_IMPL(bitwise_and)
 SYNTAX_PRIM_BINARY_IMPL(left_shift)
 SYNTAX_PRIM_BINARY_IMPL(right_shift)

-Variable Program::elementwise_add(const Variable& a, const Variable& b, int axis) {
+Variable Program::elementwise_add(const Variable& a,
+                                  const Variable& b,
+                                  int axis) {
   Instruction instr("elementwise_add", {a, b});
   instr.SetAttr("axis", axis);
   AppendInstruction(instr);
   return instr.GetOutput(0);
 }

-Variable Program::elementwise_mul(const Variable& a, const Variable& b, int axis) {
+Variable Program::elementwise_mul(const Variable& a,
+                                  const Variable& b,
+                                  int axis) {
   Instruction instr("elementwise_mul", {a, b});
   instr.SetAttr("axis", axis);
   AppendInstruction(instr);
   return instr.GetOutput(0);
 }

-Variable Program::elementwise_div(const Variable& a, const Variable& b, int axis) {
+Variable Program::elementwise_div(const Variable& a,
+                                  const Variable& b,
+                                  int axis) {
   Instruction instr("divide", {a, b});
   instr.SetAttr("axis", axis);
   AppendInstruction(instr);
   return instr.GetOutput(0);
 }

-Variable Program::elementwise_sub(const Variable& a, const Variable& b, int axis) {
+Variable Program::elementwise_sub(const Variable& a,
+                                  const Variable& b,
+                                  int axis) {
   Instruction instr("subtract", {a, b});
   instr.SetAttr("axis", axis);
   AppendInstruction(instr);
   return instr.GetOutput(0);
 }

-#define SYNTAX_PRIM_REDUCE_IMPL(name__)                                                               \
-  Variable Program::reduce_##name__(const Variable& a, const std::vector<int>& dim, bool keep_dim) {  \
-    Instruction instr("reduce_" #name__, {a});                                                        \
-    instr.SetAttr("dim", dim);                                                                        \
-    instr.SetAttr("keep_dim", keep_dim);                                                              \
-    AppendInstruction(instr);                                                                         \
-    return instr.GetOutput(0);                                                                        \
+#define SYNTAX_PRIM_REDUCE_IMPL(name__)                                \
+  Variable Program::reduce_##name__(                                   \
+      const Variable& a, const std::vector<int>& dim, bool keep_dim) { \
+    Instruction instr("reduce_" #name__, {a});                         \
+    instr.SetAttr("dim", dim);                                         \
+    instr.SetAttr("keep_dim", keep_dim);                               \
+    AppendInstruction(instr);                                          \
+    return instr.GetOutput(0);                                         \
   }

 SYNTAX_PRIM_REDUCE_IMPL(sum)
@@ -443,7 +482,10 @@ Variable Program::relu6(const Variable& a) {
   return instr.GetOutput(0);
 }

-Variable Program::mul(const Variable& a, const Variable& b, int x_num_col_dims, int y_num_col_dims) {
+Variable Program::mul(const Variable& a,
+                      const Variable& b,
+                      int x_num_col_dims,
+                      int y_num_col_dims) {
   Instruction instr("mul", {a, b});
   instr.SetAttr("x_num_col_dims", x_num_col_dims);
   instr.SetAttr("y_num_col_dims", y_num_col_dims);
@@ -451,7 +493,11 @@ Variable Program::mul(const Variable& a, const Variable& b, int x_num_col_dims,
   return instr.GetOutput(0);
 }

-Variable Program::matmul(const Variable& a, const Variable& b, bool trans_a, bool trans_b, float alpha) {
+Variable Program::matmul(const Variable& a,
+                         const Variable& b,
+                         bool trans_a,
+                         bool trans_b,
+                         float alpha) {
   Instruction instr("matmul", {a, b});
   instr.SetAttr("trans_a", trans_a);
   instr.SetAttr("trans_b", trans_b);
@@ -491,20 +537,36 @@ std::string _Instruction_::debug_string() const {
     void operator()(double x) { s_ << x; }
     void operator()(bool x) { s_ << (x ? "true" : "false"); }
     void operator()(const std::string& x) { s_ << x; }
-    void operator()(const std::vector<int>& x) { s_ << "[" + utils::Join(x, ",") + "]"; }
-    void operator()(const std::vector<int64_t>& x) { s_ << "[" + utils::Join(x, ",") + "]"; }
-    void operator()(const std::vector<float>& x) { s_ << "[" + utils::Join(x, ",") + "]"; }
-    void operator()(const std::vector<double>& x) { s_ << "[" + utils::Join(x, ",") + "]"; }
-    void operator()(const std::vector<bool>& x) { s_ << "[" + utils::Join(x, ",") + "]"; }
-    void operator()(const std::vector<std::string>& x) { s_ << "[" + utils::Join(x, ",") + "]"; }
+    void operator()(const std::vector<int>& x) {
+      s_ << "[" + utils::Join(x, ",") + "]";
+    }
+    void operator()(const std::vector<int64_t>& x) {
+      s_ << "[" + utils::Join(x, ",") + "]";
+    }
+    void operator()(const std::vector<float>& x) {
+      s_ << "[" + utils::Join(x, ",") + "]";
+    }
+    void operator()(const std::vector<double>& x) {
+      s_ << "[" + utils::Join(x, ",") + "]";
+    }
+    void operator()(const std::vector<bool>& x) {
+      s_ << "[" + utils::Join(x, ",") + "]";
+    }
+    void operator()(const std::vector<std::string>& x) {
+      s_ << "[" + utils::Join(x, ",") + "]";
+    }
   };

   std::stringstream ss;
   std::vector<std::string> input_names, output_names;
-  std::transform(
-      inputs.begin(), inputs.end(), std::back_inserter(input_names), [](const Variable& x) { return x->id; });
-  std::transform(
-      outputs.begin(), outputs.end(), std::back_inserter(output_names), [](const Variable& x) { return x->id; });
+  std::transform(inputs.begin(),
+                 inputs.end(),
+                 std::back_inserter(input_names),
+                 [](const Variable& x) { return x->id; });
+  std::transform(outputs.begin(),
+                 outputs.end(),
+                 std::back_inserter(output_names),
+                 [](const Variable& x) { return x->id; });

   ss << utils::Join(output_names, ", ");
   ss << " = ";
@@ -535,11 +597,14 @@ std::string _Instruction_::debug_string() const {

 struct HashVariable {
   bool operator()(const Variable& lhs, const Variable& rhs) const {
-    return lhs->id == rhs->id && lhs->shape == rhs->shape && lhs->type == rhs->type;
+    return lhs->id == rhs->id && lhs->shape == rhs->shape &&
+           lhs->type == rhs->type;
   }

   std::size_t operator()(const Variable& var) const {
-    return std::hash<std::string>()(var->id + cinn::utils::Join(var->shape, ", ") + cinn::common::Type2Str(var->type));
+    return std::hash<std::string>()(var->id +
+                                    cinn::utils::Join(var->shape, ", ") +
+                                    cinn::common::Type2Str(var->type));
   }
 };
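Every Program::xxx builder in the file above follows the same four-step shape, which is worth keeping in mind while scanning the reflowed bodies. A condensed sketch of that recurring pattern ("some_op" is a hypothetical op name, not one of the real builders):

// 1. create an Instruction with an op name and its inputs,
// 2. copy attributes from attr_store onto it,
// 3. append it to the Program,
// 4. return its (conventionally single) output Variable.
Variable Program::some_op(
    const Variable& a,
    const absl::flat_hash_map<std::string, attr_t>& attr_store) {
  Instruction instr("some_op", {a});
  for (auto& iter : attr_store) {
    instr.SetAttr(iter.first, iter.second);
  }
  AppendInstruction(instr);
  return instr.GetOutput(0);
}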
a/paddle/cinn/frontend/syntax.h +++ b/paddle/cinn/frontend/syntax.h @@ -57,11 +57,14 @@ struct _Variable_ : public common::Object { struct Variable : public common::Shared<_Variable_> { /** * Constructor. - * @param id_hint The identifier of the variable, if null, a random ID will be assigned. + * @param id_hint The identifier of the variable, if null, a random ID will be + * assigned. */ - explicit Variable(const std::string& id_hint = "") : common::Shared<_Variable_>(common::make_shared<_Variable_>()) { + explicit Variable(const std::string& id_hint = "") + : common::Shared<_Variable_>(common::make_shared<_Variable_>()) { if (!id_hint.empty()) CheckVarNameValid(id_hint); - get()->id = id_hint.empty() ? common::Context::Global().NewName("var") : id_hint; + get()->id = + id_hint.empty() ? common::Context::Global().NewName("var") : id_hint; } void set_id(const std::string& id) { operator->()->id = id; } @@ -85,17 +88,18 @@ class Placeholder { Placeholder(const common::Type& type, const std::vector& shape, absl::string_view id_hint = "", - bool is_const = false) { + bool is_const = false) { if (!id_hint.empty()) CheckVarNameValid(std::string(id_hint)); - id_ = id_hint.empty() ? common::Context::Global().NewName("placeholder") : (std::string)id_hint; - var_ = Variable(id_); - var_->shape = shape; - var_->type = type; + id_ = id_hint.empty() ? common::Context::Global().NewName("placeholder") + : (std::string)id_hint; + var_ = Variable(id_); + var_->shape = shape; + var_->type = type; var_->is_const = is_const; } explicit Placeholder(const Variable& var) { - id_ = var->id; + id_ = var->id; var_ = var; } @@ -138,10 +142,13 @@ struct _Instruction_ : public common::Object { }; /** - * Instruction is the basic computational unit of a Program, similar to the operator concept in a DNN platform. + * Instruction is the basic computational unit of a Program, similar to the + * operator concept in a DNN platform. */ struct Instruction : public common::Shared<_Instruction_> { - explicit Instruction(absl::string_view op_type, const std::vector& inputs = {}, Program* parent = nullptr); + explicit Instruction(absl::string_view op_type, + const std::vector& inputs = {}, + Program* parent = nullptr); /** * Set the inputs of the instruction. @@ -174,9 +181,11 @@ struct Instruction : public common::Shared<_Instruction_> { template T GetAttrs(const std::string& key) const { auto it = get()->attrs.find(key); - CHECK(it != get()->attrs.end()) << "No attribute called [" << key << "] in op " << get()->op_type; + CHECK(it != get()->attrs.end()) + << "No attribute called [" << key << "] in op " << get()->op_type; CHECK(absl::holds_alternative(it->second)) - << "Try get attribute " << key << " from a error type " << typeid(T()).name() << " in op " << get()->op_type; + << "Try get attribute " << key << " from a error type " + << typeid(T()).name() << " in op " << get()->op_type; return absl::get(it->second); } @@ -259,12 +268,19 @@ struct Program { /** * Multiply two matrix. */ - Variable mul(const Variable& a, const Variable& b, int x_num_col_dims = 1, int y_num_col_dims = 1); + Variable mul(const Variable& a, + const Variable& b, + int x_num_col_dims = 1, + int y_num_col_dims = 1); /** * Multiply two matrix. */ - Variable matmul(const Variable& a, const Variable& b, bool trans_a = false, bool trans_b = false, float alpha = 1); + Variable matmul(const Variable& a, + const Variable& b, + bool trans_a = false, + bool trans_b = false, + float alpha = 1); /** * Reshape a tensor. 
@@ -284,7 +300,8 @@ struct Program { Variable transpose(const Variable& input_vars, const std::vector& axis); -#define SYNTAX_PRIM_UNARY_DECL(name__) Variable primitive_##name__(const Variable& a); +#define SYNTAX_PRIM_UNARY_DECL(name__) \ + Variable primitive_##name__(const Variable& a); SYNTAX_PRIM_UNARY_DECL(exp); SYNTAX_PRIM_UNARY_DECL(erf); @@ -321,7 +338,8 @@ struct Program { SYNTAX_PRIM_UNARY_DECL(abs); SYNTAX_PRIM_UNARY_DECL(rsqrt); -#define SYNTAX_PRIM_BINARY_DECL(name__) Variable primitive_##name__(const Variable& a, const Variable& b); +#define SYNTAX_PRIM_BINARY_DECL(name__) \ + Variable primitive_##name__(const Variable& a, const Variable& b); SYNTAX_PRIM_BINARY_DECL(subtract) SYNTAX_PRIM_BINARY_DECL(divide) SYNTAX_PRIM_BINARY_DECL(floor_divide) @@ -347,7 +365,8 @@ struct Program { SYNTAX_PRIM_BINARY_DECL(right_shift) #define SYNTAX_PRIM_REDUCE_DECL(name__) \ - Variable reduce_##name__(const Variable& a, const std::vector& dim, bool keep_dim = false); + Variable reduce_##name__( \ + const Variable& a, const std::vector& dim, bool keep_dim = false); SYNTAX_PRIM_REDUCE_DECL(sum) SYNTAX_PRIM_REDUCE_DECL(prod) @@ -357,7 +376,8 @@ struct Program { /** broadcast one operand to the target shape * broadcast axes: the target axis which a's ith axis is mapped to * Notes: a's dim should be one or same with the output dim mapped to. - * e.g. if a[64] broadcasts to out[1, 64, 112, 112], then out_shape is {1, 64, 112, 112} and broadcast_axes are {1} + * e.g. if a[64] broadcasts to out[1, 64, 112, 112], then out_shape is {1, 64, + * 112, 112} and broadcast_axes are {1} */ Variable primitive_broadcast_to(const Variable& a, const std::vector& out_shape, @@ -405,15 +425,22 @@ struct Program { * @param attr_store The params like padding, stride, dilation, etc. * @return The result. */ - Variable conv2d(const Variable& a, const Variable& b, const absl::flat_hash_map& attr_store); - Variable layout_transform(const Variable& a, const absl::flat_hash_map& attr_store); - Variable conv2d_NCHWc(const Variable& a, - const Variable& b, - const absl::flat_hash_map& attr_store); - Variable depthwise_conv2d(const Variable& a, - const Variable& b, - const absl::flat_hash_map& attr_store); - Variable pool2d(const Variable& a, const absl::flat_hash_map& attr_store); + Variable conv2d(const Variable& a, + const Variable& b, + const absl::flat_hash_map& attr_store); + Variable layout_transform( + const Variable& a, + const absl::flat_hash_map& attr_store); + Variable conv2d_NCHWc( + const Variable& a, + const Variable& b, + const absl::flat_hash_map& attr_store); + Variable depthwise_conv2d( + const Variable& a, + const Variable& b, + const absl::flat_hash_map& attr_store); + Variable pool2d(const Variable& a, + const absl::flat_hash_map& attr_store); /** * The batchnorm layer can be used as a normalizer function @@ -424,39 +451,47 @@ struct Program { * @param attr_store The params like eplison. * @return The result. 
*/ - Variable batchnorm(const Variable& a, - const Variable& scale, - const Variable& bias, - const Variable& mean, - const Variable& variance, - const absl::flat_hash_map& attr_store); + Variable batchnorm( + const Variable& a, + const Variable& scale, + const Variable& bias, + const Variable& mean, + const Variable& variance, + const absl::flat_hash_map& attr_store); /** * batchnorm composed of primitive ops */ - Variable fused_meta_batchnorm_inference(const Variable& a, - const Variable& scale, - const Variable& bias, - const Variable& mean, - const Variable& variance, - const absl::flat_hash_map& attr_store); - - Variable fused_batchnorm_inference(const Variable& a, - const Variable& scale, - const Variable& bias, - const Variable& mean, - const Variable& variance, - const absl::flat_hash_map& attr_store); - - Variable scale(const Variable& a, const absl::flat_hash_map& attr_store); - - Variable softmax(const Variable& a, const absl::flat_hash_map& attr_store); + Variable fused_meta_batchnorm_inference( + const Variable& a, + const Variable& scale, + const Variable& bias, + const Variable& mean, + const Variable& variance, + const absl::flat_hash_map& attr_store); + + Variable fused_batchnorm_inference( + const Variable& a, + const Variable& scale, + const Variable& bias, + const Variable& mean, + const Variable& variance, + const absl::flat_hash_map& attr_store); + + Variable scale(const Variable& a, + const absl::flat_hash_map& attr_store); + + Variable softmax(const Variable& a, + const absl::flat_hash_map& attr_store); Variable sigmoid(const Variable& a); - Variable slice(const Variable& a, const absl::flat_hash_map& attr_store); + Variable slice(const Variable& a, + const absl::flat_hash_map& attr_store); - Variable dropout_infer(const Variable& a, const absl::flat_hash_map& attr_store); + Variable dropout_infer( + const Variable& a, + const absl::flat_hash_map& attr_store); /** * Get \p i-th instruction. @@ -486,18 +521,19 @@ struct Program { * Load a Paddle model and return a frontend program. * @param model_dir The directory of the model. * @param is_combined Whether the parameters in the Paddle model is combined. 
- * @returns program, a map from name to variable and a map from variable name in Paddle model to the corresponding in - * program + * @returns program, a map from name to variable and a map from variable name in + * Paddle model to the corresponding in program */ std::tuple, absl::flat_hash_map, absl::flat_hash_map, absl::flat_hash_set> -LoadPaddleProgram(const std::string& model_dir, - hlir::framework::Scope* scope, - std::unordered_map>& input_shape_map, - bool is_combined, - const common::Target& target = common::DefaultHostTarget()); +LoadPaddleProgram( + const std::string& model_dir, + hlir::framework::Scope* scope, + std::unordered_map>& input_shape_map, + bool is_combined, + const common::Target& target = common::DefaultHostTarget()); std::ostream& operator<<(std::ostream& os, const Variable& x); std::ostream& operator<<(std::ostream& os, const Instruction& instr); diff --git a/paddle/cinn/frontend/syntax_test.cc b/paddle/cinn/frontend/syntax_test.cc index 59689cc0c27d4..925cd8fb49d2c 100644 --- a/paddle/cinn/frontend/syntax_test.cc +++ b/paddle/cinn/frontend/syntax_test.cc @@ -61,7 +61,7 @@ TEST(syntax, basic) { } TEST(syntax, program_execute_multi_elementwise_add) { - auto program = CreateAddProgram(); + auto program = CreateAddProgram(); Target target = common::DefaultTarget(); std::unordered_set fetch_ids; auto graph = Optimize(&program, fetch_ids, target); @@ -81,7 +81,7 @@ TEST(syntax, program_execute_multi_elementwise_add) { } TEST(syntax, program_execute_multi_elementwise_add2) { - auto program = CreateAddProgram(); + auto program = CreateAddProgram(); Target target = common::DefaultTarget(); std::unordered_set fetch_ids; auto graph = Optimize(&program, fetch_ids, target); @@ -107,11 +107,12 @@ TEST(syntax, program_execute_multi_elementwise_add2) { TEST(load_paddle_model, fc_execute) { auto scope = std::make_shared(); - std::unordered_map> input_shape_map = {{"A", {1, 30}}}; - auto programTuple = LoadPaddleProgram(FLAGS_model_dir, scope.get(), input_shape_map, false); - auto& program = std::get<0>(programTuple); - auto& var_map = std::get<1>(programTuple); - auto& var_map_paddle_to_program = std::get<2>(programTuple); + std::unordered_map> input_shape_map = {{"A", {1, +30}}}; auto programTuple = LoadPaddleProgram(FLAGS_model_dir, +scope.get(), input_shape_map, false); auto& program = +std::get<0>(programTuple); auto& var_map = +std::get<1>(programTuple); auto& var_map_paddle_to_program = +std::get<2>(programTuple); LOG(INFO) << "program:\n" << *program; @@ -132,10 +133,11 @@ TEST(load_paddle_model, fc_execute) { LOG(INFO) << "scope.names: " << Join(scope->var_names(), ","); const std::string output_name = "fc_0.tmp_2"; - auto tensor = scope->GetTensor(var_map_paddle_to_program.at(output_name)); - LOG(INFO) << "tensor.shape: " << utils::Join(tensor->shape().data(), ","); - auto data = GetTensorData(tensor, target); - for (int i = 0; i < 10; i++) LOG(INFO) << "data: " << data[i]; + auto tensor = +scope->GetTensor(var_map_paddle_to_program.at(output_name)); LOG(INFO) << +"tensor.shape: " << utils::Join(tensor->shape().data(), ","); auto data = +GetTensorData(tensor, target); for (int i = 0; i < 10; i++) LOG(INFO) << +"data: " << data[i]; } */ diff --git a/paddle/cinn/frontend/var_type_utils.h b/paddle/cinn/frontend/var_type_utils.h index a46af22620f72..b11c222da3f80 100644 --- a/paddle/cinn/frontend/var_type_utils.h +++ b/paddle/cinn/frontend/var_type_utils.h @@ -62,7 +62,8 @@ inline common::Type CppVarType2CommonType(paddle::cpp::VarDescAPI::Type type) { "PSTRING", // 29 
"SPARSE_COO", // 30 "SPARSE_CSR"}; // 31 - CHECK_LT(static_cast(type), var_type_names_.size()) << "Unknown VarDesc type: " << static_cast(type); + CHECK_LT(static_cast(type), var_type_names_.size()) + << "Unknown VarDesc type: " << static_cast(type); switch (type) { SET_TYPE_CASE_ITEM(BOOL, Bool) @@ -81,14 +82,16 @@ inline common::Type CppVarType2CommonType(paddle::cpp::VarDescAPI::Type type) { // so here need convert back to unkown type. SET_TYPE_CASE_ITEM(RAW, Type) default: - LOG(FATAL) << "Unknown VarDesc type: " << var_type_names_[static_cast(type)] << "(" << static_cast(type) - << ")"; + LOG(FATAL) << "Unknown VarDesc type: " + << var_type_names_[static_cast(type)] << "(" + << static_cast(type) << ")"; } #undef SET_DATA_TYPE_CASE_ITEM return common::Type(); } -inline OpMapperContext::FeedInfo GetFeedInfoFromDesc(const paddle::cpp::VarDesc& desc) { +inline OpMapperContext::FeedInfo GetFeedInfoFromDesc( + const paddle::cpp::VarDesc& desc) { OpMapperContext::FeedInfo info; for (auto num : desc.GetShape()) { info.shape.emplace_back(static_cast(num)); diff --git a/paddle/cinn/hlir/framework/accuracy_checker.cc b/paddle/cinn/hlir/framework/accuracy_checker.cc index 23cf7e6d78593..be58a9443f427 100644 --- a/paddle/cinn/hlir/framework/accuracy_checker.cc +++ b/paddle/cinn/hlir/framework/accuracy_checker.cc @@ -108,11 +108,14 @@ std::string GetTypeString() { } template -std::string DebugString(const Tensor& cpu_tensor, const std::string& name, const CheckResult& res) { +std::string DebugString(const Tensor& cpu_tensor, + const std::string& name, + const CheckResult& res) { std::stringstream ss; - ss << "name=" << name << ", dtype=" << GetTypeString() << ", shape=" << cpu_tensor->shape().data() << ", data=["; - size_t numel = cpu_tensor->shape().numel(); - const T* data = cpu_tensor->data(); + ss << "name=" << name << ", dtype=" << GetTypeString() + << ", shape=" << cpu_tensor->shape().data() << ", data=["; + size_t numel = cpu_tensor->shape().numel(); + const T* data = cpu_tensor->data(); size_t print_num = 5L; if (FLAGS_cinn_self_check_accuracy_num < 0) { print_num = numel; @@ -191,10 +194,12 @@ std::string AccuracyChecker::operator()(const std::string& arg_name) { } } -std::string AccuracyChecker::operator()(const std::map* name2podargs, - const std::string& arg_name) { +std::string AccuracyChecker::operator()( + const std::map* name2podargs, + const std::string& arg_name) { CHECK(name2podargs) << "name2podargs should not be nullptr."; - const cinn_buffer_t* buffer = cinn_pod_value_to_buffer_p(const_cast(&name2podargs->at(arg_name))); + const cinn_buffer_t* buffer = cinn_pod_value_to_buffer_p( + const_cast(&name2podargs->at(arg_name))); if (buffer->type == cinn_float32_t()) { return CheckBuffer(buffer, arg_name); } else if (buffer->type == cinn_float64_t()) { @@ -228,7 +233,8 @@ std::string AccuracyChecker::operator()(const std::map -std::string AccuracyChecker::CheckTensor(const Tensor& tensor, const std::string& arg_name) { +std::string AccuracyChecker::CheckTensor(const Tensor& tensor, + const std::string& arg_name) { Tensor cpu_tensor; cpu_tensor->Resize(tensor->shape()); T* dst = cpu_tensor->mutable_data(common::DefaultHostTarget()); @@ -237,13 +243,14 @@ std::string AccuracyChecker::CheckTensor(const Tensor& tensor, const std::string size_t numel = tensor->shape().numel(); MemcpyDeviceToHost(src, numel, dst); - auto res = CheckNanOrInf(cpu_tensor); + auto res = CheckNanOrInf(cpu_tensor); auto result_str = DebugString(cpu_tensor, arg_name, res); return result_str; } template 
-std::string AccuracyChecker::CheckBuffer(const cinn_buffer_t* buffer, const std::string& arg_name) { +std::string AccuracyChecker::CheckBuffer(const cinn_buffer_t* buffer, + const std::string& arg_name) { std::vector shape; shape.resize(buffer->dimensions); for (size_t i = 0; i < shape.size(); ++i) { @@ -258,7 +265,7 @@ std::string AccuracyChecker::CheckBuffer(const cinn_buffer_t* buffer, const std: size_t numel = cpu_tensor->shape().numel(); MemcpyDeviceToHost(src, numel, dst); - auto res = CheckNanOrInf(cpu_tensor); + auto res = CheckNanOrInf(cpu_tensor); auto result_str = DebugString(cpu_tensor, arg_name, res); return result_str; } @@ -283,9 +290,9 @@ void AccuracyChecker::MemcpyDeviceToHost(const T* src, size_t numel, T* dst) { template CheckResult AccuracyChecker::CheckNanOrInf(const Tensor& cpu_tensor) { bool zero_flag = true; - bool one_flag = true; - size_t numel = cpu_tensor->shape().numel(); - const T* data = cpu_tensor->data(); + bool one_flag = true; + size_t numel = cpu_tensor->shape().numel(); + const T* data = cpu_tensor->data(); for (size_t i = 0; i < numel; ++i) { if (std::isnan(data[i])) { return CheckResult::kNaN; diff --git a/paddle/cinn/hlir/framework/accuracy_checker.h b/paddle/cinn/hlir/framework/accuracy_checker.h index db24d4d84b9e2..a27c1f5bc7344 100644 --- a/paddle/cinn/hlir/framework/accuracy_checker.h +++ b/paddle/cinn/hlir/framework/accuracy_checker.h @@ -25,17 +25,21 @@ enum CheckResult { kOK = 0, kZero = 1, kNaN = 2, kInf = 3, kOne = 4 }; class AccuracyChecker { public: - AccuracyChecker(const Target& target, Scope* scope) : target_(target), scope_(scope) {} + AccuracyChecker(const Target& target, Scope* scope) + : target_(target), scope_(scope) {} std::string operator()(const std::string& arg_name); - std::string operator()(const std::map* name2podargs, const std::string& arg_name); + std::string operator()( + const std::map* name2podargs, + const std::string& arg_name); private: template std::string CheckTensor(const Tensor& tensor, const std::string& arg_name); template - std::string CheckBuffer(const cinn_buffer_t* buffer, const std::string& arg_name); + std::string CheckBuffer(const cinn_buffer_t* buffer, + const std::string& arg_name); template void MemcpyDeviceToHost(const T* src, size_t numel, T* dst); diff --git a/paddle/cinn/hlir/framework/accuracy_checker_test.cc b/paddle/cinn/hlir/framework/accuracy_checker_test.cc index 39e22c7692eff..b712752953709 100644 --- a/paddle/cinn/hlir/framework/accuracy_checker_test.cc +++ b/paddle/cinn/hlir/framework/accuracy_checker_test.cc @@ -43,14 +43,17 @@ void GenerateRandomData(float* data, size_t numel, bool generate_nan) { void SetRandomTensor(Tensor tensor, Target target, bool generate_nan) { size_t numel = tensor->shape().numel(); - float* dst = tensor->mutable_data(target); + float* dst = tensor->mutable_data(target); std::vector random_nan_vec(numel); GenerateRandomData(random_nan_vec.data(), numel, generate_nan); #ifdef CINN_WITH_CUDA if (target == common::DefaultNVGPUTarget()) { - cudaMemcpy(dst, random_nan_vec.data(), numel * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(dst, + random_nan_vec.data(), + numel * sizeof(float), + cudaMemcpyHostToDevice); } #endif if (target == common::DefaultHostTarget()) { @@ -78,10 +81,12 @@ std::unique_ptr GetLoweredFunc(Target target) { lang::Placeholder x("x", {m, n}); auto y = Compute( - {m, n}, [=](Expr i, Expr j) { return lang::CallExtern("sqrt", {x(i, j)}); }, "y"); + {m, n}, + [=](Expr i, Expr j) { return lang::CallExtern("sqrt", {x(i, j)}); }, + "y"); 
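CheckNanOrInf above copies the tensor to the host and then scans linearly for non-finite values; the core scan, extracted as a self-contained sketch (the flat array stands in for the framework's Tensor):

#include <cassert>
#include <cmath>
#include <cstddef>

enum CheckResult { kOK = 0, kZero = 1, kNaN = 2, kInf = 3, kOne = 4 };

template <typename T>
CheckResult ScanNanOrInf(const T* data, std::size_t numel) {
  for (std::size_t i = 0; i < numel; ++i) {
    if (std::isnan(data[i])) return kNaN;  // report the first NaN found
    if (std::isinf(data[i])) return kInf;
  }
  return kOK;
}

int main() {
  const float ok[2] = {0.5f, 1.0f};
  const float bad[2] = {0.5f, NAN};
  assert(ScanNanOrInf(ok, 2) == kOK);
  assert(ScanNanOrInf(bad, 2) == kNaN);
}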
auto stages = CreateStages({y}); - auto fn = Lower("fn_sqrt", stages, {x, y}); + auto fn = Lower("fn_sqrt", stages, {x, y}); ir::Module::Builder builder("some_module", target); builder.AddFunction(fn); @@ -105,7 +110,7 @@ TEST(AccuracyChecker, instruction) { Scope scope; InstantiateScope(&scope, target); - auto jit = GetLoweredFunc(target); + auto jit = GetLoweredFunc(target); auto fn_ptr = jit->Lookup("fn_sqrt"); CHECK(fn_ptr); @@ -122,17 +127,21 @@ TEST(AccuracyChecker, instruction) { void InitName2PodArgs(Target target, std::vector* args_buffer, std::map* name2podargs) { - auto* default_memory_mng = MemoryManager::Global().RetrieveSafely(target.arch); + auto* default_memory_mng = + MemoryManager::Global().RetrieveSafely(target.arch); - int count = 0; + int count = 0; const auto& shape = Shape({16, 16}); - size_t numel = shape.numel(); + size_t numel = shape.numel(); for (const auto& name : std::vector({"x", "y"})) { auto* buffer = &args_buffer->at(count++); buffer->type = cinn_float32_t(); - buffer->resize(reinterpret_cast(shape.data().data()), shape.size()); - buffer->memory = reinterpret_cast(default_memory_mng->malloc(numel * sizeof(float))); - float* data = reinterpret_cast(buffer->memory); + buffer->resize( + reinterpret_cast(shape.data().data()), + shape.size()); + buffer->memory = reinterpret_cast( + default_memory_mng->malloc(numel * sizeof(float))); + float* data = reinterpret_cast(buffer->memory); GenerateRandomData(data, numel, false); name2podargs->emplace(name, buffer); } @@ -144,7 +153,7 @@ TEST(AccuracyChecker, instruction_podargs) { std::map name2podargs; InitName2PodArgs(target, &args_buffer, &name2podargs); - auto jit = GetLoweredFunc(target); + auto jit = GetLoweredFunc(target); auto fn_ptr = jit->Lookup("fn_sqrt"); CHECK(fn_ptr); diff --git a/paddle/cinn/hlir/framework/buffer.cc b/paddle/cinn/hlir/framework/buffer.cc index 8051236acd4db..83427abe9cbe7 100755 --- a/paddle/cinn/hlir/framework/buffer.cc +++ b/paddle/cinn/hlir/framework/buffer.cc @@ -25,9 +25,9 @@ void Buffer::Resize(uint32_t size) { } if (size_ != size) { - data_.memory = reinterpret_cast(Malloc(size)); + data_.memory = reinterpret_cast(Malloc(size)); data_.memory_size = size; - size_ = size; + size_ = size; } } @@ -38,14 +38,14 @@ void Buffer::Resize(uint32_t alignment, uint32_t size) { } if (size_ != size) { - data_.memory = reinterpret_cast(AlignedAlloc(alignment, size)); + data_.memory = reinterpret_cast(AlignedAlloc(alignment, size)); data_.memory_size = size; - size_ = size; + size_ = size; } } void Buffer::SetTarget(const common::Target& target) { - target_ = target; + target_ = target; memory_mng_cache_ = MemoryManager::Global().RetrieveSafely(target_.arch); } @@ -67,7 +67,9 @@ void Buffer::Resize(uint32_t size, const common::Target& target) { Resize(size); } -void Buffer::Resize(uint32_t alignment, uint32_t size, const common::Target& target) { +void Buffer::Resize(uint32_t alignment, + uint32_t size, + const common::Target& target) { if (target.arch != target_.arch) { Free(); SetTarget(target); @@ -83,7 +85,9 @@ void Buffer::ResizeLazy(uint32_t size, const common::Target& target) { ResizeLazy(size); } -void Buffer::ResizeLazy(uint32_t alignment, uint32_t size, const common::Target& target) { +void Buffer::ResizeLazy(uint32_t alignment, + uint32_t size, + const common::Target& target) { if (target.arch != target_.arch) { Free(); SetTarget(target); diff --git a/paddle/cinn/hlir/framework/buffer.h b/paddle/cinn/hlir/framework/buffer.h index f41a717a80c44..4d5e7cb0afbea 100644 --- 
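Buffer::Resize above reallocates only when the requested size differs from the cached size_; a simplified standalone sketch of that idiom (plain malloc/free in place of the target-aware memory manager):

#include <cstdint>
#include <cstdlib>

struct SimpleBuffer {
  std::uint8_t* memory = nullptr;
  std::uint32_t size = 0;

  void Resize(std::uint32_t new_size) {
    if (size == new_size) return;  // same size: keep the current allocation
    std::free(memory);
    memory = static_cast<std::uint8_t*>(std::malloc(new_size));
    size = new_size;
  }

  ~SimpleBuffer() { std::free(memory); }
};

int main() {
  SimpleBuffer buf;
  buf.Resize(64);
  buf.Resize(64);   // no-op
  buf.Resize(128);  // frees and reallocates
}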
a/paddle/cinn/hlir/framework/buffer.h +++ b/paddle/cinn/hlir/framework/buffer.h @@ -29,7 +29,8 @@ namespace hlir { namespace framework { /** - * Buffer helps to hold the memory, and offers a set of methods to help manage the memory. + * Buffer helps to hold the memory, and offers a set of methods to help manage + * the memory. */ struct Buffer final { Buffer() = default; @@ -49,7 +50,9 @@ struct Buffer final { //! Lazily resize the memory to \p size in target \p target. void ResizeLazy(uint32_t size, const common::Target& target); - void ResizeLazy(uint32_t alignment, uint32_t size, const common::Target& target); + void ResizeLazy(uint32_t alignment, + uint32_t size, + const common::Target& target); void SetTarget(const common::Target& target); @@ -68,7 +71,8 @@ struct Buffer final { return memory_mng_cache_->malloc(size); } - inline void* AlignedAlloc(uint32_t alignment, uint32_t size) CINN_RESULT_SHOULD_USE { + inline void* AlignedAlloc(uint32_t alignment, + uint32_t size) CINN_RESULT_SHOULD_USE { CHECK(memory_mng_cache_) << "Should set target first"; return memory_mng_cache_->aligned_alloc(alignment, size); } diff --git a/paddle/cinn/hlir/framework/graph.cc b/paddle/cinn/hlir/framework/graph.cc index 17666ea77591d..6ebd405aeed7f 100644 --- a/paddle/cinn/hlir/framework/graph.cc +++ b/paddle/cinn/hlir/framework/graph.cc @@ -39,10 +39,11 @@ void Graph::Initialize(const frontend::Program& prog, int counter = 0; for (size_t i = 0; i < prog.size(); i++) { auto temp = prog[i]; - VLOG(3) << "operator [" << temp->op_type << "] has [" << temp->inputs.size() << "] inputs, and [" - << temp->outputs.size() << "] outputs"; - Node* node_tmp = - new Node(Operator::Get(temp->op_type), temp->op_type, temp->op_type + "_" + std::to_string(counter++)); + VLOG(3) << "operator [" << temp->op_type << "] has [" << temp->inputs.size() + << "] inputs, and [" << temp->outputs.size() << "] outputs"; + Node* node_tmp = new Node(Operator::Get(temp->op_type), + temp->op_type, + temp->op_type + "_" + std::to_string(counter++)); Shared node_ptr(node_tmp); node_tmp->attrs.attr_store = temp->attrs; for (auto& input_v : temp->inputs) { @@ -50,7 +51,8 @@ void Graph::Initialize(const frontend::Program& prog, if (!graph_node) { dtype_dict[input_v->id] = input_v->type; shape_dict[input_v->id] = input_v->shape; - NodeData* input_data = new NodeData(nullptr, 0, 0, input_v->id, input_v.is_const()); + NodeData* input_data = + new NodeData(nullptr, 0, 0, input_v->id, input_v.is_const()); input_data->LinkTo(node_tmp); this->RegisterNode(input_v->id, input_data); } else { @@ -63,7 +65,7 @@ void Graph::Initialize(const frontend::Program& prog, if (!graph_node) { dtype_dict[output_v->id] = output_v->type; shape_dict[output_v->id] = output_v->shape; - auto* output_data = new NodeData(node_ptr, out_idx++, 0, output_v->id); + auto* output_data = new NodeData(node_ptr, out_idx++, 0, output_v->id); if (fetch_var_ids.count(output_v->id)) { outputs.push_back(output_data); } @@ -73,7 +75,7 @@ void Graph::Initialize(const frontend::Program& prog, node_tmp->LinkTo(graph_node->as()); graph_node->as()->set_const(false); graph_node->as()->output_index = out_idx++; - graph_node->as()->source_node = node_ptr; + graph_node->as()->source_node = node_ptr; } } this->RegisterNode(node_tmp->id(), node_tmp); @@ -87,7 +89,8 @@ std::vector> Graph::FusionGroupsToGroups() { if (fusion_groups.empty()) { // if no fusion_groups, the graph will be treated as a big group const auto& nodes = this->CollectNodes([](const common::GraphNode* node) { - return 
node->safe_as<Node>() != nullptr && node->safe_as<Node>()->op() != nullptr; + return node->safe_as<Node>() != nullptr && + node->safe_as<Node>()->op() != nullptr; }); std::vector<Node*> group; group.reserve(nodes.size()); @@ -104,7 +107,8 @@ std::vector<std::vector<Node*>> Graph::FusionGroupsToGroups() { return groups; } -std::string Graph::DebugGroupedGraph(const std::unordered_set<std::string>& fetch_var_ids) { +std::string Graph::DebugGroupedGraph( + const std::unordered_set<std::string>& fetch_var_ids) { if (!fusion_groups.empty()) { return DebugGroupedGraph(FusionGroupsToGroups(), fetch_var_ids); } @@ -128,10 +132,13 @@ std::string Graph::DebugGroupedGraph(const std::unordered_set<std::string>& fetc return debug_str.str(); } -std::string Graph::DebugGroupedGraph(const std::vector<Node*>& group, - const std::unordered_set<std::string>& fetch_var_ids) { - auto& shape_dict = HasAttr("infershape") ? GetAttrs<ShapeDict>("infershape") : ShapeDict{}; - auto& dtype_dict = HasAttr("inferdtype") ? GetAttrs<DTypeDict>("inferdtype") : DTypeDict{}; +std::string Graph::DebugGroupedGraph( + const std::vector<Node*>& group, + const std::unordered_set<std::string>& fetch_var_ids) { + auto& shape_dict = + HasAttr("infershape") ? GetAttrs<ShapeDict>("infershape") : ShapeDict{}; + auto& dtype_dict = + HasAttr("inferdtype") ? GetAttrs<DTypeDict>("inferdtype") : DTypeDict{}; auto get_all_out_names = [](const std::vector<Node*>& nodes) { // collect all op's output var names in the group @@ -144,8 +151,10 @@ std::string Graph::DebugGroupedGraph(const std::vector<Node*>& group, } return out_names; }; - auto get_feed_list = [](const std::vector<Node*>& nodes, const std::unordered_set<std::string>& out_names) { - // if the op's input var name cannot found in out_names, it is the group's feed var + auto get_feed_list = [](const std::vector<Node*>& nodes, + const std::unordered_set<std::string>& out_names) { + // if the op's input var name cannot be found in out_names, it is the + // group's feed var std::unordered_set<std::string> feed_list; for (auto* node : nodes) { for (const auto& link : node->inlinks()) { @@ -157,7 +166,8 @@ std::string Graph::DebugGroupedGraph(const std::vector<Node*>& group, } return std::vector<std::string>(feed_list.begin(), feed_list.end()); }; - auto get_fetch_list = [&](const std::vector<Node*>& nodes, const std::unordered_set<std::string>& out_names) { + auto get_fetch_list = [&](const std::vector<Node*>& nodes, + const std::unordered_set<std::string>& out_names) { // if the fetch var is in out_names, it's the group's fetch var, otherwise not std::unordered_set<std::string> in_names; for (auto* node : nodes) { @@ -169,7 +179,8 @@ std::string Graph::DebugGroupedGraph(const std::vector<Node*>& group, std::vector<std::string> fetch_list; for (const auto& out : out_names) { if (!in_names.count(out) || fetch_var_ids.count(out)) { - // if the var not any op's input, or in fetch_var_ids, it's the group's fetch list + // if the var is not any op's input, or is in fetch_var_ids, it goes in + // the group's fetch list fetch_list.emplace_back(out); } } @@ -182,12 +193,15 @@ std::string Graph::DebugGroupedGraph(const std::vector<Node*>& group, std::stringstream debug_str; // generate python test code for (const auto& id : feed_list) { - const auto& shape = shape_dict.count(id) ? cinn::utils::Join(shape_dict.at(id), ", ") : "-1"; - const auto& dtype = dtype_dict.count(id) ? common::Type2Str(dtype_dict.at(id)) : "float32"; + const auto& shape = shape_dict.count(id) + ? cinn::utils::Join(shape_dict.at(id), ", ") + : "-1"; + const auto& dtype = + dtype_dict.count(id) ?
common::Type2Str(dtype_dict.at(id)) : "float32"; // generator python create_input code - debug_str << " " << id << " = builder.create_input(type=\"" << dtype << "\", shape=[" << shape << "], id_hint=\"" - << id << "\")\n"; + debug_str << " " << id << " = builder.create_input(type=\"" << dtype + << "\", shape=[" << shape << "], id_hint=\"" << id << "\")\n"; } debug_str << "\n"; // generator builder.op code @@ -196,17 +210,22 @@ std::string Graph::DebugGroupedGraph(const std::vector& group, } debug_str << "\n"; // generator - debug_str << " feed_list = [" << cinn::utils::Join(feed_list, ", ") << "]\n"; - debug_str << " fetch_list = [" << cinn::utils::Join(get_fetch_list(group, out_names), ", ") << "]\n"; + debug_str << " feed_list = [" << cinn::utils::Join(feed_list, ", ") + << "]\n"; + debug_str << " fetch_list = [" + << cinn::utils::Join(get_fetch_list(group, out_names), ", ") + << "]\n"; return debug_str.str(); } -std::string Graph::GenerateGroupPythonCode(const std::vector& group, - const std::unordered_set& fetch_var_ids) { +std::string Graph::GenerateGroupPythonCode( + const std::vector& group, + const std::unordered_set& fetch_var_ids) { std::stringstream ss; ss << "#!/usr/bin/env python3\n"; - ss << "# Please set \"export PYTHONPATH=${CINN_ROOT}/build/python:${PYTHONPATH}\" first\n"; + ss << "# Please set \"export " + "PYTHONPATH=${CINN_ROOT}/build/python:${PYTHONPATH}\" first\n"; ss << "\n"; ss << "import unittest\n"; ss << "import numpy as np\n"; @@ -222,8 +241,10 @@ std::string Graph::GenerateGroupPythonCode(const std::vector& group, ss << "\n"; ss << " prog = builder.build()\n"; ss << "\n"; - ss << " feed_data = [OpTest.random(shape=var.shape(), dtype=var.type()) for var in feed_list]\n"; - ss << " result = prog.build_and_get_output(DefaultNVGPUTarget(), feed_list, feed_data, fetch_list)\n"; + ss << " feed_data = [OpTest.random(shape=var.shape(), dtype=var.type()) " + "for var in feed_list]\n"; + ss << " result = prog.build_and_get_output(DefaultNVGPUTarget(), " + "feed_list, feed_data, fetch_list)\n"; ss << "\n"; ss << " result = [res.numpy(DefaultNVGPUTarget()) for res in result]\n"; ss << " for i in range(len(result)):\n"; @@ -241,8 +262,9 @@ std::string Graph::GenerateGroupPythonCode(const std::vector& group, return ss.str(); } -std::string Graph::DebugGroupedGraph(const std::vector>& groups, - const std::unordered_set& fetch_var_ids) { +std::string Graph::DebugGroupedGraph( + const std::vector>& groups, + const std::unordered_set& fetch_var_ids) { std::unordered_set fetch_list; if (!fetch_var_ids.empty()) { fetch_list = fetch_var_ids; @@ -265,32 +287,44 @@ std::string Graph::DebugGroupedGraph(const std::vector>& grou debug_str << "\n"; debug_str << "graph_fetch_list=[" - << cinn::utils::Join(std::vector(fetch_list.begin(), fetch_list.end()), ", ") << "]\n"; + << cinn::utils::Join(std::vector(fetch_list.begin(), + fetch_list.end()), + ", ") + << "]\n"; return debug_str.str(); } -void Graph::VisualizeGroupedGraph(const std::unordered_set& fetch_var_ids) { +void Graph::VisualizeGroupedGraph( + const std::unordered_set& fetch_var_ids) { VisualizeGroupedGraph(FusionGroupsToGroups(), fetch_var_ids); } -void Graph::VisualizeGroupedGraph(const std::vector>& origin_groups, - const std::unordered_set& fetch_var_ids) { - if (cinn::runtime::CheckStringFlagFalse(FLAGS_cinn_fusion_groups_graphviz_dir)) { +void Graph::VisualizeGroupedGraph( + const std::vector>& origin_groups, + const std::unordered_set& fetch_var_ids) { + if (cinn::runtime::CheckStringFlagFalse( + 
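The get_feed_list and get_fetch_list lambdas above classify a group's boundary variables by set membership; the same classification over a toy op list, as a standalone sketch (the Op struct is a stand-in for graph nodes):

#include <string>
#include <unordered_set>
#include <vector>

struct Op {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Feeds: vars read inside the group but produced outside of it.
std::vector<std::string> FeedList(const std::vector<Op>& group) {
  std::unordered_set<std::string> outs, feeds;
  for (const auto& op : group) outs.insert(op.outputs.begin(), op.outputs.end());
  for (const auto& op : group)
    for (const auto& in : op.inputs)
      if (!outs.count(in)) feeds.insert(in);
  return {feeds.begin(), feeds.end()};
}

// Fetches: vars produced inside the group but never consumed inside it.
std::vector<std::string> FetchList(const std::vector<Op>& group) {
  std::unordered_set<std::string> ins;
  for (const auto& op : group) ins.insert(op.inputs.begin(), op.inputs.end());
  std::vector<std::string> fetches;
  for (const auto& op : group)
    for (const auto& out : op.outputs)
      if (!ins.count(out)) fetches.push_back(out);
  return fetches;
}

int main() {
  // b = f(a); c = g(b): feed is {a}, fetch is {c}.
  std::vector<Op> group = {{{"a"}, {"b"}}, {{"b"}, {"c"}}};
  return FeedList(group).size() == 1 && FetchList(group).size() == 1 ? 0 : 1;
}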
FLAGS_cinn_fusion_groups_graphviz_dir)) { return; } int viz_id = viz_count_.fetch_add(1); { // create base Directory - viz_path_ = utils::StringFormat("%s/fusion_groups_%d/", FLAGS_cinn_fusion_groups_graphviz_dir.c_str(), viz_id); - if (!MakeDirectory(viz_path_, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { - LOG_IF(WARNING, viz_id == 0) << "Failed to make directory: \"" << viz_path_ - << "\", the CINN subgraph's fusion group information will not print."; + viz_path_ = + utils::StringFormat("%s/fusion_groups_%d/", + FLAGS_cinn_fusion_groups_graphviz_dir.c_str(), + viz_id); + if (!MakeDirectory(viz_path_, + S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { + LOG_IF(WARNING, viz_id == 0) + << "Failed to make directory: \"" << viz_path_ + << "\", the CINN subgraph's fusion group information will not be printed."; viz_path_.clear(); return; } - LOG_IF(INFO, viz_id == 0) << "The CINN subgraph's fusion group information will writing into path: \"" + LOG_IF(INFO, viz_id == 0) << "The CINN subgraph's fusion group information " "will be written into path: \"" << FLAGS_cinn_fusion_groups_graphviz_dir << "\""; } @@ -298,9 +332,11 @@ void Graph::VisualizeGroupedGraph(const std::vector<std::vector<Node*>>& origin_ { // save python test file std::string py_test_path = viz_path_ + "/tests/"; - if (!MakeDirectory(py_test_path, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { - LOG_IF(WARNING, viz_id == 0) << "Failed to make directory: \"" << py_test_path - << "\", the CINN subgraph's python test file will not generate."; + if (!MakeDirectory(py_test_path, + S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { + LOG_IF(WARNING, viz_id == 0) + << "Failed to make directory: \"" << py_test_path + << "\", the CINN subgraph's python test file will not be generated."; py_test_path.clear(); } if (!py_test_path.empty()) { @@ -312,14 +348,17 @@ void Graph::VisualizeGroupedGraph(const std::vector<std::vector<Node*>>& origin_ } Summary(groups, viz_path_); - WriteToFile(viz_path_ + "grouped_graph.dot", VisualizeGraph(groups, fetch_var_ids)); + WriteToFile(viz_path_ + "grouped_graph.dot", + VisualizeGraph(groups, fetch_var_ids)); { // save each group's graphviz dot file std::string group_path = viz_path_ + "/groups/"; - if (!MakeDirectory(group_path, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { - LOG_IF(WARNING, viz_id == 0) << "Failed to make directory: \"" << group_path - << "\", the CINN subgraph's group graphviz file will not save."; + if (!MakeDirectory(group_path, + S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { + LOG_IF(WARNING, viz_id == 0) + << "Failed to make directory: \"" << group_path + << "\", the CINN subgraph's group graphviz file will not be saved."; group_path.clear(); } if (!group_path.empty()) { @@ -331,14 +370,18 @@ void Graph::VisualizeGroupedGraph(const std::vector<std::vector<Node*>>& origin_ } } -std::string Graph::VisualizeGraph(const std::unordered_set<std::string>& fetch_var_ids) { +std::string Graph::VisualizeGraph( + const std::unordered_set<std::string>& fetch_var_ids) { return VisualizeGraph(FusionGroupsToGroups(), fetch_var_ids); } -std::string Graph::VisualizeGraph(const std::vector<std::vector<Node*>>& groups, - const std::unordered_set<std::string>& fetch_var_ids) { - auto& shape_dict = HasAttr("infershape") ? GetAttrs<ShapeDict>("infershape") : ShapeDict{}; - auto& dtype_dict = HasAttr("inferdtype") ? GetAttrs<DTypeDict>("inferdtype") : DTypeDict{}; +std::string Graph::VisualizeGraph( + const std::vector<std::vector<Node*>>& groups, + const std::unordered_set<std::string>& fetch_var_ids) { + auto& shape_dict = + HasAttr("infershape") ? GetAttrs<ShapeDict>("infershape") : ShapeDict{}; + auto& dtype_dict = + HasAttr("inferdtype") ?
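MakeDirectory above receives explicit POSIX permission bits, and the failure warning is logged only for the first subgraph (viz_id == 0); a minimal POSIX-only sketch of the pattern (a plain conditional stands in for glog's LOG_IF, and the path is illustrative):

#include <cerrno>
#include <iostream>
#include <string>
#include <sys/stat.h>

// rwx for user, r-x for group/others, matching the flag set above.
bool MakeDir(const std::string& path) {
  if (mkdir(path.c_str(), S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH) == 0)
    return true;
  return errno == EEXIST;  // an already-existing directory is acceptable
}

int main() {
  const int viz_id = 0;
  if (!MakeDir("/tmp/fusion_groups_0/") && viz_id == 0)
    std::cerr << "Failed to make directory\n";  // warn once, like LOG_IF above
}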
GetAttrs("inferdtype") : DTypeDict{}; std::unordered_map recompute_nodes; FindRecomputeNodes(groups, &recompute_nodes); @@ -371,14 +414,18 @@ std::string Graph::VisualizeGraph(const std::vector>& groups, return dot(); } -std::vector Graph::VisualizeGroups(const std::unordered_set& fetch_var_ids) { +std::vector Graph::VisualizeGroups( + const std::unordered_set& fetch_var_ids) { return VisualizeGroups(FusionGroupsToGroups(), fetch_var_ids); } -std::vector Graph::VisualizeGroups(const std::vector>& groups, - const std::unordered_set& fetch_var_ids) { - auto& shape_dict = HasAttr("infershape") ? GetAttrs("infershape") : ShapeDict{}; - auto& dtype_dict = HasAttr("inferdtype") ? GetAttrs("inferdtype") : DTypeDict{}; +std::vector Graph::VisualizeGroups( + const std::vector>& groups, + const std::unordered_set& fetch_var_ids) { + auto& shape_dict = + HasAttr("infershape") ? GetAttrs("infershape") : ShapeDict{}; + auto& dtype_dict = + HasAttr("inferdtype") ? GetAttrs("inferdtype") : DTypeDict{}; std::unordered_map recompute_nodes; FindRecomputeNodes(groups, &recompute_nodes); @@ -466,7 +513,9 @@ std::unordered_set Graph::Group::GetInputNodeDatas() { continue; } - if (std::find(this->input_names.begin(), this->input_names.end(), input_data->id()) != this->input_names.end()) { + if (std::find(this->input_names.begin(), + this->input_names.end(), + input_data->id()) != this->input_names.end()) { // if the input data in group's input_names group_inputs.insert(input_data); continue; @@ -495,14 +544,18 @@ std::unordered_set Graph::Group::GetOutputNodeDatas() { } void Graph::SaveSourceCode(const std::string& code) { - if (cinn::runtime::CheckStringFlagFalse(FLAGS_cinn_fusion_groups_graphviz_dir) || viz_path_.empty()) { + if (cinn::runtime::CheckStringFlagFalse( + FLAGS_cinn_fusion_groups_graphviz_dir) || + viz_path_.empty()) { return; } WriteToFile(viz_path_ + "source_code.cu", code); } void Graph::SavePTXCode(const std::string& ptx) { - if (cinn::runtime::CheckStringFlagFalse(FLAGS_cinn_fusion_groups_graphviz_dir) || viz_path_.empty()) { + if (cinn::runtime::CheckStringFlagFalse( + FLAGS_cinn_fusion_groups_graphviz_dir) || + viz_path_.empty()) { return; } diff --git a/paddle/cinn/hlir/framework/graph.h b/paddle/cinn/hlir/framework/graph.h index 47002c8a34708..39f9ae4a12ce5 100644 --- a/paddle/cinn/hlir/framework/graph.h +++ b/paddle/cinn/hlir/framework/graph.h @@ -39,7 +39,9 @@ class Graph : public cinn::common::Graph { std::unordered_set fetch_var_ids; Initialize(prog, fetch_var_ids, target); } - Graph(const frontend::Program& prog, const std::unordered_set& fetch_var_ids, const Target& target) { + Graph(const frontend::Program& prog, + const std::unordered_set& fetch_var_ids, + const Target& target) { Initialize(prog, fetch_var_ids, target); } @@ -85,18 +87,28 @@ class Graph : public cinn::common::Graph { } }; struct SharedGroupComparator { - bool operator()(const std::shared_ptr& first, const std::shared_ptr& second) const noexcept { + bool operator()(const std::shared_ptr& first, + const std::shared_ptr& second) const noexcept { return first.get() == second.get(); } }; // input groups - std::unordered_set, SharedGroupHasher, SharedGroupComparator> producer_groups; + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator> + producer_groups; // output grous - std::unordered_set, SharedGroupHasher, SharedGroupComparator> consumer_groups; + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator> + consumer_groups; // fused sub-groups, used for fusion merge pass std::vector> 
fused_sub_groups; // if as sub-group, used for belong groups. - std::unordered_set, SharedGroupHasher, SharedGroupComparator> belong_groups; + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator> + belong_groups; // for op lowering. std::vector input_names; @@ -106,7 +118,8 @@ class Graph : public cinn::common::Graph { if (fused_sub_groups.size()) { std::vector tmp_nodes; for (auto& group : fused_sub_groups) { - tmp_nodes.insert(tmp_nodes.end(), group->nodes.begin(), group->nodes.end()); + tmp_nodes.insert( + tmp_nodes.end(), group->nodes.begin(), group->nodes.end()); } return tmp_nodes; } else { @@ -129,7 +142,9 @@ class Graph : public cinn::common::Graph { }; std::vector> fusion_groups; - void RegisterNode(size_t key, Node* node) { this->common::Graph::RegisterNode(key, node->as()); } + void RegisterNode(size_t key, Node* node) { + this->common::Graph::RegisterNode(key, node->as()); + } void RegisterNode(size_t key, NodeData* node) { this->common::Graph::RegisterNode(key, node->as()); } @@ -149,7 +164,8 @@ class Graph : public cinn::common::Graph { template inline const T& GetAttrs(const std::string& attr_name) const { auto it = attrs.find(attr_name); - CHECK(it != attrs.end()) << "Cannot find attribute [" << attr_name << "] in the graph"; + CHECK(it != attrs.end()) + << "Cannot find attribute [" << attr_name << "] in the graph"; return absl::any_cast(*it->second); } @@ -162,7 +178,8 @@ class Graph : public cinn::common::Graph { template inline T& GetMutableAttrs(const std::string& attr_name) { auto it = attrs.find(attr_name); - CHECK(it != attrs.end()) << "Cannot find attribute [" << attr_name << "] in the graph"; + CHECK(it != attrs.end()) + << "Cannot find attribute [" << attr_name << "] in the graph"; return absl::any_cast(*it->second); } @@ -179,45 +196,56 @@ class Graph : public cinn::common::Graph { /** * \brief Debug the grouped graph according to fusion_groups. */ - std::string DebugGroupedGraph(const std::unordered_set& fetch_var_ids = {}); - std::string DebugGroupedGraph(const std::vector& group, - const std::unordered_set& fetch_var_ids = {}); + std::string DebugGroupedGraph( + const std::unordered_set& fetch_var_ids = {}); + std::string DebugGroupedGraph( + const std::vector& group, + const std::unordered_set& fetch_var_ids = {}); /** - * \brief Debug the grouped graph with GraphViz dot format according to fusion_groups. + * \brief Debug the grouped graph with GraphViz dot format according to + * fusion_groups. */ - std::string VisualizeGraph(const std::unordered_set& fetch_var_ids = {}); - std::vector VisualizeGroups(const std::unordered_set& fetch_var_ids = {}); + std::string VisualizeGraph( + const std::unordered_set& fetch_var_ids = {}); + std::vector VisualizeGroups( + const std::unordered_set& fetch_var_ids = {}); /** * \brief Genereate the python test code for group test */ - std::string GenerateGroupPythonCode(const std::vector& group, - const std::unordered_set& fetch_var_ids = {}); + std::string GenerateGroupPythonCode( + const std::vector& group, + const std::unordered_set& fetch_var_ids = {}); /** * \brief Visualize the grouped graph according to fusion_groups. */ - void VisualizeGroupedGraph(const std::unordered_set& fetch_var_ids = {}); + void VisualizeGroupedGraph( + const std::unordered_set& fetch_var_ids = {}); /** * \brief Visualize the grouped graph according to user specified groups. 
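The producer_groups, consumer_groups, and belong_groups sets above hash shared_ptr keys by pointer identity; a standalone sketch of such a hasher/comparator pair (only the comparator appears in the diff, so the hasher body here is an assumed implementation):

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <unordered_set>

struct Group {};

struct SharedGroupHasher {
  // Assumed implementation: hash the raw address, so two shared_ptrs to
  // the same Group land in the same bucket.
  std::size_t operator()(const std::shared_ptr<Group>& g) const noexcept {
    return std::hash<std::uintptr_t>()(
        reinterpret_cast<std::uintptr_t>(g.get()));
  }
};

struct SharedGroupComparator {
  bool operator()(const std::shared_ptr<Group>& first,
                  const std::shared_ptr<Group>& second) const noexcept {
    return first.get() == second.get();  // identity, not structural equality
  }
};

int main() {
  std::unordered_set<std::shared_ptr<Group>,
                     SharedGroupHasher,
                     SharedGroupComparator>
      groups;
  auto g = std::make_shared<Group>();
  groups.insert(g);
  groups.insert(g);  // same object: the set keeps a single entry
  return groups.size() == 1 ? 0 : 1;
}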
*/ - void VisualizeGroupedGraph(const std::vector>& groups, - const std::unordered_set& fetch_var_ids = {}); + void VisualizeGroupedGraph( + const std::vector>& groups, + const std::unordered_set& fetch_var_ids = {}); void SaveSourceCode(const std::string& code); void SavePTXCode(const std::string& ptx); private: - std::string DebugGroupedGraph(const std::vector>& groups, - const std::unordered_set& fetch_var_ids = {}); + std::string DebugGroupedGraph( + const std::vector>& groups, + const std::unordered_set& fetch_var_ids = {}); - std::string VisualizeGraph(const std::vector>& groups, - const std::unordered_set& fetch_var_ids = {}); + std::string VisualizeGraph( + const std::vector>& groups, + const std::unordered_set& fetch_var_ids = {}); - std::vector VisualizeGroups(const std::vector>& groups, - const std::unordered_set& fetch_var_ids = {}); + std::vector VisualizeGroups( + const std::vector>& groups, + const std::unordered_set& fetch_var_ids = {}); std::vector> FusionGroupsToGroups(); diff --git a/paddle/cinn/hlir/framework/graph_compiler.cc b/paddle/cinn/hlir/framework/graph_compiler.cc index 6cbfa5d677624..575283987cd5a 100644 --- a/paddle/cinn/hlir/framework/graph_compiler.cc +++ b/paddle/cinn/hlir/framework/graph_compiler.cc @@ -51,7 +51,8 @@ void AddAttrs(const absl::flat_hash_map& attrs_store, instr->attrs.push_back(absl::get(attrs_store.at(attr))); break; case 3: - instr->str_attrs.push_back(absl::get(attrs_store.at(attr))); + instr->str_attrs.push_back( + absl::get(attrs_store.at(attr))); break; case 5: auto temp = absl::get>(attrs_store.at(attr)); @@ -64,7 +65,8 @@ void AddAttrs(const absl::flat_hash_map& attrs_store, } } -Program::Program(const std::shared_ptr& scope, std::vector>&& instrs) +Program::Program(const std::shared_ptr& scope, + std::vector>&& instrs) : scope_(scope) { for (auto& ins : instrs) { if (ins->pre_run) { @@ -75,7 +77,8 @@ Program::Program(const std::shared_ptr& scope, std::vector* name2podargs) { +void Program::PreRun( + const std::map* name2podargs) { for (auto& ins : prerun_instrs_) { ins->Run(name2podargs); } @@ -86,7 +89,8 @@ void Program::PreRun(const std::map* name2podargs } } -void Program::Export(const std::vector& persistent_vars, const std::string& filename) { +void Program::Export(const std::vector& persistent_vars, + const std::string& filename) { auto writeplaceholder = [=](int s, int n, FILE* f) -> int { int pos = ftell(f); for (int i = 0; i < s * n; i++) { @@ -105,7 +109,7 @@ void Program::Export(const std::vector& persistent_vars, const std: setplaceholder(p, &cur, 4, 1, f); }; auto padding = [=](int alignment, uint8_t value, FILE* f) { - int cur = ftell(f); + int cur = ftell(f); int padding = (alignment - (cur % alignment)) % alignment; for (int i = 0; i < padding; i++) { fwrite(&value, 1, 1, f); @@ -129,7 +133,7 @@ void Program::Export(const std::vector& persistent_vars, const std: // varname list int varnamesec = writeplaceholder(4, 1, f); - int namesnum = varnames.size(); + int namesnum = varnames.size(); fwrite(&namesnum, 4, 1, f); int nameoffset = writeplaceholder(4, namesnum, f); for (int i = 0; i < namesnum; i++) { @@ -148,12 +152,14 @@ void Program::Export(const std::vector& persistent_vars, const std: tellplaceholder(bufoffset, f); std::vector> pvars; for (auto& varname : varnames) { - std::string name = (std::string)varname; - auto t = scope_->GetTensor(name); + std::string name = (std::string)varname; + auto t = scope_->GetTensor(name); cinn_buffer_t buffer = *t->buffer(); - buffer.memory = (uint8_t*)0; - if 
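The padding lambda in Program::Export above rounds the current file offset up to the next multiple of alignment with (alignment - (cur % alignment)) % alignment; the formula checks out standalone:

#include <cassert>

// Bytes needed to advance 'cur' to the next multiple of 'alignment'.
int PaddingTo(int alignment, int cur) {
  return (alignment - (cur % alignment)) % alignment;
}

int main() {
  assert(PaddingTo(16, 0) == 0);   // already aligned
  assert(PaddingTo(16, 13) == 3);  // 13 -> 16
  assert(PaddingTo(16, 32) == 0);  // exact multiple
  assert(PaddingTo(8, 9) == 7);    // 9 -> 16
}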
(std::find(persistent_vars.begin(), persistent_vars.end(), name) != persistent_vars.end()) { - pvars.emplace_back(t->buffer(), ftell(f) + offsetof(cinn_buffer_t, memory)); + buffer.memory = (uint8_t*)0; + if (std::find(persistent_vars.begin(), persistent_vars.end(), name) != + persistent_vars.end()) { + pvars.emplace_back(t->buffer(), + ftell(f) + offsetof(cinn_buffer_t, memory)); } fwrite(&buffer, sizeof(cinn_buffer_t), 1, f); } @@ -172,22 +178,23 @@ void Program::Export(const std::vector& persistent_vars, const std: tellplaceholder(pbuffer, f); // instructions int instsec = writeplaceholder(4, 1, f); - int insnum = 0; + int insnum = 0; for (auto& ins : instrs_) { ins->Run(nullptr, true); insnum += ins->GetFnNames().size(); } fwrite(&insnum, 4, 1, f); int instplaceholder = writeplaceholder(4 * 3, insnum, f); - int findex = 0; + int findex = 0; for (auto& ins : instrs_) { - auto in_args = ins->GetInArgs(); + auto in_args = ins->GetInArgs(); auto out_args = ins->GetOutArgs(); auto fn_names = ins->GetFnNames(); for (int i = 0; i < fn_names.size(); i++, findex++) { std::vector all_args(in_args[i].begin(), in_args[i].end()); - all_args.insert(std::end(all_args), out_args[i].begin(), out_args[i].end()); - auto fname = fn_names[i]; + all_args.insert( + std::end(all_args), out_args[i].begin(), out_args[i].end()); + auto fname = fn_names[i]; int fnamesize = fname.size(); fwrite(&fnamesize, 4, 1, f); tellplaceholder(instplaceholder + findex * 12, f); @@ -209,7 +216,10 @@ void Program::Export(const std::vector& persistent_vars, const std: fclose(f); } -void Program::Execute(const std::map* name2podargs, void* stream, bool use_cache) { +void Program::Execute( + const std::map* name2podargs, + void* stream, + bool use_cache) { for (auto& ins : instrs_) { ins->Run(name2podargs, false, stream, use_cache); } @@ -240,13 +250,14 @@ void Program::ExecuteTest(int repeat_) { } #endif double test_op_time = timer1.Stop() / repeat_; - VLOG(3) << "Repeat times: [" << repeat_ << "], average op time: [" << test_op_time << "] ms"; + VLOG(3) << "Repeat times: [" << repeat_ << "], average op time: [" + << test_op_time << "] ms"; } void GraphCompiler::PrintFunc() { auto topo_order = graph_->topological_order(); - auto& nodes = std::get<0>(topo_order); - auto& edges = std::get<1>(topo_order); + auto& nodes = std::get<0>(topo_order); + auto& edges = std::get<1>(topo_order); for (auto& n : nodes) { auto* node = n->safe_as(); @@ -258,8 +269,8 @@ void GraphCompiler::PrintFunc() { std::string GraphCompiler::GenSourceCode() { auto topo_order = graph_->topological_order(); - auto& nodes = std::get<0>(topo_order); - auto& edges = std::get<1>(topo_order); + auto& nodes = std::get<0>(topo_order); + auto& edges = std::get<1>(topo_order); for (auto& n : nodes) { auto* node = n->safe_as(); @@ -280,7 +291,8 @@ std::string GraphCompiler::GenSourceCode() { return compiler_->GetSourceCode(build_module); } -const std::string& GraphCompiler::GetOrGenFullFuncName(const std::string& prefix) { +const std::string& GraphCompiler::GetOrGenFullFuncName( + const std::string& prefix) { // try_emplace only insert once, so the same function // can get a consistent name next time prefix2full_namemap_.try_emplace(prefix, Context::Global().NewName(prefix)); @@ -301,10 +313,11 @@ std::vector GraphCompiler::GetOpFuncWithIRSchedule( // 1.Collect inputs info and outputs info for (auto& i : node->inlinks_in_order()) { std::string id = i->source()->as()->id(); - auto shape = shape_dict_.at(id); - Type dtype = type_dict_.at(id); - CHECK(dtype.is_supported()) 
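GetOrGenFullFuncName above leans on try_emplace inserting only on first use, so repeated calls with the same prefix return the same generated name; a self-contained sketch (a plain counter stands in for Context::Global().NewName):

#include <cassert>
#include <string>
#include <unordered_map>

std::unordered_map<std::string, std::string> prefix2full_name;
int name_counter = 0;

const std::string& GetOrGenName(const std::string& prefix) {
  // try_emplace leaves the map untouched if the key already exists,
  // so the first generated name sticks.
  prefix2full_name.try_emplace(prefix,
                               prefix + "_" + std::to_string(name_counter++));
  return prefix2full_name.at(prefix);
}

int main() {
  const std::string first = GetOrGenName("fn_add");
  assert(GetOrGenName("fn_add") == first);  // stable across calls
}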
<< "The dtype of node " << id - << " is not float or bool or int! Other dtype is not implemented yet."; + auto shape = shape_dict_.at(id); + Type dtype = type_dict_.at(id); + CHECK(dtype.is_supported()) + << "The dtype of node " << id + << " is not float or bool or int! Other dtype is not implemented yet."; ir::Tensor input; if (dtype.is_float(32)) { input = lang::Placeholder(id, shape); @@ -351,28 +364,35 @@ std::vector GraphCompiler::GetOpFuncWithIRSchedule( input_output_nodes.push_back(out_name); } - auto impl = - OpStrategy::SelectImpl(cinn_strategy[node->op()](node->attrs, tensor_inputs, out_types, out_shapes, target_)); + auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()]( + node->attrs, tensor_inputs, out_types, out_shapes, target_)); - auto res = - GetFuncFromImpl(impl, common::CINNValuePack{cinn_inputs}, tensor_inputs, input_output_nodes, node->id(), target_); + auto res = GetFuncFromImpl(impl, + common::CINNValuePack{cinn_inputs}, + tensor_inputs, + input_output_nodes, + node->id(), + target_); return res; } std::vector GraphCompiler::GetOpFunc(const Node* node) { - auto& strategy = Operator::GetAttrs("CINNStrategy"); - auto& shape_dict = graph_->GetAttrs>("infershape"); - auto& dtype_dict = graph_->GetAttrs>("inferdtype"); + auto& strategy = Operator::GetAttrs("CINNStrategy"); + auto& shape_dict = + graph_->GetAttrs>("infershape"); + auto& dtype_dict = + graph_->GetAttrs>("inferdtype"); std::vector inputs; std::vector cinn_inputs; std::vector> output_shapes; VLOG(3) << "GetOpFunc of op " << node->id(); for (auto& i : node->inlinks_in_order()) { std::string input_id = i->source()->as()->id(); - auto in_shape = shape_dict.at(input_id); - Type dtype = dtype_dict.at(input_id); - CHECK(dtype.is_supported()) << "The dtype of node " << input_id - << " is not float or bool or int! Other dtype is not implemented yet."; + auto in_shape = shape_dict.at(input_id); + Type dtype = dtype_dict.at(input_id); + CHECK(dtype.is_supported()) + << "The dtype of node " << input_id + << " is not float or bool or int! Other dtype is not implemented yet."; ir::Tensor temp; if (dtype.is_float(32)) { temp = lang::Placeholder(input_id, in_shape); @@ -407,16 +427,17 @@ std::vector GraphCompiler::GetOpFunc(const Node* node) { std::vector out_types; for (auto& out : node->outlinks_in_order()) { std::string out_id = out->sink()->safe_as()->id(); - auto out_shape = shape_dict.at(out_id); - Type dtype = dtype_dict.at(out_id); + auto out_shape = shape_dict.at(out_id); + Type dtype = dtype_dict.at(out_id); output_shapes.push_back(out_shape); out_types.push_back(dtype); } - auto impl = OpStrategy::SelectImpl(strategy[node->op()](node->attrs, inputs, out_types, output_shapes, target_)); + auto impl = OpStrategy::SelectImpl(strategy[node->op()]( + node->attrs, inputs, out_types, output_shapes, target_)); common::CINNValuePack C = impl->fcompute(common::CINNValuePack{cinn_inputs}); - poly::StageMap stages = C.back(); + poly::StageMap stages = C.back(); // make sure all the tensors in the stages before schedule launch. for (int i = 0; i < C->size() - 1; i++) { ir::Expr temp = C[i]; @@ -427,46 +448,58 @@ std::vector GraphCompiler::GetOpFunc(const Node* node) { for (int i = 0; i < C->size() - 1; i++) { ir::Expr temp = C[i]; // checkout whether the tensor is with buffer. 
- if ((!temp.as_tensor_ref()->buffer.defined() || this->target_ != common::DefaultNVGPUTarget()) && + if ((!temp.as_tensor_ref()->buffer.defined() || + this->target_ != common::DefaultNVGPUTarget()) && !stages[temp.as_tensor_ref()]->inlined()) { inputs.push_back(temp.as_tensor_ref()); } } - auto func = lang::LowerVec(GetOrGenFullFuncName(GenOpFuncName(node)), stages, inputs, {}, {}, nullptr, this->target_); - VLOG(3) << "The [" << func.size() << "] functions of node [" << node->attrs.node_name << "] are:\n"; + auto func = lang::LowerVec(GetOrGenFullFuncName(GenOpFuncName(node)), + stages, + inputs, + {}, + {}, + nullptr, + this->target_); + VLOG(3) << "The [" << func.size() << "] functions of node [" + << node->attrs.node_name << "] are:\n"; for (auto& i : func) { VLOG(3) << i; } return func; } -// get the most complex op's index in the fused groups according to the OpPattern. If the OpPattern is same, we will -// take the latter. +// get the most complex op's index in the fused groups according to the +// OpPattern. If the OpPattern is same, we will take the latter. int GetMasterRefNode(const std::vector& nodes) { auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - int master_index = 0; - int master_pattern = op_pattern_dict[nodes[0]->op()]; + int master_index = 0; + int master_pattern = op_pattern_dict[nodes[0]->op()]; for (int i = 1; i < nodes.size(); i++) { - int pattern = op_pattern_dict[nodes[i]->op()]; - master_index = pattern >= master_pattern ? i : master_index; + int pattern = op_pattern_dict[nodes[i]->op()]; + master_index = pattern >= master_pattern ? i : master_index; master_pattern = std::max(pattern, master_pattern); } - VLOG(3) << "master_index: " << master_index << ", master op: " << nodes[master_index]->op()->name; + VLOG(3) << "master_index: " << master_index + << ", master op: " << nodes[master_index]->op()->name; return master_index; } -std::vector GraphCompiler::GetOpFunc(const std::vector& nodes) { +std::vector GraphCompiler::GetOpFunc( + const std::vector& nodes) { CHECK_GT(nodes.size(), 1) << "fuse nodes number must be greater than 1"; - auto& strategy = Operator::GetAttrs("CINNStrategy"); - auto& shape_dict = graph_->GetAttrs>("infershape"); - auto& dtype_dict = graph_->GetAttrs>("inferdtype"); - int fuse_number = nodes.size(); + auto& strategy = Operator::GetAttrs("CINNStrategy"); + auto& shape_dict = + graph_->GetAttrs>("infershape"); + auto& dtype_dict = + graph_->GetAttrs>("inferdtype"); + int fuse_number = nodes.size(); VLOG(3) << "fuse begin: " << nodes[0]->id(); std::vector inputs; std::vector outputs; poly::StageMap stages; - int index = 0; + int index = 0; std::string fuse_name = "fn_"; std::unordered_set in_vars; std::unordered_set out_vars; @@ -491,10 +524,11 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& temp_inputs.push_back(fuse_out.as_tensor_ref()); } else { std::string input_id = source_data->id(); - auto in_shape = shape_dict.at(input_id); - Type dtype = dtype_dict.at(input_id); + auto in_shape = shape_dict.at(input_id); + Type dtype = dtype_dict.at(input_id); CHECK(dtype.is_supported()) << "The dtype of node " << input_id - << " is not float or bool or int! Other dtype is not implemented yet."; + << " is not float or bool or int! 
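GetMasterRefNode above is an argmax over OpPattern values whose >= comparison makes the later node win ties; the selection logic, extracted standalone:

#include <algorithm>
#include <cassert>
#include <vector>

int MasterIndex(const std::vector<int>& patterns) {
  int master_index = 0;
  int master_pattern = patterns[0];
  for (int i = 1; i < static_cast<int>(patterns.size()); i++) {
    // '>=' lets a later node with an equal pattern replace the earlier one.
    master_index = patterns[i] >= master_pattern ? i : master_index;
    master_pattern = std::max(patterns[i], master_pattern);
  }
  return master_index;
}

int main() {
  assert(MasterIndex({2, 5, 5, 1}) == 2);  // tie at 5: the later index wins
  assert(MasterIndex({7, 1, 1}) == 0);
}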
Other " + "dtype is not implemented yet."; ir::Tensor temp_in; if (dtype.is_float(32)) { temp_in = lang::Placeholder(input_id, in_shape); @@ -540,19 +574,20 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& std::string out_id = out_var->id(); VLOG(3) << "out_id " << out_id; auto out_shape = shape_dict.at(out_id); - Type dtype = dtype_dict.at(out_id); + Type dtype = dtype_dict.at(out_id); output_shapes.push_back(out_shape); out_types.push_back(dtype); } - auto impl = - OpStrategy::SelectImpl(strategy[node->op()](node->attrs, temp_inputs, out_types, output_shapes, target_)); + auto impl = OpStrategy::SelectImpl(strategy[node->op()]( + node->attrs, temp_inputs, out_types, output_shapes, target_)); - common::CINNValuePack C = impl->fcompute(common::CINNValuePack{cinn_inputs}); + common::CINNValuePack C = + impl->fcompute(common::CINNValuePack{cinn_inputs}); if (index == master_index) { // use the most complex op's schedule as the fused ops' schedule. C = impl->fschedule(C); CHECK(!C.empty()); - Expr out = C[0]; + Expr out = C[0]; master_out_tensor = out.as_tensor_ref(); } @@ -568,7 +603,7 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& C = C_temp; } for (int i = 0; i < C.size() - 1; i++) { - Expr out = C[i]; + Expr out = C[i]; temp_var_map[temp_outvars[i]] = out; if (fetch_var_ids_.count(temp_outvars[i]->id())) { VLOG(3) << "get fetch output var " << temp_outvars[i]->id(); @@ -589,7 +624,8 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& auto temp_tensor = temp.as_tensor_ref(); stages->InsertLazily(temp_tensor, temp_stages[temp_tensor]); if (index < fuse_number - 1 && !temp_tensor->is_reduce_tensor()) { - // assume that only the first out_var links to other op node which will compute inline + // assume that only the first out_var links to other op node which will + // compute inline if (fetch_tensors.count(temp_tensor)) { VLOG(3) << "add op's fetch out_vars: " << temp_tensor->name; outputs.insert(outputs.begin(), temp_tensor); @@ -620,7 +656,9 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& // args order: inputs + final output + fetch outputs + other no_fused outputs for (auto& tensor : outputs) { // checkout the tensor is with buffer. 
- if ((!tensor->buffer.defined() || this->target_ != common::DefaultNVGPUTarget()) && !stages[tensor]->inlined()) { + if ((!tensor->buffer.defined() || + this->target_ != common::DefaultNVGPUTarget()) && + !stages[tensor]->inlined()) { inputs.push_back(tensor); } } @@ -637,20 +675,21 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& for (auto& s : stages) { auto& compute_ats = s.second->GetComputeAts(); - auto tensor = s.second->tensor(); + auto tensor = s.second->tensor(); if (!compute_ats.empty()) { poly::ComputeAtRelation new_relation; CHECK_EQ(compute_ats.size(), 1U); auto new_stage = stages[final_out_tensor]; for (auto& compute_at : compute_ats) { - auto& old_relation = compute_at.second; + auto& old_relation = compute_at.second; auto old_target_tensor = old_relation.stage->tensor(); if (stages[old_target_tensor]->inlined()) { new_relation.stage = new_stage; new_relation.level = old_relation.level; compute_ats.clear(); - CHECK(new_relation.IsCompatible(s.second.get())) << "new computeAt should be compatible"; + CHECK(new_relation.IsCompatible(s.second.get())) + << "new computeAt should be compatible"; compute_ats[new_stage->id()] = new_relation; break; } @@ -659,10 +698,13 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& } // deal with fetch tensors, not compute_inline but do compute_at for (auto& fetch_tensor : fetch_tensors) { - if (fetch_tensor->is_reduce_tensor() || fetch_tensor->name == final_out_tensor->name) continue; + if (fetch_tensor->is_reduce_tensor() || + fetch_tensor->name == final_out_tensor->name) + continue; stages[fetch_tensor]->DisableComputeInline(); int level = stages[final_out_tensor]->n_out_dims() - 1; - VLOG(3) << "no fuse fetch tensor " << fetch_tensor->name << " and recomputeAt in level " << level; + VLOG(3) << "no fuse fetch tensor " << fetch_tensor->name + << " and recomputeAt in level " << level; // if the fetch tensor size is 1, the fetch tensor cannot fuse by ComputeAt2 int len = 1; @@ -676,7 +718,13 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& stages[fetch_tensor]->ComputeAt2(stages[final_out_tensor], level); } - auto func = lang::LowerVec(GetOrGenFullFuncName(fuse_name), stages, inputs, {}, {}, nullptr, this->target_); + auto func = lang::LowerVec(GetOrGenFullFuncName(fuse_name), + stages, + inputs, + {}, + {}, + nullptr, + this->target_); VLOG(3) << "The [" << func.size() << "] functions are:\n"; for (auto& i : func) { VLOG(3) << "Function [" << i->name << "] is:\n"; @@ -686,7 +734,8 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& return func; } -void GraphCompiler::ProcessFunction(const std::vector& lowered_funcs) { +void GraphCompiler::ProcessFunction( + const std::vector& lowered_funcs) { for (auto&& func : lowered_funcs) { std::vector input_args; std::vector output_args; @@ -699,12 +748,14 @@ void GraphCompiler::ProcessFunction(const std::vector& lowered_ input_args.push_back(arg_name); auto* var = scope_->FindVar(arg_name); if (!arg.is_buffer()) { - VLOG(3) << "function:" << func->name << "-argument:" << arg_name << " type is not buffer, lowered_func:\n" + VLOG(3) << "function:" << func->name << "-argument:" << arg_name + << " type is not buffer, lowered_func:\n" << func; } - if (!var && arg.is_buffer()) { // For argument buffer not in scope, create it. + if (!var && + arg.is_buffer()) { // For argument buffer not in scope, create it. 
auto* new_var = scope_->Var(arg_name); - auto& tensor = absl::get(*new_var); + auto& tensor = absl::get(*new_var); std::vector shape; for (auto& shape_dim : arg.buffer_arg()->shape) { CHECK(shape_dim.is_constant()); @@ -713,14 +764,15 @@ void GraphCompiler::ProcessFunction(const std::vector& lowered_ tensor->Resize(Shape{shape}); tensor->set_type(arg.buffer_arg()->dtype); VLOG(3) << utils::StringFormat( - "Will create a new variable in scope for argument[%s] in function[%s] with shape[%s],dtype[%s]", + "Will create a new variable in scope for argument[%s] in " + "function[%s] with shape[%s],dtype[%s]", arg_name.c_str(), func->name.c_str(), utils::Join(tensor->shape().data(), ","), common::Type2Str(tensor->type())); } } - function2input_args_[func->name] = input_args; + function2input_args_[func->name] = input_args; function2output_args_[func->name] = output_args; m_builder_.AddFunction(func); } @@ -729,38 +781,44 @@ void GraphCompiler::ProcessFunction(const std::vector& lowered_ std::unique_ptr GraphCompiler::Build(const std::string& code) { utils::RecordEvent("GraphCompiler::Build", utils::EventType::kGraph); GraphCompiler::CompileOptions options; - options.attached_code = code; + options.attached_code = code; options.with_instantiate_variables = true; auto&& result = Build(options); return std::move(result.runtime_program); } -void GraphCompiler::CompileOptions::Apply(const auto_schedule::TuningResult& tuning_result) { +void GraphCompiler::CompileOptions::Apply( + const auto_schedule::TuningResult& tuning_result) { // assign options with TuningResult directly groups.assign(tuning_result.subgraphs.begin(), tuning_result.subgraphs.end()); - lowered_funcs.assign(tuning_result.function_groups.begin(), tuning_result.function_groups.end()); + lowered_funcs.assign(tuning_result.function_groups.begin(), + tuning_result.function_groups.end()); } -GraphCompiler::CompilationResult GraphCompiler::Build(const GraphCompiler::CompileOptions& options, - std::unordered_set&& fetch_var_ids, - void* stream) { +GraphCompiler::CompilationResult GraphCompiler::Build( + const GraphCompiler::CompileOptions& options, + std::unordered_set&& fetch_var_ids, + void* stream) { Context::Global().ResetNameId(); if (FLAGS_cinn_parallel_compile_size) { // write group's information into FLAGS_cinn_fusion_groups_graphviz_dir - graph_->VisualizeGroupedGraph(fetch_var_ids.empty() ? fetch_var_ids_ : fetch_var_ids); + graph_->VisualizeGroupedGraph(fetch_var_ids.empty() ? 
fetch_var_ids_ + : fetch_var_ids); if (options.with_instantiate_variables) { VLOG(3) << "Instantiate all variables on compile-time"; - utils::RecordEvent("GraphCompiler MutableData", utils::EventType::kOrdinary); + utils::RecordEvent("GraphCompiler MutableData", + utils::EventType::kOrdinary); // All variables reside in scope_, so traverse it to instantiate each one for (auto& name : scope_->var_names()) { - auto* var = scope_->Var(std::string({name.data(), name.size()})); + auto* var = + scope_->Var(std::string({name.data(), name.size()})); auto& tensor = absl::get(*var); if (reuse_vars_map_.count(name)) { auto src_var_name = reuse_vars_map_.at(name); - auto* src_var = scope_->Var(src_var_name); - auto& src_tensor = absl::get(*src_var); + auto* src_var = scope_->Var(src_var_name); + auto& src_tensor = absl::get(*src_var); tensor->set_buffer(src_tensor->get_buffer()); } else { tensor->mutable_data(target_, tensor->type()); @@ -769,12 +827,14 @@ GraphCompiler::CompilationResult GraphCompiler::Build(const GraphCompiler::Compi } VLOG(2) << "Compile With Parallel Compiler!"; - utils::RecordEvent("GraphCompiler CompileResult", utils::EventType::kOrdinary); + utils::RecordEvent("GraphCompiler CompileResult", + utils::EventType::kOrdinary); ParallelCompiler::CompileOptions option; option.lowered_funcs = options.lowered_funcs; - parallel_compiler_ = std::make_shared(scope_, graph_, option, target_); - auto instructions = (*parallel_compiler_.get())(); + parallel_compiler_ = + std::make_shared(scope_, graph_, option, target_); + auto instructions = (*parallel_compiler_.get())(); if (options.remove_unused_variables) { RemoveInvalidVariables(instructions); @@ -787,20 +847,22 @@ GraphCompiler::CompilationResult GraphCompiler::Build(const GraphCompiler::Compi VLOG(2) << "Compile With Parallel Compiler Done!"; GraphCompiler::CompilationResult compilation_result; - compilation_result.runtime_program.reset(new Program(scope_, std::move(instructions))); + compilation_result.runtime_program.reset( + new Program(scope_, std::move(instructions))); return compilation_result; } compile_options_ = options; - fetch_var_ids_ = std::move(fetch_var_ids); - auto topo_order = graph_->topological_order(); - auto& nodes = std::get<0>(topo_order); + fetch_var_ids_ = std::move(fetch_var_ids); + auto topo_order = graph_->topological_order(); + auto& nodes = std::get<0>(topo_order); VLOG(3) << "Begin GraphCompiler::Build"; function2input_args_.clear(); function2output_args_.clear(); m_builder_.Clear(); // if there are no available groups, we will take each node as a group - if (options.groups.empty() && graph_->groups.empty() && graph_->fusion_groups.empty()) { + if (options.groups.empty() && graph_->groups.empty() && + graph_->fusion_groups.empty()) { VLOG(3) << "not run opfusion pass"; for (auto& node : nodes) { auto op_node = node->safe_as(); @@ -819,33 +881,46 @@ GraphCompiler::CompilationResult GraphCompiler::Build(const GraphCompiler::Compi } } - // if the input lowered_funcs is empty, we will use the default lowering process to generate + // if the input lowered_funcs is empty, we will use the default lowering + // process to generate std::vector> local_lowered_funcs; if (options.lowered_funcs.empty()) { - utils::RecordEvent("GraphCompiler LoweredFuncs", utils::EventType::kOrdinary); - // lowering of new fusion pass is not compatible with the groups from the input options, - // thus process it separately + utils::RecordEvent("GraphCompiler LoweredFuncs", + utils::EventType::kOrdinary); + // lowering of new fusion 
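The instantiation loop above either shares the source variable's buffer (names present in reuse_vars_map_) or allocates fresh memory; a simplified standalone sketch, with a shared vector standing in for tensor buffers and allocation split into two passes for clarity:

#include <map>
#include <memory>
#include <string>
#include <vector>

struct Tensor {
  std::shared_ptr<std::vector<float>> buffer;
};

void Instantiate(std::map<std::string, Tensor>* scope,
                 const std::map<std::string, std::string>& reuse_map) {
  // Pass 1: allocate every variable that is not marked for reuse.
  for (auto& [name, tensor] : *scope)
    if (!reuse_map.count(name))
      tensor.buffer = std::make_shared<std::vector<float>>(16);
  // Pass 2: reused variables share their source variable's buffer.
  for (auto& [name, tensor] : *scope)
    if (reuse_map.count(name))
      tensor.buffer = scope->at(reuse_map.at(name)).buffer;
}

int main() {
  std::map<std::string, Tensor> scope{{"a", {}}, {"b", {}}};
  Instantiate(&scope, {{"b", "a"}});  // "b" reuses "a"'s buffer
  return scope.at("a").buffer == scope.at("b").buffer ? 0 : 1;
}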
pass is not compatible with the groups from the + // input options, thus process it separately if (!graph_->fusion_groups.empty()) { - auto& dtype_dict = graph_->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph_->GetMutableAttrs>("infershape"); + auto& dtype_dict = + graph_->GetMutableAttrs>( + "inferdtype"); + auto& shape_dict = + graph_->GetMutableAttrs>( + "infershape"); OpLowerer op_lowerer(dtype_dict, shape_dict, target_); for (auto& group : graph_->fusion_groups) { - VLOG(3) << "group_id is : " << group->group_id << ", and its number is : " << group->nodes.size(); + VLOG(3) << "group_id is : " << group->group_id + << ", and its number is : " << group->nodes.size(); groups.push_back(std::move(group->CollectNodes())); local_lowered_funcs.emplace_back(std::move(op_lowerer.Lower(group))); - CHECK_EQ(local_lowered_funcs.back().size(), 1) << "Lowered Function Is Not Equal 1!"; + CHECK_EQ(local_lowered_funcs.back().size(), 1) + << "Lowered Function Is Not Equal 1!"; VLOG(3) << local_lowered_funcs.back()[0]; } } else { VLOG(3) << "fusion_groups is empty"; std::vector lowered_func; if (FLAGS_cinn_ir_schedule) { - auto& dtype_dict = graph_->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph_->GetMutableAttrs>("infershape"); + auto& dtype_dict = + graph_->GetMutableAttrs>( + "inferdtype"); + auto& shape_dict = + graph_->GetMutableAttrs>( + "infershape"); for (int i = 0; i < groups.size(); i++) { for (int j = 0; j < groups[i].size(); j++) { - lowered_func = GetOpFuncWithIRSchedule(groups[i][j], dtype_dict, shape_dict); + lowered_func = + GetOpFuncWithIRSchedule(groups[i][j], dtype_dict, shape_dict); local_lowered_funcs.emplace_back(std::move(lowered_func)); } } @@ -862,13 +937,18 @@ GraphCompiler::CompilationResult GraphCompiler::Build(const GraphCompiler::Compi } } // write group's information into FLAGS_cinn_fusion_groups_graphviz_dir - graph_->VisualizeGroupedGraph(groups, fetch_var_ids.empty() ? fetch_var_ids_ : fetch_var_ids); + graph_->VisualizeGroupedGraph( + groups, fetch_var_ids.empty() ? fetch_var_ids_ : fetch_var_ids); // use the input lowered_funcs in options firstly if exists - const auto& lowered_funcs = options.lowered_funcs.empty() ? local_lowered_funcs : options.lowered_funcs; - CHECK_EQ(groups.size(), lowered_funcs.size()) << "The size of groups and lowered_funcs should be equal"; + const auto& lowered_funcs = options.lowered_funcs.empty() + ? local_lowered_funcs + : options.lowered_funcs; + CHECK_EQ(groups.size(), lowered_funcs.size()) + << "The size of groups and lowered_funcs should be equal"; { - utils::RecordEvent("GraphCompiler ProcessFunction", utils::EventType::kOrdinary); + utils::RecordEvent("GraphCompiler ProcessFunction", + utils::EventType::kOrdinary); for (auto&& lowered_func : lowered_funcs) { this->ProcessFunction(lowered_func); } @@ -876,13 +956,15 @@ GraphCompiler::CompilationResult GraphCompiler::Build(const GraphCompiler::Compi // compile the module // Need to create a new compiler for every call of Build, - // because the underneath jit engine doesn't support addIRModule repeatedly now. + // because the underneath jit engine doesn't support addIRModule repeatedly + // now. 
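(For orientation across this long hunk: a minimal end-to-end use of GraphCompiler::Build, assembled only from APIs exercised by graph_compiler_test.cc later in this patch; error handling and CUDA specifics omitted.)

frontend::NetBuilder builder("demo");
auto a = builder.CreateInput(Float(32), {32, 16}, "a");
auto b = builder.CreateInput(Float(32), {32, 16}, "b");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);

auto target  = common::DefaultHostTarget();
auto program = builder.Build();
auto graph   = Optimize(&program, {}, target);  // frontend passes, as in the tests
auto scope   = BuildScope(target, graph);       // one tensor per graph variable

GraphCompiler gc(target, scope, graph);
GraphCompiler::CompileOptions options;
options.with_instantiate_variables = true;      // allocate buffers at compile time
auto runtime_program = gc.Build(options).runtime_program;
runtime_program->Execute();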
compiler_ = backends::Compiler::Create(target_); auto build_module = m_builder_.Build(); VLOG(3) << "End of m_builder_.Build()"; if (this->target_.arch == Target::Arch::X86) { - utils::RecordEvent("GraphCompiler CodeGenCX86", utils::EventType::kOrdinary); + utils::RecordEvent("GraphCompiler CodeGenCX86", + utils::EventType::kOrdinary); CodeGenCX86 codegen(this->target_, CodeGenCX86::Feature::AVX512); codegen.SetInlineBuiltinCodes(false); auto out = codegen.Compile(build_module, CodeGenC::OutputKind::CImpl); @@ -890,12 +972,14 @@ GraphCompiler::CompilationResult GraphCompiler::Build(const GraphCompiler::Compi } { - utils::RecordEvent("GraphCompiler BackendsBuild", utils::EventType::kOrdinary); + utils::RecordEvent("GraphCompiler BackendsBuild", + utils::EventType::kOrdinary); compiler_->Build(build_module, options.attached_code); VLOG(3) << "End of compiler_->Build"; } - auto instructions = BuildInstructions(groups, options.groups.empty() ? graph_->fusion_groups : options.groups); + auto instructions = BuildInstructions( + groups, options.groups.empty() ? graph_->fusion_groups : options.groups); VLOG(3) << "End of BuildInstructions"; if (options.remove_unused_variables) { RemoveInvalidVariables(instructions); @@ -907,15 +991,16 @@ GraphCompiler::CompilationResult GraphCompiler::Build(const GraphCompiler::Compi if (options.with_instantiate_variables) { VLOG(3) << "Instantiate all variables on compile-time"; - utils::RecordEvent("GraphCompiler MutableData", utils::EventType::kOrdinary); + utils::RecordEvent("GraphCompiler MutableData", + utils::EventType::kOrdinary); // All variables reside in scope_, so traverse it to instantiate each one for (auto& name : scope_->var_names()) { - auto* var = scope_->Var(std::string({name.data(), name.size()})); + auto* var = scope_->Var(std::string({name.data(), name.size()})); auto& tensor = absl::get(*var); if (reuse_vars_map_.count(name)) { auto src_var_name = reuse_vars_map_.at(name); - auto* src_var = scope_->Var(src_var_name); - auto& src_tensor = absl::get(*src_var); + auto* src_var = scope_->Var(src_var_name); + auto& src_tensor = absl::get(*src_var); tensor->set_buffer(src_tensor->get_buffer()); } else { tensor->mutable_data(target_, tensor->type()); @@ -928,8 +1013,9 @@ GraphCompiler::CompilationResult GraphCompiler::Build(const GraphCompiler::Compi return result; } -void GraphCompiler::SetSubKernels(Instruction* instr, const std::string& func_name) { - int i = 1; +void GraphCompiler::SetSubKernels(Instruction* instr, + const std::string& func_name) { + int i = 1; std::string new_op_func = func_name + "_" + std::to_string(i); if (function2input_args_.count(new_op_func) != 0) { CHECK_GT(function2input_args_.count(func_name), 0); @@ -947,29 +1033,34 @@ void GraphCompiler::SetSubKernels(Instruction* instr, const std::string& func_na } } -void GraphCompiler::BuildCublasInstr(const Node& node, Instruction* instr) const { +void GraphCompiler::BuildCublasInstr(const Node& node, + Instruction* instr) const { instr->ClearInArgs(); instr->AddInArgs(OpGetInputNames(&node)); - auto& shape_dict = graph_->GetAttrs>("infershape"); + auto& shape_dict = + graph_->GetAttrs>("infershape"); // shape info std::vector shape_sizes; for (auto& in_node : node.inlinks_in_order()) { std::string in_id = in_node->source()->safe_as()->id(); - auto in_shape = shape_dict.at(in_id); + auto in_shape = shape_dict.at(in_id); instr->attrs.insert(instr->attrs.end(), in_shape.begin(), in_shape.end()); shape_sizes.push_back(in_shape.size()); } - // cublas_gemm has three input vars, and 
its output shape is equal to the input bias. - // cublas_matmul only has two input vars, so we should get its output shape from shape_dict. + // cublas_gemm has three input vars, and its output shape is equal to the + // input bias. cublas_matmul only has two input vars, so we should get its + // output shape from shape_dict. if (node.op()->name == "cublas_matmul") { for (auto& out_node : node.outlinks_in_order()) { std::string out_id = out_node->sink()->safe_as()->id(); - auto out_shape = shape_dict.at(out_id); - instr->attrs.insert(instr->attrs.end(), out_shape.begin(), out_shape.end()); + auto out_shape = shape_dict.at(out_id); + instr->attrs.insert( + instr->attrs.end(), out_shape.begin(), out_shape.end()); shape_sizes.push_back(out_shape.size()); } } - instr->attrs.insert(instr->attrs.end(), shape_sizes.begin(), shape_sizes.end()); + instr->attrs.insert( + instr->attrs.end(), shape_sizes.begin(), shape_sizes.end()); // attribute info bool trans_a = false; if (node.attrs.attr_store.contains("trans_a")) { @@ -994,17 +1085,20 @@ void GraphCompiler::BuildCublasInstr(const Node& node, Instruction* instr) const } std::vector> GraphCompiler::BuildInstructions( - const std::vector>& groups, const std::vector>& fusion_groups) { - utils::RecordEvent("GraphCompiler BuildInstructions", utils::EventType::kOrdinary); + const std::vector>& groups, + const std::vector>& fusion_groups) { + utils::RecordEvent("GraphCompiler BuildInstructions", + utils::EventType::kOrdinary); std::vector> instructions; auto topo_order = graph_->topological_order(); - auto& nodes = std::get<0>(topo_order); - auto& edges = std::get<1>(topo_order); + auto& nodes = std::get<0>(topo_order); + auto& edges = std::get<1>(topo_order); VLOG(3) << "Begin GraphCompiler::BuildInstructions"; CHECK_GT(groups.size(), 0); CHECK_EQ(fusion_groups.size() != 0, groups.size() == fusion_groups.size()) - << "fusion_groups's size must be 0 or equal to groups. Currently fusion_group's size = " << fusion_groups.size() - << ", group's size = " << groups.size(); + << "fusion_groups's size must be 0 or equal to groups. Currently " + "fusion_group's size = " + << fusion_groups.size() << ", group's size = " << groups.size(); for (int idx = 0; idx < groups.size(); ++idx) { auto& group = groups[idx]; std::shared_ptr fusion_group(nullptr); @@ -1012,37 +1106,47 @@ std::vector> GraphCompiler::BuildInstructions( fusion_group = fusion_groups[idx]; } if (group.size() == 1) { - auto node = group[0]; + auto node = group[0]; auto instr_name = node->op()->name; - if (node->op()->name == "reshape" && compile_options_.with_instantiate_variables) { + if (node->op()->name == "reshape" && + compile_options_.with_instantiate_variables) { // not run instruction and shares buffer only when instantiate_variables - const auto& inlinks = node->inlinks_in_order(); + const auto& inlinks = node->inlinks_in_order(); const auto& outlinks = node->outlinks_in_order(); CHECK_EQ(inlinks.size(), 1U); CHECK_EQ(outlinks.size(), 1U); - std::string in_id = inlinks[0]->source()->safe_as()->id(); - std::string out_id = outlinks[0]->sink()->safe_as()->id(); + std::string in_id = inlinks[0]->source()->safe_as()->id(); + std::string out_id = outlinks[0]->sink()->safe_as()->id(); reuse_vars_map_[out_id] = in_id; - instr_name = "no_run"; + instr_name = "no_run"; } auto instr = std::unique_ptr( new Instruction(target_, scope_.get(), - fusion_group.get() ? fusion_group->input_names : OpGetInputNames(node), - fusion_group.get() ? 
fusion_group->output_names : OpGetOutputNames(node), + fusion_group.get() ? fusion_group->input_names + : OpGetInputNames(node), + fusion_group.get() ? fusion_group->output_names + : OpGetOutputNames(node), instr_name)); if (target_.arch == Target::Arch::NVGPU) { if (node->op()->name == "conv2d") { - auto& shape_dict = graph_->GetAttrs>("infershape"); + auto& shape_dict = + graph_->GetAttrs>( + "infershape"); for (auto& in_node : node->inlinks_in_order()) { std::string in_id = in_node->source()->safe_as()->id(); - auto in_shape = shape_dict.at(in_id); - instr->attrs.insert(instr->attrs.end(), in_shape.begin(), in_shape.end()); + auto in_shape = shape_dict.at(in_id); + instr->attrs.insert( + instr->attrs.end(), in_shape.begin(), in_shape.end()); } - AddAttrs(node->attrs.attr_store, {"padding", "stride", "dilation"}, instr.get()); - if (node->attrs.attr_store.find("groups") != node->attrs.attr_store.end()) { - auto conv_groups = absl::get(node->attrs.attr_store.at("groups")); + AddAttrs(node->attrs.attr_store, + {"padding", "stride", "dilation"}, + instr.get()); + if (node->attrs.attr_store.find("groups") != + node->attrs.attr_store.end()) { + auto conv_groups = + absl::get(node->attrs.attr_store.at("groups")); instr->attrs.push_back(conv_groups); } else { instr->attrs.push_back(1); @@ -1050,30 +1154,41 @@ std::vector> GraphCompiler::BuildInstructions( // output shape const auto& out_links = node->outlinks_in_order(); CHECK(!out_links.empty()); - auto& out_node = out_links.front(); + auto& out_node = out_links.front(); std::string out_id = out_node->sink()->safe_as()->id(); - auto out_shape = shape_dict.at(out_id); - instr->attrs.insert(instr->attrs.end(), out_shape.begin(), out_shape.end()); + auto out_shape = shape_dict.at(out_id); + instr->attrs.insert( + instr->attrs.end(), out_shape.begin(), out_shape.end()); CHECK_EQ(instr->attrs.size(), 19UL); // conv type {forward, backward_data, backward_filter} std::string type = "forward"; - if (node->attrs.attr_store.find("conv_type") != node->attrs.attr_store.end()) { - type = absl::get(node->attrs.attr_store.at("conv_type")); + if (node->attrs.attr_store.find("conv_type") != + node->attrs.attr_store.end()) { + type = + absl::get(node->attrs.attr_store.at("conv_type")); } instr->str_attrs.push_back(type); - if (node->attrs.attr_store.find("data_format") != node->attrs.attr_store.end()) { - instr->str_attrs.push_back(absl::get(node->attrs.attr_store["data_format"])); + if (node->attrs.attr_store.find("data_format") != + node->attrs.attr_store.end()) { + instr->str_attrs.push_back( + absl::get(node->attrs.attr_store["data_format"])); } } else if (node->op()->name == "depthwise_conv2d") { - auto& shape_dict = graph_->GetAttrs>("infershape"); + auto& shape_dict = + graph_->GetAttrs>( + "infershape"); for (auto& in_node : node->inlinks_in_order()) { std::string in_id = in_node->source()->safe_as()->id(); - auto in_shape = shape_dict.at(in_id); - instr->attrs.insert(instr->attrs.end(), in_shape.begin(), in_shape.end()); + auto in_shape = shape_dict.at(in_id); + instr->attrs.insert( + instr->attrs.end(), in_shape.begin(), in_shape.end()); } // conv - AddAttrs(node->attrs.attr_store, {"padding", "stride", "dilation"}, instr.get()); - if (node->attrs.attr_store.find("groups") != node->attrs.attr_store.end()) { + AddAttrs(node->attrs.attr_store, + {"padding", "stride", "dilation"}, + instr.get()); + if (node->attrs.attr_store.find("groups") != + node->attrs.attr_store.end()) { auto groups = absl::get(node->attrs.attr_store.at("groups")); 
instr->attrs.push_back(groups); } else { @@ -1082,43 +1197,59 @@ std::vector> GraphCompiler::BuildInstructions( // output shape const auto& out_links = node->outlinks_in_order(); CHECK(!out_links.empty()); - auto& out_node = out_links.front(); + auto& out_node = out_links.front(); std::string out_id = out_node->sink()->safe_as()->id(); - auto out_shape = shape_dict.at(out_id); - instr->attrs.insert(instr->attrs.end(), out_shape.begin(), out_shape.end()); + auto out_shape = shape_dict.at(out_id); + instr->attrs.insert( + instr->attrs.end(), out_shape.begin(), out_shape.end()); CHECK_EQ(instr->attrs.size(), 19UL); // conv type {forward, backward_data, backward_filter} std::string type = "forward"; - if (node->attrs.attr_store.find("conv_type") != node->attrs.attr_store.end()) { - type = absl::get(node->attrs.attr_store.at("conv_type")); + if (node->attrs.attr_store.find("conv_type") != + node->attrs.attr_store.end()) { + type = + absl::get(node->attrs.attr_store.at("conv_type")); } instr->str_attrs.push_back(type); } else if (node->op()->name == "pool2d") { - auto& shape_dict = graph_->GetAttrs>("infershape"); + auto& shape_dict = + graph_->GetAttrs>( + "infershape"); for (auto& in_node : node->inlinks_in_order()) { std::string in_id = in_node->source()->safe_as()->id(); - auto in_shape = shape_dict.at(in_id); + auto in_shape = shape_dict.at(in_id); CHECK_EQ(in_shape.size(), 4UL); - instr->attrs.insert(instr->attrs.end(), in_shape.begin(), in_shape.end()); + instr->attrs.insert( + instr->attrs.end(), in_shape.begin(), in_shape.end()); } bool global_pooling = false; - if (node->attrs.attr_store.find("global_pooling") != node->attrs.attr_store.end()) { - global_pooling = absl::get(node->attrs.attr_store.at("global_pooling")); + if (node->attrs.attr_store.find("global_pooling") != + node->attrs.attr_store.end()) { + global_pooling = + absl::get(node->attrs.attr_store.at("global_pooling")); } - if (node->attrs.attr_store.find("kernel_size") != node->attrs.attr_store.end()) { + if (node->attrs.attr_store.find("kernel_size") != + node->attrs.attr_store.end()) { if (global_pooling == false) { - auto kernel_size = absl::get>(node->attrs.attr_store.at("kernel_size")); - instr->attrs.insert(instr->attrs.end(), kernel_size.begin(), kernel_size.end()); + auto kernel_size = absl::get>( + node->attrs.attr_store.at("kernel_size")); + instr->attrs.insert( + instr->attrs.end(), kernel_size.begin(), kernel_size.end()); } else { instr->attrs.push_back(instr->attrs[2]); instr->attrs.push_back(instr->attrs[3]); } } - if (node->attrs.attr_store.find("padding_size") != node->attrs.attr_store.end()) { + if (node->attrs.attr_store.find("padding_size") != + node->attrs.attr_store.end()) { if (global_pooling == false) { - auto padding = absl::get>(node->attrs.attr_store.at("padding_size")); - instr->attrs.insert(instr->attrs.end(), padding.begin(), padding.end()); - if (padding.size() == 2) instr->attrs.insert(instr->attrs.end(), padding.begin(), padding.end()); + auto padding = absl::get>( + node->attrs.attr_store.at("padding_size")); + instr->attrs.insert( + instr->attrs.end(), padding.begin(), padding.end()); + if (padding.size() == 2) + instr->attrs.insert( + instr->attrs.end(), padding.begin(), padding.end()); } else { instr->attrs.push_back(0); instr->attrs.push_back(0); @@ -1126,15 +1257,20 @@ std::vector> GraphCompiler::BuildInstructions( instr->attrs.push_back(0); } } - AddAttrs(node->attrs.attr_store, {"stride_size", "pool_type"}, instr.get()); + AddAttrs(node->attrs.attr_store, + {"stride_size", 
"pool_type"}, + instr.get()); for (auto& out_node : node->outlinks_in_order()) { std::string out_id = out_node->sink()->safe_as()->id(); - auto out_shape = shape_dict.at(out_id); - instr->attrs.insert(instr->attrs.end(), out_shape.begin(), out_shape.end()); + auto out_shape = shape_dict.at(out_id); + instr->attrs.insert( + instr->attrs.end(), out_shape.begin(), out_shape.end()); } - if (node->attrs.attr_store.find("adaptive") != node->attrs.attr_store.end()) { - bool adaptive = absl::get(node->attrs.attr_store.at("adaptive")); + if (node->attrs.attr_store.find("adaptive") != + node->attrs.attr_store.end()) { + bool adaptive = + absl::get(node->attrs.attr_store.at("adaptive")); if (adaptive) instr->attrs.push_back(1); else @@ -1143,38 +1279,50 @@ std::vector> GraphCompiler::BuildInstructions( CHECK_EQ(instr->attrs.size(), 17UL); CHECK_EQ(instr->str_attrs.size(), 1UL); } else if (node->op()->name == "softmax") { - auto& shape_dict = graph_->GetAttrs>("infershape"); + auto& shape_dict = + graph_->GetAttrs>( + "infershape"); for (auto& in_node : node->inlinks_in_order()) { std::string in_id = in_node->source()->safe_as()->id(); - auto in_shape = shape_dict.at(in_id); - instr->attrs.insert(instr->attrs.end(), in_shape.begin(), in_shape.end()); + auto in_shape = shape_dict.at(in_id); + instr->attrs.insert( + instr->attrs.end(), in_shape.begin(), in_shape.end()); } AddAttrs(node->attrs.attr_store, {"axis"}, instr.get()); } else if (node->op()->name == "mul") { - auto& shape_dict = graph_->GetAttrs>("infershape"); + auto& shape_dict = + graph_->GetAttrs>( + "infershape"); for (auto& in_node : node->inlinks_in_order()) { std::string in_id = in_node->source()->safe_as()->id(); - auto in_shape = shape_dict.at(in_id); - instr->attrs.insert(instr->attrs.end(), in_shape.begin(), in_shape.end()); + auto in_shape = shape_dict.at(in_id); + instr->attrs.insert( + instr->attrs.end(), in_shape.begin(), in_shape.end()); } - if (node->attrs.attr_store.find("x_num_col_dims") != node->attrs.attr_store.end()) { - auto axis = absl::get(node->attrs.attr_store.at("x_num_col_dims")); + if (node->attrs.attr_store.find("x_num_col_dims") != + node->attrs.attr_store.end()) { + auto axis = + absl::get(node->attrs.attr_store.at("x_num_col_dims")); instr->attrs.push_back(axis); } else { instr->attrs.push_back(1); } - if (node->attrs.attr_store.find("y_num_col_dims") != node->attrs.attr_store.end()) { - auto axis = absl::get(node->attrs.attr_store.at("y_num_col_dims")); + if (node->attrs.attr_store.find("y_num_col_dims") != + node->attrs.attr_store.end()) { + auto axis = + absl::get(node->attrs.attr_store.at("y_num_col_dims")); instr->attrs.push_back(axis); } else { instr->attrs.push_back(1); } - } else if (node->op()->name == "cublas_gemm" || node->op()->name == "cublas_matmul") { + } else if (node->op()->name == "cublas_gemm" || + node->op()->name == "cublas_matmul") { BuildCublasInstr(*node, instr.get()); } } std::string op_func_name = - fusion_group.get() ? fusion_group->GetFuncName() : GetOrGenFullFuncName(GenOpFuncName(node)); + fusion_group.get() ? 
fusion_group->GetFuncName() + : GetOrGenFullFuncName(GenOpFuncName(node)); auto* fn_ptr = compiler_->Lookup(op_func_name); CHECK(fn_ptr); instr->SetLoweredFunc(reinterpret_cast(fn_ptr), op_func_name); @@ -1185,7 +1333,8 @@ std::vector> GraphCompiler::BuildInstructions( if (node->attrs.attr_store.count("pre_run")) { instr->pre_run = absl::get(node->attrs.attr_store["pre_run"]); } - // explicitly call Finalize of the instruction after all assignments on it were done + // explicitly call Finalize of the instruction after all assignments on it + // were done instr->Finalize(); instructions.push_back(std::move(instr)); } else { @@ -1193,7 +1342,7 @@ std::vector> GraphCompiler::BuildInstructions( std::vector inputNames; std::vector outputNames; std::unordered_set names_set; - int count = 0; + int count = 0; std::string fuse_name = "fn_"; if (!fusion_group.get()) { for (int i = 0; i < group.size(); i++) { @@ -1232,13 +1381,14 @@ std::vector> GraphCompiler::BuildInstructions( VLOG(3) << "input_names: " << utils::Join(inputNames, ", "); VLOG(3) << "out_names: " << utils::Join(outputNames, ", "); } - fuse_name = fusion_group.get() ? fusion_group->GetFuncName() : GetOrGenFullFuncName(fuse_name); - auto instr = - std::unique_ptr(new Instruction(target_, - scope_.get(), - fusion_group.get() ? fusion_group->input_names : inputNames, - fusion_group.get() ? fusion_group->output_names : outputNames, - fuse_name)); + fuse_name = fusion_group.get() ? fusion_group->GetFuncName() + : GetOrGenFullFuncName(fuse_name); + auto instr = std::unique_ptr(new Instruction( + target_, + scope_.get(), + fusion_group.get() ? fusion_group->input_names : inputNames, + fusion_group.get() ? fusion_group->output_names : outputNames, + fuse_name)); auto* fn_ptr = compiler_->Lookup(fuse_name); CHECK(fn_ptr); @@ -1249,11 +1399,13 @@ std::vector> GraphCompiler::BuildInstructions( for (int j = 0; j < group.size(); j++) { auto node = group[j]; - if (node->attrs.attr_store.count("pre_run") && absl::get(node->attrs.attr_store["pre_run"]) == true) { + if (node->attrs.attr_store.count("pre_run") && + absl::get(node->attrs.attr_store["pre_run"]) == true) { instr->pre_run = true; } } - // explicitly call Finalize of the instruction after all assignments on it were done + // explicitly call Finalize of the instruction after all assignments on it + // were done instr->Finalize(); instructions.push_back(std::move(instr)); } @@ -1261,52 +1413,63 @@ std::vector> GraphCompiler::BuildInstructions( return instructions; } -void GraphCompiler::RemoveInvalidVariables(const std::vector>& instructions) { +void GraphCompiler::RemoveInvalidVariables( + const std::vector>& instructions) { // mark all variables are invalid initially - utils::RecordEvent("GraphCompiler RemoveInvalidVariables", utils::EventType::kOrdinary); + utils::RecordEvent("GraphCompiler RemoveInvalidVariables", + utils::EventType::kOrdinary); std::unordered_set invalid_variables; auto var_names = scope_->var_names(); invalid_variables.reserve(var_names.size()); - std::transform(var_names.begin(), - var_names.end(), - std::inserter(invalid_variables, invalid_variables.end()), - [](const auto& name_view) { return std::string(name_view.data()); }); + std::transform( + var_names.begin(), + var_names.end(), + std::inserter(invalid_variables, invalid_variables.end()), + [](const auto& name_view) { return std::string(name_view.data()); }); // erase used variable names - auto exclude_arguments_fn = [&invalid_variables](const std::vector& args) { - std::for_each(args.begin(), args.end(), 
[&invalid_variables](const std::string& var_name) { - invalid_variables.erase(var_name); - }); - }; + auto exclude_arguments_fn = + [&invalid_variables](const std::vector& args) { + std::for_each(args.begin(), + args.end(), + [&invalid_variables](const std::string& var_name) { + invalid_variables.erase(var_name); + }); + }; // iterate the arguments of each instruction, eliminate the // used variables, and remain variables are invalid finally auto unused_var_num = invalid_variables.size(); - VLOG(3) << "Before removing invalid variables: " << instructions.size() << " instructions, " - << invalid_variables.size() << " variables"; + VLOG(3) << "Before removing invalid variables: " << instructions.size() + << " instructions, " << invalid_variables.size() << " variables"; for (auto i = 0; i < instructions.size(); ++i) { - const auto& instr = instructions.at(i); - const auto& in_args = instr->GetInArgs(); + const auto& instr = instructions.at(i); + const auto& in_args = instr->GetInArgs(); const auto& out_args = instr->GetOutArgs(); std::for_each(in_args.begin(), in_args.end(), exclude_arguments_fn); std::for_each(out_args.begin(), out_args.end(), exclude_arguments_fn); - VLOG(3) << "Instruction-" << i << " filter " << unused_var_num - invalid_variables.size() << " used variables"; + VLOG(3) << "Instruction-" << i << " filter " + << unused_var_num - invalid_variables.size() << " used variables"; unused_var_num = invalid_variables.size(); } - VLOG(3) << "There are " << unused_var_num << " invalid variables to be removed from scope"; - std::for_each(invalid_variables.begin(), invalid_variables.end(), [this](const std::string& var_name) { - scope_->EraseVar(var_name); - VLOG(3) << "Variable(" << var_name << ") is erased"; - }); + VLOG(3) << "There are " << unused_var_num + << " invalid variables to be removed from scope"; + std::for_each(invalid_variables.begin(), + invalid_variables.end(), + [this](const std::string& var_name) { + scope_->EraseVar(var_name); + VLOG(3) << "Variable(" << var_name << ") is erased"; + }); } static void BufferMallocWithCallback(void* args, int num_args) { cinn_pod_value_t* pod_args = static_cast(args); for (int i = 0; i < num_args; ++i) { cinn_buffer_t* buffer = static_cast(pod_args[i]); - CHECK(buffer->external_malloc) << "external_malloc is nullptr at " << i << "-th argumemnts"; + CHECK(buffer->external_malloc) + << "external_malloc is nullptr at " << i << "-th arguments"; buffer->external_malloc->operator()(nullptr, buffer); } } @@ -1320,10 +1483,12 @@ static void BufferFreeWithCallback(void* args, int num_args) { for (int i = 0; i < num_args; ++i) { cinn_buffer_t* buffer = static_cast(pod_args[i]); } } -void GraphCompiler::AnalyzeVariableLifeTime(const std::vector>& instructions, - std::unordered_map>* step2malloc, - std::unordered_map>* step2free) { - utils::RecordEvent("GraphCompiler AnalyzeVariableLifeTime", utils::EventType::kOrdinary); +void GraphCompiler::AnalyzeVariableLifeTime( + const std::vector>& instructions, + std::unordered_map>* step2malloc, + std::unordered_map>* step2free) { + utils::RecordEvent("GraphCompiler AnalyzeVariableLifeTime", + utils::EventType::kOrdinary); absl::flat_hash_map variable_last_used, variable_first_used; for (auto step = 0; step < instructions.size(); ++step) { const auto& instr = instructions.at(step); @@ -1352,8 +1517,10 @@ void GraphCompiler::AnalyzeVariableLifeTime(const std::vector>* instructions) { - utils::RecordEvent("GraphCompiler InsertBufferHandlers", utils::EventType::kOrdinary); +void GraphCompiler::InsertBufferHandlers( + std::vector>* instructions) { + utils::RecordEvent("GraphCompiler 
InsertBufferHandlers", + utils::EventType::kOrdinary); std::unordered_map> step2malloc, step2free; AnalyzeVariableLifeTime(*instructions, &step2malloc, &step2free); @@ -1366,11 +1533,17 @@ void GraphCompiler::InsertBufferHandlers(std::vectorsecond; - auto function_name = "malloc_buffer_instruction_" + std::to_string(step); - auto malloc_instr = std::make_unique( - common::DefaultHostTarget(), scope_.get(), malloc_var_names, std::vector({}), function_name); - VLOG(4) << "seting malloc function " << function_name << " for var " << cinn::utils::Join(malloc_var_names, ", "); - malloc_instr->SetLoweredFunc(reinterpret_cast(BufferMallocWithCallback), function_name); + auto function_name = "malloc_buffer_instruction_" + std::to_string(step); + auto malloc_instr = + std::make_unique(common::DefaultHostTarget(), + scope_.get(), + malloc_var_names, + std::vector({}), + function_name); + VLOG(4) << "seting malloc function " << function_name << " for var " + << cinn::utils::Join(malloc_var_names, ", "); + malloc_instr->SetLoweredFunc( + reinterpret_cast(BufferMallocWithCallback), function_name); malloc_instr->Finalize(); results.emplace_back(std::move(malloc_instr)); } @@ -1383,11 +1556,17 @@ void GraphCompiler::InsertBufferHandlers(std::vectorsecond; - auto function_name = "free_buffer_instruction_" + std::to_string(step); - auto free_instr = std::make_unique( - common::DefaultHostTarget(), scope_.get(), std::vector({}), free_var_names, function_name); - VLOG(4) << "setting free function " << function_name << " for var " << cinn::utils::Join(free_var_names, ", "); - free_instr->SetLoweredFunc(reinterpret_cast(BufferFreeWithCallback), function_name); + auto function_name = "free_buffer_instruction_" + std::to_string(step); + auto free_instr = + std::make_unique(common::DefaultHostTarget(), + scope_.get(), + std::vector({}), + free_var_names, + function_name); + VLOG(4) << "setting free function " << function_name << " for var " + << cinn::utils::Join(free_var_names, ", "); + free_instr->SetLoweredFunc( + reinterpret_cast(BufferFreeWithCallback), function_name); free_instr->Finalize(); results.emplace_back(std::move(free_instr)); } @@ -1397,11 +1576,14 @@ void GraphCompiler::InsertBufferHandlers(std::vectorswap(results); } -std::vector GraphCompiler::OpGetInputNames(const Node* node) const { +std::vector GraphCompiler::OpGetInputNames( + const Node* node) const { std::vector res; - if (node->op()->name == "cublas_gemm" || node->op()->name == "cublas_matmul" || node->op()->name == "conv2d" || - node->op()->name == "depthwise_conv2d" || node->op()->name == "pool2d" || node->op()->name == "softmax" || - node->op()->name == "mul" || node->op()->name == "matmul") { + if (node->op()->name == "cublas_gemm" || + node->op()->name == "cublas_matmul" || node->op()->name == "conv2d" || + node->op()->name == "depthwise_conv2d" || node->op()->name == "pool2d" || + node->op()->name == "softmax" || node->op()->name == "mul" || + node->op()->name == "matmul") { for (auto& i : node->inlinks_in_order()) { res.push_back(i->source()->as()->id()); } @@ -1419,7 +1601,8 @@ std::vector GraphCompiler::OpGetInputNames(const Node* node) const return res; } -std::vector GraphCompiler::OpGetOutputNames(const Node* node) const { +std::vector GraphCompiler::OpGetOutputNames( + const Node* node) const { std::vector res; for (auto& i : node->outlinks_in_order()) { res.push_back(i->sink()->as()->id()); @@ -1427,37 +1610,47 @@ std::vector GraphCompiler::OpGetOutputNames(const Node* node) const return res; } -std::shared_ptr 
BuildScope(Target target, const std::shared_ptr& graph, std::shared_ptr scope) { +std::shared_ptr BuildScope(Target target, + const std::shared_ptr& graph, + std::shared_ptr scope) { utils::RecordEvent("GraphCompiler BuildScope", utils::EventType::kOrdinary); - auto& shape_dict = graph->GetAttrs>("infershape"); - auto& dtype_dict = graph->GetAttrs>("inferdtype"); + auto& shape_dict = + graph->GetAttrs>("infershape"); + auto& dtype_dict = + graph->GetAttrs>("inferdtype"); if (!scope) scope = std::make_shared(); for (auto& iter : shape_dict) { - auto* var = scope->Var(iter.first); + auto* var = scope->Var(iter.first); auto& tensor = absl::get(*var); std::vector shape; for (auto& shape_dim : iter.second) { shape.push_back(Shape::dim_t(shape_dim)); } - VLOG(3) << "Tensor [" << iter.first << "] resize to " << utils::Join(shape, ","); + VLOG(3) << "Tensor [" << iter.first << "] resize to " + << utils::Join(shape, ","); tensor->Resize(Shape{shape}); CHECK(dtype_dict.count(iter.first)); CHECK(dtype_dict.at(iter.first).is_supported()) - << "The dtype of node " << iter.first << " is not float or bool or int! Its type " - << dtype_dict.at(iter.first).type() << ", " << dtype_dict.at(iter.first).bits() << " is not implemented yet."; + << "The dtype of node " << iter.first + << " is not float or bool or int! Its type " + << dtype_dict.at(iter.first).type() << ", " + << dtype_dict.at(iter.first).bits() << " is not implemented yet."; tensor->set_type(dtype_dict.at(iter.first)); } return scope; } -std::vector GetFuncFromImpl(const std::shared_ptr& impl, - const common::CINNValuePack& cinn_inputs, - std::vector& all_arg_tensors, - const std::vector& input_output_nodes, - const std::string& node_id, - const Target& target) { - utils::RecordEvent("GraphCompiler GetFuncFromImpl", utils::EventType::kOrdinary); - // 1.Call Op's Compute function, using the default stages and LowerVec to get IR tree. +std::vector GetFuncFromImpl( + const std::shared_ptr& impl, + const common::CINNValuePack& cinn_inputs, + std::vector& all_arg_tensors, + const std::vector& input_output_nodes, + const std::string& node_id, + const Target& target) { + utils::RecordEvent("GraphCompiler GetFuncFromImpl", + utils::EventType::kOrdinary); + // 1.Call Op's Compute function, using the default stages and LowerVec to get + // IR tree. common::CINNValuePack C = impl->fcompute(cinn_inputs); // 2. Collect tensors and arguments @@ -1465,14 +1658,22 @@ std::vector GetFuncFromImpl(const std::shared_ptr& impl for (int i = 0; i < C->size() - 1; i++) { ir::Expr temp = C[i]; // checkout whether the tensor is with buffer. - if (!temp.as_tensor_ref()->buffer.defined() || target != common::DefaultNVGPUTarget()) { + if (!temp.as_tensor_ref()->buffer.defined() || + target != common::DefaultNVGPUTarget()) { all_arg_tensors.push_back(temp.as_tensor_ref()); } } - poly::StageMap stages = C.back(); + poly::StageMap stages = C.back(); std::string func_name_prefix = "fn_"; - auto funcs = lang::LowerVec(func_name_prefix + node_id, stages, all_arg_tensors, {}, {}, nullptr, target, true); + auto funcs = lang::LowerVec(func_name_prefix + node_id, + stages, + all_arg_tensors, + {}, + {}, + nullptr, + target, + true); std::vector schedule_inputs; for (int i = 0; i < C.size() - 1; ++i) { @@ -1484,20 +1685,23 @@ std::vector GetFuncFromImpl(const std::shared_ptr& impl } // 3. 
Call Op's Schedule function, optimizing the IR tree by new IR schedule - common::CINNValuePack expr_pack = impl->fschedule(common::CINNValuePack{schedule_inputs}); + common::CINNValuePack expr_pack = + impl->fschedule(common::CINNValuePack{schedule_inputs}); // 4. Optimize the LoweredFunc - VLOG(3) << "expr_pack.size() is : " << expr_pack.size() << ", funcs.size() is " << funcs.size(); + VLOG(3) << "expr_pack.size() is : " << expr_pack.size() + << ", funcs.size() is " << funcs.size(); VLOG(3) << "input_output_nodes.size() is: " << input_output_nodes.size() << ", all_arg_tensors.size() is: " << all_arg_tensors.size(); std::vector funcs_after_schedule; CHECK_GE(funcs.size(), expr_pack.size()); - if (funcs.size() > expr_pack.size() || all_arg_tensors.size() > input_output_nodes.size()) { + if (funcs.size() > expr_pack.size() || + all_arg_tensors.size() > input_output_nodes.size()) { for (int i = 0; i < funcs.size(); i++) { for (int j = 0; j < expr_pack.size(); j++) { Expr temp = expr_pack[j]; if (temp == funcs[i]->body) { - auto new_args = lang::GetArgs(funcs[i]->body, input_output_nodes); + auto new_args = lang::GetArgs(funcs[i]->body, input_output_nodes); funcs[i]->args = new_args; funcs_after_schedule.push_back(funcs[i]); break; @@ -1515,13 +1719,17 @@ std::vector GetFuncFromImpl(const std::shared_ptr& impl #ifdef CINN_WITH_CUDA optim::OptimizeExprGPU(&(funcs_after_schedule[i]->body)); #endif - auto temp_buffers = lang::GetTempBuffers(all_arg_tensors, stages, funcs_after_schedule[i]->body); + auto temp_buffers = lang::GetTempBuffers( + all_arg_tensors, stages, funcs_after_schedule[i]->body); funcs_after_schedule[i]->temp_bufs = temp_buffers; - funcs_after_schedule[i] = ir::_LoweredFunc_::Make(funcs_after_schedule[i]->name, - funcs_after_schedule[i]->args, - funcs_after_schedule[i]->body, - funcs_after_schedule[i]->temp_bufs); - res.emplace_back(optim::Optimize(Expr(funcs_after_schedule[i]), target, false).as_lowered_func_ref()); + funcs_after_schedule[i] = + ir::_LoweredFunc_::Make(funcs_after_schedule[i]->name, + funcs_after_schedule[i]->args, + funcs_after_schedule[i]->body, + funcs_after_schedule[i]->temp_bufs); + res.emplace_back( + optim::Optimize(Expr(funcs_after_schedule[i]), target, false) + .as_lowered_func_ref()); } // 5. Return the result. return res; diff --git a/paddle/cinn/hlir/framework/graph_compiler.h b/paddle/cinn/hlir/framework/graph_compiler.h index 5cb6ed7e8267e..ae482e1165ff9 100644 --- a/paddle/cinn/hlir/framework/graph_compiler.h +++ b/paddle/cinn/hlir/framework/graph_compiler.h @@ -50,18 +50,22 @@ class Program { * @param scope The scope containing all the runtime variables. * @param instrs The instructions belonging to this program. */ - Program(const std::shared_ptr& scope, std::vector>&& instrs); + Program(const std::shared_ptr& scope, + std::vector>&& instrs); - void PreRun(const std::map* name2podargs = nullptr); + void PreRun( + const std::map* name2podargs = nullptr); - void Export(const std::vector& persistent_vars, const std::string& filename); + void Export(const std::vector& persistent_vars, + const std::string& filename); /** * Execute the program -- that is running all the instructions inside it. 
*/ - void Execute(const std::map* name2podargs = nullptr, - void* stream = nullptr, - bool use_cache = true); + void Execute( + const std::map* name2podargs = nullptr, + void* stream = nullptr, + bool use_cache = true); void ExecuteTest(int repeat_); @@ -70,8 +74,12 @@ class Program { */ size_t size() const { return instrs_.size(); } - const std::vector>& GetPreRunInstructions() { return prerun_instrs_; } - const std::vector>& GetRunInstructions() { return instrs_; } + const std::vector>& GetPreRunInstructions() { + return prerun_instrs_; + } + const std::vector>& GetRunInstructions() { + return instrs_; + } private: // We need to hold scope to assure tensors alive used in instructions. @@ -87,18 +95,23 @@ class Program { */ class GraphCompiler final { public: - GraphCompiler(Target target, const std::shared_ptr& scope, const std::shared_ptr& graph) - : target_(std::move(target)), scope_(scope), graph_(graph), m_builder_(UniqName("module"), target) {} + GraphCompiler(Target target, + const std::shared_ptr& scope, + const std::shared_ptr& graph) + : target_(std::move(target)), + scope_(scope), + graph_(graph), + m_builder_(UniqName("module"), target) {} struct CompilationResult { std::unique_ptr runtime_program; }; struct CompileOptions { - std::string attached_code = ""; - bool with_instantiate_variables = false; + std::string attached_code = ""; + bool with_instantiate_variables = false; bool with_buffer_handle_instruction_inserted = false; - bool remove_unused_variables = true; + bool remove_unused_variables = true; // nodes group, it may come from the result of op fusion or graph tuning. // nodes in a group will be built into an Instruction std::vector> groups; @@ -113,7 +126,7 @@ class GraphCompiler final { // Compile with a packing option and result, to be extended easily. 
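(Sketch: how the Build overload declared just below is typically driven from auto-tuning, per CompileOptions::Apply in graph_compiler.cc above; the tuning_result value is assumed to be produced elsewhere by the auto_schedule tuner and is not constructed here.)

GraphCompiler gc(target, scope, graph);
GraphCompiler::CompileOptions options;
options.Apply(tuning_result);  // copies tuned subgraphs into groups and
                               // tuned function_groups into lowered_funcs
auto compiled = gc.Build(options);
auto& runtime_program = compiled.runtime_program;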
CompilationResult Build(const CompileOptions& options, std::unordered_set&& fetch_var_ids = {}, - void* stream = nullptr); + void* stream = nullptr); void ExportObject(const std::string& path) { compiler_->ExportObject(path); } std::unique_ptr Build(const std::string& code = ""); @@ -129,11 +142,14 @@ class GraphCompiler final { std::vector GetOpFunc(const Node* node); // Given a node, lower it to LoweredFunc using new ir schedule - std::vector GetOpFuncWithIRSchedule(const Node* node, - const absl::flat_hash_map& type_dict_, - const absl::flat_hash_map& shape_dict_); + std::vector GetOpFuncWithIRSchedule( + const Node* node, + const absl::flat_hash_map& type_dict_, + const absl::flat_hash_map& shape_dict_); - std::string GenOpFuncName(const Node* node) const { return "fn_" + node->id(); } + std::string GenOpFuncName(const Node* node) const { + return "fn_" + node->id(); + } // append a unique number at the end of the function name to distinguish // different functions from graphs whose structures are same @@ -145,25 +161,29 @@ class GraphCompiler final { std::vector OpGetOutputNames(const Node* node) const; std::vector> BuildInstructions( - const std::vector>& groups, const std::vector>& fusion_groups); + const std::vector>& groups, + const std::vector>& fusion_groups); void BuildCublasInstr(const Node& node, Instruction* instr) const; // some variables are eliminated by optimized passes(such as OpFusion), // we can filter out them according to arguments of the built instructions, // and erase them from the scope to avoid unnecessary buffer allocation - void RemoveInvalidVariables(const std::vector>& instructions); + void RemoveInvalidVariables( + const std::vector>& instructions); // find the first and last instruction where a variable used, and mark the // variable should allocate buffer before the first instruction runing and // can release the buffer after the last instruction finished. 
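(The lifetime scan behind the declaration that follows, implemented in graph_compiler.cc above, reduces to a first-use/last-use pass; a condensed sketch of the idea rather than the exact code -- AllArgNames is a hypothetical helper standing in for the in/out-argument iteration.)

absl::flat_hash_map<std::string, int> first_used, last_used;
for (size_t step = 0; step < instructions.size(); ++step) {
  // every input/output argument of an instruction counts as a use
  for (const auto& name : AllArgNames(*instructions[step])) {  // hypothetical
    if (!first_used.count(name)) first_used[name] = step;
    last_used[name] = step;
  }
}
// malloc just before the first use, free just after the last one
for (const auto& kv : first_used) (*step2malloc)[kv.second].push_back(kv.first);
for (const auto& kv : last_used) (*step2free)[kv.second].push_back(kv.first);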
- void AnalyzeVariableLifeTime(const std::vector>& instructions, - std::unordered_map>* step2malloc, - std::unordered_map>* step2free); + void AnalyzeVariableLifeTime( + const std::vector>& instructions, + std::unordered_map>* step2malloc, + std::unordered_map>* step2free); // insert a buffer malloc instruction applying on variables before they are // firstly used in the next instruction, and insert a buffer free instruction // applying on variables after no instruction will use them anymore - void InsertBufferHandlers(std::vector>* instructions); + void InsertBufferHandlers( + std::vector>* instructions); private: // parallel compiler @@ -178,7 +198,8 @@ class GraphCompiler final { std::map> function2input_args_; // mapping a function's name to its output artuments' names std::map> function2output_args_; - // fetch var ids in cinn and the corresponding var nodes will not be fused so as to get the result + // fetch var ids in cinn and the corresponding var nodes will not be fused so + // as to get the result std::unordered_set fetch_var_ids_; absl::flat_hash_map prefix2full_namemap_; @@ -198,12 +219,13 @@ std::shared_ptr BuildScope(Target target, std::shared_ptr scope = nullptr); // Given params, lower the op to LoweredFunc using new IR Schedule -std::vector GetFuncFromImpl(const std::shared_ptr& impl, - const common::CINNValuePack& cinn_inputs, - std::vector& tensor_inputs, - const std::vector& input_output_nodes, - const std::string& node_id, - const Target& target); +std::vector GetFuncFromImpl( + const std::shared_ptr& impl, + const common::CINNValuePack& cinn_inputs, + std::vector& tensor_inputs, + const std::vector& input_output_nodes, + const std::string& node_id, + const Target& target); } // namespace framework } // namespace hlir diff --git a/paddle/cinn/hlir/framework/graph_compiler_test.cc b/paddle/cinn/hlir/framework/graph_compiler_test.cc index 78e57d81668c4..d04d29c64602d 100644 --- a/paddle/cinn/hlir/framework/graph_compiler_test.cc +++ b/paddle/cinn/hlir/framework/graph_compiler_test.cc @@ -40,9 +40,9 @@ TEST(GraphCompilerTest, TestRemoveInvaildVariables) { auto c = builder.Add(a, b, 1); auto d = builder.Relu(c); - auto target = common::DefaultHostTarget(); + auto target = common::DefaultHostTarget(); auto program = builder.Build(); - auto graph = Optimize(&program, {}, target); + auto graph = Optimize(&program, {}, target); auto scope = BuildScope(target, graph); ASSERT_EQ(scope->var_names().size(), 6); @@ -64,47 +64,56 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) { auto c = builder.Add(a, b, 1); auto d = builder.Relu(c); - auto target = common::DefaultHostTarget(); + auto target = common::DefaultHostTarget(); auto program = builder.Build(); - auto graph = Optimize(&program, {}, target); - auto scope = BuildScope(target, graph); + auto graph = Optimize(&program, {}, target); + auto scope = BuildScope(target, graph); GraphCompiler gc_disable(target, scope, graph); GraphCompiler::CompileOptions options; // disable with_buffer_handle_instruction_inserted: only 1 instruction auto runtime_program_disable = gc_disable.Build(options).runtime_program; ASSERT_EQ(runtime_program_disable->size(), 1); - const auto& computation_instr_disable = runtime_program_disable->GetRunInstructions().front(); + const auto& computation_instr_disable = + runtime_program_disable->GetRunInstructions().front(); // enable with_buffer_handle_instruction_inserted: 3 instructions, 1st -> // malloc instruction(a, b, d), 2nd -> the real computation // instruction(add + relu) and 3rd -> free 
instruction GraphCompiler gc_enable(target, scope, graph); options.with_buffer_handle_instruction_inserted = true; - auto runtime_program_enable = gc_enable.Build(options).runtime_program; - const auto& instructions = runtime_program_enable->GetRunInstructions(); + auto runtime_program_enable = gc_enable.Build(options).runtime_program; + const auto& instructions = runtime_program_enable->GetRunInstructions(); ASSERT_EQ(instructions.size(), 3); const auto& malloc_instr = instructions.front(); ASSERT_EQ(malloc_instr->size(), 1); auto malloc_variable_names = malloc_instr->GetInArgs().front(); - auto used_variable_names = std::unordered_set( - {static_cast(a)->id, static_cast(b)->id, d->id}); + auto used_variable_names = + std::unordered_set({static_cast(a)->id, + static_cast(b)->id, + d->id}); EXPECT_EQ(malloc_instr->GetFnNames().size(), 1); EXPECT_EQ(malloc_instr->GetFnNames().front(), "malloc_buffer_instruction_0"); EXPECT_EQ(malloc_instr->GetOutArgs().size(), 1); EXPECT_TRUE(malloc_instr->GetOutArgs().front().empty()); EXPECT_EQ(malloc_variable_names.size(), 3); - EXPECT_EQ(std::unordered_set(malloc_variable_names.begin(), malloc_variable_names.end()), + EXPECT_EQ(std::unordered_set(malloc_variable_names.begin(), + malloc_variable_names.end()), used_variable_names); const auto& computation_instr_enable = instructions.at(1); - ASSERT_EQ(computation_instr_disable->size(), computation_instr_enable->size()); - auto computation_instr_function_names = computation_instr_enable->GetFnNames(); - ASSERT_EQ(computation_instr_disable->GetFnNames().size(), computation_instr_enable->GetFnNames().size()); - - EXPECT_EQ(computation_instr_disable->GetInArgs(), computation_instr_enable->GetInArgs()); - EXPECT_EQ(computation_instr_disable->GetOutArgs(), computation_instr_enable->GetOutArgs()); + ASSERT_EQ(computation_instr_disable->size(), + computation_instr_enable->size()); + auto computation_instr_function_names = + computation_instr_enable->GetFnNames(); + ASSERT_EQ(computation_instr_disable->GetFnNames().size(), + computation_instr_enable->GetFnNames().size()); + + EXPECT_EQ(computation_instr_disable->GetInArgs(), + computation_instr_enable->GetInArgs()); + EXPECT_EQ(computation_instr_disable->GetOutArgs(), + computation_instr_enable->GetOutArgs()); const auto& free_instr = instructions.back(); ASSERT_EQ(free_instr->size(), 1); @@ -113,13 +122,19 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) { EXPECT_EQ(free_instr->GetInArgs().size(), 1); EXPECT_TRUE(free_instr->GetInArgs().front().empty()); auto free_variable_names = free_instr->GetOutArgs().front(); - EXPECT_EQ(std::unordered_set(free_variable_names.begin(), free_variable_names.end()), + EXPECT_EQ(std::unordered_set(free_variable_names.begin(), + free_variable_names.end()), used_variable_names); } #ifdef CINN_WITH_CUDA -std::vector test_mul( - const std::vector& A, const std::vector& B, int M, int K, int N, bool trans_a, bool trans_b) { +std::vector test_mul(const std::vector& A, + const std::vector& B, + int M, + int K, + int N, + bool trans_a, + bool trans_b) { std::vector C(M * N, 0); if (!trans_a && !trans_b) { for (int idx = 0; idx < M; ++idx) { @@ -157,15 +172,22 @@ std::vector test_mul( return C; } -void RunCublas(int M, int N, int K, bool trans_a = false, bool trans_b = false) { +void RunCublas( + int M, int N, int K, bool trans_a = false, bool trans_b = false) { frontend::NetBuilder net_builder("builder"); - auto A = net_builder.CreateInput(Float(32), trans_a ? 
std::vector({K, M}) : std::vector({M, K}), "A"); - auto B = net_builder.CreateInput(Float(32), trans_b ? std::vector({N, K}) : std::vector({K, N}), "B"); + auto A = net_builder.CreateInput( + Float(32), + trans_a ? std::vector({K, M}) : std::vector({M, K}), + "A"); + auto B = net_builder.CreateInput( + Float(32), + trans_b ? std::vector({N, K}) : std::vector({K, N}), + "B"); auto C = net_builder.Matmul(A, B, trans_a, trans_b); auto program = net_builder.Build(); - auto target = common::DefaultTarget(); - auto graph = std::make_shared(program, target); + auto target = common::DefaultTarget(); + auto graph = std::make_shared(program, target); hlir::framework::ApplyPass(graph.get(), "TransToCustomCallPass"); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); diff --git a/paddle/cinn/hlir/framework/graph_test.cc b/paddle/cinn/hlir/framework/graph_test.cc index c5928869fd0d5..4dd944ca64b7a 100644 --- a/paddle/cinn/hlir/framework/graph_test.cc +++ b/paddle/cinn/hlir/framework/graph_test.cc @@ -28,15 +28,15 @@ namespace framework { TEST(Graph, visualize) { frontend::NetBuilder builder("test"); - auto x = builder.CreateInput(Float(32), {32, 16}, "x"); - auto y = builder.CreateInput(Float(32), {32, 16}, "y"); - auto add_1 = builder.Add(x, y); - auto relu_1 = builder.Relu(add_1); + auto x = builder.CreateInput(Float(32), {32, 16}, "x"); + auto y = builder.CreateInput(Float(32), {32, 16}, "y"); + auto add_1 = builder.Add(x, y); + auto relu_1 = builder.Relu(add_1); auto reduce_sum_1 = builder.ReduceSum(relu_1, {1}); - auto program = builder.Build(); + auto program = builder.Build(); auto target = common::DefaultHostTarget(); - auto graph = std::make_shared(program, target); + auto graph = std::make_shared(program, target); ApplyPass(graph.get(), "OpFusion"); FLAGS_cinn_fusion_groups_graphviz_dir = "./visualize"; @@ -45,19 +45,19 @@ TEST(Graph, visualize) { TEST(Graph, visualize_recompute) { frontend::NetBuilder builder("test"); - auto x = builder.CreateInput(Float(32), {16, 32}, "x"); - auto y = builder.CreateInput(Float(32), {32, 16}, "y"); - auto z = builder.CreateInput(Float(32), {16}, "z"); - auto constant_1 = builder.FillConstant({16}, 1, "constant_1"); - auto add_1 = builder.Add(z, constant_1); + auto x = builder.CreateInput(Float(32), {16, 32}, "x"); + auto y = builder.CreateInput(Float(32), {32, 16}, "y"); + auto z = builder.CreateInput(Float(32), {16}, "z"); + auto constant_1 = builder.FillConstant({16}, 1, "constant_1"); + auto add_1 = builder.Add(z, constant_1); auto broadcast_to_1 = builder.BroadcastTo(add_1, {16, 32}); auto broadcast_to_2 = builder.BroadcastTo(add_1, {32, 16}); - auto add_2 = builder.Add(x, broadcast_to_1); - auto add_3 = builder.Add(y, broadcast_to_2); - auto program = builder.Build(); + auto add_2 = builder.Add(x, broadcast_to_1); + auto add_3 = builder.Add(y, broadcast_to_2); + auto program = builder.Build(); auto target = common::DefaultHostTarget(); - auto graph = std::make_shared(program, target); + auto graph = std::make_shared(program, target); ApplyPass(graph.get(), "OpFusionPass"); ApplyPass(graph.get(), "FusionMergePass"); diff --git a/paddle/cinn/hlir/framework/instruction.cc b/paddle/cinn/hlir/framework/instruction.cc index 7e822a69af42c..abd86b8a6d4de 100644 --- a/paddle/cinn/hlir/framework/instruction.cc +++ b/paddle/cinn/hlir/framework/instruction.cc @@ -49,16 +49,19 @@ class ResultsPrint { private: ResultsPrint() { - bool print_to_file = !FLAGS_cinn_self_check_accuracy.empty() && - !cinn::runtime::CheckStringFlagTrue(FLAGS_cinn_self_check_accuracy) 
&& - !cinn::runtime::CheckStringFlagFalse(FLAGS_cinn_self_check_accuracy); + bool print_to_file = + !FLAGS_cinn_self_check_accuracy.empty() && + !cinn::runtime::CheckStringFlagTrue(FLAGS_cinn_self_check_accuracy) && + !cinn::runtime::CheckStringFlagFalse(FLAGS_cinn_self_check_accuracy); if (print_to_file) { of_.open(FLAGS_cinn_self_check_accuracy, std::ios_base::out); if (of_.is_open()) { - LOG(INFO) << "The CINN compute results will writing into file: \"" << FLAGS_cinn_self_check_accuracy << "\""; + LOG(INFO) << "The CINN compute results will be written into file: \"" + << FLAGS_cinn_self_check_accuracy << "\""; } else if (!FLAGS_cinn_self_check_accuracy.empty()) { - LOG(WARNING) << "Failed to open file: \"" << FLAGS_cinn_self_check_accuracy + LOG(WARNING) << "Failed to open file: \"" + << FLAGS_cinn_self_check_accuracy << "\", the CINN compute result will print."; } } @@ -74,19 +77,23 @@ class ResultsPrint { }; } // namespace details -void Instruction::UpdateArgsCache(const std::map* name2podargs) { +void Instruction::UpdateArgsCache( + const std::map* name2podargs) { int cache_size = size(); args_cached_.resize(cache_size); for (int i = 0; i < cache_size; ++i) { common::ArgsBuilder builder; std::vector all_args = in_args_[i]; - all_args.insert(std::end(all_args), out_args_[i].begin(), out_args_[i].end()); + all_args.insert( + std::end(all_args), out_args_[i].begin(), out_args_[i].end()); if (name2podargs != nullptr) { for (const auto& arg : all_args) { - CHECK_NE(name2podargs->count(arg), 0) << "Argument [" << arg << "] not found in the name2podargs"; - VLOG(5) << "Get a argument, name=" << arg << ",type_code=" << name2podargs->at(arg).type_code(); + CHECK_NE(name2podargs->count(arg), 0) + << "Argument [" << arg << "] not found in the name2podargs"; + VLOG(5) << "Get an argument, name=" << arg + << ",type_code=" << name2podargs->at(arg).type_code(); builder.Add(name2podargs->at(arg)); } } else { @@ -115,11 +122,13 @@ void Instruction::Finalize() { finalized_flag_ = true; } -void Instruction::Run(const std::map* name2podargs, - bool dryrun, - void* stream, - bool use_cache) { - utils::RecordEvent record_run(function_name_, cinn::utils::EventType::kInstruction); +void Instruction::Run( + const std::map* name2podargs, + bool dryrun, + void* stream, + bool use_cache) { + utils::RecordEvent record_run(function_name_, + cinn::utils::EventType::kInstruction); CHECK(finalized_flag_) << "Instruction must be finalized before run"; if (function_name_ == "no_run") { VLOG(2) << "skip instruction"; @@ -129,35 +138,49 @@ void Instruction::Run(const std::map* name2podarg VLOG(2) << "Run function " << function_name_; { - utils::RecordEvent record_args("UpdateArgsCache", cinn::utils::EventType::kInstruction); + utils::RecordEvent record_args("UpdateArgsCache", + cinn::utils::EventType::kInstruction); if (!use_cache || args_cached_.size() != size()) { UpdateArgsCache(name2podargs); } } - utils::RecordEvent record_args("Instruction::Run", cinn::utils::EventType::kInstruction); + utils::RecordEvent record_args("Instruction::Run", + cinn::utils::EventType::kInstruction); #if defined(CINN_WITH_CUDA) && !defined(CINN_WITH_CUDNN) if (function_name_ == "cublas_gemm" && target_.arch == Target::Arch::NVGPU) { auto& pod_args = args_cached_[0]; VLOG(3) << "The pod_args size of cublas_gemm: " << pod_args.size(); - runtime::cuda::cinn_gpu_cublas_gemm( - attrs, pod_args[0], pod_args[1], pod_args[2], pod_args[3], static_cast(stream)); - } else if (function_name_ == "cublas_matmul" && target_.arch == Target::Arch::NVGPU) { + 
runtime::cuda::cinn_gpu_cublas_gemm(attrs, + pod_args[0], + pod_args[1], + pod_args[2], + pod_args[3], + static_cast(stream)); + } else if (function_name_ == "cublas_matmul" && + target_.arch == Target::Arch::NVGPU) { auto& pod_args = args_cached_[0]; VLOG(3) << "The pod_args size of cublas_matmul: " << pod_args.size(); - runtime::cuda::cinn_gpu_cublas_gemm( - attrs, pod_args[0], pod_args[1], nullptr, pod_args[2], static_cast(stream)); + runtime::cuda::cinn_gpu_cublas_gemm(attrs, + pod_args[0], + pod_args[1], + nullptr, + pod_args[2], + static_cast(stream)); } else { VLOG(3) << "Runing extern function " << function_name_; for (int idx = 0; idx < fn_ptrs_.size(); ++idx) { VLOG(3) << "Runing func name: " << fn_names_[idx]; auto& pod_args = args_cached_[idx]; - CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by calling SetLoweredFunc method"; + CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by " + "calling SetLoweredFunc method"; if (!dryrun) { if (target_ == common::DefaultNVGPUTarget()) { - ((lower_func_ptr_g)fn_ptrs_[idx])(static_cast(pod_args.data()), pod_args.size(), stream); + ((lower_func_ptr_g)fn_ptrs_[idx])( + static_cast(pod_args.data()), pod_args.size(), stream); } else { - ((lower_func_ptr_t)fn_ptrs_[idx])(static_cast(pod_args.data()), pod_args.size()); + ((lower_func_ptr_t)fn_ptrs_[idx])(static_cast(pod_args.data()), + pod_args.size()); } } } @@ -165,84 +188,146 @@ void Instruction::Run(const std::map* name2podarg } #elif defined(CINN_WITH_CUDNN) auto& pod_args = args_cached_[0]; - // Here conv2d and depthwise_conv2d are implemented by one cudnn api cudnnConvolutionForward - if ((function_name_ == "conv2d" || function_name_ == "depthwise_conv2d") && target_.arch == Target::Arch::NVGPU) { + // Here conv2d and depthwise_conv2d are implemented by one cudnn api + // cudnnConvolutionForward + if ((function_name_ == "conv2d" || function_name_ == "depthwise_conv2d") && + target_.arch == Target::Arch::NVGPU) { if (str_attrs[0] == "forward") { if (str_attrs.size() > 1 && str_attrs[1] == "NHWC") { absl::flat_hash_map attrs_map = { - {"input_n", attrs[0]}, {"input_h", attrs[1]}, {"input_w", attrs[2]}, {"input_c", attrs[3]}, - {"weights_n", attrs[4]}, {"weights_c", attrs[5]}, {"weights_h", attrs[6]}, {"weights_w", attrs[7]}, - {"pad_h", attrs[8]}, {"pad_w", attrs[9]}, {"stride_h", attrs[10]}, {"stride_w", attrs[11]}, - {"dilation_h", attrs[12]}, {"dilation_w", attrs[13]}, {"groups", attrs[14]}, {"output_n", attrs[15]}, - {"output_h", attrs[16]}, {"output_w", attrs[17]}, {"output_c", attrs[18]}, + {"input_n", attrs[0]}, {"input_h", attrs[1]}, + {"input_w", attrs[2]}, {"input_c", attrs[3]}, + {"weights_n", attrs[4]}, {"weights_c", attrs[5]}, + {"weights_h", attrs[6]}, {"weights_w", attrs[7]}, + {"pad_h", attrs[8]}, {"pad_w", attrs[9]}, + {"stride_h", attrs[10]}, {"stride_w", attrs[11]}, + {"dilation_h", attrs[12]}, {"dilation_w", attrs[13]}, + {"groups", attrs[14]}, {"output_n", attrs[15]}, + {"output_h", attrs[16]}, {"output_w", attrs[17]}, + {"output_c", attrs[18]}, }; - runtime::cuda::cinn_gpu_cudnn_conv2d( - attrs_map, pod_args[0], pod_args[1], pod_args[2], static_cast(stream), common::Layout::kNHWC); + runtime::cuda::cinn_gpu_cudnn_conv2d(attrs_map, + pod_args[0], + pod_args[1], + pod_args[2], + static_cast(stream), + common::Layout::kNHWC); } else { absl::flat_hash_map attrs_map = { - {"input_n", attrs[0]}, {"input_c", attrs[1]}, {"input_h", attrs[2]}, {"input_w", attrs[3]}, - {"weights_n", attrs[4]}, {"weights_c", attrs[5]}, {"weights_h", 
attrs[6]}, {"weights_w", attrs[7]}, - {"pad_h", attrs[8]}, {"pad_w", attrs[9]}, {"stride_h", attrs[10]}, {"stride_w", attrs[11]}, - {"dilation_h", attrs[12]}, {"dilation_w", attrs[13]}, {"groups", attrs[14]}, {"output_n", attrs[15]}, - {"output_c", attrs[16]}, {"output_h", attrs[17]}, {"output_w", attrs[18]}, + {"input_n", attrs[0]}, {"input_c", attrs[1]}, + {"input_h", attrs[2]}, {"input_w", attrs[3]}, + {"weights_n", attrs[4]}, {"weights_c", attrs[5]}, + {"weights_h", attrs[6]}, {"weights_w", attrs[7]}, + {"pad_h", attrs[8]}, {"pad_w", attrs[9]}, + {"stride_h", attrs[10]}, {"stride_w", attrs[11]}, + {"dilation_h", attrs[12]}, {"dilation_w", attrs[13]}, + {"groups", attrs[14]}, {"output_n", attrs[15]}, + {"output_c", attrs[16]}, {"output_h", attrs[17]}, + {"output_w", attrs[18]}, }; - runtime::cuda::cinn_gpu_cudnn_conv2d( - attrs_map, pod_args[0], pod_args[1], pod_args[2], static_cast(stream), common::Layout::kNCHW); + runtime::cuda::cinn_gpu_cudnn_conv2d(attrs_map, + pod_args[0], + pod_args[1], + pod_args[2], + static_cast(stream), + common::Layout::kNCHW); } } else if (str_attrs[0] == "backward_data") { // w, dy, dx absl::flat_hash_map attrs_map = { - {"input_n", attrs[15]}, {"input_c", attrs[16]}, {"input_h", attrs[17]}, {"input_w", attrs[18]}, - {"weights_n", attrs[0]}, {"weights_c", attrs[1]}, {"weights_h", attrs[2]}, {"weights_w", attrs[3]}, - {"pad_h", attrs[8]}, {"pad_w", attrs[9]}, {"stride_h", attrs[10]}, {"stride_w", attrs[11]}, - {"dilation_h", attrs[12]}, {"dilation_w", attrs[13]}, {"groups", attrs[14]}, {"output_n", attrs[4]}, - {"output_c", attrs[5]}, {"output_h", attrs[6]}, {"output_w", attrs[7]}, + {"input_n", attrs[15]}, {"input_c", attrs[16]}, + {"input_h", attrs[17]}, {"input_w", attrs[18]}, + {"weights_n", attrs[0]}, {"weights_c", attrs[1]}, + {"weights_h", attrs[2]}, {"weights_w", attrs[3]}, + {"pad_h", attrs[8]}, {"pad_w", attrs[9]}, + {"stride_h", attrs[10]}, {"stride_w", attrs[11]}, + {"dilation_h", attrs[12]}, {"dilation_w", attrs[13]}, + {"groups", attrs[14]}, {"output_n", attrs[4]}, + {"output_c", attrs[5]}, {"output_h", attrs[6]}, + {"output_w", attrs[7]}, }; // w, dy, dx runtime::cuda::cinn_gpu_cudnn_conv2d_backward_data( - attrs_map, pod_args[0], pod_args[1], pod_args[2], static_cast(stream)); + attrs_map, + pod_args[0], + pod_args[1], + pod_args[2], + static_cast(stream)); } else { // x, dy, w absl::flat_hash_map attrs_map = { - {"input_n", attrs[0]}, {"input_c", attrs[1]}, {"input_h", attrs[2]}, {"input_w", attrs[3]}, - {"weights_n", attrs[15]}, {"weights_c", attrs[16]}, {"weights_h", attrs[17]}, {"weights_w", attrs[18]}, - {"pad_h", attrs[8]}, {"pad_w", attrs[9]}, {"stride_h", attrs[10]}, {"stride_w", attrs[11]}, - {"dilation_h", attrs[12]}, {"dilation_w", attrs[13]}, {"groups", attrs[14]}, {"output_n", attrs[4]}, - {"output_c", attrs[5]}, {"output_h", attrs[6]}, {"output_w", attrs[7]}, + {"input_n", attrs[0]}, {"input_c", attrs[1]}, + {"input_h", attrs[2]}, {"input_w", attrs[3]}, + {"weights_n", attrs[15]}, {"weights_c", attrs[16]}, + {"weights_h", attrs[17]}, {"weights_w", attrs[18]}, + {"pad_h", attrs[8]}, {"pad_w", attrs[9]}, + {"stride_h", attrs[10]}, {"stride_w", attrs[11]}, + {"dilation_h", attrs[12]}, {"dilation_w", attrs[13]}, + {"groups", attrs[14]}, {"output_n", attrs[4]}, + {"output_c", attrs[5]}, {"output_h", attrs[6]}, + {"output_w", attrs[7]}, }; // x, dy, w runtime::cuda::cinn_gpu_cudnn_conv2d_backward_filter( - attrs_map, pod_args[0], pod_args[1], pod_args[2], static_cast(stream)); + attrs_map, + pod_args[0], + pod_args[1], + pod_args[2], 
+ static_cast(stream)); } - } else if (function_name_ == "pool2d" && target_.arch == Target::Arch::NVGPU) { - runtime::cuda::cinn_gpu_cudnn_pool2d(attrs, str_attrs, pod_args[0], pod_args[1], static_cast(stream)); - } else if (function_name_ == "softmax" && target_.arch == Target::Arch::NVGPU) { + } else if (function_name_ == "pool2d" && + target_.arch == Target::Arch::NVGPU) { + runtime::cuda::cinn_gpu_cudnn_pool2d(attrs, + str_attrs, + pod_args[0], + pod_args[1], + static_cast(stream)); + } else if (function_name_ == "softmax" && + target_.arch == Target::Arch::NVGPU) { CHECK_EQ(pod_args.size(), 3); - runtime::cuda::cinn_gpu_cudnn_softmax(attrs, pod_args[0], pod_args[1], static_cast(stream)); + runtime::cuda::cinn_gpu_cudnn_softmax( + attrs, pod_args[0], pod_args[1], static_cast(stream)); } else if (function_name_ == "mul" && target_.arch == Target::Arch::NVGPU) { CHECK_EQ(pod_args.size(), 4); - runtime::cuda::cinn_gpu_cublas_mul(attrs, pod_args[0], pod_args[1], pod_args[2], static_cast(stream)); - } else if (function_name_ == "cublas_gemm" && target_.arch == Target::Arch::NVGPU) { + runtime::cuda::cinn_gpu_cublas_mul(attrs, + pod_args[0], + pod_args[1], + pod_args[2], + static_cast(stream)); + } else if (function_name_ == "cublas_gemm" && + target_.arch == Target::Arch::NVGPU) { VLOG(3) << "The pod_args size of cublas_gemm: " << pod_args.size(); - runtime::cuda::cinn_gpu_cublas_gemm( - attrs, pod_args[0], pod_args[1], pod_args[2], pod_args[3], static_cast(stream)); + runtime::cuda::cinn_gpu_cublas_gemm(attrs, + pod_args[0], + pod_args[1], + pod_args[2], + pod_args[3], + static_cast(stream)); + } else if (function_name_ == "cublas_matmul" && + target_.arch == Target::Arch::NVGPU) { auto& pod_args = args_cached_[0]; VLOG(3) << "The pod_args size of cublas_matmul: " << pod_args.size(); - runtime::cuda::cinn_gpu_cublas_gemm( - attrs, pod_args[0], pod_args[1], nullptr, pod_args[2], static_cast(stream)); + runtime::cuda::cinn_gpu_cublas_gemm(attrs, + pod_args[0], + pod_args[1], + nullptr, + pod_args[2], + static_cast(stream)); } else { VLOG(3) << "Running extern function " << function_name_; for (int idx = 0; idx < fn_ptrs_.size(); ++idx) { VLOG(3) << "Running func name: " << fn_names_[idx]; auto& pod_args = args_cached_[idx]; - CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by calling SetLoweredFunc method"; + CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by " + "calling SetLoweredFunc method"; if (!dryrun) { if (target_ == common::DefaultNVGPUTarget()) { - ((lower_func_ptr_g)fn_ptrs_[idx])(static_cast(pod_args.data()), pod_args.size(), stream); + ((lower_func_ptr_g)fn_ptrs_[idx])( + static_cast(pod_args.data()), pod_args.size(), stream); } else { - ((lower_func_ptr_t)fn_ptrs_[idx])(static_cast(pod_args.data()), pod_args.size()); + ((lower_func_ptr_t)fn_ptrs_[idx])(static_cast(pod_args.data()), + pod_args.size()); } } } @@ -253,12 +338,15 @@ void Instruction::Run(const std::map* name2podarg for (int idx = 0; idx < fn_ptrs_.size(); ++idx) { VLOG(3) << "Running func name: " << fn_names_[idx]; auto& pod_args = args_cached_[idx]; - CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by calling SetLoweredFunc method"; + CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by " + "calling SetLoweredFunc method"; if (!dryrun) { if (target_ == common::DefaultNVGPUTarget()) { - ((lower_func_ptr_g)fn_ptrs_[idx])(static_cast(pod_args.data()),
pod_args.size(), stream); + ((lower_func_ptr_g)fn_ptrs_[idx])( + static_cast(pod_args.data()), pod_args.size(), stream); } else { - ((lower_func_ptr_t)fn_ptrs_[idx])(static_cast(pod_args.data()), pod_args.size()); + ((lower_func_ptr_t)fn_ptrs_[idx])(static_cast(pod_args.data()), + pod_args.size()); } } } @@ -277,13 +365,15 @@ void Instruction::Run(const std::map* name2podarg // } } -void Instruction::CheckResults(const std::map* name2podargs, void* stream) { +void Instruction::CheckResults( + const std::map* name2podargs, void* stream) { #ifdef CINN_WITH_CUDA cudaStreamSynchronize(static_cast(stream)); #endif if (fn_names_.size() == 1) { - std::unordered_set skipped_instr_set = {"malloc_buffer_instruction", "free_buffer_instruction"}; + std::unordered_set skipped_instr_set = { + "malloc_buffer_instruction", "free_buffer_instruction"}; for (auto& name : skipped_instr_set) { if (fn_names_[0].find(name) != std::string::npos) { // Skip the malloc & free buffer instructions. diff --git a/paddle/cinn/hlir/framework/instruction.h b/paddle/cinn/hlir/framework/instruction.h index 61184b370ebd0..225cf6d08fd56 100644 --- a/paddle/cinn/hlir/framework/instruction.h +++ b/paddle/cinn/hlir/framework/instruction.h @@ -32,28 +32,36 @@ namespace hlir { namespace framework { /** - * Instruction is the basic executable element in runtime; it holds a pointer to the JIT-compiled LoweredFunc, - * collects the cinn_buffer of the inputs and outputs from the scope, prepares the arguments and finally passes them - * into the LoweredFunc and executes it. + * Instruction is the basic executable element in runtime; it holds a pointer to + * the JIT-compiled LoweredFunc, collects the cinn_buffer of the inputs and + * outputs from the scope, prepares the arguments and finally passes them into + * the LoweredFunc and executes it. */ class Instruction { public: - using infershape_t = std::function&)>; + using infershape_t = + std::function&)>; /** * Constructor. * @param target The \p target the instruction runs on. - * @param scope The scope containing all the runtime variables (Tensors and PODs). + * @param scope The scope containing all the runtime variables (Tensors and + * PODs). * @param in_args The names of the inputs. * @param out_args The names of the outputs. - * @param infershape The handler of this Instruction to perform shape inference. + * @param infershape The handler of this Instruction to perform shape + * inference. */ Instruction(const Target& target, Scope* scope, const std::vector& in_args, const std::vector& out_args, const std::string& function_name = "") - : target_(target), scope_(scope), in_args_({in_args}), out_args_({out_args}), function_name_(function_name) {} + : target_(target), + scope_(scope), + in_args_({in_args}), + out_args_({out_args}), + function_name_(function_name) {} /** * Set compiled function address. @@ -64,19 +72,23 @@ class Instruction { fn_names_.push_back(name); } - // explicitly finalize the instruction; no more functions can be appended after calling it + // explicitly finalize the instruction; no more functions can be appended + // after calling it void Finalize(); - void UpdateArgsCache(const std::map* name2podargs); + void UpdateArgsCache( + const std::map* name2podargs); /** * Run the Instruction.
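 *
 * A minimal usage sketch (editorial illustration only, mirroring the
 * instruction_test.cc flow later in this diff; fn_ptr is assumed to be the
 * address of a JIT-compiled function):
 * \code
 *   Instruction instr(common::DefaultHostTarget(), &scope, {"x", "y"}, {"z"});
 *   instr.SetLoweredFunc(reinterpret_cast<void*>(fn_ptr));
 *   instr.Finalize();  // no more functions may be appended after this
 *   instr.Run();       // gathers buffers, builds pod args, calls the func
 * \endcode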
*/ - void Run(const std::map* name2podargs = nullptr, - bool dryrun = false, - void* stream = nullptr, - bool use_cache = true); - - void PreRun(const std::map* name2podargs = nullptr) { + void Run( + const std::map* name2podargs = nullptr, + bool dryrun = false, + void* stream = nullptr, + bool use_cache = true); + + void PreRun( + const std::map* name2podargs = nullptr) { CHECK_EQ(fn_ptrs_.size(), 4); if (fn_ptrs_.size() > 1 && fn_ptrs_.size() != in_args_.size()) { out_args_.back()[0] = out_args_.front()[0]; @@ -88,18 +100,21 @@ class Instruction { CHECK_EQ(fn_ptrs_.size(), in_args_.size()); CHECK_EQ(fn_ptrs_.size(), out_args_.size()); - int flag = -1; + int flag = -1; void* stream = nullptr; for (int idx = 0; idx < 4; idx++) { if (utils::Startswith(out_args_[idx][0], "kernel_pack")) { VLOG(3) << "PreRun " << idx << "-th function of fn_:" << fn_names_[idx]; - flag = idx; + flag = idx; auto& pod_args = args_cached_[idx]; - CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by calling SetLoweredFunc method"; + CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first " + "by calling SetLoweredFunc method"; if (target_ == common::DefaultNVGPUTarget()) { - ((lower_func_ptr_g)fn_ptrs_[idx])(static_cast(pod_args.data()), pod_args.size(), stream); + ((lower_func_ptr_g)fn_ptrs_[idx])( + static_cast(pod_args.data()), pod_args.size(), stream); } else { - ((lower_func_ptr_t)fn_ptrs_[idx])(static_cast(pod_args.data()), pod_args.size()); + ((lower_func_ptr_t)fn_ptrs_[idx])(static_cast(pod_args.data()), + pod_args.size()); } #ifdef CINN_WITH_CUDA CUDA_CALL(cudaDeviceSynchronize()); @@ -122,15 +137,21 @@ class Instruction { void ClearInArgs() { in_args_.clear(); } void ClearOutArgs() { out_args_.clear(); } std::vector GetFnNames() { return fn_names_; } - void AddInArgs(const std::vector& in_args) { in_args_.push_back(in_args); } - void AddOutArgs(const std::vector& out_args) { out_args_.push_back(out_args); } + void AddInArgs(const std::vector& in_args) { + in_args_.push_back(in_args); + } + void AddOutArgs(const std::vector& out_args) { + out_args_.push_back(out_args); + } std::vector attrs; std::vector str_attrs; bool pre_run = false; Target target_; protected: - void CheckResults(const std::map* name2podargs = nullptr, void* stream = nullptr); + void CheckResults( + const std::map* name2podargs = nullptr, + void* stream = nullptr); private: bool finalized_flag_ = false; diff --git a/paddle/cinn/hlir/framework/instruction_test.cc b/paddle/cinn/hlir/framework/instruction_test.cc index 5c5bf493b9120..85c99282ee747 100644 --- a/paddle/cinn/hlir/framework/instruction_test.cc +++ b/paddle/cinn/hlir/framework/instruction_test.cc @@ -44,7 +44,7 @@ std::unique_ptr GetLoweredFunc(int M, int N) { {m, n}, [=](Expr i, Expr j) { return x(i, j) + y(i, j); }, "z"); auto stages = CreateStages({z}); - auto fn = Lower("fn", stages, {x, y, z}); + auto fn = Lower("fn", stages, {x, y, z}); ir::Module::Builder builder("some_module", common::DefaultHostTarget()); builder.AddFunction(fn); @@ -56,7 +56,7 @@ std::unique_ptr GetLoweredFunc(int M, int N) { void InstantiateScope(int M, int N, Scope* scope) { for (auto& name : std::vector({"x", "y", "z"})) { - auto* var = scope->Var(name); + auto* var = scope->Var(name); auto& tensor = absl::get(*var); tensor->Resize(Shape{{M, N}}); auto* data = tensor->mutable_data(common::DefaultHostTarget()); @@ -74,7 +74,7 @@ TEST(Instruction, basic) { InstantiateScope(M, N, &scope); // create Instruction Instruction instr(common::DefaultHostTarget(), &scope, 
{"x", "y"}, {"z"}); - auto jit = GetLoweredFunc(M, N); + auto jit = GetLoweredFunc(M, N); auto fn_ptr = jit->Lookup("fn"); CHECK(fn_ptr); instr.SetLoweredFunc(reinterpret_cast(fn_ptr)); @@ -90,28 +90,34 @@ TEST(Instruction, basic) { auto* zd = scope.GetTensor("z")->data(); for (int i = 0; i < M * N; i++) { - LOG_FIRST_N(INFO, 3) << "data: " << xd[i] << " + " << yd[i] << " = " << zd[i]; + LOG_FIRST_N(INFO, 3) << "data: " << xd[i] << " + " << yd[i] << " = " + << zd[i]; ASSERT_NEAR(xd[i] + yd[i], zd[i], 1e-5); } } } TEST(Instruction, RunWithRawPodArgs) { - const int M = 10; - const int N = 20; + const int M = 10; + const int N = 20; const auto& shape = Shape({M, N}); std::map name2podargs; // case 1: create cinn_pod_value_t arguments directly - std::vector args_buffer(3); // store {"x", "y", "z"} buffer objects - auto* default_memory_mng = MemoryManager::Global().RetrieveSafely(common::DefaultHostTarget().arch); + std::vector args_buffer( + 3); // store {"x", "y", "z"} buffer objects + auto* default_memory_mng = + MemoryManager::Global().RetrieveSafely(common::DefaultHostTarget().arch); int count = 0; for (const auto& name : std::vector({"x", "y", "z"})) { auto* buffer = &args_buffer.at(count++); - buffer->resize(reinterpret_cast(shape.data().data()), shape.size()); - buffer->memory = reinterpret_cast(default_memory_mng->malloc(shape.numel() * sizeof(float))); - auto* data = reinterpret_cast(buffer->memory); + buffer->resize( + reinterpret_cast(shape.data().data()), + shape.size()); + buffer->memory = reinterpret_cast( + default_memory_mng->malloc(shape.numel() * sizeof(float))); + auto* data = reinterpret_cast(buffer->memory); for (int i = 0; i < M * N; i++) { data[i] = (rand() * 1.f) / RAND_MAX; // NOLINT } @@ -119,19 +125,24 @@ TEST(Instruction, RunWithRawPodArgs) { } // create Instruction - auto jit = GetLoweredFunc(M, N); + auto jit = GetLoweredFunc(M, N); auto fn_ptr = jit->Lookup("fn"); CHECK(fn_ptr); - Instruction instr(common::DefaultHostTarget(), nullptr, {"x", "y"}, {"z"}); // empty scope + Instruction instr( + common::DefaultHostTarget(), nullptr, {"x", "y"}, {"z"}); // empty scope instr.SetLoweredFunc(reinterpret_cast(fn_ptr)); instr.Finalize(); auto check_equal_by_element = [&]() { - auto xd = reinterpret_cast(cinn_pod_value_to_buffer_p(&name2podargs.at("x"))->memory); - auto yd = reinterpret_cast(cinn_pod_value_to_buffer_p(&name2podargs.at("y"))->memory); - auto zd = reinterpret_cast(cinn_pod_value_to_buffer_p(&name2podargs.at("z"))->memory); + auto xd = reinterpret_cast( + cinn_pod_value_to_buffer_p(&name2podargs.at("x"))->memory); + auto yd = reinterpret_cast( + cinn_pod_value_to_buffer_p(&name2podargs.at("y"))->memory); + auto zd = reinterpret_cast( + cinn_pod_value_to_buffer_p(&name2podargs.at("z"))->memory); for (int i = 0; i < M * N; ++i) { - LOG_FIRST_N(INFO, 3) << "data: " << xd[i] << " + " << yd[i] << " = " << zd[i]; + LOG_FIRST_N(INFO, 3) << "data: " << xd[i] << " + " << yd[i] << " = " + << zd[i]; ASSERT_NEAR(xd[i] + yd[i], zd[i], 1e-5); } }; @@ -166,7 +177,9 @@ class TestInstruction : public Instruction { const std::string& func_name) : Instruction(target, scope, in_args, out_args, func_name) {} void SetArgs(const std::vector& args) { args_ = args; } - void SetPodArgs(const std::vector& pod_args) { pod_args_ = pod_args; } + void SetPodArgs(const std::vector& pod_args) { + pod_args_ = pod_args; + } void RunX(std::string conv_type) { if (conv_type == "forward") { @@ -271,27 +284,52 @@ TEST(Instruction, CONV_FORWARD) { int sh = 1, sw = 1; int dila_h = 1, dila_w = 1; -
int group = 1; - std::vector args = {in, ic, ih, iw, fn, fc, fh, fw, ph, pw, sh, sw, dila_h, dila_w, group, on, oc, oh, ow}; + int group = 1; + std::vector args = {in, + ic, + ih, + iw, + fn, + fc, + fh, + fw, + ph, + pw, + sh, + sw, + dila_h, + dila_w, + group, + on, + oc, + oh, + ow}; // infer shape - auto conv2d = Operator::Get("conv2d"); - auto strategy = Operator::GetAttrs("CINNStrategy"); - auto infer_shape_func = Operator::GetAttrs("infershape")[conv2d]; + auto conv2d = Operator::Get("conv2d"); + auto strategy = Operator::GetAttrs("CINNStrategy"); + auto infer_shape_func = + Operator::GetAttrs("infershape")[conv2d]; CUDA_CALL(cudaSetDevice(0)); - auto buffer_x = common::BufferBuilder(Float(32), {in, ic, ih, iw}).set_random().Build(); - auto buffer_w = common::BufferBuilder(Float(32), {fn, fc, fh, fw}).set_random().Build(); - auto buffer_y = common::BufferBuilder(Float(32), {on, oc, oh, ow}).set_random().Build(); + auto buffer_x = + common::BufferBuilder(Float(32), {in, ic, ih, iw}).set_random().Build(); + auto buffer_w = + common::BufferBuilder(Float(32), {fn, fc, fh, fw}).set_random().Build(); + auto buffer_y = + common::BufferBuilder(Float(32), {on, oc, oh, ow}).set_random().Build(); void *dev_x = nullptr, *dev_w = nullptr, *dev_y = nullptr; CUDA_CALL(cudaMalloc(&dev_x, buffer_x->memory_size)); CUDA_CALL(cudaMalloc(&dev_w, buffer_w->memory_size)); CUDA_CALL(cudaMalloc(&dev_y, buffer_y->memory_size)); - CUDA_CALL(cudaMemcpy(dev_x, buffer_x->memory, buffer_x->memory_size, cudaMemcpyHostToDevice)); - CUDA_CALL(cudaMemcpy(dev_w, buffer_w->memory, buffer_w->memory_size, cudaMemcpyHostToDevice)); - CUDA_CALL(cudaMemcpy(dev_y, buffer_y->memory, buffer_y->memory_size, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy( + dev_x, buffer_x->memory, buffer_x->memory_size, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy( + dev_w, buffer_w->memory, buffer_w->memory_size, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy( + dev_y, buffer_y->memory, buffer_y->memory_size, cudaMemcpyHostToDevice)); cinn_buffer_t _x; cinn_buffer_t _w; @@ -337,27 +375,52 @@ TEST(Instruction, CONV_BACKWARD_DATA) { int sh = 1, sw = 1; int dila_h = 1, dila_w = 1; - int group = 1; - std::vector args = {in, ic, ih, iw, fn, fc, fh, fw, ph, pw, sh, sw, dila_h, dila_w, group, on, oc, oh, ow}; + int group = 1; + std::vector args = {in, + ic, + ih, + iw, + fn, + fc, + fh, + fw, + ph, + pw, + sh, + sw, + dila_h, + dila_w, + group, + on, + oc, + oh, + ow}; // infer shape - auto conv2d = Operator::Get("conv2d"); - auto strategy = Operator::GetAttrs("CINNStrategy"); - auto infer_shape_func = Operator::GetAttrs("infershape")[conv2d]; + auto conv2d = Operator::Get("conv2d"); + auto strategy = Operator::GetAttrs("CINNStrategy"); + auto infer_shape_func = + Operator::GetAttrs("infershape")[conv2d]; CUDA_CALL(cudaSetDevice(0)); - auto buffer_x = common::BufferBuilder(Float(32), {in, ic, ih, iw}).set_random().Build(); - auto buffer_w = common::BufferBuilder(Float(32), {fn, fc, fh, fw}).set_random().Build(); - auto buffer_y = common::BufferBuilder(Float(32), {on, oc, oh, ow}).set_random().Build(); + auto buffer_x = + common::BufferBuilder(Float(32), {in, ic, ih, iw}).set_random().Build(); + auto buffer_w = + common::BufferBuilder(Float(32), {fn, fc, fh, fw}).set_random().Build(); + auto buffer_y = + common::BufferBuilder(Float(32), {on, oc, oh, ow}).set_random().Build(); void *dev_x = nullptr, *dev_w = nullptr, *dev_y = nullptr; CUDA_CALL(cudaMalloc(&dev_x, buffer_x->memory_size)); CUDA_CALL(cudaMalloc(&dev_w, buffer_w->memory_size)); 
CUDA_CALL(cudaMalloc(&dev_y, buffer_y->memory_size)); - CUDA_CALL(cudaMemcpy(dev_x, buffer_x->memory, buffer_x->memory_size, cudaMemcpyHostToDevice)); - CUDA_CALL(cudaMemcpy(dev_w, buffer_w->memory, buffer_w->memory_size, cudaMemcpyHostToDevice)); - CUDA_CALL(cudaMemcpy(dev_y, buffer_y->memory, buffer_y->memory_size, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy( + dev_x, buffer_x->memory, buffer_x->memory_size, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy( + dev_w, buffer_w->memory, buffer_w->memory_size, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy( + dev_y, buffer_y->memory, buffer_y->memory_size, cudaMemcpyHostToDevice)); cinn_buffer_t _x; cinn_buffer_t _w; @@ -404,41 +467,67 @@ TEST(Instruction, CONV_BACKWARD_FILTER) { int sh = 1, sw = 1; int dila_h = 1, dila_w = 1; - int group = 1; - std::vector args = {in, ic, ih, iw, fn, fc, fh, fw, ph, pw, sh, sw, dila_h, dila_w, group, on, oc, oh, ow}; + int group = 1; + std::vector args = {in, + ic, + ih, + iw, + fn, + fc, + fh, + fw, + ph, + pw, + sh, + sw, + dila_h, + dila_w, + group, + on, + oc, + oh, + ow}; // infer shape - auto conv2d = Operator::Get("conv2d"); - auto strategy = Operator::GetAttrs("CINNStrategy"); - auto infer_shape_func = Operator::GetAttrs("infershape")[conv2d]; + auto conv2d = Operator::Get("conv2d"); + auto strategy = Operator::GetAttrs("CINNStrategy"); + auto infer_shape_func = + Operator::GetAttrs("infershape")[conv2d]; absl::flat_hash_map attrs_map; - attrs_map["padding"] = std::vector({ph, pw}); - attrs_map["stride"] = std::vector({sh, sw}); - attrs_map["dilation"] = std::vector({dila_h, dila_w}); - attrs_map["data_format"] = std::string("NCHW"); - attrs_map["conv_type"] = std::string("backward_filter"); + attrs_map["padding"] = std::vector({ph, pw}); + attrs_map["stride"] = std::vector({sh, sw}); + attrs_map["dilation"] = std::vector({dila_h, dila_w}); + attrs_map["data_format"] = std::string("NCHW"); + attrs_map["conv_type"] = std::string("backward_filter"); attrs_map["output_shape"] = std::vector({fn, fc, fh, fw}); - auto infer_shape = infer_shape_func({{in, ic, ih, iw}, {on, oc, oh, ow}}, attrs_map); + auto infer_shape = + infer_shape_func({{in, ic, ih, iw}, {on, oc, oh, ow}}, attrs_map); ASSERT_EQ(infer_shape[0][0], fn); ASSERT_EQ(infer_shape[0][1], fc); ASSERT_EQ(infer_shape[0][2], fh); ASSERT_EQ(infer_shape[0][3], fw); CUDA_CALL(cudaSetDevice(0)); - auto buffer_x = common::BufferBuilder(Float(32), {in, ic, ih, iw}).set_random().Build(); - auto buffer_w = common::BufferBuilder(Float(32), {fn, fc, fh, fw}).set_random().Build(); - auto buffer_y = common::BufferBuilder(Float(32), {on, oc, oh, ow}).set_random().Build(); + auto buffer_x = + common::BufferBuilder(Float(32), {in, ic, ih, iw}).set_random().Build(); + auto buffer_w = + common::BufferBuilder(Float(32), {fn, fc, fh, fw}).set_random().Build(); + auto buffer_y = + common::BufferBuilder(Float(32), {on, oc, oh, ow}).set_random().Build(); void *dev_x = nullptr, *dev_w = nullptr, *dev_y = nullptr; CUDA_CALL(cudaMalloc(&dev_x, buffer_x->memory_size)); CUDA_CALL(cudaMalloc(&dev_w, buffer_w->memory_size)); CUDA_CALL(cudaMalloc(&dev_y, buffer_y->memory_size)); - CUDA_CALL(cudaMemcpy(dev_x, buffer_x->memory, buffer_x->memory_size, cudaMemcpyHostToDevice)); - CUDA_CALL(cudaMemcpy(dev_w, buffer_w->memory, buffer_w->memory_size, cudaMemcpyHostToDevice)); - CUDA_CALL(cudaMemcpy(dev_y, buffer_y->memory, buffer_y->memory_size, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy( + dev_x, buffer_x->memory, buffer_x->memory_size, cudaMemcpyHostToDevice)); + 
CUDA_CALL(cudaMemcpy( + dev_w, buffer_w->memory, buffer_w->memory_size, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy( + dev_y, buffer_y->memory, buffer_y->memory_size, cudaMemcpyHostToDevice)); cinn_buffer_t _x; cinn_buffer_t _w; diff --git a/paddle/cinn/hlir/framework/memory.cc b/paddle/cinn/hlir/framework/memory.cc index d2739544b715f..6c567bb84f6b7 100755 --- a/paddle/cinn/hlir/framework/memory.cc +++ b/paddle/cinn/hlir/framework/memory.cc @@ -36,7 +36,9 @@ class X86MemoryMng : public MemoryInterface { if (!data) return; ::free(data); } - void* aligned_alloc(size_t alignment, size_t nbytes) override { return ::aligned_alloc(alignment, nbytes); } + void* aligned_alloc(size_t alignment, size_t nbytes) override { + return ::aligned_alloc(alignment, nbytes); + } }; #ifdef CINN_WITH_CUDA diff --git a/paddle/cinn/hlir/framework/memory.h b/paddle/cinn/hlir/framework/memory.h index a17f3608cf1d4..ee84433ed29e4 100755 --- a/paddle/cinn/hlir/framework/memory.h +++ b/paddle/cinn/hlir/framework/memory.h @@ -29,8 +29,10 @@ namespace framework { class MemoryInterface { public: virtual void* malloc(size_t nbytes) = 0; - virtual void free(void* data) = 0; - virtual void* aligned_alloc(size_t alignment, size_t nbytes) { return nullptr; } + virtual void free(void* data) = 0; + virtual void* aligned_alloc(size_t alignment, size_t nbytes) { + return nullptr; + } virtual ~MemoryInterface() {} }; @@ -67,7 +69,8 @@ class MemoryManager final { private: MemoryManager(); - absl::flat_hash_map> memory_mngs_; + absl::flat_hash_map> + memory_mngs_; CINN_DISALLOW_COPY_AND_ASSIGN(MemoryManager); }; diff --git a/paddle/cinn/hlir/framework/node.cc b/paddle/cinn/hlir/framework/node.cc index e2e8ae865f99c..20b2eb90921f0 100644 --- a/paddle/cinn/hlir/framework/node.cc +++ b/paddle/cinn/hlir/framework/node.cc @@ -22,17 +22,23 @@ namespace cinn { namespace hlir { namespace framework { -std::tuple Node::LinkTo(NodeData* other) { +std::tuple Node::LinkTo( + NodeData* other) { return this->common::GraphNode::LinkTo(other->as()); } -std::tuple NodeData::LinkTo(Node* other) { +std::tuple NodeData::LinkTo( + Node* other) { return this->common::GraphNode::LinkTo(other->as()); } -void Node::Controls(NodeData* other) { return this->common::GraphNode::Controls(other->as()); } +void Node::Controls(NodeData* other) { + return this->common::GraphNode::Controls(other->as()); +} -void NodeData::Controls(Node* other) { return this->common::GraphNode::Controls(other->as()); } +void NodeData::Controls(Node* other) { + return this->common::GraphNode::Controls(other->as()); +} namespace { @@ -76,7 +82,8 @@ std::ostream& operator<<(std::ostream& os, const NodeAttr& node_attr) { } //! Using index to sort the input/output tensors -bool edge_index_compare(const common::Shared& a, const common::Shared& b) { +bool edge_index_compare(const common::Shared& a, + const common::Shared& b) { CHECK_NOTNULL(a.get()); CHECK_NOTNULL(b.get()); return a->index() < b->index(); @@ -86,8 +93,9 @@ std::vector> Node::inlinks_in_order() const { std::vector> ordered_links; for (auto& in_edge : this->inlinks()) { ordered_links.push_back(in_edge); - CHECK_GE(in_edge->index(), 0) << "The index of a node's inlinks should be >= 0! Now index is: " << in_edge->index() - << ". Please check."; + CHECK_GE(in_edge->index(), 0) + << "The index of a node's inlinks should be >= 0! Now index is: " + << in_edge->index() << ". 
Please check."; } std::sort(ordered_links.begin(), ordered_links.end(), edge_index_compare); return ordered_links; } @@ -97,21 +105,26 @@ std::vector> Node::outlinks_in_order() const { std::vector> ordered_links; for (auto& out_edge : this->outlinks()) { ordered_links.push_back(out_edge); - CHECK_GE(out_edge->index(), 0) << "The index of a node's outlinks should be >= 0! Now index is: " - << out_edge->index() << ". Please check."; + CHECK_GE(out_edge->index(), 0) + << "The index of a node's outlinks should be >= 0! Now index is: " + << out_edge->index() << ". Please check."; } std::sort(ordered_links.begin(), ordered_links.end(), edge_index_compare); return ordered_links; } -NodeData* InsertGraphOpNodeAfter( - common::Graph* graph, Node* insert_node, NodeData* input_nodedata, Node* out_node, int pos) { +NodeData* InsertGraphOpNodeAfter(common::Graph* graph, + Node* insert_node, + NodeData* input_nodedata, + Node* out_node, + int pos) { CHECK(graph); CHECK(insert_node); CHECK(input_nodedata); input_nodedata->Controls(insert_node); common::Shared node_ptr(insert_node); - auto* out_nodedata = new NodeData(node_ptr, 0, 0, common::UniqName(insert_node->id() + "_out")); + auto* out_nodedata = new NodeData( + node_ptr, 0, 0, common::UniqName(insert_node->id() + "_out")); insert_node->Controls(out_nodedata); std::vector old_sources; auto input_links = out_node->inlinks_in_order(); @@ -138,14 +151,18 @@ NodeData* InsertGraphOpNodeAfter( return out_nodedata; } -NodeData* InsertGraphOpNodeBefore( - common::Graph* graph, Node* insert_node, Node* input_node, NodeData* dst_data, int pos) { +NodeData* InsertGraphOpNodeBefore(common::Graph* graph, + Node* insert_node, + Node* input_node, + NodeData* dst_data, + int pos) { CHECK(graph); CHECK(insert_node); CHECK(input_node); CHECK(dst_data); - auto node_ptr = dst_data->source_node; - auto* input_node_out = new NodeData(node_ptr, 0, 0, common::UniqName(input_node->id() + "_out")); + auto node_ptr = dst_data->source_node; + auto* input_node_out = + new NodeData(node_ptr, 0, 0, common::UniqName(input_node->id() + "_out")); std::vector old_sinks; const auto& old_outlinks = input_node->outlinks_in_order(); for (auto& link : old_outlinks) { diff --git a/paddle/cinn/hlir/framework/node.h b/paddle/cinn/hlir/framework/node.h index 6c48efca99746..31d316bbbff8d 100644 --- a/paddle/cinn/hlir/framework/node.h +++ b/paddle/cinn/hlir/framework/node.h @@ -32,8 +32,8 @@ namespace framework { class Node; class NodeData; -using NodePtr = common::Shared; -using AttrType = utils::Attribute; +using NodePtr = common::Shared; +using AttrType = utils::Attribute; using AttrMapType = utils::AttributeMap; /** @@ -69,14 +69,15 @@ class Node : public common::GraphNode { public: Node() = default; Node(const Operator *op, const std::string &name, std::string id = {}) { - this->attrs.op = op; + this->attrs.op = op; this->attrs.node_name = name; - this->id_ = std::move(id); + this->id_ = std::move(id); } const char *type_info() const override { return __type_info__; } std::tuple LinkTo(NodeData *other); - // This node determines another node, which means the other node depends on this node. + // This node determines another node, which means the other node depends on + // this node. void Controls(NodeData *other); /** @@ -89,10 +90,12 @@ class Node : public common::GraphNode { */ NodeAttr attrs; - //! Get the input tensors in order to match tensors correctly. If a refresh is done, we will update the links. + //! Get the input tensors in order to match tensors correctly. If a refresh + //! is done, we will update the links. std::vector> inlinks_in_order() const; - //! Get the output tensors in order to match tensors correctly. If a refresh is done, we will update the links. + //! Get the output tensors in order to match tensors correctly. If a refresh + //! is done, we will update the links. std::vector> outlinks_in_order() const; inline const Operator *op() const { return this->attrs.op; } @@ -102,17 +105,21 @@ class Node : public common::GraphNode { inline uint32_t num_outputs() { if (is_variable()) return 1; if (this->op()->num_outputs == 0) { - using shape_func_t = std::function(const std::vector &, const AttrMapType &)>; - const auto &op_infershape = Operator::GetAttrs("infershape"); - auto key = Operator::Get(this->op()->name); - auto out_shapes = op_infershape[key]({}, this->attrs.attr_store); + using shape_func_t = std::function( + const std::vector &, const AttrMapType &)>; + const auto &op_infershape = + Operator::GetAttrs("infershape"); + auto key = Operator::Get(this->op()->name); + auto out_shapes = op_infershape[key]({}, this->attrs.attr_store); return out_shapes.size(); } else { return this->op()->num_outputs; } } - inline uint32_t num_inputs() { return is_variable() ? 1 : this->op()->num_inputs; } + inline uint32_t num_inputs() { + return is_variable() ? 1 : this->op()->num_inputs; + } template static NodePtr Create(Args &&...args) { @@ -135,29 +142,39 @@ class NodeData : public common::GraphNode { using attr_t = AttrType; public: - NodeData(NodePtr node, uint32_t index, uint32_t version, std::string id, bool is_const = false) - : source_node(std::move(node)), output_index(index), version(version), id_(std::move(id)), is_const_(is_const) {} + NodeData(NodePtr node, + uint32_t index, + uint32_t version, + std::string id, + bool is_const = false) + : source_node(std::move(node)), + output_index(index), + version(version), + id_(std::move(id)), + is_const_(is_const) {} NodeData() : source_node(), output_index(), version(), id_(), is_const_() {} std::tuple LinkTo(Node *other); - // This node determines another node, which means the other node depends on this node. + // This node determines another node, which means the other node depends on + // this node. void Controls(Node *other); static std::shared_ptr Create( const char *op_name, std::string node_name, std::vector inputs, - std::string id = nullptr, - absl::flat_hash_map attrs = absl::flat_hash_map(), - bool is_const = false) { - auto res = std::make_shared(); - res->id_ = std::move(id); - res->is_const_ = is_const; - res->source_node = Node::Create(); - res->source_node->attrs.op = Operator::Get(op_name); - res->source_node->attrs.node_name = std::move(node_name); + std::string id = nullptr, + absl::flat_hash_map attrs = + absl::flat_hash_map(), + bool is_const = false) { + auto res = std::make_shared(); + res->id_ = std::move(id); + res->is_const_ = is_const; + res->source_node = Node::Create(); + res->source_node->attrs.op = Operator::Get(op_name); + res->source_node->attrs.node_name = std::move(node_name); res->source_node->attrs.attr_store = attrs; return res; } @@ -186,7 +203,8 @@ class NodeData : public common::GraphNode { /** * \brief The version of input Variable. * This field can only be nonzero when this->node is a Variable node. - * version is increased by one each time a Variable gets composed to a mutation Op. + * version is increased by one each time a Variable gets composed to a + * mutation Op.
*/ uint32_t version; @@ -201,11 +219,17 @@ class NodeData : public common::GraphNode { }; // insert op_node after input_data -NodeData *InsertGraphOpNodeAfter( - common::Graph *graph, Node *insert_node, NodeData *input_nodedata, Node *dst_node, int pos); +NodeData *InsertGraphOpNodeAfter(common::Graph *graph, + Node *insert_node, + NodeData *input_nodedata, + Node *dst_node, + int pos); // insert op_node before out_data -NodeData *InsertGraphOpNodeBefore( - common::Graph *graph, Node *insert_node, Node *input_node, NodeData *dst_data, int pos); +NodeData *InsertGraphOpNodeBefore(common::Graph *graph, + Node *insert_node, + Node *input_node, + NodeData *dst_data, + int pos); } // namespace framework } // namespace hlir diff --git a/paddle/cinn/hlir/framework/op.h b/paddle/cinn/hlir/framework/op.h index cc21ed086d2c7..78e408c5e9980 100755 --- a/paddle/cinn/hlir/framework/op.h +++ b/paddle/cinn/hlir/framework/op.h @@ -40,22 +40,26 @@ namespace framework { class Operator; using shape_t = utils::ShapeType; -using dim_t = utils::DimType; +using dim_t = utils::DimType; /*! \brief operator pattern used in graph fusion */ enum OpPatternKind { - // The relation between input tensor index and output tensor index is a one-to-one correspondence. + // The relation between input tensor index and output tensor index is + // a one-to-one correspondence. // for example :code:`out[i, j] = input[i, j] + 1`. // Note that the axes need to be in order. kElementWise = 0, - // The relation between input tensor index and output tensor index is a one-to-many correspondence. + // The relation between input tensor index and output tensor index is + // a one-to-many correspondence. // for example :code:`out[i, j, k] = input[i, j]`. // Note that the axes need to be in order. kBroadcast = 1, - // Injective operator, we can always injectively map an output axis to an input axis. + // Injective operator, we can always injectively map an output axis to an + // input axis. // for example :code:`out[i, j] = input[j, i]`. kInjective = 2, - // The relation between input tensor index and output tensor index is a many-to-one correspondence. + // The relation between input tensor index and output tensor index is + // a many-to-one correspondence. // for example :code:`out[i, j] = sum(input[i, j, k]) along k`. kReduction = 3, // Complex operation, can still fuse one-to-one operations into its output. @@ -84,7 +88,8 @@ class OpValueType { public: inline const ValueType& operator[](const Operator* op) const; - inline const ValueType& Get(const Operator* op, const ValueType& def_value) const; + inline const ValueType& Get(const Operator* op, + const ValueType& def_value) const; inline bool Find(const Operator* op) const; @@ -137,14 +142,16 @@ class Operator { } template - inline Operator& set_attr(const std::string& attr_name, const ValueType& value) { + inline Operator& set_attr(const std::string& attr_name, + const ValueType& value) { UpdateAttrMap(attr_name, [this, attr_name, value](absl::any* pmap) { if (!pmap->has_value()) { OpValueType pm; pm.attr_name = attr_name; - *pmap = std::move(pm); + *pmap = std::move(pm); } - std::vector& vec = absl::any_cast&>(*pmap).data; + std::vector& vec = + absl::any_cast&>(*pmap).data; // resize the value type.
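// (Editorial sketch, not part of the diff: set_attr keys each value by the
// operator's global index, so a registration such as
//   CINN_REGISTER_OP(my_op).set_attr<OpPatternKind>("OpPattern", kElementWise);
// ends up storing kElementWise at vec[my_op_index]; "my_op" is a hypothetical
// operator name. The resize below grows the table on demand so the indexed
// store stays in bounds.)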
if (vec.size() <= index) { vec.resize(index + 1, ValueType()); @@ -162,7 +169,7 @@ class Operator { if (!pmap->has_value()) { OpValueType pm; pm.attr_name = attr_name; - *pmap = std::move(pm); + *pmap = std::move(pm); } }); ref = GetAttrMap(attr_name); @@ -180,7 +187,7 @@ class Operator { Operator() { index = OpRegistry::Global()->op_counter++; } static const absl::any* GetAttrMap(const std::string& key) { auto& dict = OpRegistry::Global()->attrs; - auto it = dict.find(key); + auto it = dict.find(key); if (it != dict.end()) { return it->second.get(); } else { @@ -188,7 +195,8 @@ class Operator { } } //! update the attribute OpValueType - static void UpdateAttrMap(const std::string& key, std::function updater) { + static void UpdateAttrMap(const std::string& key, + std::function updater) { OpRegistry* reg = OpRegistry::Global(); std::lock_guard(reg->mutex); std::unique_ptr& value = reg->attrs[key]; @@ -201,12 +209,15 @@ template const ValueType& OpValueType::operator[](const Operator* op) const { CHECK(op) << "The input op is nullptr and it is invalid! Please check again."; const uint32_t idx = op->index; - CHECK_LT(idx, data.size()) << "Attribute " << attr_name << " has not been registered for Operator " << op->name; + CHECK_LT(idx, data.size()) + << "Attribute " << attr_name << " has not been registered for Operator " + << op->name; return data[idx]; } template -const ValueType& OpValueType::Get(const Operator* op, const ValueType& def_value) const { +const ValueType& OpValueType::Get(const Operator* op, + const ValueType& def_value) const { if (!op) return def_value; const uint32_t idx = op->index; if (idx < data.size()) { @@ -224,7 +235,8 @@ bool OpValueType::Find(const Operator* op) const { } // internal macros to make -#define CINN_REGISTER_VAR_DEF(OpName) static ::cinn::hlir::framework::Operator& __make_##HlirOp##_##OpName +#define CINN_REGISTER_VAR_DEF(OpName) \ + static ::cinn::hlir::framework::Operator& __make_##HlirOp##_##OpName /** * @def CINN_REGISTER_OP * * \brief Register a new operator, or set attribute of the corresponding op. * * @param OpName The name of registry * * \code * * CINN_REGISTER_OP(add) * .describe("add two inputs together") * .set_num_inputs(2) * .set_attr("gpu_kernel", AddKernel); * \endcode */ -#define CINN_REGISTER_OP(OpName) \ - CINN_STR_CONCAT(CINN_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ - ::cinn::hlir::framework::OpRegistry::Global()->__REGISTER_OR_GET__(#OpName) +#define CINN_REGISTER_OP(OpName) \ + CINN_STR_CONCAT(CINN_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ + ::cinn::hlir::framework::OpRegistry::Global()->__REGISTER_OR_GET__( \ + #OpName) } // namespace framework } // namespace hlir diff --git a/paddle/cinn/hlir/framework/op_lowering.cc b/paddle/cinn/hlir/framework/op_lowering.cc index 4700d4a530d06..3f7dcbc9a1e97 100644 --- a/paddle/cinn/hlir/framework/op_lowering.cc +++ b/paddle/cinn/hlir/framework/op_lowering.cc @@ -40,13 +40,15 @@ using namespace lang; using cinn::hlir::op::ExternalApiRegistry; -OpLowerer::OpLowerer(const absl::flat_hash_map& type_dict, - const absl::flat_hash_map& shape_dict, - const Target& target) +OpLowerer::OpLowerer( + const absl::flat_hash_map& type_dict, + const absl::flat_hash_map& shape_dict, + const Target& target) : type_dict_(type_dict), shape_dict_(shape_dict), target_(target) {} std::vector OpLowerer::Lower(GroupPtr& group) { - VLOG(3) << "Lowering Group : " << group->group_id << " , Op Pattern : " << group->op_pattern_kind; + VLOG(3) << "Lowering Group : " << group->group_id + << " , Op Pattern : " << group->op_pattern_kind; group->input_names.clear(); group->output_names.clear(); if (FLAGS_cinn_ir_schedule) { @@ -70,13 +72,15 @@ std::vector
OpLowerer::Lower(GroupPtr& group) { } std::vector OpLowerer::LowerWithoutSchedule(GroupPtr& group) { - VLOG(3) << "Lowering Group : " << group->group_id << " , Op Pattern : " << group->op_pattern_kind; + VLOG(3) << "Lowering Group : " << group->group_id + << " , Op Pattern : " << group->op_pattern_kind; if (FLAGS_cinn_ir_schedule) { switch (group->op_pattern_kind) { case framework::kElementWise: case framework::kBroadcast: case framework::kInjective: - return IRLowerOpWithoutSchedule(&OpLowerer::IRElementwiseCompute, group); + return IRLowerOpWithoutSchedule(&OpLowerer::IRElementwiseCompute, + group); case framework::kReduction: return IRLowerOpWithoutSchedule(&OpLowerer::IRReduceCompute, group); case framework::kOutFusible: @@ -91,18 +95,30 @@ std::vector OpLowerer::LowerWithoutSchedule(GroupPtr& group) { } } -std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, GroupPtr& group) { +std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, + GroupPtr& group) { poly::StageMap stages; std::vector arg_tensors; std::unordered_map tensor_map; // do compute. - VLOG(3) << "group->fused_sub_groups.size() is : " << group->fused_sub_groups.size(); + VLOG(3) << "group->fused_sub_groups.size() is : " + << group->fused_sub_groups.size(); std::vector ast_exprs; if (group->fused_sub_groups.size() == 0) { - ast_exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, group, /*apply_impl_schedule = */ true); + ast_exprs = (this->*compute)(stages, + arg_tensors, + tensor_map, + group, + group, + /*apply_impl_schedule = */ true); } else { for (auto& sub_group : group->fused_sub_groups) { - auto exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, sub_group, /*apply_impl_schedule = */ true); + auto exprs = (this->*compute)(stages, + arg_tensors, + tensor_map, + group, + sub_group, + /*apply_impl_schedule = */ true); ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end()); } } @@ -110,13 +126,15 @@ std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, Gro ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); - Node* first = nullptr; + Node* first = nullptr; Node* second = nullptr; - VLOG(3) << "Before IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "Before IRLowerOp schedule, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); // do schedule. IRSchedule(ir_sch, group, tensor_map); - VLOG(3) << "After IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRLowerOp schedule, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); // function args group->input_names.clear(); std::vector func_args; @@ -134,7 +152,7 @@ std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, Gro group->output_names.push_back(node_data->id()); } // collect all output tensor. 
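// (Editorial note, inferred from the tensor_map writes in the compute
// functions below: a node's outputs are keyed by its id plus a positional
// suffix, e.g. a two-output node "x" is stored as tensor_map["x"] and
// tensor_map["x_0"]; the post/prefix strings here rebuild those keys when
// collecting the group's output tensors.)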
- std::string post = ""; + std::string post = ""; std::string prefix = GetNodeData(node)->id(); for (int idx = 0; idx < 1; ++idx) { CHECK(tensor_map.count(prefix)) << "Can't find output tensor " << prefix; @@ -155,25 +173,38 @@ std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, Gro #endif auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body); - auto func = - ir::_LoweredFunc_::Make(group->GetFuncName(), func_args, ir_sch.GetModule().GetExprs().at(0), temp_buffers); + auto func = ir::_LoweredFunc_::Make(group->GetFuncName(), + func_args, + ir_sch.GetModule().GetExprs().at(0), + temp_buffers); func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); return {func}; } -std::vector OpLowerer::IRLowerOpWithoutSchedule(IRComputeFunction compute, GroupPtr& group) { +std::vector OpLowerer::IRLowerOpWithoutSchedule( + IRComputeFunction compute, GroupPtr& group) { poly::StageMap stages; std::vector arg_tensors; std::unordered_map tensor_map; // do compute. - VLOG(3) << "group->fused_sub_groups.size() is : " << group->fused_sub_groups.size(); + VLOG(3) << "group->fused_sub_groups.size() is : " + << group->fused_sub_groups.size(); std::vector ast_exprs; if (group->fused_sub_groups.size() == 0) { - ast_exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, group, /*apply_impl_schedule = */ false); + ast_exprs = (this->*compute)(stages, + arg_tensors, + tensor_map, + group, + group, + /*apply_impl_schedule = */ false); } else { for (auto& sub_group : group->fused_sub_groups) { - auto exprs = - (this->*compute)(stages, arg_tensors, tensor_map, group, sub_group, /*apply_impl_schedule = */ false); + auto exprs = (this->*compute)(stages, + arg_tensors, + tensor_map, + group, + sub_group, + /*apply_impl_schedule = */ false); ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end()); } } @@ -181,7 +212,8 @@ std::vector OpLowerer::IRLowerOpWithoutSchedule(IRComputeFuncti ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); - VLOG(3) << "After IRLowerOp compute, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRLowerOp compute, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); // function args group->input_names.clear(); std::vector func_args; @@ -199,7 +231,7 @@ std::vector OpLowerer::IRLowerOpWithoutSchedule(IRComputeFuncti group->output_names.push_back(node_data->id()); } // collect all output tensor. 
- std::string post = ""; + std::string post = ""; std::string prefix = GetNodeData(node)->id(); for (int idx = 0; idx < 1; ++idx) { CHECK(tensor_map.count(prefix)) << "Can't find output tensor " << prefix; @@ -225,7 +257,8 @@ std::vector OpLowerer::IRLowerOpWithoutSchedule(IRComputeFuncti continue; } arg_tensors.push_back(tensor.second); - // use the underlying tensor name to be consistent with the argument name in the lowered function + // use the underlying tensor name to be consistent with the argument name in + // the lowered function group->output_names.push_back(tensor.second->name); func_args.emplace_back(tensor.second->buffer, ir::Argument::IO::kOutput); } @@ -236,20 +269,23 @@ std::vector OpLowerer::IRLowerOpWithoutSchedule(IRComputeFuncti #endif auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body); - auto func = - ir::_LoweredFunc_::Make(group->GetFuncName(), func_args, ir_sch.GetModule().GetExprs().at(0), temp_buffers); + auto func = ir::_LoweredFunc_::Make(group->GetFuncName(), + func_args, + ir_sch.GetModule().GetExprs().at(0), + temp_buffers); func->PrepareBufferCastExprs(); func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); return {func}; } -std::vector OpLowerer::IRElementwiseCompute(poly::StageMap& stages, - std::vector& func_tensors, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group, - bool apply_impl_schedule) { +std::vector OpLowerer::IRElementwiseCompute( + poly::StageMap& stages, + std::vector& func_tensors, + std::unordered_map& tensor_map, + const GroupPtr& group, + const GroupPtr& sub_group, + bool apply_impl_schedule) { VLOG(2) << "ElementwiseCompute Group : " << sub_group->group_id; auto& strategy = Operator::GetAttrs("CINNStrategy"); @@ -259,8 +295,8 @@ std::vector OpLowerer::IRElementwiseCompute(poly::StageMap& stages, auto node_data = GetNodeData(node); CHECK_EQ(GetAllNodeData(node).size(), 1U); std::vector cinn_inputs; - std::vector tensor_inputs = - std::move(CollectInputTensor(node, func_tensors, tensor_map, this->type_dict_, this->shape_dict_)); + std::vector tensor_inputs = std::move(CollectInputTensor( + node, func_tensors, tensor_map, this->type_dict_, this->shape_dict_)); for (auto& tensor : tensor_inputs) { cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); } @@ -271,18 +307,26 @@ std::vector OpLowerer::IRElementwiseCompute(poly::StageMap& stages, std::vector> out_shapes; out_types.push_back(this->type_dict_.at(node_data->id())); out_shapes.push_back(this->shape_dict_.at(node_data->id())); - auto impl = - OpStrategy::SelectImpl(strategy[node->op()](node->attrs, tensor_inputs, out_types, out_shapes, this->target_)); + auto impl = OpStrategy::SelectImpl(strategy[node->op()]( + node->attrs, tensor_inputs, out_types, out_shapes, this->target_)); // do compute - common::CINNValuePack pack = impl->fcompute(common::CINNValuePack{cinn_inputs}); + common::CINNValuePack pack = + impl->fcompute(common::CINNValuePack{cinn_inputs}); CHECK_EQ(pack.size(), 2U); - Expr expr = pack[0]; + Expr expr = pack[0]; poly::StageMap node_stages = pack.back(); tensor_inputs.push_back(expr.as_tensor_ref()); tensor_map[node_data->id()] = expr.as_tensor_ref(); - auto func = lang::LowerVec("fn_" + node->id(), node_stages, tensor_inputs, {}, {}, nullptr, this->target_, true); + auto func = lang::LowerVec("fn_" + node->id(), + node_stages, + tensor_inputs, + {}, + {}, + nullptr, + this->target_, + true); CHECK_EQ(func.size(), 1); if (apply_impl_schedule) { @@ -296,7 +340,8 @@ std::vector 
OpLowerer::IRElementwiseCompute(poly::StageMap& stages, schedule_inputs.push_back(common::CINNValue(f->body)); } // do ast tree schedule - common::CINNValuePack expr_pack = impl->fschedule(common::CINNValuePack{schedule_inputs}); + common::CINNValuePack expr_pack = + impl->fschedule(common::CINNValuePack{schedule_inputs}); CHECK_EQ(expr_pack.size(), 1); Expr ast_expr = expr_pack[0]; @@ -309,24 +354,26 @@ std::vector OpLowerer::IRElementwiseCompute(poly::StageMap& stages, return ast_exprs; } -std::vector OpLowerer::IRReduceCompute(poly::StageMap& stages, - std::vector& func_args, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group, - bool apply_impl_schedule) { +std::vector OpLowerer::IRReduceCompute( + poly::StageMap& stages, + std::vector& func_args, + std::unordered_map& tensor_map, + const GroupPtr& group, + const GroupPtr& sub_group, + bool apply_impl_schedule) { VLOG(2) << "ReduceCompute Group : " << sub_group->group_id; - auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); + auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); std::vector ast_exprs; for (auto& node : sub_group->nodes) { auto node_data = GetNodeData(node); - VLOG(3) << "In ReduceCompute, process node: " << node->id() << " with op type: " << node->op()->name; + VLOG(3) << "In ReduceCompute, process node: " << node->id() + << " with op type: " << node->op()->name; std::vector cinn_inputs; - std::vector tensor_inputs = - std::move(CollectInputTensor(node, func_args, tensor_map, this->type_dict_, this->shape_dict_)); + std::vector tensor_inputs = std::move(CollectInputTensor( + node, func_args, tensor_map, this->type_dict_, this->shape_dict_)); for (auto& tensor : tensor_inputs) { cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); } @@ -338,10 +385,11 @@ std::vector OpLowerer::IRReduceCompute(poly::StageMap& stages, out_types.push_back(this->type_dict_.at(node_data->id())); out_shapes.push_back(this->shape_dict_.at(node_data->id())); - auto impl = - OpStrategy::SelectImpl(cinn_strategy[node->op()](node->attrs, tensor_inputs, out_types, out_shapes, target_)); + auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()]( + node->attrs, tensor_inputs, out_types, out_shapes, target_)); // do compute - common::CINNValuePack pack = impl->fcompute(common::CINNValuePack{cinn_inputs}); + common::CINNValuePack pack = + impl->fcompute(common::CINNValuePack{cinn_inputs}); CHECK_GE(pack.size(), 2UL); CHECK_LE(pack.size(), 5UL); @@ -349,20 +397,29 @@ std::vector OpLowerer::IRReduceCompute(poly::StageMap& stages, std::string post = ""; for (int idx = 0; idx < pack.size() - 1; ++idx) { - Expr expr = pack[idx]; + Expr expr = pack[idx]; tensor_map[node_data->id() + post] = expr.as_tensor_ref(); // As op may have more than 1 output tensor, using id + "_0"/"_1" as key.
post = "_" + std::to_string(idx); // Insert output tensors - if (!expr.as_tensor_ref()->buffer.defined() || this->target_ != common::DefaultNVGPUTarget()) { + if (!expr.as_tensor_ref()->buffer.defined() || + this->target_ != common::DefaultNVGPUTarget()) { tensor_inputs.push_back(expr.as_tensor_ref()); } } - auto func = lang::LowerVec("fn_" + node->id(), tmp_stages, tensor_inputs, {}, {}, nullptr, this->target_, true); + auto func = lang::LowerVec("fn_" + node->id(), + tmp_stages, + tensor_inputs, + {}, + {}, + nullptr, + this->target_, + true); // node is kReduction - if (op_pattern_dict[node->op()] == framework::kReduction && apply_impl_schedule) { + if (op_pattern_dict[node->op()] == framework::kReduction && + apply_impl_schedule) { std::vector schedule_inputs; // collect tensor for (int idx = 0; idx < pack.size() - 1; ++idx) { @@ -373,12 +430,14 @@ std::vector OpLowerer::IRReduceCompute(poly::StageMap& stages, schedule_inputs.push_back(common::CINNValue(f->body)); } // do ast tree schedule - common::CINNValuePack expr_pack = impl->fschedule(common::CINNValuePack{schedule_inputs}); + common::CINNValuePack expr_pack = + impl->fschedule(common::CINNValuePack{schedule_inputs}); // ast tree after schedule. Expr ast_expr = expr_pack[0]; ast_exprs.push_back(ast_expr); } else if (group->master_nodes.count(node)) { - // as master node should copy transform from reducer, leave it to the reduce schedule. + // as master node should copy transform from reducer, leave it to the + // reduce schedule. ast_exprs.push_back(func[0]->body); } else { ast_exprs.push_back(func[0]->body); @@ -388,14 +447,17 @@ std::vector OpLowerer::IRReduceCompute(poly::StageMap& stages, return ast_exprs; } -std::vector OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, bool apply_impl_schedule) { +std::vector OpLowerer::IRLowerNonFusibleOp( + GroupPtr& group, bool apply_impl_schedule) { VLOG(3) << "LowerNonFusibleOp Group : " << group->group_id; // get input tensor and output tensor CHECK(group->nodes.size() || group->fused_sub_groups.size()); - auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); + auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - auto node = group->fused_sub_groups.size() ? group->fused_sub_groups[0]->nodes.front() : group->nodes.front(); + auto node = group->fused_sub_groups.size() + ? group->fused_sub_groups[0]->nodes.front() + : group->nodes.front(); VLOG(3) << "GetOpFunc of op " << node->id(); std::vector inputs; std::vector cinn_inputs; @@ -431,18 +493,23 @@ std::vector OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, boo cinn_inputs.push_back(common::CINNValue(node_data->id())); } - auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()](node->attrs, inputs, out_types, out_shapes, target_)); + auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()]( + node->attrs, inputs, out_types, out_shapes, target_)); // if node op is custom_call, apply custom_call compute.
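// (Editorial sketch, illustration only: a pass can pin the external kernel by
// setting the attribute directly, e.g.
//   node->attrs.attr_store["custom_call"] = std::string("cinn_call_cublas");
// where "cinn_call_cublas" is an assumed registry name; otherwise the
// ExternalApiRegistry lookup below chooses the target's default.)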
if (node->op()->name == "custom_call") { std::string external_api; if (node->attrs.attr_store.count("custom_call")) { - external_api = absl::get(node->attrs.attr_store.at("custom_call")); + external_api = + absl::get(node->attrs.attr_store.at("custom_call")); } else { - external_api = ExternalApiRegistry::Global()->GetExternalApi(node, target_); + external_api = + ExternalApiRegistry::Global()->GetExternalApi(node, target_); } - std::vector compute_args = {common::CINNValue(group->GetFuncName()), - common::CINNValue(external_api)}; - common::CINNValuePack pack = impl->fcompute(common::CINNValuePack{compute_args}); + std::vector compute_args = { + common::CINNValue(group->GetFuncName()), + common::CINNValue(external_api)}; + common::CINNValuePack pack = + impl->fcompute(common::CINNValuePack{compute_args}); CHECK_EQ(pack.size(), 1UL); // reset input names as extern api input args can't be deduplicated. group->input_names.clear(); @@ -452,19 +519,29 @@ std::vector OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, boo return {pack[0].operator ir::Expr().as_lowered_func_ref()}; } - common::CINNValuePack pack = impl->fcompute(common::CINNValuePack{cinn_inputs}); + common::CINNValuePack pack = + impl->fcompute(common::CINNValuePack{cinn_inputs}); for (int i = 0; i < pack->size() - 1; i++) { ir::Expr temp = pack[i]; // check whether the tensor has a buffer. - if (!temp.as_tensor_ref()->buffer.defined() || this->target_ != common::DefaultNVGPUTarget()) { + if (!temp.as_tensor_ref()->buffer.defined() || + this->target_ != common::DefaultNVGPUTarget()) { inputs.push_back(temp.as_tensor_ref()); temp.as_tensor_ref()->WithBuffer(); - args.emplace_back(temp.as_tensor_ref()->buffer, ir::Argument::IO::kOutput); + args.emplace_back(temp.as_tensor_ref()->buffer, + ir::Argument::IO::kOutput); } } poly::StageMap stages = pack.back(); - auto func = lang::LowerVec(group->GetFuncName(), stages, inputs, {}, {}, nullptr, this->target_, true); + auto func = lang::LowerVec(group->GetFuncName(), + stages, + inputs, + {}, + {}, + nullptr, + this->target_, + true); if (apply_impl_schedule) { std::vector schedule_inputs; @@ -477,13 +554,18 @@ std::vector OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, boo schedule_inputs.push_back(common::CINNValue(f->body)); } // do ast tree schedule - common::CINNValuePack expr_pack = impl->fschedule(common::CINNValuePack{schedule_inputs}); + common::CINNValuePack expr_pack = + impl->fschedule(common::CINNValuePack{schedule_inputs}); ir::Expr func_body = expr_pack[0]; std::vector input_output_nodes(group->input_names); - input_output_nodes.insert(input_output_nodes.end(), group->output_names.begin(), group->output_names.end()); - VLOG(6) << "func.size() = " << func.size() << ", expr_pack.size() = " << expr_pack.size(); - VLOG(6) << "args.size() = " << args.size() << ", input_output_nodes.size() = " << input_output_nodes.size(); + input_output_nodes.insert(input_output_nodes.end(), + group->output_names.begin(), + group->output_names.end()); + VLOG(6) << "func.size() = " << func.size() + << ", expr_pack.size() = " << expr_pack.size(); + VLOG(6) << "args.size() = " << args.size() + << ", input_output_nodes.size() = " << input_output_nodes.size(); if (args.size() > input_output_nodes.size()) { args = lang::GetArgs(func_body, input_output_nodes); } @@ -494,7 +576,8 @@ std::vector OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, boo optim::OptimizeExprGPU(&(func_body)); #endif auto temp_buffers = lang::GetTempBuffers(inputs, stages, func_body); - auto function =
ir::_LoweredFunc_::Make(group->GetFuncName(), args, func_body, temp_buffers); + auto function = ir::_LoweredFunc_::Make( + group->GetFuncName(), args, func_body, temp_buffers); res.push_back(function); } for (auto& i : res) { @@ -513,26 +596,30 @@ std::vector OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, boo } // group schedule -void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, - const GroupPtr& group, - const std::unordered_map<std::string, ir::Tensor>& tensor_map) { +void OpLowerer::IRSchedule( + ir::IRSchedule& ir_sch, + const GroupPtr& group, + const std::unordered_map<std::string, ir::Tensor>& tensor_map) { // topological order. - auto nodes_set = group->NodeSet(); - auto v_consumers = BuildVirtualConsumer(group, this->shape_dict_); - auto nodes_in_order = BFSTopologicalOrderWithPriority(group, v_consumers, this->shape_dict_); + auto nodes_set = group->NodeSet(); + auto v_consumers = BuildVirtualConsumer(group, this->shape_dict_); + auto nodes_in_order = + BFSTopologicalOrderWithPriority(group, v_consumers, this->shape_dict_); // find reducer. std::unordered_set<Node*> nodes_inline; - auto greducer = FindGlobalReducer(nodes_in_order); + auto greducer = FindGlobalReducer(nodes_in_order); auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern"); // do schedule for (auto node : nodes_in_order) { VLOG(4) << "Try FUSION " << node->op()->name; // consumers. - auto consumers = GetConsumersInSet(node, nodes_set); - const Node* reducer = greducer ? FindNearestReducer(node, nodes_set) : greducer; + auto consumers = GetConsumersInSet(node, nodes_set); + const Node* reducer = + greducer ? FindNearestReducer(node, nodes_set) : greducer; if (!reducer && greducer) { - reducer = v_consumers.count(node) ? v_consumers.find(node)->second : reducer; + reducer = + v_consumers.count(node) ? v_consumers.find(node)->second : reducer; if (reducer && op_pattern_dict[reducer->op()] != framework::kReduction) { reducer = nullptr; } @@ -540,8 +627,15 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, auto masters = GetMasters(node, nodes_inline, nodes_set); // node can be inlined. - if (CanbeInline(node, consumers, reducer, masters, group, nodes_set, this->shape_dict_)) { - VLOG(3) << "Before compute inline, ir is:\n" << ir_sch.GetModule().GetExprs().at(0); + if (CanbeInline(node, + consumers, + reducer, + masters, + group, + nodes_set, + this->shape_dict_)) { + VLOG(3) << "Before compute inline, ir is:\n" + << ir_sch.GetModule().GetExprs().at(0); auto block = ir_sch.GetBlock(GetNodeData(node)->id()); ir::ComputeInlineChecker checker(ir_sch, block); if (!checker.Check()) { @@ -561,28 +655,51 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, ir_sch.ComputeInline(block); nodes_inline.insert(node); - VLOG(3) << "After compute inline, ir is:\n" << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After compute inline, ir is:\n" + << ir_sch.GetModule().GetExprs().at(0); continue; } // find master to compute at. - auto master = GetMasterToComputeAt(node, nodes_in_order, nodes_inline, nodes_set, v_consumers, this->shape_dict_); + auto master = GetMasterToComputeAt(node, + nodes_in_order, + nodes_inline, + nodes_set, + v_consumers, + this->shape_dict_); // assign to reducer/master loop. if (reducer) { - VLOG(3) << "Before assign node " << node->id() << " into vertical link reducer " << reducer->id() << ", ir is:\n" + VLOG(3) << "Before assign node " << node->id() + << " into vertical link reducer " << reducer->id() << ", ir is:\n" << ir_sch.GetModule().GetExprs().at(0); // if node is vertical with reduce, loop assign reducer.
- LoopAssignReduce(ir_sch, node, reducer, this->target_, tensor_map, this->shape_dict_); + LoopAssignReduce( + ir_sch, node, reducer, this->target_, tensor_map, this->shape_dict_); } else if (greducer) { - auto greducer_out_shape = this->shape_dict_.at(greducer->outlinks_in_order()[0]->sink()->id()); - auto node_out_shape = this->shape_dict_.at(node->outlinks_in_order()[0]->sink()->id()); - if (std::accumulate(greducer_out_shape.begin(), greducer_out_shape.end(), 1, std::multiplies()) != - std::accumulate(node_out_shape.begin(), node_out_shape.end(), 1, std::multiplies())) { - LoopAssignReduce(ir_sch, node, greducer, this->target_, tensor_map, this->shape_dict_); + auto greducer_out_shape = + this->shape_dict_.at(greducer->outlinks_in_order()[0]->sink()->id()); + auto node_out_shape = + this->shape_dict_.at(node->outlinks_in_order()[0]->sink()->id()); + if (std::accumulate(greducer_out_shape.begin(), + greducer_out_shape.end(), + 1, + std::multiplies()) != + std::accumulate(node_out_shape.begin(), + node_out_shape.end(), + 1, + std::multiplies())) { + LoopAssignReduce(ir_sch, + node, + greducer, + this->target_, + tensor_map, + this->shape_dict_); } else { - VLOG(3) << "Before assign node " << node->id() << " into horizontal link reducer " << greducer->id() + VLOG(3) << "Before assign node " << node->id() + << " into horizontal link reducer " << greducer->id() << ", ir is:\n" << ir_sch.GetModule().GetExprs().at(0); - // if node is horizontal with reduce or node is reduce, loop assign master. + // if node is horizontal with reduce or node is reduce, loop assign + // master. auto loops = ir_sch.GetLoops(GetNodeData(node)->id()); if (op_pattern_dict[node->op()] == framework::kElementWise) { ir_sch.FlattenLoops(loops, true); @@ -601,10 +718,17 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, } } } - VLOG(3) << "Before loop fusion, ir is:\n" << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "Before loop fusion, ir is:\n" + << ir_sch.GetModule().GetExprs().at(0); // do loop fuse. - LoopComputeAt(ir_sch, node, master ? master : nodes_in_order.front(), group, this->shape_dict_, tensor_map); - VLOG(3) << "After loop fusion, ir is:\n" << ir_sch.GetModule().GetExprs().at(0); + LoopComputeAt(ir_sch, + node, + master ? master : nodes_in_order.front(), + group, + this->shape_dict_, + tensor_map); + VLOG(3) << "After loop fusion, ir is:\n" + << ir_sch.GetModule().GetExprs().at(0); } // do vectorize @@ -615,23 +739,28 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, // only support first block? 
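// NOTE: the greducer branch above picks between LoopAssignReduce and loop
// flattening by comparing total element counts (numel) of the two output
// shapes via std::accumulate. A minimal standalone sketch of that pattern,
// assuming plain std::vector<int> shapes; the helper names here are
// hypothetical, not CINN API. The vectorize handling of the first block
// continues after this sketch.
#include <functional>
#include <numeric>
#include <vector>

// Product of all dimension extents; 1 for an empty (scalar) shape.
inline int Numel(const std::vector<int>& shape) {
  return std::accumulate(
      shape.begin(), shape.end(), 1, std::multiplies<int>());
}

// Two shapes are treated as loop-compatible here iff they cover the same
// element count, e.g. {256, 60} and {60, 256, 1} both hold 15360 elements.
inline bool SameNumel(const std::vector<int>& a, const std::vector<int>& b) {
  return Numel(a) == Numel(b);
}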
auto block = all_blocks[0]; CHECK(block->as()); - CHECK(block->as()->schedule_block->as()); + CHECK(block->as() + ->schedule_block->as()); auto is_tensor_block = true; - auto tensor_name = block->as()->schedule_block->as()->name; + auto tensor_name = block->as() + ->schedule_block->as() + ->name; if (!tensor_map.count(tensor_name)) { is_tensor_block = false; } if (FLAGS_cinn_use_cuda_vectorize && is_tensor_block && - (group->op_pattern_kind == framework::kElementWise || group->op_pattern_kind == framework::kInjective || + (group->op_pattern_kind == framework::kElementWise || + group->op_pattern_kind == framework::kInjective || group->op_pattern_kind == framework::kBroadcast)) { // auto loops = ir_sch.GetLoops(GetNodeData(node)->id()); auto loops = ir_sch.GetLoops(block); VLOG(4) << "Op Pattern : " << loops.size(); if (loops.size() >= 1) { - VLOG(4) << "Before vectorize, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - auto loop_inner = loops.back(); + VLOG(4) << "Before vectorize, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); + auto loop_inner = loops.back(); int vector_width = 1; - auto psize = ir::GetLoopExtent(loop_inner); + auto psize = ir::GetLoopExtent(loop_inner); // get dtype of vectorized var auto dtype = this->type_dict_.at(tensor_name); VLOG(4) << tensor_name << " dtype " << dtype; @@ -644,17 +773,22 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, } if (vector_width > 1) { auto splited = ir_sch.Split(loop_inner, {-1, vector_width}); - splited[0].As()->set_bind_info(loop_inner.As()->bind_info()); + splited[0].As()->set_bind_info( + loop_inner.As()->bind_info()); splited[1].As()->set_serial(); ir_sch.Vectorize(splited[1], vector_width); } - VLOG(4) << "After vectorize, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + VLOG(4) << "After vectorize, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); } } - VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - SyncThreadWithShared(ir_sch, group, nodes_inline, nodes_set, this->shape_dict_, tensor_map); - VLOG(4) << "After IRSchedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); + SyncThreadWithShared( + ir_sch, group, nodes_inline, nodes_set, this->shape_dict_, tensor_map); + VLOG(4) << "After IRSchedule, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); } } // namespace framework diff --git a/paddle/cinn/hlir/framework/op_lowering.h b/paddle/cinn/hlir/framework/op_lowering.h index 520e5c165bb52..97bdaeb485883 100755 --- a/paddle/cinn/hlir/framework/op_lowering.h +++ b/paddle/cinn/hlir/framework/op_lowering.h @@ -39,12 +39,13 @@ using GroupPtr = std::shared_ptr; using common::Target; class OpLowerer; -typedef std::vector (OpLowerer::*IRComputeFunction)(poly::StageMap&, - std::vector&, - std::unordered_map&, - const GroupPtr&, - const GroupPtr&, - bool); +typedef std::vector (OpLowerer::*IRComputeFunction)( + poly::StageMap&, + std::vector&, + std::unordered_map&, + const GroupPtr&, + const GroupPtr&, + bool); class OpLowerer { public: @@ -57,23 +58,26 @@ class OpLowerer { private: std::vector IRLowerOp(IRComputeFunction, GroupPtr&); std::vector IRLowerNonFusibleOp(GroupPtr&, bool); - std::vector IRLowerOpWithoutSchedule(IRComputeFunction, GroupPtr&); -#define DEFINE_IR_COMPUTE(type) \ - std::vector IR##type##Compute(poly::StageMap& stages, \ - std::vector& func_args, \ - std::unordered_map& tensor_map, \ - const GroupPtr& group, \ - const GroupPtr& sub_group, \ - 
bool apply_impl_schedule = false); + std::vector IRLowerOpWithoutSchedule(IRComputeFunction, + GroupPtr&); +#define DEFINE_IR_COMPUTE(type) \ + std::vector IR##type##Compute( \ + poly::StageMap& stages, \ + std::vector& func_args, \ + std::unordered_map& tensor_map, \ + const GroupPtr& group, \ + const GroupPtr& sub_group, \ + bool apply_impl_schedule = false); // compute and schedule DEFINE_IR_COMPUTE(Elementwise); DEFINE_IR_COMPUTE(Reduce); DEFINE_IR_COMPUTE(OutEWiseFusable); - void IRSchedule(ir::IRSchedule& ir_sch, - const GroupPtr& group, - const std::unordered_map& tensor_map); + void IRSchedule( + ir::IRSchedule& ir_sch, + const GroupPtr& group, + const std::unordered_map& tensor_map); Target target_; const absl::flat_hash_map& type_dict_; diff --git a/paddle/cinn/hlir/framework/op_lowering_test.cc b/paddle/cinn/hlir/framework/op_lowering_test.cc index 34ecabee866a5..d63dd7d4d6283 100644 --- a/paddle/cinn/hlir/framework/op_lowering_test.cc +++ b/paddle/cinn/hlir/framework/op_lowering_test.cc @@ -37,7 +37,7 @@ void CodeGen(ir::LoweredFunc& func) { Module::Builder builder("module_builder", target); builder.AddFunction(func); - auto module = builder.Build(); + auto module = builder.Build(); auto compiler = backends::Compiler::Create(target); std::string code = ""; @@ -49,22 +49,27 @@ void CodeGen(ir::LoweredFunc& func) { CodeGenCX86 codegen(target, CodeGenCX86::Feature::AVX512); codegen.SetInlineBuiltinCodes(false); - auto source_code = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + auto source_code = + codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); LOG(INFO) << "compiled code of " << func->name << "is:\n\n\n" << source_code; #endif } void Compile(NetBuilder& net_builder) { auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); + auto& dtype_dict = + graph->GetMutableAttrs>( + "inferdtype"); + auto& shape_dict = + graph->GetMutableAttrs>( + "infershape"); OpLowerer op_lowerer(dtype_dict, shape_dict, target); for (auto& fusion_op : graph->fusion_groups) { @@ -129,25 +134,28 @@ TEST(OP_LOWERING, Reduce_With_Last_Axis_1) { TEST(OP_LOWERING, Reduce_Fuse_Broadcast_With_Output) { NetBuilder net_builder("Reduce_Fuse_Broadcast_With_Output"); - auto layer_norm_51__tmp_1 = net_builder.CreateInput(Float(32), {256}, "layer_norm_51__tmp_1"); - auto var_3216 = net_builder.CreateInput(Float(32), {256, 60}, "var_3216"); - auto var_3202 = net_builder.CreateInput(Float(32), {1, 60}, "var_3202"); - auto var_3212 = net_builder.CreateInput(Float(32), {256, 60}, "var_3212"); - - auto var_3206 = net_builder.Reshape(layer_norm_51__tmp_1, {256, 1}); - auto composite_tmp_8 = net_builder.FillConstant({256, 1}, 1e-5, "composite_tmp_8"); - auto var_3214 = net_builder.Add(var_3206, composite_tmp_8); - auto composite_tmp_10 = net_builder.FillConstant({256, 1}, 1.0, "composite_tmp_10"); - auto var_3220 = net_builder.Divide(composite_tmp_10, var_3214); - auto var_3226 = net_builder.Sqrt(var_3220); - auto var_3224 = net_builder.Scale(var_3220, -1.0, 0.0, true); - auto var_3366 = net_builder.BroadcastTo(var_3224, {256, 60}); - auto var_3228 = net_builder.Multiply(var_3366, var_3216); - auto var_3368 = 
net_builder.BroadcastTo(var_3202, {256, 60}); - auto var_3236 = net_builder.Multiply(var_3228, var_3212); - auto var_3244 = net_builder.Multiply(var_3236, var_3368); - auto var_3252 = net_builder.ReduceSum(var_3244, {1}, true); - auto var_3232 = net_builder.Scale(var_3226, 0.0166667, 0.0, true); + auto layer_norm_51__tmp_1 = + net_builder.CreateInput(Float(32), {256}, "layer_norm_51__tmp_1"); + auto var_3216 = net_builder.CreateInput(Float(32), {256, 60}, "var_3216"); + auto var_3202 = net_builder.CreateInput(Float(32), {1, 60}, "var_3202"); + auto var_3212 = net_builder.CreateInput(Float(32), {256, 60}, "var_3212"); + + auto var_3206 = net_builder.Reshape(layer_norm_51__tmp_1, {256, 1}); + auto composite_tmp_8 = + net_builder.FillConstant({256, 1}, 1e-5, "composite_tmp_8"); + auto var_3214 = net_builder.Add(var_3206, composite_tmp_8); + auto composite_tmp_10 = + net_builder.FillConstant({256, 1}, 1.0, "composite_tmp_10"); + auto var_3220 = net_builder.Divide(composite_tmp_10, var_3214); + auto var_3226 = net_builder.Sqrt(var_3220); + auto var_3224 = net_builder.Scale(var_3220, -1.0, 0.0, true); + auto var_3366 = net_builder.BroadcastTo(var_3224, {256, 60}); + auto var_3228 = net_builder.Multiply(var_3366, var_3216); + auto var_3368 = net_builder.BroadcastTo(var_3202, {256, 60}); + auto var_3236 = net_builder.Multiply(var_3228, var_3212); + auto var_3244 = net_builder.Multiply(var_3236, var_3368); + auto var_3252 = net_builder.ReduceSum(var_3244, {1}, true); + auto var_3232 = net_builder.Scale(var_3226, 0.0166667, 0.0, true); Compile(net_builder); } @@ -168,7 +176,7 @@ TEST(OP_LOWERING, Reduce_Fuse_Broadcast_Layernorm) { // constant w auto E = net_builder.FillConstant({h}, 1024.0f, "E"); // mean - auto F = net_builder.Divide(C, E); + auto F = net_builder.Divide(C, E); auto FF = net_builder.BroadcastTo(F, {h, w}, {0}); // mean x*x auto G = net_builder.Divide(D, E); @@ -181,7 +189,7 @@ TEST(OP_LOWERING, Reduce_Fuse_Broadcast_Layernorm) { // eps + delta auto K = net_builder.Add(I, J); // var - auto L = net_builder.Sqrt(K); + auto L = net_builder.Sqrt(K); auto LL = net_builder.BroadcastTo(L, {h, w}, {0}); // x - mean auto M = net_builder.Subtract(A, FF); @@ -515,16 +523,16 @@ TEST(OP_LOWERING, Elementwise_Test_Reshape_After_Reduce) { TEST(OP_LOWERING, Elementwise_Test_Reshape_Fuse_Concat) { NetBuilder net_builder("Elementwise_Test_Reshape_Fuse_Concat"); { - auto A = net_builder.CreateInput(Float(32), {8, 8, 8, 8}, "A"); - auto B = net_builder.Reshape(A, {16, 16, 16}); - auto C = net_builder.CreateInput(Float(32), {16, 16}, "C"); - auto D = net_builder.CreateInput(Float(32), {16, 16}, "D"); + auto A = net_builder.CreateInput(Float(32), {8, 8, 8, 8}, "A"); + auto B = net_builder.Reshape(A, {16, 16, 16}); + auto C = net_builder.CreateInput(Float(32), {16, 16}, "C"); + auto D = net_builder.CreateInput(Float(32), {16, 16}, "D"); auto DT = net_builder.Transpose(D, {1, 0}); - auto E = net_builder.Add(C, DT); - auto F = net_builder.BroadcastTo(E, {16, 16, 16}, {1, 2}); - auto G = net_builder.Add(B, F); - auto H = net_builder.CreateInput(Float(32), {16, 16, 16}, "H"); - auto I = net_builder.Concat({G, H}, 2); + auto E = net_builder.Add(C, DT); + auto F = net_builder.BroadcastTo(E, {16, 16, 16}, {1, 2}); + auto G = net_builder.Add(B, F); + auto H = net_builder.CreateInput(Float(32), {16, 16, 16}, "H"); + auto I = net_builder.Concat({G, H}, 2); } Compile(net_builder); @@ -563,7 +571,7 @@ TEST(OP_LOWERING, Elementwise_TEST_Split_2) { TEST(OP_LOWERING, Elementwise_TEST_0) { NetBuilder 
net_builder("Elementwise_TEST_0"); { - auto x = net_builder.FillConstant({1}, 128.0, "x"); + auto x = net_builder.FillConstant({1}, 128.0, "x"); auto o1 = net_builder.Scale(x, -1.0, 0.0); auto o2 = net_builder.Scale(x, -1.0, 0.0); } @@ -1195,11 +1203,12 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_21) { */ TEST(OpFusionPass, Block_Reduce_Fuse_Broadcast) { - int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); - int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = + common::DefaultNVGPUTarget().get_max_threads_per_sm(); int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; - int h = warp_reduce_threshold - 10; - int w = 256; + int h = warp_reduce_threshold - 10; + int w = 256; NetBuilder net_builder("Block_Reduce_Fuse_Broadcast"); // create model { @@ -1212,11 +1221,12 @@ TEST(OpFusionPass, Block_Reduce_Fuse_Broadcast) { } TEST(OpFusionPass, Block_Reduce_Fuse_Elementwise) { - int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); - int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = + common::DefaultNVGPUTarget().get_max_threads_per_sm(); int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; - int h = warp_reduce_threshold - 10; - int w = 256; + int h = warp_reduce_threshold - 10; + int w = 256; NetBuilder net_builder("Block_Reduce_Fuse_Elementwise"); // create model { @@ -1229,11 +1239,12 @@ TEST(OpFusionPass, Block_Reduce_Fuse_Elementwise) { Compile(net_builder); } TEST(OpFusionPass, Warp_Reduce_Fuse_Broadcast) { - int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); - int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = + common::DefaultNVGPUTarget().get_max_threads_per_sm(); int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; - int h = warp_reduce_threshold + 10; - int w = 256; + int h = warp_reduce_threshold + 10; + int w = 256; NetBuilder net_builder("Warp_Reduce_Fuse_Broadcast"); // create model { @@ -1246,11 +1257,12 @@ TEST(OpFusionPass, Warp_Reduce_Fuse_Broadcast) { } TEST(OpFusionPass, Warp_Reduce_Fuse_Elementwise) { - int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); - int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = + common::DefaultNVGPUTarget().get_max_threads_per_sm(); int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; - int h = warp_reduce_threshold + 10; - int w = 256; + int h = warp_reduce_threshold + 10; + int w = 256; NetBuilder net_builder("Warp_Reduce_Fuse_Elementwise"); // create model { diff --git a/paddle/cinn/hlir/framework/op_lowering_util.cc b/paddle/cinn/hlir/framework/op_lowering_util.cc index 4eb45b1d2884b..807d70eb864d0 100644 --- a/paddle/cinn/hlir/framework/op_lowering_util.cc +++ b/paddle/cinn/hlir/framework/op_lowering_util.cc @@ -41,46 +41,61 @@ std::vector GetInputNodeData(const Node* node) { return producers; } -ir::Tensor GetTensor(const NodeData* node_data, - const absl::flat_hash_map& type_dict, - const absl::flat_hash_map& shape_dict) { +ir::Tensor GetTensor( + const NodeData* node_data, + const absl::flat_hash_map& 
type_dict, + const absl::flat_hash_map<std::string, shape_t>& shape_dict) { auto dtype = type_dict.at(node_data->id()); if (dtype.is_float(32)) { - return lang::Placeholder<float>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<float>(node_data->id(), + shape_dict.at(node_data->id())); } else if (dtype.is_float(64)) { - return lang::Placeholder<double>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<double>(node_data->id(), + shape_dict.at(node_data->id())); } else if (dtype.is_bfloat16()) { - return lang::Placeholder<bfloat16>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<bfloat16>(node_data->id(), + shape_dict.at(node_data->id())); } else if (dtype.is_float16()) { - return lang::Placeholder<float16>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<float16>(node_data->id(), + shape_dict.at(node_data->id())); } else if (dtype.is_bool()) { - return lang::Placeholder<bool>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<bool>(node_data->id(), + shape_dict.at(node_data->id())); } else if (dtype.is_int(8)) { - return lang::Placeholder<int8_t>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<int8_t>(node_data->id(), + shape_dict.at(node_data->id())); } else if (dtype.is_int(16)) { - return lang::Placeholder<int16_t>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<int16_t>(node_data->id(), + shape_dict.at(node_data->id())); } else if (dtype.is_int(32)) { - return lang::Placeholder<int32_t>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<int32_t>(node_data->id(), + shape_dict.at(node_data->id())); } else if (dtype.is_int(64)) { - return lang::Placeholder<int64_t>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<int64_t>(node_data->id(), + shape_dict.at(node_data->id())); } else if (dtype.is_uint(8)) { - return lang::Placeholder<uint8_t>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<uint8_t>(node_data->id(), + shape_dict.at(node_data->id())); } else if (dtype.is_uint(16)) { - return lang::Placeholder<uint16_t>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<uint16_t>(node_data->id(), + shape_dict.at(node_data->id())); } else if (dtype.is_uint(32)) { - return lang::Placeholder<uint32_t>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<uint32_t>(node_data->id(), + shape_dict.at(node_data->id())); } else if (dtype.is_uint(64)) { - return lang::Placeholder<uint64_t>(node_data->id(), shape_dict.at(node_data->id())); + return lang::Placeholder<uint64_t>(node_data->id(), + shape_dict.at(node_data->id())); } else { LOG(FATAL) << "Unsupported dtype: " << dtype; } } -std::vector<ir::Tensor> CollectInputTensor(const Node* node, - std::vector<ir::Tensor>& func_args, - std::unordered_map<std::string, ir::Tensor>& tensor_map, - const absl::flat_hash_map<std::string, Type>& type_dict, - const absl::flat_hash_map<std::string, shape_t>& shape_dict) { +std::vector<ir::Tensor> CollectInputTensor( + const Node* node, + std::vector<ir::Tensor>& func_args, + std::unordered_map<std::string, ir::Tensor>& tensor_map, + const absl::flat_hash_map<std::string, Type>& type_dict, + const absl::flat_hash_map<std::string, shape_t>& shape_dict) { std::vector<ir::Tensor> tensors; // get all input nodes for (auto& node_data : GetInputNodeData(node)) { @@ -124,7 +139,8 @@ std::vector GetConsumers(const Node* node) { return consumers; } -std::vector<Node*> GetConsumersInSet(const Node* node, const std::unordered_set<Node*>& node_set) { +std::vector<Node*> GetConsumersInSet( + const Node* node, const std::unordered_set<Node*>& node_set) { std::vector<Node*> consumers; auto node_data = GetNodeData(node); for (auto& link : node_data->outlinks()) { @@ -149,7 +165,8 @@ std::vector GetProducers(const Node* node) { return producers; } -std::vector<Node*>
GetProducersInSet(const Node* node, const std::unordered_set& node_set) { +std::vector GetProducersInSet( + const Node* node, const std::unordered_set& node_set) { std::vector producers; for (auto& link : node->inlinks_in_order()) { auto data = link->source()->safe_as(); @@ -162,7 +179,8 @@ std::vector GetProducersInSet(const Node* node, const std::unordered_set< } bool IsConstOp(const framework::Node* node) { - static std::unordered_set const_op_type = {"const_scalar", "fill_constant", "arange"}; + static std::unordered_set const_op_type = { + "const_scalar", "fill_constant", "arange"}; if (const_op_type.count(node->op()->name)) { return true; } else { @@ -170,23 +188,29 @@ bool IsConstOp(const framework::Node* node) { } } -std::vector GetInputShape(const Node* node, const absl::flat_hash_map& shape_dict) { +std::vector GetInputShape( + const Node* node, + const absl::flat_hash_map& shape_dict) { const auto& in_links = node->inlinks_in_order(); - CHECK(!in_links.empty()) << "Cannot get input shape from a no-input op \"" << node->id() << "\""; + CHECK(!in_links.empty()) << "Cannot get input shape from a no-input op \"" + << node->id() << "\""; auto* producer_data = in_links.front()->source()->safe_as(); CHECK_NOTNULL(producer_data); return shape_dict.at(producer_data->id()); } -std::vector GetOutputShape(const Node* node, const absl::flat_hash_map& shape_dict) { +std::vector GetOutputShape( + const Node* node, + const absl::flat_hash_map& shape_dict) { auto node_data = GetNodeData(node); return shape_dict.at(node_data->id()); } Node* FindGlobalReducer(const std::vector& nodes_in_order) { auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - for (auto iter = nodes_in_order.rbegin(); iter != nodes_in_order.rend(); ++iter) { + for (auto iter = nodes_in_order.rbegin(); iter != nodes_in_order.rend(); + ++iter) { if (op_pattern_dict[(*iter)->op()] == framework::kReduction) { return *iter; } @@ -195,8 +219,11 @@ Node* FindGlobalReducer(const std::vector& nodes_in_order) { return nullptr; } -using Visitor = std::function(const Node*, const std::unordered_set&)>; -Node* FindReducerInRoute(const Node* node, const std::unordered_set& nodes_set, Visitor visitor) { +using Visitor = std::function( + const Node*, const std::unordered_set&)>; +Node* FindReducerInRoute(const Node* node, + const std::unordered_set& nodes_set, + Visitor visitor) { auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); std::queue candidates; candidates.push(node); @@ -215,7 +242,8 @@ Node* FindReducerInRoute(const Node* node, const std::unordered_set& node return nullptr; } -Node* FindNearestReducer(const Node* node, const std::unordered_set& nodes_set) { +Node* FindNearestReducer(const Node* node, + const std::unordered_set& nodes_set) { auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); // from consumers find reducer. 
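// NOTE: FindReducerInRoute above is a breadth-first walk along one direction
// of the graph (consumers or producers) that stops at the first node whose
// OpPattern is kReduction. A generic standalone sketch of the same shape of
// algorithm, with hypothetical names; the real code uses CINN graph nodes and
// op_pattern_dict. FindNearestReducer, defined just below, applies this first
// over consumers and then over producers.
#include <functional>
#include <queue>
#include <unordered_set>
#include <vector>

template <typename Node>
const Node* FindFirstOnRoute(
    const Node* start,
    const std::unordered_set<const Node*>& group,  // stay inside the group
    const std::function<std::vector<const Node*>(const Node*)>& neighbors,
    const std::function<bool(const Node*)>& pred) {
  std::queue<const Node*> candidates;
  candidates.push(start);
  while (!candidates.empty()) {
    const Node* cur = candidates.front();
    candidates.pop();
    for (const Node* next : neighbors(cur)) {
      if (!group.count(next)) continue;
      if (pred(next)) return next;  // nearest match by BFS depth
      // No visited set, mirroring the original: the group graph is a DAG, so
      // the walk terminates even if a diamond enqueues a node twice.
      candidates.push(next);
    }
  }
  return nullptr;
}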
auto reducer = FindReducerInRoute(node, nodes_set, GetConsumersInSet); @@ -225,8 +253,9 @@ Node* FindNearestReducer(const Node* node, const std::unordered_set& node return FindReducerInRoute(node, nodes_set, GetProducersInSet); } -std::unordered_map BuildVirtualConsumer(const GroupPtr& group, - const absl::flat_hash_map& shape_dict) { +std::unordered_map BuildVirtualConsumer( + const GroupPtr& group, + const absl::flat_hash_map& shape_dict) { std::unordered_map virtual_consumers; std::unordered_set nodes_set = group->NodeSet(); if (group->op_pattern_kind != framework::kReduction) { @@ -276,10 +305,11 @@ std::unordered_map BuildVirtualConsumer(const GroupPtr& group, continue; } - auto reducer = FindReducerInRoute(producer, nodes_set, GetConsumersInSet); + auto reducer = + FindReducerInRoute(producer, nodes_set, GetConsumersInSet); if (reducer) { virtual_consumers[t_node] = reducer; - found = true; + found = true; break; } candidates.push(producer); @@ -290,23 +320,36 @@ std::unordered_map BuildVirtualConsumer(const GroupPtr& group, auto output_shape = GetOutputShape(t_node, shape_dict); if (!found && t_node != e_node && e_node) { auto enode_output_shape = GetOutputShape(e_node, shape_dict); - if (std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()) == - std::accumulate(enode_output_shape.begin(), enode_output_shape.end(), 1, std::multiplies())) { + if (std::accumulate(output_shape.begin(), + output_shape.end(), + 1, + std::multiplies()) == + std::accumulate(enode_output_shape.begin(), + enode_output_shape.end(), + 1, + std::multiplies())) { virtual_consumers[t_node] = e_node; - found = true; + found = true; } } if (!found && r_node) { auto rnode_input_shape = GetInputShape(r_node, shape_dict); - if (std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()) == - std::accumulate(rnode_input_shape.begin(), rnode_input_shape.end(), 1, std::multiplies())) { + if (std::accumulate(output_shape.begin(), + output_shape.end(), + 1, + std::multiplies()) == + std::accumulate(rnode_input_shape.begin(), + rnode_input_shape.end(), + 1, + std::multiplies())) { virtual_consumers[t_node] = r_node; - found = true; + found = true; } } } - // Establish virtual consumer relationships between output nodes with the same shape. - // This allows the calculation of output nodes without affiliation to be placed under the same loop. + // Establish virtual consumer relationships between output nodes with the same + // shape. This allows the calculation of output nodes without affiliation to + // be placed under the same loop. 
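// NOTE: concretely, the bucketing that follows keeps the first output node
// seen for each element count and makes it the virtual consumer of every
// later output with the same numel. A standalone sketch of that logic, using
// a hypothetical FakeNode type in place of CINN graph nodes; the original
// implementation follows right after.
#include <unordered_map>
#include <utility>
#include <vector>

struct FakeNode { int id; };

std::unordered_map<FakeNode*, FakeNode*> BucketOutputsByNumel(
    const std::vector<std::pair<FakeNode*, int>>& outputs) {  // (node, numel)
  std::unordered_map<int, FakeNode*> first_with_numel;
  std::unordered_map<FakeNode*, FakeNode*> virtual_consumers;
  for (const auto& kv : outputs) {
    auto it = first_with_numel.find(kv.second);
    if (it == first_with_numel.end()) {
      first_with_numel.emplace(kv.second, kv.first);  // becomes the anchor
    } else {
      virtual_consumers[kv.first] = it->second;  // share the anchor's loop
    }
  }
  return virtual_consumers;
}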
std::unordered_map numel_consumers; for (auto out_node : group->output_nodes) { if (virtual_consumers.find(out_node) != virtual_consumers.end() || @@ -314,7 +357,8 @@ std::unordered_map BuildVirtualConsumer(const GroupPtr& group, continue; } auto shape = GetOutputShape(out_node, shape_dict); - int numel = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + int numel = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); if (numel_consumers.find(numel) == numel_consumers.end()) { numel_consumers.insert(std::make_pair(numel, out_node)); } else { @@ -325,9 +369,10 @@ std::unordered_map BuildVirtualConsumer(const GroupPtr& group, return virtual_consumers; } -std::vector FindConsumers(Node* node, - const std::unordered_set& nodes_set, - const std::unordered_map& virtual_consumers) { +std::vector FindConsumers( + Node* node, + const std::unordered_set& nodes_set, + const std::unordered_map& virtual_consumers) { auto consumers = GetConsumersInSet(node, nodes_set); if (virtual_consumers.count(node)) { consumers.push_back(virtual_consumers.find(node)->second); @@ -335,9 +380,10 @@ std::vector FindConsumers(Node* node, return consumers; } -std::vector FindProducers(Node* node, - const std::unordered_set& nodes_set, - const std::unordered_map& virtual_consumers) { +std::vector FindProducers( + Node* node, + const std::unordered_set& nodes_set, + const std::unordered_map& virtual_consumers) { auto producers = GetProducersInSet(node, nodes_set); for (const auto& iter : virtual_consumers) { if (iter.second == node) { @@ -348,14 +394,17 @@ std::vector FindProducers(Node* node, return producers; } -std::vector TopologicalOrder(const GroupPtr& group, const std::unordered_map& virtual_consumers) { +std::vector TopologicalOrder( + const GroupPtr& group, + const std::unordered_map& virtual_consumers) { std::vector nodes_in_order; std::unordered_set nodes_set = group->NodeSet(); while (!nodes_set.empty()) { - std::set tmp_node_set(nodes_set.begin(), nodes_set.end()); + std::set tmp_node_set(nodes_set.begin(), + nodes_set.end()); for (auto node : tmp_node_set) { - auto consumers = FindConsumers(node, nodes_set, virtual_consumers); + auto consumers = FindConsumers(node, nodes_set, virtual_consumers); bool cant_be_erase = false; for (auto consumer : consumers) { if (nodes_set.count(consumer)) { @@ -373,23 +422,29 @@ std::vector TopologicalOrder(const GroupPtr& group, const std::unordered_ return nodes_in_order; } -std::vector BFSTopologicalOrderWithPriority(const GroupPtr& group, - const std::unordered_map& virtual_consumers, - const absl::flat_hash_map& shape_dict) { +std::vector BFSTopologicalOrderWithPriority( + const GroupPtr& group, + const std::unordered_map& virtual_consumers, + const absl::flat_hash_map& shape_dict) { struct NodeWithPriority { Node* node; int priority; }; struct Comparator { - bool operator()(const NodeWithPriority& lhs, const NodeWithPriority& rhs) { return lhs.priority > rhs.priority; } + bool operator()(const NodeWithPriority& lhs, const NodeWithPriority& rhs) { + return lhs.priority > rhs.priority; + } }; std::vector nodes_in_order; std::unordered_set visited; std::unordered_set nodes_set = group->NodeSet(); std::unordered_map degree_map; - std::priority_queue, Comparator> priority_candidates; + std::priority_queue, + Comparator> + priority_candidates; std::vector visited_numel; // Calculate the priority of a node. 
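// NOTE: the Comparator above inverts the default std::priority_queue
// ordering: returning lhs.priority > rhs.priority makes the queue pop the
// *smallest* priority first, so nodes whose output numel was registered
// earliest are scheduled first. A tiny self-contained demonstration:
#include <iostream>
#include <queue>
#include <vector>

struct Item {
  const char* name;
  int priority;
};
struct MinFirst {
  // priority_queue keeps the "largest" element on top; with '>' as the
  // comparison, the element with the smallest priority ranks largest.
  bool operator()(const Item& lhs, const Item& rhs) const {
    return lhs.priority > rhs.priority;
  }
};

int main() {
  std::priority_queue<Item, std::vector<Item>, MinFirst> q;
  q.push({"late", 2});
  q.push({"first", 0});
  q.push({"mid", 1});
  while (!q.empty()) {
    std::cout << q.top().name << "\n";  // prints: first, mid, late
    q.pop();
  }
}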
@@ -397,8 +452,9 @@ std::vector BFSTopologicalOrderWithPriority(const GroupPtr& group, // Prioritize the same shape before considering OpPattern auto PriorityFunc = [&visited_numel, &shape_dict](const Node* node) -> int { auto node_shape = GetOutputShape(node, shape_dict); - int numel = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); - int index = -1; + int numel = std::accumulate( + node_shape.begin(), node_shape.end(), 1, std::multiplies()); + int index = -1; for (int i = 0; i < visited_numel.size(); ++i) { if (numel == visited_numel[i]) { index = i; @@ -415,16 +471,18 @@ std::vector BFSTopologicalOrderWithPriority(const GroupPtr& group, for (Node* node : nodes_set) { auto consumers = FindConsumers(node, nodes_set, virtual_consumers); - // Some nodes may have multiple edges between them, resulting in duplicates in the consumer. - // We only need to calculate once. - std::unordered_set consumers_without_duplicate(consumers.begin(), consumers.end()); + // Some nodes may have multiple edges between them, resulting in duplicates + // in the consumer. We only need to calculate once. + std::unordered_set consumers_without_duplicate(consumers.begin(), + consumers.end()); degree_map[node] = consumers_without_duplicate.size(); if (degree_map.at(node) == 0) { priority_candidates.push(NodeWithPriority{node, PriorityFunc(node)}); } } - // Nested BFS, outer layer traverses priority, inner layer performs BFS on current priority. + // Nested BFS, outer layer traverses priority, inner layer performs BFS on + // current priority. while (!priority_candidates.empty()) { Node* cur_priority_node = priority_candidates.top().node; priority_candidates.pop(); @@ -438,10 +496,12 @@ std::vector BFSTopologicalOrderWithPriority(const GroupPtr& group, nodes_in_order.push_back(cur); auto producers = FindProducers(cur, nodes_set, virtual_consumers); - std::unordered_set producers_without_duplicate(producers.begin(), producers.end()); + std::unordered_set producers_without_duplicate(producers.begin(), + producers.end()); for (Node* node : producers_without_duplicate) { --degree_map[node]; - // Ensure that each node is accessed only once and maintain topological order. + // Ensure that each node is accessed only once and maintain topological + // order. if (visited.count(node) != 0 || degree_map[node] != 0) { continue; } @@ -460,7 +520,8 @@ std::vector BFSTopologicalOrderWithPriority(const GroupPtr& group, return nodes_in_order; } -bool WithoutLastDimInReduce(const std::vector& shape, const std::vector& axes) { +bool WithoutLastDimInReduce(const std::vector& shape, + const std::vector& axes) { if (axes.empty()) { return false; } @@ -506,7 +567,7 @@ void LoopOrderAssignReduce(ir::IRSchedule& ir_sch, } // fuse others none-reduce axis. 
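// NOTE: the index arithmetic just below is easier to check with concrete
// numbers. One reading, assuming axes holds the reduce axes in ascending
// order: last_dimension_num counts the output dims after the last reduce
// axis, and index marks where fusion of those trailing non-reduce dims
// starts. A plain-C++ check of the arithmetic (function name hypothetical):
#include <cassert>
#include <vector>

// Assumes non-empty axes, as the caller guarantees.
int FuseStartIndex(int n_out_dims, const std::vector<int>& axes) {
  int last_dimension_num = n_out_dims - axes.back() - 1;
  return n_out_dims - last_dimension_num - static_cast<int>(axes.size());
}

int main() {
  assert(FuseStartIndex(6, {1, 2}) == 1);  // three dims trail axis 2
  assert(FuseStartIndex(4, {3}) == 3);     // reduce on last axis: none trail
}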
int last_dimension_num = n_out_dims - axes.back() - 1; - int index = n_out_dims - last_dimension_num - axes.size(); + int index = n_out_dims - last_dimension_num - axes.size(); // fuse last_dimension_num - 1 times for (auto idx = index; idx < index + last_dimension_num - 1; ++idx) { @@ -537,16 +598,19 @@ void LoopAssignReduceWithoutLast(ir::IRSchedule& ir_sch, const std::vector<int>& inshape, const std::vector<int>& axes, const common::Target& target) { - int tail = 0; + int tail = 0; bool bound = true; auto shape = pe::GetFirstStepReduceShape(inshape, axes, bound, tail); - CHECK(bound) << std::accumulate( - inshape.begin(), inshape.end(), std::string(""), [](const std::string& left, const int right) { - return left + std::to_string(right) + " "; - }); - - VLOG(4) << "LoopAssignReduceWithoutLast: THe input shape=[" << cinn::utils::Join(inshape, ", ") - << "], first step reduce shape=[" << cinn::utils::Join(shape, ", ") << "]" + CHECK(bound) << std::accumulate(inshape.begin(), + inshape.end(), + std::string(""), + [](const std::string& left, const int right) { + return left + std::to_string(right) + " "; + }); + + VLOG(4) << "LoopAssignReduceWithoutLast: The input shape=[" + << cinn::utils::Join(inshape, ", ") << "], first step reduce shape=[" + << cinn::utils::Join(shape, ", ") << "]" << ", axes=[" << cinn::utils::Join(axes, ", ") << "], tail=" << tail; // remove loop size = 1 and remove axis in axes. @@ -651,20 +715,21 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, // If the number of current device SM is smaller than the number of SM // required by Warp Reduce, the performance of Warp Reduce is better. // Otherwise, use Block Reduce. - auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); int need_reduce_last_count = 1; for (int i = 0; i < inshape.size(); i++) { if (find(axes.begin(), axes.end(), i) == axes.end()) { need_reduce_last_count *= inshape[i]; } } - int warp_reduce_need_sm_count = ceil((need_reduce_last_count * 32) / float(target.get_max_threads_per_sm())); + int warp_reduce_need_sm_count = ceil((need_reduce_last_count * 32) / + float(target.get_max_threads_per_sm())); // Setting max_num_threads to 32 selects Warp Reduce if (target.get_multi_processor_count() < warp_reduce_need_sm_count) { max_num_threads = 32; } // find first reduce and second reduce axis. - int lane = 1; + int lane = 1; int index = static_cast<int>(axes.size()) - 1; for (; index >= 0; --index) { @@ -673,7 +738,8 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, } lane *= inshape[axes[index]]; if (index == 0 && lane <= max_num_threads) { - LOG(FATAL) << "Error! lane is less equal than max_num_threads, Please check!"; + LOG(FATAL) + << "Error!
lane is less than or equal to max_num_threads, please check!"; } if (lane >= max_num_threads / 2) { if (lane <= max_num_threads) { @@ -686,7 +752,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, if (lane > max_num_threads) { // last reduce axis size > 1024 if (index == static_cast<int>(axes.size()) - 1) { - int tail = max_num_threads; + int tail = max_num_threads; bool check_bound = true; for (; tail >= max_num_threads / 2; --tail) { if (lane % tail == 0) { @@ -695,7 +761,8 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, } } if (check_bound) { - lane = ((lane + max_num_threads - 1) / max_num_threads) * max_num_threads; + lane = + ((lane + max_num_threads - 1) / max_num_threads) * max_num_threads; ir_sch.Split(block_name, axes[index], {lane}); } int idx = max_num_threads; @@ -709,23 +776,25 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, // if it can't be divided by (1024, 512), it shouldn't be fused. CHECK_GE(idx, max_num_threads / 2) << "Check bounds exist, can't fuse!"; } else { - int axis = axes[index]; + int axis = axes[index]; int prefix = inshape[axis]; - int tail = lane / prefix; - for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; --idx) { + int tail = lane / prefix; + for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; + --idx) { if (prefix % idx == 0) { ir_sch.Split(block_name, axis, {-1, idx}); break; } - CHECK_GT(idx, (max_num_threads / 2) / tail) << "Error, it's shouldn't fuse!"; + CHECK_GT(idx, (max_num_threads / 2) / tail) + << "Error, it shouldn't fuse!"; } } LoopOrderAssignReduce(ir_sch, block_name, first_axes, target); // The current one-dimensional reduce does not make full use of SM. // This case is optimized into a two-dimensional one. - auto loops = ir_sch.GetLoops(block_name); auto block_dim_x = loops[1].As<ir::For>()->extent.as_int32(); - int block_dim_y = block_dim_x <= 32 ? 2 : 1; + auto loops = ir_sch.GetLoops(block_name); + auto block_dim_x = loops[1].As<ir::For>()->extent.as_int32(); + int block_dim_y = block_dim_x <= 32 ? 2 : 1; if (block_dim_y != 1) { ir_sch.Split(loops[0], {-1, block_dim_y}); } @@ -776,11 +845,17 @@ bool CanbeInline(Node* node, // node is before reducer and node is not after reduce. if (FindReducerInRoute(node, nodes_set, GetConsumersInSet) && !FindReducerInRoute(node, nodes_set, GetProducersInSet)) { - auto node_shape = GetOutputShape(node, shape_dict); + auto node_shape = GetOutputShape(node, shape_dict); auto input_shape = GetInputShape(reducer, shape_dict); // check whether it has the same shape as the reducer input.
- if (std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()) != - std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies())) { + if (std::accumulate(node_shape.begin(), + node_shape.end(), + 1, + std::multiplies()) != + std::accumulate(input_shape.begin(), + input_shape.end(), + 1, + std::multiplies())) { return true; } } @@ -788,11 +863,13 @@ bool CanbeInline(Node* node, return false; } else { auto node_shape = GetOutputShape(node, shape_dict); - auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); + auto node_size = std::accumulate( + node_shape.begin(), node_shape.end(), 1, std::multiplies()); for (auto master : masters) { auto master_shape = GetOutputShape(master, shape_dict); - auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); + auto master_size = std::accumulate( + master_shape.begin(), master_shape.end(), 1, std::multiplies()); if (node_size != master_size) { return true; } @@ -802,12 +879,13 @@ bool CanbeInline(Node* node, } } -Node* GetMasterToComputeAt(Node* node, - const std::vector& nodes_in_order, - const std::unordered_set& nodes_inline, - const std::unordered_set& nodes_set, - const std::unordered_map& virtual_consumers, - const absl::flat_hash_map& shape_dict) { +Node* GetMasterToComputeAt( + Node* node, + const std::vector& nodes_in_order, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set, + const std::unordered_map& virtual_consumers, + const absl::flat_hash_map& shape_dict) { auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); // if node is reduction, try find horizontal to compute at. if (op_pattern_dict[node->op()] == framework::kReduction) { @@ -846,7 +924,8 @@ Node* GetMasterToComputeAt(Node* node, if (done_schedule.size()) { auto shape = shape_dict.at(node->inlinks_in_order()[0]->source()->id()); for (auto rnode : done_schedule) { - auto rshape = shape_dict.at(rnode->inlinks_in_order()[0]->source()->id()); + auto rshape = + shape_dict.at(rnode->inlinks_in_order()[0]->source()->id()); if (shape == rshape) { return rnode; } @@ -892,18 +971,19 @@ Node* GetMasterToComputeAt(Node* node, return nullptr; } -void LoopAssignReduce(ir::IRSchedule& ir_sch, - const Node* node, - const Node* reducer, - const Target& target, - const std::unordered_map& tensor_map, - const absl::flat_hash_map& shape_dict) { +void LoopAssignReduce( + ir::IRSchedule& ir_sch, + const Node* node, + const Node* reducer, + const Target& target, + const std::unordered_map& tensor_map, + const absl::flat_hash_map& shape_dict) { auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); // if node is reducer, return. if (op_pattern_dict[node->op()] == framework::kReduction) { return; } - auto node_data = GetNodeData(node); + auto node_data = GetNodeData(node); auto reducer_data = GetNodeData(reducer); // flatten loops. @@ -918,14 +998,15 @@ void LoopAssignReduce(ir::IRSchedule& ir_sch, // shape and axis. 
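// NOTE: the block just below normalizes the reducer's "dim" attribute: an
// empty axis list means "reduce over every axis", so it is expanded to
// {0, 1, ..., rank - 1} before any loop transformation. A standalone
// equivalent (function name hypothetical):
#include <numeric>
#include <vector>

std::vector<int> NormalizeReduceAxes(std::vector<int> axes, int rank) {
  if (axes.empty()) {
    axes.resize(rank);
    std::iota(axes.begin(), axes.end(), 0);  // 0, 1, ..., rank - 1
  }
  return axes;
}
// e.g. NormalizeReduceAxes({}, 3) yields {0, 1, 2}; non-empty lists pass
// through unchanged.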
CHECK(shape_dict.count(reducer->inlinks_in_order()[0]->source()->id())); auto shape = shape_dict.at(reducer->inlinks_in_order()[0]->source()->id()); - auto axes = absl::get>(reducer->attrs.attr_store.at("dim")); + auto axes = absl::get>(reducer->attrs.attr_store.at("dim")); if (axes.empty()) { for (int idx = 0; idx < shape.size(); idx++) { axes.push_back(idx); } } - auto copy_loop_info = [](std::vector& loops, std::vector& rloops) { + auto copy_loop_info = [](std::vector& loops, + std::vector& rloops) { for (int idx = 0; idx < std::min(rloops.size(), loops.size()); ++idx) { auto l0 = rloops[idx].As(); auto l1 = loops[idx].As(); @@ -937,46 +1018,54 @@ void LoopAssignReduce(ir::IRSchedule& ir_sch, auto node_shape = shape_dict.at(node_data->id()); // The output shape of node is different from that of reduce node if (std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) != - std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies())) { + std::accumulate( + node_shape.begin(), node_shape.end(), 1, std::multiplies())) { // get loop factors of reduce node int extend = 1; std::vector factors; - loops = ir_sch.GetLoops(node_data->id()); + loops = ir_sch.GetLoops(node_data->id()); auto rloops = ir_sch.GetLoops(reducer_data->id()); for (auto& loop : rloops) { - if (extend >= loops.back().As()->extent.as_int32() && factors.size() && - loop.As()->extent.as_int32() > 1) { + if (extend >= loops.back().As()->extent.as_int32() && + factors.size() && loop.As()->extent.as_int32() > 1) { break; } extend *= loop.As()->extent.as_int32(); factors.push_back(loop.As()->extent.as_int32()); } - // If there are IfThenElse stmt in loop, we need to find out the indices in condition, - // and special treatment should be applied to loops with these indices. - // We apply two step split on loop of src node to align the loop of reduce node. + // If there are IfThenElse stmt in loop, we need to find out the indices in + // condition, and special treatment should be applied to loops with these + // indices. We apply two step split on loop of src node to align the loop of + // reduce node. 
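// NOTE: the two-step split described above partitions the reduce node's loop
// factors around the first loop index that appears in an IfThenElse
// condition: factors before it go into the first Split (with a -1
// placeholder), the rest into a second Split. A standalone sketch of that
// partitioning, assuming the condition indices are already collected into a
// set; it mirrors the first_step_factors / second_step_factors loop below.
#include <unordered_set>
#include <vector>

void PartitionFactors(const std::vector<int>& factors,
                      const std::unordered_set<int>& index_in_if,
                      std::vector<int>* first_step,
                      std::vector<int>* second_step,
                      int* second_start) {
  *second_start = -1;
  for (int i = 0; i < static_cast<int>(factors.size()); ++i) {
    if (index_in_if.count(i) == 0) {
      first_step->push_back(factors[i]);
    } else if (second_step->empty()) {
      first_step->push_back(-1);  // -1: "whatever remains" for this loop
      second_step->push_back(factors[i]);
      *second_start = i;          // remember where step two begins
    } else {
      second_step->push_back(factors[i]);
    }
  }
}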
std::unordered_set<int> loop_index_in_if; auto first_reduce_loop = rloops.front(); // collect if - auto if_checker = [](const Expr* x) { return x->As<ir::IfThenElse>(); }; - auto if_set = ir::CollectIRNodesWithoutTensor(first_reduce_loop.As<ir::For>()->body, if_checker); + auto if_checker = [](const Expr* x) { return x->As<ir::IfThenElse>(); }; + auto if_set = ir::CollectIRNodesWithoutTensor( + first_reduce_loop.As<ir::For>()->body, if_checker); std::string reduce_block_name = reducer_data->id(); for (auto if_expr : if_set) { auto checker = [reduce_block_name](const Expr* x) { return x->As<ir::ScheduleBlockRealize>() && - x->As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name == reduce_block_name; + x->As<ir::ScheduleBlockRealize>() + ->schedule_block.As<ir::ScheduleBlock>() + ->name == reduce_block_name; }; auto blocks_in_if = ir::CollectIRNodesWithoutTensor(if_expr, checker); if (!blocks_in_if.empty()) { ir::Expr condition = if_expr.As<ir::IfThenElse>()->condition; - auto indices_in_if = - ir::CollectIRNodesWithoutTensor(condition, [](const Expr* x) { return x->As<ir::_Var_>(); }); + auto indices_in_if = ir::CollectIRNodesWithoutTensor( + condition, [](const Expr* x) { return x->As<ir::_Var_>(); }); for (int i = 0; i < rloops.size(); ++i) { std::string var_name = rloops[i].As<ir::For>()->loop_var->name; - auto find_var_iter = std::find_if(indices_in_if.begin(), indices_in_if.end(), [&var_name](const ir::Expr& x) { - return x.As<ir::_Var_>()->name == var_name; - }); + auto find_var_iter = + std::find_if(indices_in_if.begin(), + indices_in_if.end(), + [&var_name](const ir::Expr& x) { + return x.As<ir::_Var_>()->name == var_name; + }); if (find_var_iter != indices_in_if.end()) { loop_index_in_if.insert(i); } @@ -992,11 +1081,13 @@ void LoopAssignReduce(ir::IRSchedule& ir_sch, for (int i = 0; i < factors.size(); ++i) { if (loop_index_in_if.count(i) == 0) { first_step_factors.push_back(factors[i]); - } else if (loop_index_in_if.count(i) != 0 && second_step_factors.empty()) { + } else if (loop_index_in_if.count(i) != 0 && + second_step_factors.empty()) { first_step_factors.push_back(-1); second_step_factors.push_back(factors[i]); second_start_loop_index = i; - } else if (loop_index_in_if.count(i) != 0 && !second_step_factors.empty()) { + } else if (loop_index_in_if.count(i) != 0 && + !second_step_factors.empty()) { second_step_factors.push_back(factors[i]); } } @@ -1019,13 +1110,16 @@ void LoopAssignReduce(ir::IRSchedule& ir_sch, if (WithoutLastDimInReduce(shape, axes)) { // if using two step reduce. if (tensor_map.count(reducer_data->id() + "_1")) { - VLOG(4) << "Try assign loop of " << node_data->id() << " into two strep reduce loop of " << reducer_data->id(); + VLOG(4) << "Try assign loop of " << node_data->id() + << " into two step reduce loop of " << reducer_data->id(); LoopAssignReduceWithoutLast(ir_sch, node_data->id(), shape, axes, target); auto nloops = ir_sch.GetLoops(node_data->id()); - auto rloops = ir_sch.GetLoops(tensor_map.find(reducer_data->id() + "_0")->second->name); + auto rloops = ir_sch.GetLoops( + tensor_map.find(reducer_data->id() + "_0")->second->name); - VLOG(4) << node_data->id() << "'s loop level is " << nloops.size() << ", and " << reducer_data->id() - << "'s loop level is " << rloops.size(); + VLOG(4) << node_data->id() << "'s loop level is " << nloops.size() + << ", and " << reducer_data->id() << "'s loop level is " + << rloops.size(); if (nloops.size() < rloops.size()) { ir_sch.Split(nloops[0], {1, -1}); } @@ -1034,13 +1128,15 @@ void LoopAssignReduce(ir::IRSchedule& ir_sch, // copy loop info from rloops.
copy_loop_info(nloops, rloops); } else { - VLOG(4) << "Try assign loop of " << node_data->id() << " into reduce loop of " << reducer_data->id(); + VLOG(4) << "Try assign loop of " << node_data->id() + << " into reduce loop of " << reducer_data->id(); auto nloops = ir_sch.GetLoops(node_data->id()); ir_sch.Split(nloops.back(), shape); LoopOrderAssignReduce(ir_sch, node_data->id(), axes, target); - nloops = ir_sch.GetLoops(node_data->id()); - auto rloops = ir_sch.GetLoops(tensor_map.find(reducer_data->id())->second->name); + nloops = ir_sch.GetLoops(node_data->id()); + auto rloops = + ir_sch.GetLoops(tensor_map.find(reducer_data->id())->second->name); if (nloops.size() < rloops.size()) { ir_sch.Split(nloops[0], {1, -1}); } @@ -1058,7 +1154,8 @@ void LoopAssignReduce(ir::IRSchedule& ir_sch, LoopAssignReduceWithLast(ir_sch, node_data->id(), shape, axes, target); auto nloops = ir_sch.GetLoops(node_data->id()); - auto rloops = ir_sch.GetLoops(tensor_map.find(reducer_data->id() + "_1")->second->name); + auto rloops = ir_sch.GetLoops( + tensor_map.find(reducer_data->id() + "_1")->second->name); if (nloops.size() < rloops.size()) { ir_sch.Split(nloops[0], {1, -1}); } @@ -1093,13 +1190,17 @@ class RemoveExpr : public ir::IRMutator<> { void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } private: - void Visit(const ir::ScheduleBlockRealize* expr, Expr* op) override { IRMutator::Visit(expr, op); } + void Visit(const ir::ScheduleBlockRealize* expr, Expr* op) override { + IRMutator::Visit(expr, op); + } - void Visit(const ir::For* expr, Expr* op) override { IRMutator::Visit(expr, op); } + void Visit(const ir::For* expr, Expr* op) override { + IRMutator::Visit(expr, op); + } void Visit(const ir::Block* expr, Expr* op) override { auto* node = op->As(); - auto iter = std::find(node->stmts.begin(), node->stmts.end(), target_); + auto iter = std::find(node->stmts.begin(), node->stmts.end(), target_); if (iter != node->stmts.end()) { node->stmts.erase(iter); } else { @@ -1113,7 +1214,10 @@ class RemoveExpr : public ir::IRMutator<> { const Expr& target_; }; -void MergeLoops(ir::Expr root, std::vector& src, std::vector& dst, int index) { +void MergeLoops(ir::Expr root, + std::vector& src, + std::vector& dst, + int index) { if (index < 0) { return; } @@ -1133,19 +1237,21 @@ void MergeLoops(ir::Expr root, std::vector& src, std::vector auto src_body = src[index].As()->body; ReplaceExpr(&src_body, src_vars, dst_vars); - dst[index].As()->body = ir::Block::Make({src_body, dst[index].As()->body}); + dst[index].As()->body = + ir::Block::Make({src_body, dst[index].As()->body}); RemoveExpr remove_expr(src[0]); remove_expr(&root); } -void InsertSyncThread(ir::IRSchedule& ir_sch, - const Node* node, - const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map) { +void InsertSyncThread( + ir::IRSchedule& ir_sch, + const Node* node, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { CHECK(shape_dict.count(node->inlinks_in_order()[0]->source()->id())); auto shape = shape_dict.at(node->inlinks_in_order()[0]->source()->id()); - auto axes = absl::get>(node->attrs.attr_store.at("dim")); + auto axes = absl::get>(node->attrs.attr_store.at("dim")); if (axes.empty()) { for (int idx = 0; idx < shape.size(); idx++) { axes.push_back(idx); @@ -1155,7 +1261,7 @@ void InsertSyncThread(ir::IRSchedule& ir_sch, return; } - auto node_data = GetNodeData(node); + auto node_data = GetNodeData(node); std::string post = ""; for (int idx = 0;; ++idx) { if (!tensor_map.count(node_data->id() + 
post)) { @@ -1184,13 +1290,17 @@ class InsertExpr : public ir::IRMutator<> { void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } private: - void Visit(const ir::ScheduleBlockRealize* expr, Expr* op) override { IRMutator::Visit(expr, op); } + void Visit(const ir::ScheduleBlockRealize* expr, Expr* op) override { + IRMutator::Visit(expr, op); + } - void Visit(const ir::For* expr, Expr* op) override { IRMutator::Visit(expr, op); } + void Visit(const ir::For* expr, Expr* op) override { + IRMutator::Visit(expr, op); + } void Visit(const ir::Block* expr, Expr* op) override { auto* node = op->As(); - auto iter = std::find(node->stmts.begin(), node->stmts.end(), anchor_); + auto iter = std::find(node->stmts.begin(), node->stmts.end(), anchor_); if (iter != node->stmts.end()) { node->stmts.insert(iter, target_); } else { @@ -1205,17 +1315,18 @@ class InsertExpr : public ir::IRMutator<> { Expr anchor_; }; -void MergeReduceToReduce(ir::IRSchedule& ir_sch, - const Node* node, - const Node* master, - const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map) { - auto node_data = GetNodeData(node); +void MergeReduceToReduce( + ir::IRSchedule& ir_sch, + const Node* node, + const Node* master, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { + auto node_data = GetNodeData(node); auto master_data = GetNodeData(master); CHECK(shape_dict.count(node->inlinks_in_order()[0]->source()->id())); auto shape = shape_dict.at(node->inlinks_in_order()[0]->source()->id()); - auto axes = absl::get>(node->attrs.attr_store.at("dim")); + auto axes = absl::get>(node->attrs.attr_store.at("dim")); if (axes.empty()) { for (int idx = 0; idx < shape.size(); idx++) { axes.push_back(idx); @@ -1277,7 +1388,10 @@ void MergeReduceToReduce(ir::IRSchedule& ir_sch, auto m_loops = ir_sch.GetLoops(m_tensor->name + "__reduce_init"); CHECK_EQ(n_loops.size(), m_loops.size()); - MergeLoops(ir_sch.GetModule().GetExprs().at(0), n_loops, m_loops, n_loops.size() - 1); + MergeLoops(ir_sch.GetModule().GetExprs().at(0), + n_loops, + m_loops, + n_loops.size() - 1); } } else { LOG(FATAL) << "not support this type fusion!"; @@ -1300,7 +1414,7 @@ void MergeReduceToReduce(ir::IRSchedule& ir_sch, } else { // reduce loop { - auto block = ir_sch.GetBlock(node_data->id()); + auto block = ir_sch.GetBlock(node_data->id()); auto nloops = ir_sch.GetLoops(node_data->id()); auto mloops = ir_sch.GetLoops(master_data->id()); for (int idx = 0; idx < mloops.size(); ++idx) { @@ -1389,18 +1503,19 @@ void MergeReduceToReduce(ir::IRSchedule& ir_sch, } } -void MergeReduceLoop(ir::IRSchedule& ir_sch, - Node* node, - const Node* master, - const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map) { +void MergeReduceLoop( + ir::IRSchedule& ir_sch, + Node* node, + const Node* master, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); if (op_pattern_dict[master->op()] == kReduction && node != master) { MergeReduceToReduce(ir_sch, node, master, shape_dict, tensor_map); return; } - auto node_data = GetNodeData(node); + auto node_data = GetNodeData(node); auto master_data = GetNodeData(master); int min_index_loop = INT_MAX; @@ -1409,7 +1524,7 @@ void MergeReduceLoop(ir::IRSchedule& ir_sch, if (!tensor_map.count(node_data->id() + post__)) { break; } - auto tensor_ = tensor_map.find(node_data->id() + post_)->second; + auto tensor_ = tensor_map.find(node_data->id() + post_)->second; auto tensor__ = 
tensor_map.find(node_data->id() + post__)->second; if (!ir_sch.HasBlock(tensor__->name)) { break; @@ -1417,7 +1532,7 @@ void MergeReduceLoop(ir::IRSchedule& ir_sch, auto dst_loops = ir_sch.GetLoops(tensor_->name); auto src_loops = ir_sch.GetLoops(tensor__->name); - int index = -1; + int index = -1; while (src_loops[index + 1].As()->extent.as_int32() == dst_loops[index + 1].As()->extent.as_int32()) { ++index; @@ -1426,25 +1541,30 @@ void MergeReduceLoop(ir::IRSchedule& ir_sch, } } min_index_loop = std::min(min_index_loop, index); - MergeLoops(ir_sch.GetModule().GetExprs().at(0), src_loops, dst_loops, index); + MergeLoops( + ir_sch.GetModule().GetExprs().at(0), src_loops, dst_loops, index); - post_ = "_" + std::to_string(idx); + post_ = "_" + std::to_string(idx); post__ = "_" + std::to_string(idx + 1); } InsertSyncThread(ir_sch, node, shape_dict, tensor_map); if (node == master) return; - auto node_loops = ir_sch.GetLoops(node_data->id()); + auto node_loops = ir_sch.GetLoops(node_data->id()); auto master_loops = ir_sch.GetLoops(master_data->id()); int index = std::min(node_loops.size(), master_loops.size()) - 1; do { // if loop range is not equal. - if (node_loops[index].As()->extent.as_int32() != master_loops[index].As()->extent.as_int32()) { + if (node_loops[index].As()->extent.as_int32() != + master_loops[index].As()->extent.as_int32()) { continue; } - MergeLoops(ir_sch.GetModule().GetExprs().at(0), node_loops, master_loops, std::min(index, min_index_loop)); + MergeLoops(ir_sch.GetModule().GetExprs().at(0), + node_loops, + master_loops, + std::min(index, min_index_loop)); if (index > min_index_loop) { auto block = ir_sch.GetBlock(node_data->id()); auto loops = ir_sch.GetLoops(master_data->id()); @@ -1472,7 +1592,9 @@ class FindExprInBlock : public ir::IRMutator<> { } private: - void Visit(const ir::ScheduleBlockRealize* expr, Expr* op) override { exprs_.push_back(*op); } + void Visit(const ir::ScheduleBlockRealize* expr, Expr* op) override { + exprs_.push_back(*op); + } void Visit(const ir::For* expr, Expr* op) override { exprs_.push_back(*op); } @@ -1486,12 +1608,13 @@ class FindExprInBlock : public ir::IRMutator<> { std::vector exprs_; }; -void LoopComputeAt(ir::IRSchedule& ir_sch, - Node* node, - const Node* master, - const GroupPtr& group, - const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map) { +void LoopComputeAt( + ir::IRSchedule& ir_sch, + Node* node, + const Node* master, + const GroupPtr& group, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); if (!group->output_nodes.count(node)) { auto block = ir_sch.GetBlock(GetNodeData(node)->id()); @@ -1505,10 +1628,10 @@ void LoopComputeAt(ir::IRSchedule& ir_sch, if (node == master) return; - auto node_data = GetNodeData(node); + auto node_data = GetNodeData(node); auto master_data = GetNodeData(master); - auto node_loops = ir_sch.GetLoops(node_data->id()); + auto node_loops = ir_sch.GetLoops(node_data->id()); auto master_loops = ir_sch.GetLoops(master_data->id()); if (op_pattern_dict[master->op()] == framework::kReduction) { @@ -1524,37 +1647,41 @@ void LoopComputeAt(ir::IRSchedule& ir_sch, } prefix = post; - post = "_" + std::to_string(idx); + post = "_" + std::to_string(idx); } - auto tensor = tensor_map.find(master_data->id() + prefix)->second; + auto tensor = tensor_map.find(master_data->id() + prefix)->second; master_loops = ir_sch.GetLoops(tensor->name); } int index = std::min(node_loops.size(), master_loops.size()) - 
1; do { // if loop range is not equal. - if (node_loops[index].As()->extent.as_int32() != master_loops[index].As()->extent.as_int32()) { + if (node_loops[index].As()->extent.as_int32() != + master_loops[index].As()->extent.as_int32()) { continue; } - MergeLoops(ir_sch.GetModule().GetExprs().at(0), node_loops, master_loops, index); + MergeLoops( + ir_sch.GetModule().GetExprs().at(0), node_loops, master_loops, index); break; } while (--index >= 0); } -std::unordered_map GetNodeDataSet(const std::unordered_set& nodes_set) { +std::unordered_map GetNodeDataSet( + const std::unordered_set& nodes_set) { std::unordered_map node_data_set; for (auto node : nodes_set) { - auto node_data = GetNodeData(node); + auto node_data = GetNodeData(node); node_data_set[node_data->id()] = node_data; } return node_data_set; } -std::unordered_set GetMasters(Node* node, - const std::unordered_set& nodes_inline, - const std::unordered_set& nodes_set) { +std::unordered_set GetMasters( + Node* node, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set) { // find consumer std::unordered_set visited; std::queue candidates; @@ -1582,14 +1709,15 @@ std::unordered_set GetMasters(Node* node, return masters; } -void SyncThreadWithShared(ir::IRSchedule& ir_sch, - const GroupPtr& group, - const std::unordered_set& nodes_inline, - const std::unordered_set& nodes_set, - const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map) { - auto exprs_inorder = ir_sch.GetAllBlocks(); - auto node_data_set = GetNodeDataSet(nodes_set); +void SyncThreadWithShared( + ir::IRSchedule& ir_sch, + const GroupPtr& group, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { + auto exprs_inorder = ir_sch.GetAllBlocks(); + auto node_data_set = GetNodeDataSet(nodes_set); auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); std::unordered_set sync_mark; @@ -1597,8 +1725,10 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch, for (int idx = start + 1; exprs_inorder.size(); ++idx) { auto expr = exprs_inorder[idx]; CHECK(expr.As()); - CHECK(expr.As()->schedule_block.As()); - auto block = expr.As()->schedule_block.As(); + CHECK(expr.As() + ->schedule_block.As()); + auto block = expr.As() + ->schedule_block.As(); if (sync_mark.count(block->name)) { return false; @@ -1614,14 +1744,16 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch, for (int idx = 0; idx < exprs_inorder.size() - 1; ++idx) { auto expr = exprs_inorder[idx]; CHECK(expr.As()); - CHECK(expr.As()->schedule_block.As()); - auto block = expr.As()->schedule_block.As(); + CHECK(expr.As() + ->schedule_block.As()); + auto block = expr.As() + ->schedule_block.As(); if (!node_data_set.count(block->name)) { continue; } - auto node_data = node_data_set.find(block->name)->second; - auto node = node_data->source_node.get(); + auto node_data = node_data_set.find(block->name)->second; + auto node = node_data->source_node.get(); auto node_shape = shape_dict.at(node_data->id()); auto masters = GetMasters(node, nodes_inline, nodes_set); @@ -1631,14 +1763,17 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch, bool do_set_buffer_to_shared = false; for (auto master : masters) { - auto master_data = GetNodeData(master); + auto master_data = GetNodeData(master); auto master_shape = shape_dict.at(master_data->id()); if (op_pattern_dict[master->op()] == framework::kReduction) { - master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id()); + 
master_shape = + shape_dict.at(master->inlinks_in_order()[0]->source()->id()); } - auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); - auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); + auto node_size = std::accumulate( + node_shape.begin(), node_shape.end(), 1, std::multiplies()); + auto master_size = std::accumulate( + master_shape.begin(), master_shape.end(), 1, std::multiplies()); if (node_size != master_size) { if (check_sync_mark(idx, master_data->id())) { @@ -1649,7 +1784,8 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch, do_set_buffer_to_shared = true; } } - if (do_set_buffer_to_shared && group->output_nodes.find(node) == group->output_nodes.end()) { + if (do_set_buffer_to_shared && + group->output_nodes.find(node) == group->output_nodes.end()) { auto block = ir_sch.GetBlock(node_data->id()); ir_sch.SetBuffer(block, "shared"); } diff --git a/paddle/cinn/hlir/framework/op_lowering_util.h b/paddle/cinn/hlir/framework/op_lowering_util.h index fd1c6addb0aae..02741820db85c 100644 --- a/paddle/cinn/hlir/framework/op_lowering_util.h +++ b/paddle/cinn/hlir/framework/op_lowering_util.h @@ -24,18 +24,21 @@ namespace framework { std::vector GetInputNodeData(const Node* node); -ir::Tensor GetTensor(const NodeData* node_data, - const absl::flat_hash_map& type_dict, - const absl::flat_hash_map& shape_dict); - -std::vector CollectInputTensor(const Node* node, - std::vector& func_args, - std::unordered_map& tensor_map, - const absl::flat_hash_map& type_dict, - const absl::flat_hash_map& shape_dict); - -std::unordered_map BuildVirtualConsumer(const GroupPtr& group, - const absl::flat_hash_map& shape_dict); +ir::Tensor GetTensor( + const NodeData* node_data, + const absl::flat_hash_map& type_dict, + const absl::flat_hash_map& shape_dict); + +std::vector CollectInputTensor( + const Node* node, + std::vector& func_args, + std::unordered_map& tensor_map, + const absl::flat_hash_map& type_dict, + const absl::flat_hash_map& shape_dict); + +std::unordered_map BuildVirtualConsumer( + const GroupPtr& group, + const absl::flat_hash_map& shape_dict); NodeData* GetNodeData(const Node* node); @@ -45,17 +48,22 @@ std::vector GetConsumers(const Node* node); bool IsConstOp(const framework::Node* node); -std::vector GetConsumersInSet(const Node* node, const std::unordered_set& node_set); +std::vector GetConsumersInSet(const Node* node, + const std::unordered_set& node_set); -std::vector TopologicalOrder(const GroupPtr& group, const std::unordered_map& virtual_consumers); +std::vector TopologicalOrder( + const GroupPtr& group, + const std::unordered_map& virtual_consumers); -std::vector BFSTopologicalOrderWithPriority(const GroupPtr& group, - const std::unordered_map& virtual_consumers, - const absl::flat_hash_map& shape_dict); +std::vector BFSTopologicalOrderWithPriority( + const GroupPtr& group, + const std::unordered_map& virtual_consumers, + const absl::flat_hash_map& shape_dict); Node* FindGlobalReducer(const std::vector& nodes_in_order); -Node* FindNearestReducer(const Node* node, const std::unordered_set& nodes_set); +Node* FindNearestReducer(const Node* node, + const std::unordered_set& nodes_set); bool CanbeInline(Node* node, const std::vector consumers, @@ -65,37 +73,42 @@ bool CanbeInline(Node* node, const std::unordered_set& nodes_set, const absl::flat_hash_map& shape_dict); -Node* GetMasterToComputeAt(Node* node, - const std::vector& nodes_in_order, - const std::unordered_set& nodes_inline, - const 
std::unordered_set& nodes_set, - const std::unordered_map& virtual_consumers, - const absl::flat_hash_map& shape_dict); - -std::unordered_set GetMasters(Node* node, - const std::unordered_set& nodes_inline, - const std::unordered_set& nodes_set); - -void LoopAssignReduce(ir::IRSchedule& ir_sch, - const Node* node, - const Node* reducer, - const Target& target, - const std::unordered_map& tensor_map, - const absl::flat_hash_map& shape_dict); - -void LoopComputeAt(ir::IRSchedule& ir_sch, - Node* node, - const Node* master, - const GroupPtr& group, - const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map); - -void SyncThreadWithShared(ir::IRSchedule& ir_sch, - const GroupPtr& group, - const std::unordered_set& nodes_inline, - const std::unordered_set& nodes_set, - const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map); +Node* GetMasterToComputeAt( + Node* node, + const std::vector& nodes_in_order, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set, + const std::unordered_map& virtual_consumers, + const absl::flat_hash_map& shape_dict); + +std::unordered_set GetMasters( + Node* node, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set); + +void LoopAssignReduce( + ir::IRSchedule& ir_sch, + const Node* node, + const Node* reducer, + const Target& target, + const std::unordered_map& tensor_map, + const absl::flat_hash_map& shape_dict); + +void LoopComputeAt( + ir::IRSchedule& ir_sch, + Node* node, + const Node* master, + const GroupPtr& group, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map); + +void SyncThreadWithShared( + ir::IRSchedule& ir_sch, + const GroupPtr& group, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map); } // namespace framework } // namespace hlir diff --git a/paddle/cinn/hlir/framework/op_strategy.cc b/paddle/cinn/hlir/framework/op_strategy.cc index 825a519987e9f..b832875036c04 100644 --- a/paddle/cinn/hlir/framework/op_strategy.cc +++ b/paddle/cinn/hlir/framework/op_strategy.cc @@ -17,9 +17,10 @@ namespace cinn { namespace hlir { namespace framework { -std::shared_ptr OpStrategy::SelectImpl(const std::shared_ptr& strategy) { +std::shared_ptr OpStrategy::SelectImpl( + const std::shared_ptr& strategy) { //! should get the host info from global environment. - std::string curr_condition = "default"; + std::string curr_condition = "default"; std::shared_ptr res = nullptr; for (auto& spec : strategy->specializations) { if (spec->condition == "default") { @@ -30,11 +31,15 @@ std::shared_ptr OpStrategy::SelectImpl(const std::shared_ptr } } } - CHECK(res) << "There is no available strategy implementation! SelectImpl failed!"; + CHECK(res) + << "There is no available strategy implementation! SelectImpl failed!"; return res; } -void OpStrategy::AddImpl(CINNCompute fcompute, CINNSchedule fschedule, std::string name, int plevel) { +void OpStrategy::AddImpl(CINNCompute fcompute, + CINNSchedule fschedule, + std::string name, + int plevel) { //! TODO(haozech) : here curr_cond should get the condition from outside. //! 
Expected : auto curr_cond = SpecializedCondition::Current(); std::string curr_condition = "default"; @@ -45,7 +50,7 @@ void OpStrategy::AddImpl(CINNCompute fcompute, CINNSchedule fschedule, std::stri } } std::shared_ptr n = std::make_shared(); - n->condition = curr_condition; + n->condition = curr_condition; n->AddImpl(fcompute, fschedule, std::move(name), plevel); this->specializations.push_back(n); } diff --git a/paddle/cinn/hlir/framework/op_strategy.h b/paddle/cinn/hlir/framework/op_strategy.h index 96d053ca5c980..b782e943b2c21 100644 --- a/paddle/cinn/hlir/framework/op_strategy.h +++ b/paddle/cinn/hlir/framework/op_strategy.h @@ -26,18 +26,19 @@ namespace cinn { namespace hlir { namespace framework { -using CINNCompute = lang::PackedFunc; +using CINNCompute = lang::PackedFunc; using CINNSchedule = lang::PackedFunc; class OpStrategy; -using StrategyFunction = std::function(const NodeAttr&, - const std::vector&, - const std::vector&, - const std::vector>&, - const common::Target&)>; -using InferShapeFunction = - std::function>(const std::vector>&, const AttrMapType&)>; +using StrategyFunction = std::function( + const NodeAttr&, + const std::vector&, + const std::vector&, + const std::vector>&, + const common::Target&)>; +using InferShapeFunction = std::function>( + const std::vector>&, const AttrMapType&)>; //! Operator implementation that includes compute and schedule function. class OpImpl : public common::Object { @@ -57,7 +58,8 @@ class OpImpl : public common::Object { * @param out_type The output type information. * @return The output compute description of the operator. */ - ir::Tensor Compute(const std::vector& inputs, const Type& out_type) { + ir::Tensor Compute(const std::vector& inputs, + const Type& out_type) { // TODO(haozech) : add support for packedfunc to return Tensor // Expected : return this->fcompute(inputs, out_type); ir::Tensor temp; @@ -70,9 +72,10 @@ class OpImpl : public common::Object { * @param target The build target. * @return The computation schedule. */ - common::Shared GetSchedule(const std::vector& outs, - const std::vector& temp_tensors, - const Target& target) { + common::Shared GetSchedule( + const std::vector& outs, + const std::vector& temp_tensors, + const Target& target) { // TODO(haozech) : add support for packedfunc to return Schedule // Expected : return this->fschedule(outs, target); return nullptr; @@ -92,19 +95,22 @@ class OpSpec : public common::Object { /** \brief Condition to enable the specialization. * Could be undefined to represent generic case. - * TODO(haozech) : build a specified class SpecializedCondition to represent the condition. - * Expected : SpecializedCondition condition; + * TODO(haozech) : build a specified class SpecializedCondition to represent + * the condition. 
Expected : SpecializedCondition condition; */ std::string condition; const char* type_info() const override { return __type_info__; } - void AddImpl(CINNCompute fcompute, CINNSchedule fschedule, std::string name, int plevel) { - auto n = std::make_shared(); - n->fcompute = fcompute; + void AddImpl(CINNCompute fcompute, + CINNSchedule fschedule, + std::string name, + int plevel) { + auto n = std::make_shared(); + n->fcompute = fcompute; n->fschedule = fschedule; - n->name = std::move(name); - n->plevel = plevel; + n->name = std::move(name); + n->plevel = plevel; this->implementations.push_back(n); } @@ -126,8 +132,12 @@ class OpStrategy : public common::Object { * @param name Name of the implementation * @param plevel Priority level of the implementation */ - void AddImpl(CINNCompute fcompute, CINNSchedule fschedule, std::string name, int plevel); - static std::shared_ptr SelectImpl(const std::shared_ptr& strategy); + void AddImpl(CINNCompute fcompute, + CINNSchedule fschedule, + std::string name, + int plevel); + static std::shared_ptr SelectImpl( + const std::shared_ptr& strategy); private: static constexpr char* __type_info__ = "OpStrategy"; diff --git a/paddle/cinn/hlir/framework/op_test.cc b/paddle/cinn/hlir/framework/op_test.cc index f883c41550b6c..b3d967c031fcc 100644 --- a/paddle/cinn/hlir/framework/op_test.cc +++ b/paddle/cinn/hlir/framework/op_test.cc @@ -33,10 +33,11 @@ namespace cinn { namespace hlir { namespace framework { -using CCompute = std::function(const std::vector)>; +using CCompute = + std::function(const std::vector)>; TEST(Operator, GetAttrs) { - auto add = Operator::Get("elementwise_add"); + auto add = Operator::Get("elementwise_add"); Operator temp = *add; auto strategy = Operator::GetAttrs("CINNStrategy"); @@ -48,7 +49,8 @@ TEST(Operator, GetAttrs) { std::vector inputs{A, B}; std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); - auto impl = OpStrategy::SelectImpl(strategy[add](attrs, inputs, type, {{100, 32}}, target)); + auto impl = OpStrategy::SelectImpl( + strategy[add](attrs, inputs, type, {{100, 32}}, target)); ASSERT_EQ(impl->name, "strategy.elementwise_add.x86"); ASSERT_EQ(add->description, "elementwise_add function"); @@ -58,17 +60,22 @@ TEST(Operator, GetAttrs) { if (FLAGS_cinn_ir_schedule) { std::string out_name = "C"; common::CINNValuePack cinn_input = - common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B), common::CINNValue(out_name)}}; + common::CINNValuePack{{common::CINNValue(A), + common::CINNValue(B), + common::CINNValue(out_name)}}; std::vector input_output_names{"A", "B", out_name}; - auto funcs = framework::GetFuncFromImpl(impl, cinn_input, inputs, input_output_names, func_name, target); + auto funcs = framework::GetFuncFromImpl( + impl, cinn_input, inputs, input_output_names, func_name, target); for (auto func : funcs) { - LOG(INFO) << "Test Operator_ElementWise_Add_Test0's Strategy, func is :\n" << func; + LOG(INFO) << "Test Operator_ElementWise_Add_Test0's Strategy, func is :\n" + << func; } } else { - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); + common::CINNValuePack cinn_input = + common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); ASSERT_EQ(rets.size(), 2UL); rets = impl->fschedule(rets); ASSERT_EQ(rets.size(), 2UL); diff --git a/paddle/cinn/hlir/framework/parallel_compiler.cc 
b/paddle/cinn/hlir/framework/parallel_compiler.cc index a5c233e0d14a5..5004b6de8b5e3 100644 --- a/paddle/cinn/hlir/framework/parallel_compiler.cc +++ b/paddle/cinn/hlir/framework/parallel_compiler.cc @@ -51,12 +51,15 @@ std::vector> ParallelCompiler::operator()() { } OpPatternKind GetOpKind(const framework::Node* node) { - auto& op_pattern_dict = framework::Operator::GetAttrs("OpPattern"); - CHECK(op_pattern_dict.Find(node->op())) << "Don't find the pattern of op : " << node->id(); + auto& op_pattern_dict = + framework::Operator::GetAttrs("OpPattern"); + CHECK(op_pattern_dict.Find(node->op())) + << "Can't find the pattern of op : " << node->id(); auto kind = op_pattern_dict[node->op()]; if (kind == framework::kBroadcast) { - // As binary op was defined as broadcast, actually it should be element-wise. + // A binary op is defined as broadcast, but it should actually be + // element-wise. if (node->op()->name != "broadcast_to") { return framework::kElementWise; } @@ -67,16 +70,19 @@ OpPatternKind GetOpKind(const framework::Node* node) { void ParallelCompiler::SplitTask() { CHECK(graph_->fusion_groups.size()); - CHECK(graph_->fusion_groups.size() == option_.lowered_funcs.size() || option_.lowered_funcs.size() == 0); + CHECK(graph_->fusion_groups.size() == option_.lowered_funcs.size() || + option_.lowered_funcs.size() == 0); // split task - int max_task_num = - FLAGS_cinn_parallel_compile_thread > 0 ? FLAGS_cinn_parallel_compile_thread : graph_->fusion_groups.size(); + int max_task_num = FLAGS_cinn_parallel_compile_thread > 0 + ? FLAGS_cinn_parallel_compile_thread + : graph_->fusion_groups.size(); int group_per_task = graph_->fusion_groups.size(); if (max_task_num > 1) { group_per_task = FLAGS_cinn_parallel_compile_size > 0 ? FLAGS_cinn_parallel_compile_size - : ((graph_->fusion_groups.size() + max_task_num - 1) / max_task_num); + : ((graph_->fusion_groups.size() + max_task_num - 1) / + max_task_num); } for (int idx = 0; idx < graph_->fusion_groups.size(); idx += group_per_task) { @@ -124,8 +130,12 @@ void ParallelCompiler::Task::Lowering() { if (options.lowered_funcs.size()) { CHECK_EQ(options.lowered_funcs.size(), graph->fusion_groups.size()); } - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); + auto& dtype_dict = + graph->GetMutableAttrs>( + "inferdtype"); + auto& shape_dict = + graph->GetMutableAttrs>( + "infershape"); OpLowerer op_lowerer(dtype_dict, shape_dict, target); while (true) { @@ -140,16 +150,19 @@ void ParallelCompiler::Task::Lowering() { continue; } auto& group = graph->fusion_groups[idx]; - VLOG(1) << "Start Lowering Group " << idx << " at " << std::this_thread::get_id() << " :\n" + VLOG(1) << "Start Lowering Group " << idx << " at " + << std::this_thread::get_id() << " :\n" << "Group " << idx << " {\n" << graph->DebugGroupedGraph(group->CollectNodes()) << "}\n"; lowered_funcs.emplace_back(std::move(op_lowerer.Lower(group))); - CHECK_EQ(lowered_funcs.back().size(), 1) << "Lowerd Function Is Not Equal 1!"; + CHECK_EQ(lowered_funcs.back().size(), 1) + << "Lowered function count is not equal to 1!"; } } void ParallelCompiler::Task::CodegenAndJit() { - VLOG(2) << "Start Codegen and JIT with Group [" << cinn::utils::Join(this->gidx, ", ") << "] at " + VLOG(2) << "Start Codegen and JIT with Group [" + << cinn::utils::Join(this->gidx, ", ") << "] at " << std::this_thread::get_id(); // build module ir::Module::Builder builder(common::UniqName("module"), target); @@ -163,14 +176,15 @@ void
ParallelCompiler::Task::CodegenAndJit() { if (target == common::DefaultNVGPUTarget()) { #ifdef CINN_WITH_CUDA auto splited_module = backends::SplitCudaAndHostModule(ir_module); - auto hmodule = std::get<0>(splited_module); - auto dmodule = std::get<1>(splited_module); + auto hmodule = std::get<0>(splited_module); + auto dmodule = std::get<1>(splited_module); VLOG(3) << "Host Code:\n" << hmodule; VLOG(3) << "Device Code:\n" << dmodule; backends::CodeGenCUDA_Dev codegen(target); auto cuda_c = codegen.Compile(dmodule); - CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n" << dmodule; + CHECK(!cuda_c.empty()) << "Failed to compile CUDA C code from device module:\n" + << dmodule; cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c); graph->SaveSourceCode(cuda_c); @@ -180,7 +194,10 @@ void ParallelCompiler::Task::CodegenAndJit() { auto ptx = compiler(cuda_c); CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c; // load cumodule - cumodule.reset(new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX)); + cumodule.reset(new CUDAModule(ptx, + compiler.compile_to_cubin() + ? CUDAModule::Kind::CUBIN + : CUDAModule::Kind::PTX)); // register kernel backends::RuntimeSymbols symbols; @@ -189,7 +206,8 @@ void ParallelCompiler::Task::CodegenAndJit() { CHECK(cufunc); symbols.RegisterVar(fn->name + "_ptr_", reinterpret_cast(cufunc)); } - engine = backends::ExecutionEngine::Create(backends::ExecutionOptions(), std::move(symbols)); + engine = backends::ExecutionEngine::Create(backends::ExecutionOptions(), + std::move(symbols)); engine->Link(hmodule); #endif } else { @@ -201,15 +219,21 @@ void ParallelCompiler::Task::BuildInstruction() { // create instruction.
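
The task-splitting arithmetic in SplitTask above is plain ceiling division: with max_task_num workers, each task receives at most (fusion_groups.size() + max_task_num - 1) / max_task_num groups. A minimal, self-contained sketch of that partitioning, assuming max_task_num >= 1 (SplitIntoTasks and its pair-of-indices output are hypothetical illustrations, not part of the CINN API):

    #include <algorithm>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Partition num_groups fusion groups into contiguous [begin, end) chunks,
    // the same way SplitTask walks graph_->fusion_groups in steps of
    // group_per_task.
    std::vector<std::pair<int, int>> SplitIntoTasks(int num_groups,
                                                    int max_task_num) {
      // Ceiling division: the last task may be smaller, none is larger.
      int group_per_task = (num_groups + max_task_num - 1) / max_task_num;
      std::vector<std::pair<int, int>> tasks;
      for (int idx = 0; idx < num_groups; idx += group_per_task) {
        tasks.emplace_back(idx, std::min(idx + group_per_task, num_groups));
      }
      return tasks;
    }

    int main() {
      for (const auto& t : SplitIntoTasks(10, 4)) {
        std::printf("task covers groups [%d, %d)\n", t.first, t.second);
      }
    }

SplitIntoTasks(10, 4) yields [0, 3), [3, 6), [6, 9), [9, 10): four tasks, none holding more than the ceiling of 10/4 groups.
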
for (int idx : gidx) { - VLOG(2) << "Start BuildInstruction of Group " << idx << " at " << std::this_thread::get_id(); + VLOG(2) << "Start BuildInstruction of Group " << idx << " at " + << std::this_thread::get_id(); auto& group = graph->fusion_groups[idx]; CHECK(group->input_names.size() > 0 || group->output_names.size() > 0); - auto instr = std::unique_ptr( - new Instruction(target, scope.get(), group->input_names, group->output_names, group->GetFuncName())); + auto instr = + std::unique_ptr(new Instruction(target, + scope.get(), + group->input_names, + group->output_names, + group->GetFuncName())); auto fn_ptr = engine->Lookup(group->GetFuncName()); CHECK(fn_ptr) << "Can't find jit function : " << group->GetFuncName(); - instr->SetLoweredFunc(reinterpret_cast(fn_ptr), group->GetFuncName()); + instr->SetLoweredFunc(reinterpret_cast(fn_ptr), + group->GetFuncName()); instr->Finalize(); instructions.push_back(std::move(instr)); diff --git a/paddle/cinn/hlir/framework/parallel_compiler_test.cc b/paddle/cinn/hlir/framework/parallel_compiler_test.cc index e4f295a8fee84..ee70f016f33bf 100644 --- a/paddle/cinn/hlir/framework/parallel_compiler_test.cc +++ b/paddle/cinn/hlir/framework/parallel_compiler_test.cc @@ -29,13 +29,13 @@ using namespace frontend; TEST(ParallelCompilerTest, Add_TEST_0) { frontend::NetBuilder builder("Add_TEST_0"); - auto A = builder.CreateInput(Float(32), {128, 128}, "A"); - auto B = builder.CreateInput(Float(32), {128, 128}, "B"); - auto C = builder.Add(A, B); - auto target = common::DefaultNVGPUTarget(); + auto A = builder.CreateInput(Float(32), {128, 128}, "A"); + auto B = builder.CreateInput(Float(32), {128, 128}, "B"); + auto C = builder.Add(A, B); + auto target = common::DefaultNVGPUTarget(); auto program = builder.Build(); - auto graph = std::make_shared(program, target); - auto scope = BuildScope(target, graph); + auto graph = std::make_shared(program, target); + auto scope = BuildScope(target, graph); ParallelCompiler::CompileOptions option; ParallelCompiler pc(scope, graph, option, target); @@ -50,10 +50,10 @@ TEST(ParallelCompilerTest, Conv2d_Test_0) { auto D = builder.Conv2d(A, B, {2, 2}, {1, 1}); auto E = builder.Add(C, D); - auto target = common::DefaultNVGPUTarget(); + auto target = common::DefaultNVGPUTarget(); auto program = builder.Build(); - auto graph = Optimize(&program, {}, target); - auto scope = BuildScope(target, graph); + auto graph = Optimize(&program, {}, target); + auto scope = BuildScope(target, graph); ParallelCompiler::CompileOptions option; ParallelCompiler pc(scope, graph, option, target); @@ -68,10 +68,10 @@ TEST(ParallelCompilerTest, Matmul_Test_0) { auto D = builder.Matmul(A, B); auto E = builder.Add(C, D); - auto target = common::DefaultNVGPUTarget(); + auto target = common::DefaultNVGPUTarget(); auto program = builder.Build(); - auto graph = Optimize(&program, {}, target); - auto scope = BuildScope(target, graph); + auto graph = Optimize(&program, {}, target); + auto scope = BuildScope(target, graph); ParallelCompiler::CompileOptions option; ParallelCompiler pc(scope, graph, option, target); diff --git a/paddle/cinn/hlir/framework/pass.cc b/paddle/cinn/hlir/framework/pass.cc index 978c52d4609bb..d10490e607ed1 100644 --- a/paddle/cinn/hlir/framework/pass.cc +++ b/paddle/cinn/hlir/framework/pass.cc @@ -32,11 +32,13 @@ void ApplyPasses(Graph* g, const std::vector& passes) { for (auto* r : fpass) { cinn::hlir::framework::PassPrinter::GetInstance()->PassBegin(r->name, g); for (auto& dep : r->graph_attr_dependency) { - 
CHECK_NE(g->attrs.count(dep), 0) << "To apply pass [" << r->name << "], Graph's attribute [" << dep - << "] is required, but it is not available."; + CHECK_NE(g->attrs.count(dep), 0) + << "To apply pass [" << r->name << "], Graph's attribute [" << dep + << "] is required, but it is not available."; if (g->attrs.count(dep) == 0) { auto* pass_dep = FindPassDep(dep); - CHECK(!pass_dep) << "And the attribute is provided by pass [" << pass_dep->name << "]."; + CHECK(!pass_dep) << "And the attribute is provided by pass [" + << pass_dep->name << "]."; } } r->body(g); diff --git a/paddle/cinn/hlir/framework/pass.h b/paddle/cinn/hlir/framework/pass.h index bfd64a0e27ef3..1a4e32e0d0fc3 100644 --- a/paddle/cinn/hlir/framework/pass.h +++ b/paddle/cinn/hlir/framework/pass.h @@ -20,8 +20,10 @@ #include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/utils/registry.h" -#define CINN_REGISTER_PASS(name) \ - CINN_REGISTRY_REGISTER(::cinn::hlir::framework::PassFunctionRegister, PassFunctionRegister, name) +#define CINN_REGISTER_PASS(name) \ + CINN_REGISTRY_REGISTER(::cinn::hlir::framework::PassFunctionRegister, \ + PassFunctionRegister, \ + name) namespace cinn { namespace hlir { @@ -31,13 +33,15 @@ class PassFunctionRegister; typedef std::function PassFunction; /** - * \brief Given an attribute of graph, find the pass that generates this attribute. + * \brief Given an attribute of graph, find the pass that generates this + * attribute. * @param attr_name Name of the graph attribute. * @return The pass that generates it. */ const PassFunctionRegister* FindPassDep(const std::string& attr_name); -class PassFunctionRegister : public FunctionRegEntryBase { +class PassFunctionRegister + : public FunctionRegEntryBase { public: bool change_structure{false}; //! dependencies on operator attributes @@ -49,7 +53,8 @@ class PassFunctionRegister : public FunctionRegEntryBase& passes); // Apply a single pass on a graph. 
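
The dependency check in ApplyPasses above follows a conventional registry pattern: each registered pass names the graph attributes it requires, and the driver refuses to run it until an earlier pass has provided them. A simplified, self-contained sketch of that pattern (Graph, PassEntry, and ApplyPassesSketch are cut-down stand-ins, not the real CINN types):

    #include <functional>
    #include <iostream>
    #include <set>
    #include <string>
    #include <vector>

    struct Graph {
      std::set<std::string> attrs;  // graph attributes computed so far
    };

    struct PassEntry {
      std::string name;
      std::vector<std::string> graph_attr_dependency;  // required attributes
      std::vector<std::string> provided_attrs;         // attributes it adds
      std::function<void(Graph*)> body;
    };

    void ApplyPassesSketch(Graph* g, const std::vector<PassEntry>& passes) {
      for (const auto& p : passes) {
        // Mirror the CHECK_NE(g->attrs.count(dep), 0) guard in pass.cc above.
        for (const auto& dep : p.graph_attr_dependency) {
          if (g->attrs.count(dep) == 0) {
            std::cerr << "To apply pass [" << p.name << "], attribute [" << dep
                      << "] is required, but it is not available.\n";
            return;
          }
        }
        p.body(g);
        for (const auto& a : p.provided_attrs) g->attrs.insert(a);
      }
    }

    int main() {
      Graph g;
      g.attrs.insert("infershape");
      ApplyPassesSketch(
          &g, {{"DemoPass", {"infershape"}, {"demo_attr"}, [](Graph*) {}}});
    }

The real PassFunctionRegister additionally records whether a pass changes graph structure (set_change_structure) and which attributes it provides (provide_graph_attr); the latter is what FindPassDep consults to name the pass that could supply a missing attribute.
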
-inline void ApplyPass(Graph* g, const std::string& pass) { return ApplyPasses(g, {pass}); } +inline void ApplyPass(Graph* g, const std::string& pass) { + return ApplyPasses(g, {pass}); +} } // namespace framework } // namespace hlir diff --git a/paddle/cinn/hlir/framework/print_graph_pass_test.cc b/paddle/cinn/hlir/framework/print_graph_pass_test.cc index d74afa6ed86f9..cc3d51c4f79c0 100644 --- a/paddle/cinn/hlir/framework/print_graph_pass_test.cc +++ b/paddle/cinn/hlir/framework/print_graph_pass_test.cc @@ -32,7 +32,7 @@ namespace framework { void PrintGraphPass(Graph* src) { std::string res; auto store_node = std::get<0>(src->topological_order()); - int index = 0; + int index = 0; for (auto& i : store_node) { if (i->is_type()) { res += std::to_string(index) + ":"; @@ -45,7 +45,9 @@ } CINN_REGISTER_PASS(PrintGraph) - .describe("This pass just save the visulization Graph to g.attrs[\"print_graph\"].") + .describe( + "This pass just saves the visualization Graph to " + "g.attrs[\"print_graph\"].") .set_change_structure(false) .provide_graph_attr("print_graph") .set_body(PrintGraphPass); @@ -54,12 +56,12 @@ TEST(Operator, GetAttrs) { frontend::Program prog; frontend::Variable a("A"); frontend::Variable b("B"); - Type t = Float(32); + Type t = Float(32); a->type = t; b->type = t; - auto c = prog.add(a, b); - auto d = prog.add(c, b); - auto e = prog.add(c, d); + auto c = prog.add(a, b); + auto d = prog.add(c, b); + auto e = prog.add(c, d); ASSERT_EQ(prog.size(), 3); Graph* g = new Graph(prog, common::DefaultHostTarget()); ApplyPass(g, "PrintGraph"); diff --git a/paddle/cinn/hlir/framework/schedule.h b/paddle/cinn/hlir/framework/schedule.h index 36dc5186a9c53..3fe12f5afae7c 100644 --- a/paddle/cinn/hlir/framework/schedule.h +++ b/paddle/cinn/hlir/framework/schedule.h @@ -40,7 +40,8 @@ class Schedule : public common::Object { */ ir::Tensor operator[](const ir::Operation& op) { auto it = stage_map.find(op.name); - CHECK(it != stage_map.end()) << "Cannot find Stage for operator " << op.name << " in the schedule"; + CHECK(it != stage_map.end()) + << "Cannot find Stage for operator " << op.name << " in the schedule"; return it->second; } diff --git a/paddle/cinn/hlir/framework/scope_test.cc b/paddle/cinn/hlir/framework/scope_test.cc index 99cc296c887b2..23ac65469af9a 100644 --- a/paddle/cinn/hlir/framework/scope_test.cc +++ b/paddle/cinn/hlir/framework/scope_test.cc @@ -22,13 +22,13 @@ namespace framework { TEST(Scope, basic) { Scope scope; - auto* var = scope.Var("key"); + auto* var = scope.Var("key"); auto& tensor = absl::get(*var); tensor->Resize(Shape{{3, 1}}); auto* data = tensor->mutable_data(common::DefaultHostTarget()); - data[0] = 0.f; - data[1] = 1.f; - data[2] = 2.f; + data[0] = 0.f; + data[1] = 1.f; + data[2] = 2.f; } TEST(ScopeTest, TestEraseVar) { diff --git a/paddle/cinn/hlir/framework/tensor.h b/paddle/cinn/hlir/framework/tensor.h index 6f35148c407fa..7b5d201d0f0ae 100644 --- a/paddle/cinn/hlir/framework/tensor.h +++ b/paddle/cinn/hlir/framework/tensor.h @@ -39,11 +39,14 @@ struct Shape { void SetData(const std::vector& data) { data_ = data; } - const std::vector& data() const CINN_RESULT_SHOULD_USE { return data_; } + const std::vector& data() const CINN_RESULT_SHOULD_USE { + return data_; + } std::vector& data() CINN_RESULT_SHOULD_USE { return data_; } size_t size() const CINN_RESULT_SHOULD_USE { return data_.size(); } uint32_t numel() const CINN_RESULT_SHOULD_USE { - return std::accumulate(data_.begin(), data_.end(), 1, [](dim_t a, dim_t b) {
return a * b; }); + return std::accumulate( + data_.begin(), data_.end(), 1, [](dim_t a, dim_t b) { return a * b; }); } private: @@ -58,7 +61,9 @@ class _Tensor_ : public Object { void Resize(const Shape& shape) { shape_ = shape; - buffer_->data()->resize(reinterpret_cast(shape.data().data()), shape.size()); + buffer_->data()->resize( + reinterpret_cast(shape.data().data()), + shape.size()); } inline void* mutable_data(const Target& target, const Type& type) { diff --git a/paddle/cinn/hlir/framework/visualize_helper.cc b/paddle/cinn/hlir/framework/visualize_helper.cc index 214637655fde7..a310ac2a0fb8a 100644 --- a/paddle/cinn/hlir/framework/visualize_helper.cc +++ b/paddle/cinn/hlir/framework/visualize_helper.cc @@ -38,44 +38,55 @@ namespace framework { bool PassPrinter::Begin(const std::unordered_set& fetch_ids) { if (FLAGS_cinn_pass_visualize_dir.empty()) { - VLOG(3) << "No set \"FLAGS_cinn_pass_visualize_dir\", the pass visualize information will print directly."; + VLOG(3) << "\"FLAGS_cinn_pass_visualize_dir\" is not set, the pass visualize " "information will be printed directly."; save_path_.clear(); return false; } - pass_id_ = 0; + pass_id_ = 0; fetch_ids_ = fetch_ids; - save_path_ = utils::StringFormat("%s/fusion_groups_%d/", FLAGS_cinn_pass_visualize_dir.c_str(), graph_id_); - if (!MakeDirectory(save_path_, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { - LOG_IF(WARNING, graph_id_ == 0) << "Failed to make directory: \"" << save_path_ - << "\", the CINN subgraph's pass visualize information will not print."; + save_path_ = utils::StringFormat( + "%s/fusion_groups_%d/", FLAGS_cinn_pass_visualize_dir.c_str(), graph_id_); + if (!MakeDirectory(save_path_, + S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { + LOG_IF(WARNING, graph_id_ == 0) + << "Failed to make directory: \"" << save_path_ + << "\", the CINN subgraph's pass visualize information will not print."; return false; } - LOG_IF(INFO, graph_id_ == 0) << "The CINN subgraph's pass visualize information will writing into path: \"" + LOG_IF(INFO, graph_id_ == 0) << "The CINN subgraph's pass visualize " "information will be written into path: \"" << FLAGS_cinn_pass_visualize_dir << "\""; return true; } -bool PassPrinter::PassBegin(const std::string& pass_name, const frontend::Program& program) { +bool PassPrinter::PassBegin(const std::string& pass_name, + const frontend::Program& program) { if (save_path_.empty()) { return false; } const auto& program_info = utils::GetStreamCnt(program); VLOG(3) << "Before " << pass_name << " Pass:\n" << program_info; - const std::string& file_path = - utils::StringFormat("%s/pass_%d_%s_before.txt", save_path_.c_str(), pass_id_, pass_name.c_str()); + const std::string& file_path = utils::StringFormat("%s/pass_%d_%s_before.txt", + save_path_.c_str(), + pass_id_, + pass_name.c_str()); WriteToFile(file_path, program_info); return true; } -bool PassPrinter::PassEnd(const std::string& pass_name, const frontend::Program& program) { +bool PassPrinter::PassEnd(const std::string& pass_name, + const frontend::Program& program) { if (save_path_.empty()) { return false; } const auto& program_info = utils::GetStreamCnt(program); VLOG(3) << "After " << pass_name << " Pass:\n" << program_info; - const std::string& file_path = - utils::StringFormat("%s/pass_%d_%s_after.txt", save_path_.c_str(), pass_id_, pass_name.c_str()); + const std::string& file_path = utils::StringFormat("%s/pass_%d_%s_after.txt", + save_path_.c_str(), + pass_id_, + pass_name.c_str()); WriteToFile(file_path, program_info); ++pass_id_; @@
-88,13 +99,17 @@ bool PassPrinter::PassBegin(const std::string& pass_name, Graph* g) { } const auto& graph_info = g->DebugGroupedGraph(fetch_ids_); VLOG(3) << "Before " << pass_name << " Pass:\n" << graph_info; - const std::string& file_path = - utils::StringFormat("%s/pass_%d_%s_before.txt", save_path_.c_str(), pass_id_, pass_name.c_str()); + const std::string& file_path = utils::StringFormat("%s/pass_%d_%s_before.txt", + save_path_.c_str(), + pass_id_, + pass_name.c_str()); WriteToFile(file_path, graph_info); const auto& dot_info = g->VisualizeGraph(fetch_ids_); - const std::string& dot_path = - utils::StringFormat("%s/pass_%d_%s_before.dot", save_path_.c_str(), pass_id_, pass_name.c_str()); + const std::string& dot_path = utils::StringFormat("%s/pass_%d_%s_before.dot", + save_path_.c_str(), + pass_id_, + pass_name.c_str()); WriteToFile(dot_path, dot_info); return true; } @@ -106,13 +121,17 @@ bool PassPrinter::PassEnd(const std::string& pass_name, Graph* g) { const auto& graph_info = g->DebugGroupedGraph(fetch_ids_); VLOG(3) << "After " << pass_name << " Pass:\n" << graph_info; - const std::string& file_path = - utils::StringFormat("%s/pass_%d_%s_after.txt", save_path_.c_str(), pass_id_, pass_name.c_str()); + const std::string& file_path = utils::StringFormat("%s/pass_%d_%s_after.txt", + save_path_.c_str(), + pass_id_, + pass_name.c_str()); WriteToFile(file_path, graph_info); const auto& dot_info = g->VisualizeGraph(fetch_ids_); - const std::string& dot_path = - utils::StringFormat("%s/pass_%d_%s_after.dot", save_path_.c_str(), pass_id_, pass_name.c_str()); + const std::string& dot_path = utils::StringFormat("%s/pass_%d_%s_after.dot", + save_path_.c_str(), + pass_id_, + pass_name.c_str()); WriteToFile(dot_path, dot_info); ++pass_id_; @@ -154,21 +173,22 @@ std::string GetFilePathForGroup(const std::vector>& groups, filename += "_" + node->id(); } - int max_len = 50; + int max_len = 50; std::string simplified_filename = filename; if (filename.size() > max_len) { - static std::unordered_map funcname_map = {{"const_scalar", "scalar"}, - {"fill_constant", "fill"}, - {"identity", "copy"}, - {"broadcast_to", "broadcast"}, - {"elementwise_add", "add"}, - {"subtract", "sub"}, - {"elementwise_mul", "mul"}, - {"divide", "div"}, - {"reduce_sum", "reduce"}, - {"reduce_prod", "reduce"}, - {"reduce_max", "reduce"}, - {"reduce_min", "reduce"}}; + static std::unordered_map funcname_map = { + {"const_scalar", "scalar"}, + {"fill_constant", "fill"}, + {"identity", "copy"}, + {"broadcast_to", "broadcast"}, + {"elementwise_add", "add"}, + {"subtract", "sub"}, + {"elementwise_mul", "mul"}, + {"divide", "div"}, + {"reduce_sum", "reduce"}, + {"reduce_prod", "reduce"}, + {"reduce_max", "reduce"}, + {"reduce_min", "reduce"}}; for (auto& item : funcname_map) { size_t index = 0; while (true) { @@ -190,10 +210,11 @@ std::string GetFilePathForGroup(const std::vector>& groups, return ss.str(); } -std::string GenNodeDataLabel(const NodeData* node, - const absl::flat_hash_map& shape_dict, - const absl::flat_hash_map& dtype_dict, - const std::string dot_nodedata_id) { +std::string GenNodeDataLabel( + const NodeData* node, + const absl::flat_hash_map& shape_dict, + const absl::flat_hash_map& dtype_dict, + const std::string dot_nodedata_id) { std::stringstream ss; ss << dot_nodedata_id; if (shape_dict.count(node->id())) { @@ -215,7 +236,8 @@ std::string GenNodeDataLabel(const NodeData* node, return ss.str(); } -void Summary(const std::vector>& groups, const std::string& viz_path) { +void Summary(const std::vector>& 
groups, + const std::string& viz_path) { std::map group_summary; std::map single_group_detail; std::map fusion_group_detail; @@ -226,7 +248,7 @@ void Summary(const std::vector>& groups, const std::string& v if (group_size == 1) { // Like "fill_constant_1", remove the "_1" at the end of the string. std::string node_id = group[0]->id(); - int index = node_id.size() - 1; + int index = node_id.size() - 1; while (index != -1) { if (node_id[index] >= '0' && node_id[index] <= '9') { index--; @@ -313,8 +335,8 @@ std::string DebugString(const Node* node) { } std::stringstream ss; - ss << cinn::utils::Join(out_names, ", ") << " = builder." << node->op()->name << "(" - << cinn::utils::Join(in_names, ", "); + ss << cinn::utils::Join(out_names, ", ") << " = builder." << node->op()->name + << "(" << cinn::utils::Join(in_names, ", "); bool first = true; std::map attr_str_map; @@ -353,17 +375,18 @@ void FindRecomputeNodes(const std::vector>& groups, } } -void AddGroupNode(const Node* node, - const std::string& dot_cluster_id, - const std::unordered_set& fetch_var_ids, - const absl::flat_hash_map& shape_dict, - const absl::flat_hash_map& dtype_dict, - std::unordered_map* recompute_nodes, - std::unordered_map* outnode2dot_id, - std::unordered_set* nodedatas_set, - utils::DotLang* dot) { +void AddGroupNode( + const Node* node, + const std::string& dot_cluster_id, + const std::unordered_set& fetch_var_ids, + const absl::flat_hash_map& shape_dict, + const absl::flat_hash_map& dtype_dict, + std::unordered_map* recompute_nodes, + std::unordered_map* outnode2dot_id, + std::unordered_set* nodedatas_set, + utils::DotLang* dot) { bool is_recomputed = recompute_nodes->count(node->id()); - int recompute_id = is_recomputed ? (*recompute_nodes)[node->id()]++ : -1; + int recompute_id = is_recomputed ? 
(*recompute_nodes)[node->id()]++ : -1; std::string dot_node_id = GenNodeId(node, is_recomputed, recompute_id); dot->AddNode(dot_node_id, GetGroupOpAttrs(is_recomputed), "", dot_cluster_id); @@ -376,8 +399,13 @@ void AddGroupNode(const Node* node, } std::string dot_innode_id = outnode2dot_id->at(innode->id()); if (!nodedatas_set || !nodedatas_set->count(dot_innode_id)) { - std::string label = GenNodeDataLabel(innode, shape_dict, dtype_dict, dot_innode_id); - dot->AddNode(dot_innode_id, GetGroupVarAttrs(false), label, dot_cluster_id, true); + std::string label = + GenNodeDataLabel(innode, shape_dict, dtype_dict, dot_innode_id); + dot->AddNode(dot_innode_id, + GetGroupVarAttrs(false), + label, + dot_cluster_id, + true); if (nodedatas_set) { nodedatas_set->insert(dot_innode_id); } @@ -389,12 +417,18 @@ void AddGroupNode(const Node* node, for (auto& outlink : node->outlinks()) { auto* outnode = outlink->sink()->safe_as(); if (outnode) { - std::string dot_outnode_id = GenNodeDataId(outnode, is_recomputed, recompute_id); + std::string dot_outnode_id = + GenNodeDataId(outnode, is_recomputed, recompute_id); (*outnode2dot_id)[outnode->id()] = dot_outnode_id; if (!nodedatas_set || !nodedatas_set->count(dot_outnode_id)) { - bool is_fetched = fetch_var_ids.count(outnode->id()); - std::string label = GenNodeDataLabel(outnode, shape_dict, dtype_dict, dot_outnode_id); - dot->AddNode(dot_outnode_id, GetGroupVarAttrs(is_fetched), label, dot_cluster_id, true); + bool is_fetched = fetch_var_ids.count(outnode->id()); + std::string label = + GenNodeDataLabel(outnode, shape_dict, dtype_dict, dot_outnode_id); + dot->AddNode(dot_outnode_id, + GetGroupVarAttrs(is_fetched), + label, + dot_cluster_id, + true); if (nodedatas_set) { nodedatas_set->insert(dot_outnode_id); } @@ -404,8 +438,12 @@ void AddGroupNode(const Node* node, } } -bool IsAccCheckOp(const Node* op) { return op->attrs.node_name.find("_acc_check") != std::string::npos; } -bool IsAccCheckVar(const NodeData* var) { return var->id().find("_acc_check") != std::string::npos; } +bool IsAccCheckOp(const Node* op) { + return op->attrs.node_name.find("_acc_check") != std::string::npos; +} +bool IsAccCheckVar(const NodeData* var) { + return var->id().find("_acc_check") != std::string::npos; +} std::string GenerateAccCheckNodeId(const std::string& node_id) { return node_id + cinn::common::UniqName("_acc_check"); @@ -420,8 +458,10 @@ bool IsAccCheckGroup(const std::vector& group) { return false; } -std::vector> RemoveAccCheckGroups(const std::vector>& groups) { - if (cinn::runtime::CheckStringFlagFalse(FLAGS_cinn_check_fusion_accuracy_pass)) { +std::vector> RemoveAccCheckGroups( + const std::vector>& groups) { + if (cinn::runtime::CheckStringFlagFalse( + FLAGS_cinn_check_fusion_accuracy_pass)) { // no set acc check flag return groups; } diff --git a/paddle/cinn/hlir/framework/visualize_helper.h b/paddle/cinn/hlir/framework/visualize_helper.h index c961a0137340c..032f5f2ec4e0c 100644 --- a/paddle/cinn/hlir/framework/visualize_helper.h +++ b/paddle/cinn/hlir/framework/visualize_helper.h @@ -41,7 +41,8 @@ class PassPrinter { } bool Begin(const std::unordered_set& fetch_ids = {}); - bool PassBegin(const std::string& pass_name, const frontend::Program& program); + bool PassBegin(const std::string& pass_name, + const frontend::Program& program); bool PassEnd(const std::string& pass_name, const frontend::Program& program); bool PassBegin(const std::string& pass_name, Graph* g); bool PassEnd(const std::string& pass_name, Graph* g); @@ -54,7 +55,8 @@ class PassPrinter { 
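
PassPrinter's usage contract, visible in the declarations above and the definitions in visualize_helper.cc, is a simple bracket protocol: call Begin() once per graph, then wrap every pass in a PassBegin()/PassEnd() pair; pass_id_ is bumped only in PassEnd(), so the dumped before/after files stay numbered in execution order. A toy sketch of just that protocol (PrinterSketch and the demo pass names are hypothetical stand-ins; the real class writes program text and DOT dumps under FLAGS_cinn_pass_visualize_dir):

    #include <iostream>
    #include <string>

    struct PrinterSketch {
      int pass_id = 0;
      void Begin() { pass_id = 0; }
      void PassBegin(const std::string& name) {
        std::cout << "pass_" << pass_id << "_" << name << "_before.txt\n";
      }
      void PassEnd(const std::string& name) {
        std::cout << "pass_" << pass_id << "_" << name << "_after.txt\n";
        ++pass_id;  // mirrors ++pass_id_ in PassPrinter::PassEnd
      }
    };

    int main() {
      PrinterSketch p;
      p.Begin();
      p.PassBegin("DemoPassA");
      p.PassEnd("DemoPassA");  // emits the pass_0_* names, then bumps the id
      p.PassBegin("DemoPassB");
      p.PassEnd("DemoPassB");  // pass_1_* names
    }
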
int64_t pass_id_{0}; }; -inline void WriteToFile(const std::string& filepath, const std::string& content) { +inline void WriteToFile(const std::string& filepath, + const std::string& content) { VLOG(4) << "Write to " << filepath; std::ofstream of(filepath); CHECK(of.is_open()) << "Failed to open " << filepath; @@ -63,10 +65,13 @@ inline void WriteToFile(const std::string& filepath, const std::string& content) } inline std::string GenClusterId(const std::vector& group, int group_id) { - return "group_" + std::to_string(group_id) + "(size=" + std::to_string(group.size()) + ")"; + return "group_" + std::to_string(group_id) + + "(size=" + std::to_string(group.size()) + ")"; } -inline std::string GenNodeId(const Node* node, bool is_recomputed, int recompute_id) { +inline std::string GenNodeId(const Node* node, + bool is_recomputed, + int recompute_id) { if (is_recomputed) { return node->id() + "/" + std::to_string(recompute_id); } else { @@ -74,7 +79,9 @@ inline std::string GenNodeId(const Node* node, bool is_recomputed, int recompute } } -inline std::string GenNodeDataId(const NodeData* data, bool is_recomputed, int recompute_id) { +inline std::string GenNodeDataId(const NodeData* data, + bool is_recomputed, + int recompute_id) { if (is_recomputed) { return data->id() + "/" + std::to_string(recompute_id); } else { @@ -84,21 +91,25 @@ inline std::string GenNodeDataId(const NodeData* data, bool is_recomputed, int r inline std::vector GetGroupOpAttrs(bool is_recomputed = false) { std::string color = is_recomputed ? "#836FFF" : "#8EABFF"; - return std::vector{ - utils::DotAttr("shape", "Mrecord"), utils::DotAttr("color", color), utils::DotAttr("style", "filled")}; + return std::vector{utils::DotAttr("shape", "Mrecord"), + utils::DotAttr("color", color), + utils::DotAttr("style", "filled")}; } inline std::vector GetOutlinkOpAttrs() { - return std::vector{ - utils::DotAttr("shape", "Mrecord"), utils::DotAttr("color", "#ff7f00"), utils::DotAttr("style", "filled")}; + return std::vector{utils::DotAttr("shape", "Mrecord"), + utils::DotAttr("color", "#ff7f00"), + utils::DotAttr("style", "filled")}; } inline std::vector GetGroupVarAttrs(bool is_fetched = false) { if (is_fetched) { - return std::vector{ - utils::DotAttr("peripheries", "2"), utils::DotAttr("color", "#43CD80"), utils::DotAttr("style", "filled")}; + return std::vector{utils::DotAttr("peripheries", "2"), + utils::DotAttr("color", "#43CD80"), + utils::DotAttr("style", "filled")}; } else { - return std::vector{utils::DotAttr("color", "#FFDC85"), utils::DotAttr("style", "filled")}; + return std::vector{utils::DotAttr("color", "#FFDC85"), + utils::DotAttr("style", "filled")}; } } @@ -114,8 +125,9 @@ inline std::vector GetGroupAttrs(size_t group_size) { // group_size > 10 fillcolor = "#EEE5DE"; } - std::vector attrs = { - utils::DotAttr("color", "grey"), utils::DotAttr("style", "filled"), utils::DotAttr("fillcolor", fillcolor)}; + std::vector attrs = {utils::DotAttr("color", "grey"), + utils::DotAttr("style", "filled"), + utils::DotAttr("fillcolor", fillcolor)}; return attrs; } @@ -125,27 +137,30 @@ std::string GetFilePathForGroup(const std::vector>& groups, const int group_id, const std::string& viz_path); -std::string GenNodeDataLabel(const NodeData* node, - const absl::flat_hash_map& shape_dict, - const absl::flat_hash_map& dtype_dict, - const std::string dot_nodedata_id); +std::string GenNodeDataLabel( + const NodeData* node, + const absl::flat_hash_map& shape_dict, + const absl::flat_hash_map& dtype_dict, + const std::string dot_nodedata_id); 
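
The inline ID helpers above encode a single convention: a recomputed node keeps its base id and gains a "/<recompute_id>" suffix, so every duplicated copy stays a distinct node in the emitted DOT graph. A tiny sketch of that convention (GenNodeIdSketch and the sample ids are hypothetical):

    #include <iostream>
    #include <string>

    // Mirrors GenNodeId/GenNodeDataId above: recomputed copies are
    // disambiguated with a per-copy suffix, ordinary nodes keep their id.
    std::string GenNodeIdSketch(const std::string& id,
                                bool is_recomputed,
                                int recompute_id) {
      return is_recomputed ? id + "/" + std::to_string(recompute_id) : id;
    }

    int main() {
      std::cout << GenNodeIdSketch("relu_3", false, -1) << "\n";  // relu_3
      std::cout << GenNodeIdSketch("relu_3", true, 0) << "\n";    // relu_3/0
      std::cout << GenNodeIdSketch("relu_3", true, 1) << "\n";    // relu_3/1
    }
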
-void Summary(const std::vector>& groups, const std::string& viz_path); +void Summary(const std::vector>& groups, + const std::string& viz_path); std::string DebugString(const Node* node); void FindRecomputeNodes(const std::vector>& groups, std::unordered_map* recompute_nodes); -void AddGroupNode(const Node* node, - const std::string& dot_cluster_id, - const std::unordered_set& fetch_var_ids, - const absl::flat_hash_map& shape_dict, - const absl::flat_hash_map& dtype_dict, - std::unordered_map* recompute_nodes, - std::unordered_map* outnode2dot_id, - std::unordered_set* nodedatas_set, - utils::DotLang* dot); +void AddGroupNode( + const Node* node, + const std::string& dot_cluster_id, + const std::unordered_set& fetch_var_ids, + const absl::flat_hash_map& shape_dict, + const absl::flat_hash_map& dtype_dict, + std::unordered_map* recompute_nodes, + std::unordered_map* outnode2dot_id, + std::unordered_set* nodedatas_set, + utils::DotLang* dot); // used for CheckFusionAccuracyPass std::string GenerateAccCheckNodeId(const std::string& node_id); @@ -154,7 +169,8 @@ bool IsAccCheckOp(const Node* op); bool IsAccCheckVar(const NodeData* var); bool IsAccCheckGroup(const std::vector& group); -std::vector> RemoveAccCheckGroups(const std::vector>& groups); +std::vector> RemoveAccCheckGroups( + const std::vector>& groups); } // namespace framework } // namespace hlir diff --git a/paddle/cinn/hlir/op/broadcast.cc b/paddle/cinn/hlir/op/broadcast.cc index 5d876bb615a84..727dcbb0c4578 100644 --- a/paddle/cinn/hlir/op/broadcast.cc +++ b/paddle/cinn/hlir/op/broadcast.cc @@ -38,13 +38,15 @@ using framework::OpStrategy; using framework::shape_t; using framework::StrategyFunction; -#define StrategyForBinary(op_name__, pe__) \ - std::shared_ptr StrategyFor##pe__(const framework::NodeAttr &attrs, \ - const std::vector &inputs, \ - const std::vector &out_type, \ - const std::vector> &output_shapes, \ - const Target &target) { \ - return StrategyForBroadcast(attrs, inputs, out_type, output_shapes, target, #op_name__, pe::pe__); \ +#define StrategyForBinary(op_name__, pe__) \ + std::shared_ptr StrategyFor##pe__( \ + const framework::NodeAttr &attrs, \ + const std::vector &inputs, \ + const std::vector &out_type, \ + const std::vector> &output_shapes, \ + const Target &target) { \ + return StrategyForBroadcast( \ + attrs, inputs, out_type, output_shapes, target, #op_name__, pe::pe__); \ } std::shared_ptr StrategyForBroadcast( @@ -54,11 +56,17 @@ std::shared_ptr StrategyForBroadcast( const std::vector> &output_shapes, const Target &target, const std::string &op_name, - ir::Tensor (*pe_func)(const ir::Tensor &A, const ir::Tensor &B, const std::string &output_name, const Expr &axis)) { - framework::CINNCompute binary_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check."; + ir::Tensor (*pe_func)(const ir::Tensor &A, + const ir::Tensor &B, + const std::string &output_name, + const Expr &axis)) { + framework::CINNCompute binary_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of " << op_name + << " compute is empty! 
Please check."; CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 2U) << "at least 2 input tensors for " << op_name << " compute"; + CHECK_GE(pack_args.size(), 2U) + << "at least 2 input tensors for " << op_name << " compute"; std::string tensor_name = UniqName(op_name + "_Out"); if (FLAGS_cinn_ir_schedule) { CHECK_GE(pack_args.size(), 3U) << op_name << " 's input is not enough!"; @@ -79,18 +87,22 @@ std::shared_ptr StrategyForBroadcast( break; } } - auto out = pe_func(A, B, tensor_name, axis); + auto out = pe_func(A, B, tensor_name, axis); auto stages = CreateStages({A, B, out}); - *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; + *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; }); auto strategy = std::make_shared(); - strategy->AddImpl(binary_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy." + op_name + ".x86", 1); + strategy->AddImpl(binary_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy." + op_name + ".x86", + 1); return strategy; } -std::vector InferShapeForBroadcast(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForBroadcast( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 2UL); std::vector out_shape; @@ -101,30 +113,37 @@ std::vector InferShapeForBroadcast(const std::vector &inputs_s break; } } - VLOG(3) << "broadcast input shapes are : " << utils::Join(inputs_shape[0], ", ") << "; " + VLOG(3) << "broadcast input shapes are : " + << utils::Join(inputs_shape[0], ", ") << "; " << utils::Join(inputs_shape[1], ", "); pe::GetBroadcastOutShape(inputs_shape[0], inputs_shape[1], &out_shape, axis); VLOG(3) << "broadcast out shape: " << utils::Join(out_shape, ", "); return {out_shape}; } -std::vector InferDtypeForBroadcast(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForBroadcast(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::vector InferDtypeForBroadcastCmp(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForBroadcastCmp( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; return {Bool()}; } -std::vector> InferLayoutForBroadcast(const std::vector> &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector> InferLayoutForBroadcast( + const std::vector> &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { int input_size = input_layouts.size(); - CHECK(input_size == 2U || input_size == 3U) << "The input's layouts size is not 2 or 3! Please check again."; + CHECK(input_size == 2U || input_size == 3U) + << "The input's layouts size is not 2 or 3! 
Please check again."; int axis = -1; if (attrs.attr_store.find("axis") != attrs.attr_store.end()) { axis = absl::get(attrs.attr_store.at("axis")); @@ -134,7 +153,7 @@ std::vector> InferLayoutForBroadcast(const std::vector< return {{input_layouts[0]}, input_layouts}; } else if (input_layouts[0].empty() || input_layouts[1].empty()) { int undef_idx = input_layouts[0] == "" ? 0 : 1; - int def_idx = 1 - undef_idx; + int def_idx = 1 - undef_idx; CHECK_GE(input_shapes[def_idx].size(), input_shapes[undef_idx].size()); auto ret = out_layouts[def_idx]; if (input_size == 2) { @@ -147,7 +166,7 @@ std::vector> InferLayoutForBroadcast(const std::vector< ir::Layout layout0(input_layouts[0]); ir::Layout layout1(input_layouts[1]); int large_idx = layout0.ndims() >= layout1.ndims() ? 0 : 1; - auto ret = input_layouts[large_idx]; + auto ret = input_layouts[large_idx]; if (input_size == 2) { return {{ret}, {ret, ret}}; } else { @@ -156,24 +175,29 @@ std::vector> InferLayoutForBroadcast(const std::vector< } } -std::shared_ptr StrategyForBroadcastTo(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForBroadcastTo( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { std::vector out_shape; std::vector broadcast_axes; CHECK(attrs.attr_store.count("out_shape")); out_shape = absl::get>(attrs.attr_store.at("out_shape")); CHECK(attrs.attr_store.count("broadcast_axes")); - broadcast_axes = absl::get>(attrs.attr_store.at("broadcast_axes")); + broadcast_axes = + absl::get>(attrs.attr_store.at("broadcast_axes")); VLOG(3) << "broadcast out shape: " << utils::Join(out_shape, ", "); VLOG(3) << "broadcast_axes shape: " << utils::Join(broadcast_axes, ", "); - framework::CINNCompute broadcast_to_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of broadcast_to compute is empty! Please check."; + framework::CINNCompute broadcast_to_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of broadcast_to compute is empty! Please check."; CINNValuePack pack_args = args[0]; - CHECK(!pack_args.empty()) << "The input tensors of broadcast_to compute is empty! Please check."; + CHECK(!pack_args.empty()) + << "The input tensors of broadcast_to compute is empty! 
Please check."; std::string tensor_name = UniqName("broadcast_to_Out"); if (FLAGS_cinn_ir_schedule) { CHECK_GE(pack_args.size(), 2U); @@ -183,26 +207,30 @@ std::shared_ptr StrategyForBroadcastTo(const framework::NodeAttr &at Expr A_expr = pack_args[0]; CHECK(A_expr.as_tensor()); ir::Tensor A = A_expr.as_tensor_ref(); - auto out = pe::BroadcastTo(A, out_shape, broadcast_axes, tensor_name); - auto stages = CreateStages({A, out}); - *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + auto out = pe::BroadcastTo(A, out_shape, broadcast_axes, tensor_name); + auto stages = CreateStages({A, out}); + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; }); auto strategy = std::make_shared(); - strategy->AddImpl( - broadcast_to_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.broadcast_to.x86", 1); + strategy->AddImpl(broadcast_to_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.broadcast_to.x86", + 1); return strategy; } -std::vector InferShapeForBroadcastTo(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 1UL) << "input_shape size should be one. Please Check."; +std::vector InferShapeForBroadcastTo( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 1UL) + << "input_shape size should be one. Please Check."; std::vector broadcast_axes; std::vector out_shape; CHECK(attrs.count("broadcast_axes")); CHECK(attrs.count("out_shape")); - out_shape = absl::get>(attrs.at("out_shape")); + out_shape = absl::get>(attrs.at("out_shape")); broadcast_axes = absl::get>(attrs.at("broadcast_axes")); VLOG(3) << "broadcast input shape: " << utils::Join(inputs_shape[0], ", "); @@ -210,58 +238,65 @@ std::vector InferShapeForBroadcastTo(const std::vector &inputs VLOG(3) << "broadcast_axes shape: " << utils::Join(broadcast_axes, ", "); CHECK_EQ(inputs_shape[0].size(), broadcast_axes.size()) << "broadcast_axes's size should be same with the input shape's size"; - CHECK_GE(out_shape.size(), broadcast_axes.size()) << "broadcast_axes's size should be no more than out_shape's size"; + CHECK_GE(out_shape.size(), broadcast_axes.size()) + << "broadcast_axes's size should be no more than out_shape's size"; return {out_shape}; } -std::vector> InferLayoutForBroadcastTo(const std::vector> &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK(input_layouts.size() == 1U) << "The input's layouts size is not 1! Please check again."; +std::vector> InferLayoutForBroadcastTo( + const std::vector> &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK(input_layouts.size() == 1U) + << "The input's layouts size is not 1! 
Please check again."; std::vector out_layouts = {""}; if (attrs.attr_store.count("out_layouts")) { - out_layouts = absl::get>(attrs.attr_store.at("out_layouts")); + out_layouts = + absl::get>(attrs.attr_store.at("out_layouts")); } return {out_layouts, input_layouts}; } -std::vector InferDtypeForBroadcastGrad(const std::vector &inputs_type, - const framework::AttrMapType &attrs) { +std::vector InferDtypeForBroadcastGrad( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK_EQ(inputs_type.size(), 3UL); - // Avoid no need buffer var, like elementwise_add_grad's input X and Y is no need buffer var, - // in this situation, the X and Y's type is default value FP32, not the real type, - // we should get the real type from dout. + // Avoid no need buffer var, like elementwise_add_grad's input X and Y is no + // need buffer var, in this situation, the X and Y's type is default value + // FP32, not the real type, we should get the real type from dout. std::vector out_type{inputs_type[0], inputs_type[0]}; return out_type; } -std::vector InferShapeForBroadcastGrad(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForBroadcastGrad( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 3UL); std::vector out_shape{inputs_shape[1], inputs_shape[2]}; return out_shape; } -std::shared_ptr StrategyForBroadcastGrad(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - LOG(FATAL) - << "Gradient operator will be decomposed into several primitive operators. Please Use Decomposer Program Pass."; +std::shared_ptr StrategyForBroadcastGrad( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + LOG(FATAL) << "Gradient operator will be decomposed into several primitive " + "operators. Please Use Decomposer Program Pass."; } -std::shared_ptr StrategyForIsClose(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector &output_shapes, - const Target &target) { +std::shared_ptr StrategyForIsClose( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector &output_shapes, + const Target &target) { float rtol = 1e-05f, atol = 1e-08f; bool equal_nan = false; - int axis = -1; + int axis = -1; if (attrs.attr_store.count("axis")) { axis = absl::get(attrs.attr_store.at("axis")); @@ -276,49 +311,60 @@ std::shared_ptr StrategyForIsClose(const framework::NodeAttr &attrs, equal_nan = absl::get(attrs.attr_store.at("equal_nan")); } - framework::CINNCompute isclose_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of isclose compute is empty! Please check."; - CINNValuePack pack_args = args[0]; - int input_size = pack_args.size(); - - std::string tensor_name = UniqName("IsClose_output"); - if (FLAGS_cinn_ir_schedule) { - // the last pack argument is the output tensor name - tensor_name = pack_args.back().operator std::string(); - --input_size; - } - CHECK_EQ(input_size, 2) << "The input number of isclose should be 2, but here " << input_size << "! 
Please check."; - - // the input tensor are in front - Expr x_expr = pack_args[0]; - CHECK(x_expr.as_tensor()); - auto x_tensor = x_expr.as_tensor_ref(); - - Expr y_expr = pack_args[1]; - CHECK(y_expr.as_tensor()); - auto y_tensor = y_expr.as_tensor_ref(); - - auto out = pe::IsClose(x_tensor, y_tensor, axis, rtol, atol, equal_nan, tensor_name); - - auto stages = CreateStages({out}); - *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; - }); + framework::CINNCompute isclose_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of isclose compute is empty! Please check."; + CINNValuePack pack_args = args[0]; + int input_size = pack_args.size(); + + std::string tensor_name = UniqName("IsClose_output"); + if (FLAGS_cinn_ir_schedule) { + // the last pack argument is the output tensor name + tensor_name = pack_args.back().operator std::string(); + --input_size; + } + CHECK_EQ(input_size, 2) + << "The input number of isclose should be 2, but here " + << input_size << "! Please check."; + + // the input tensor are in front + Expr x_expr = pack_args[0]; + CHECK(x_expr.as_tensor()); + auto x_tensor = x_expr.as_tensor_ref(); + + Expr y_expr = pack_args[1]; + CHECK(y_expr.as_tensor()); + auto y_tensor = y_expr.as_tensor_ref(); + + auto out = pe::IsClose( + x_tensor, y_tensor, axis, rtol, atol, equal_nan, tensor_name); + + auto stages = CreateStages({out}); + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + }); auto strategy = std::make_shared(); - strategy->AddImpl(isclose_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.assertisclose", 1); + strategy->AddImpl(isclose_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.assertisclose", + 1); return strategy; } -std::vector InferDtypeForIsClose(const std::vector &inputs_type, const framework::AttrMapType &attrs) { +std::vector InferDtypeForIsClose(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { int input_size = inputs_type.size(); - CHECK_EQ(input_size, 2UL) << "The input number of isclose should be a multiple of 2, but here " << input_size - << "! Please check."; - CHECK(inputs_type[0].is_float()) << "The op \"isclose\" only support float point dtype now, but here " - << inputs_type[0]; + CHECK_EQ(input_size, 2UL) + << "The input number of isclose should be a multiple of 2, but here " + << input_size << "! Please check."; + CHECK(inputs_type[0].is_float()) + << "The op \"isclose\" only support float point dtype now, but here " + << inputs_type[0]; CHECK(inputs_type[0] == inputs_type[1]) - << "The two inputs dtype sof isclose should be equal, but here x:" << inputs_type[0] << " != y:" << inputs_type[1] - << "! Please check."; + << "The two inputs dtype sof isclose should be equal, but here x:" + << inputs_type[0] << " != y:" << inputs_type[1] << "! 
Please check."; return {Bool()}; } @@ -360,28 +406,38 @@ StrategyForBinary(logical_right_shift, LogicalRightShift); } // namespace cinn CINN_REGISTER_HELPER(broadcast_ops) { -#define CINN_REGISTER_BINARY(op__, op_stragegy__) \ - CINN_REGISTER_OP(op__) \ - .describe(#op__ " function") \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr("CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBroadcast)) \ - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForBroadcast)) \ - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForBroadcast)) \ - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kBroadcast) \ +#define CINN_REGISTER_BINARY(op__, op_stragegy__) \ + CINN_REGISTER_OP(op__) \ + .describe(#op__ " function") \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr( \ + "CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ + .set_attr("infershape", \ + MakeOpFunction(cinn::hlir::op::InferShapeForBroadcast)) \ + .set_attr("inferdtype", \ + MakeOpFunction(cinn::hlir::op::InferDtypeForBroadcast)) \ + .set_attr("inferlayout", \ + MakeOpFunction(cinn::hlir::op::InferLayoutForBroadcast)) \ + .set_attr( \ + "OpPattern", cinn::hlir::framework::OpPatternKind::kBroadcast) \ .set_support_level(4); -#define CINN_REGISTER_BINARY_CMP(op__, op_stragegy__) \ - CINN_REGISTER_OP(op__) \ - .describe(#op__ " function") \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr("CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBroadcast)) \ - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForBroadcastCmp)) \ - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForBroadcast)) \ - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kBroadcast) \ +#define CINN_REGISTER_BINARY_CMP(op__, op_stragegy__) \ + CINN_REGISTER_OP(op__) \ + .describe(#op__ " function") \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr( \ + "CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ + .set_attr("infershape", \ + MakeOpFunction(cinn::hlir::op::InferShapeForBroadcast)) \ + .set_attr("inferdtype", \ + MakeOpFunction(cinn::hlir::op::InferDtypeForBroadcastCmp)) \ + .set_attr("inferlayout", \ + MakeOpFunction(cinn::hlir::op::InferLayoutForBroadcast)) \ + .set_attr( \ + "OpPattern", cinn::hlir::framework::OpPatternKind::kBroadcast) \ .set_support_level(4); CINN_REGISTER_BINARY(elementwise_add, Add); @@ -419,24 +475,36 @@ CINN_REGISTER_HELPER(broadcast_ops) { .describe("broadcast one tensor to the target shape") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForBroadcastTo) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBroadcastTo)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForBroadcast)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForBroadcastTo) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForBroadcastTo)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForBroadcast)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForBroadcastTo)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForBroadcastTo)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kBroadcast) + .set_attr( + "OpPattern", 
cinn::hlir::framework::OpPatternKind::kBroadcast) .set_support_level(4); CINN_REGISTER_OP(isclose) - .describe("This operator checks if all x and y satisfy the condition: |x - y| <= atol + rtol * |y|") + .describe( + "This operator checks if all x and y satisfy the condition: |x - y| " + "<= atol + rtol * |y|") .set_num_inputs(2) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForIsClose) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBroadcast)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForIsClose)) - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForBroadcast)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kBroadcast) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForIsClose) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForBroadcast)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForIsClose)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForBroadcast)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kBroadcast) .set_support_level(4); return true; @@ -447,9 +515,12 @@ CINN_REGISTER_HELPER(broadcast_grad_ops) { .describe("The gradient of elementwise_add operator.") .set_num_inputs(3) .set_num_outputs(2) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForBroadcastGrad) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBroadcastGrad)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForBroadcastGrad)); + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForBroadcastGrad) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForBroadcastGrad)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForBroadcastGrad)); return true; } diff --git a/paddle/cinn/hlir/op/contrib/argmax.cc b/paddle/cinn/hlir/op/contrib/argmax.cc index a8c0150fc38af..ef3473fd51bf6 100644 --- a/paddle/cinn/hlir/op/contrib/argmax.cc +++ b/paddle/cinn/hlir/op/contrib/argmax.cc @@ -52,7 +52,7 @@ std::vector Argmax(const Tensor &in_tensor, const bool &keep_dims, const std::string &name) { auto shape = in_tensor->shape; - auto ndim = shape.size(); + auto ndim = shape.size(); CHECK_GT(ndim, 0) << "tensor's dim must be more than 0"; int pos_axis = axis; @@ -64,7 +64,8 @@ std::vector Argmax(const Tensor &in_tensor, std::vector output_shape; for (int i = 0; i < shape.size(); ++i) { - CHECK(shape[i].is_constant()) << "Input tensor's shape should be constant value."; + CHECK(shape[i].is_constant()) + << "Input tensor's shape should be constant value."; if (pos_axis == i) { if (keep_dims) { output_shape.push_back(Expr(1)); @@ -77,8 +78,9 @@ std::vector Argmax(const Tensor &in_tensor, output_shape.push_back(Expr(1)); } - auto sort_index = ArgSort(in_tensor, target, stages, pos_axis, false, name + "_index"); - auto res = Compute( + auto sort_index = + ArgSort(in_tensor, target, stages, pos_axis, false, name + "_index"); + auto res = Compute( output_shape, [=](const std::vector &indices) { std::vector eval_indices(indices); @@ -94,11 +96,12 @@ std::vector Argmax(const Tensor &in_tensor, return {res, sort_index.at(0), sort_index.at(1)}; } -std::shared_ptr StrategyForArgmax(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForArgmax( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + 
const std::vector> &output_shapes, + const Target &target) { int axis; bool keep_dims = false; @@ -111,28 +114,36 @@ std::shared_ptr StrategyForArgmax(const framework::NodeAt keep_dims = absl::get(attrs.attr_store.at("keep_dim")); } - framework::CINNCompute argmax_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of argmax compute is empty! Please check."; - common::CINNValuePack pack_args = args[0]; - std::string tensor_name = UniqName("Argmax_out"); - CHECK_GE(pack_args.size(), 1U) << "There should be 1 input args for argmax compute"; - Expr in_expr = pack_args[0]; - CHECK(in_expr.as_tensor()); - Tensor in_tensor = in_expr.as_tensor_ref(); - auto stages = CreateStages({in_tensor}); - CHECK_EQ(pack_args.size(), 2U); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - std::vector out_tensor = Argmax(in_tensor, target, stages, axis, keep_dims, tensor_name); - - stages->InsertLazily(out_tensor[0]); - std::vector cinn_values{ - CINNValue(out_tensor[0]), CINNValue(out_tensor[1]), CINNValue(out_tensor[2]), CINNValue(stages)}; - *ret = common::CINNValuePack{cinn_values}; - }); + framework::CINNCompute argmax_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of argmax compute is empty! Please check."; + common::CINNValuePack pack_args = args[0]; + std::string tensor_name = UniqName("Argmax_out"); + CHECK_GE(pack_args.size(), 1U) + << "There should be at least 1 input arg for argmax compute"; + Expr in_expr = pack_args[0]; + CHECK(in_expr.as_tensor()); + Tensor in_tensor = in_expr.as_tensor_ref(); + auto stages = CreateStages({in_tensor}); + CHECK_EQ(pack_args.size(), 2U); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + std::vector out_tensor = + Argmax(in_tensor, target, stages, axis, keep_dims, tensor_name); + + stages->InsertLazily(out_tensor[0]); + std::vector cinn_values{CINNValue(out_tensor[0]), + CINNValue(out_tensor[1]), + CINNValue(out_tensor[2]), + CINNValue(stages)}; + *ret = common::CINNValuePack{cinn_values}; + }); - framework::CINNSchedule argmax_schedule([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of argmax_schedule is empty! Please check.\n"; + framework::CINNSchedule argmax_schedule([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of argmax_schedule is empty! Please check.\n"; common::CINNValuePack arg_pack = args[0]; std::vector vec_ast; for (int i = 0; i < arg_pack.size(); i++) { @@ -146,16 +157,21 @@ std::shared_ptr StrategyForArgmax(const framework::NodeAt ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); auto blocks = ir_sch.GetAllBlocks(); - // TODO: It needs to be rewritten according to the reduction_max operator to improve performance. - // Do not use local variables, because the size will exceed the limit. + // TODO: It needs to be rewritten according to the reduction_max operator to + // improve performance. Do not use local variables, because the size will + // exceed the limit.
ir_sch.SetBuffer(blocks[0], "local"); ir_sch.SetBuffer(blocks[1], "local"); - long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); + long prod_size = std::accumulate(output_shapes[0].begin(), + output_shapes[0].end(), + 1, + std::multiplies()); if (prod_size > 1 && target.arch == Target::Arch::X86) { pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true); } - std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = common::CINNValuePack{res}; }); @@ -165,8 +181,9 @@ std::shared_ptr StrategyForArgmax(const framework::NodeAt return strategy; } -std::vector InferShapeForArgmax(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForArgmax( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { CHECK(inputs_shape.size() == 1UL); auto ndim = inputs_shape[0].size(); CHECK_GT(ndim, 0) << "tensor's dim must be more than 0"; @@ -207,17 +224,22 @@ std::vector InferShapeForArgmax(const std::vector &inputs_shap return {out_shapes}; } -std::vector InferDtypeForArgmax(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForArgmax(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; return {Int(32)}; } -std::vector> InferLayoutForArgmax(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_shapes.size(), 1U) << "The input's shape size is not 1! Please check again."; - CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again."; +std::vector> InferLayoutForArgmax( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_shapes.size(), 1U) + << "The input's shape size is not 1! Please check again."; + CHECK_EQ(input_layouts.size(), 1U) + << "The input's layout size is not 1! 
Please check again."; return {input_layouts, input_layouts}; } } // namespace op @@ -229,10 +251,14 @@ CINN_REGISTER_HELPER(argmax_ops) { .describe("This operator implements the op argmax.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForArgmax) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForArgmax)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForArgmax)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForArgmax) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForArgmax)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForArgmax)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/argmax.h b/paddle/cinn/hlir/op/contrib/argmax.h index 43b306202924b..b52f9e80f4ce5 100644 --- a/paddle/cinn/hlir/op/contrib/argmax.h +++ b/paddle/cinn/hlir/op/contrib/argmax.h @@ -25,7 +25,7 @@ std::vector Argmax(const ir::Tensor &in_tensor, const common::Target &target, poly::StageMap stages, const int &axis, - const bool &keep_dims = false, + const bool &keep_dims = false, const std::string &name = "T_Argmax_out"); } // namespace op } // namespace hlir diff --git a/paddle/cinn/hlir/op/contrib/argmax_test.cc b/paddle/cinn/hlir/op/contrib/argmax_test.cc index 49b3cfb38c91f..786e19b163a9a 100644 --- a/paddle/cinn/hlir/op/contrib/argmax_test.cc +++ b/paddle/cinn/hlir/op/contrib/argmax_test.cc @@ -47,11 +47,19 @@ TEST(GenerateCode_Cpu, Argmax_Keep) { lang::Placeholder in("in", {n, in_c, h, w}); poly::StageMap stages = poly::CreateStages({in}); - ir::Tensor res = Argmax(in, target, stages, axis, true, "test_argmax_in").at(0); + ir::Tensor res = + Argmax(in, target, stages, axis, true, "test_argmax_in").at(0); stages->InsertLazily(res); std::vector funcs = - lang::LowerVec("TestGenerateCodeCpu_Argmax_Keep", stages, {in, res}, {}, {}, nullptr, target, true); + lang::LowerVec("TestGenerateCodeCpu_Argmax_Keep", + stages, + {in, res}, + {}, + {}, + nullptr, + target, + true); VLOG(6) << "Expr before CPU codegen:"; VLOG(6) << funcs[0]->body; @@ -63,7 +71,8 @@ TEST(GenerateCode_Cpu, Argmax_Keep) { backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); codegen.SetInlineBuiltinCodes(false); - std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + std::string code = + codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); auto target_source = R"ROC( #include #include diff --git a/paddle/cinn/hlir/op/contrib/argmin.cc b/paddle/cinn/hlir/op/contrib/argmin.cc index f6f2c641cfc73..b0377f3434959 100644 --- a/paddle/cinn/hlir/op/contrib/argmin.cc +++ b/paddle/cinn/hlir/op/contrib/argmin.cc @@ -52,7 +52,7 @@ std::vector Argmin(const Tensor &in_tensor, const bool &keep_dims, const std::string &name) { auto shape = in_tensor->shape; - auto ndim = shape.size(); + auto ndim = shape.size(); CHECK_GT(ndim, 0) << "tensor's dim must be more than 0"; int pos_axis = axis; @@ -64,7 +64,8 @@ std::vector Argmin(const Tensor &in_tensor, std::vector output_shape; for (int i = 0; i < shape.size(); ++i) { - CHECK(shape[i].is_constant()) << "Input tensor's shape should be constant value."; + CHECK(shape[i].is_constant()) + << "Input tensor's shape should be constant value."; if (pos_axis == i) { if (keep_dims) { output_shape.push_back(Expr(1)); @@ -76,8 
+77,9 @@ std::vector Argmin(const Tensor &in_tensor, if (output_shape.empty()) { output_shape.push_back(Expr(1)); } - auto sort_index = ArgSort(in_tensor, target, stages, pos_axis, true, name + "_index"); - auto res = Compute( + auto sort_index = + ArgSort(in_tensor, target, stages, pos_axis, true, name + "_index"); + auto res = Compute( output_shape, [=](const std::vector &indices) { std::vector eval_indices(indices); @@ -93,11 +95,12 @@ std::vector Argmin(const Tensor &in_tensor, return {res, sort_index.at(0), sort_index.at(1)}; } -std::shared_ptr StrategyForArgmin(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForArgmin( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { int axis; bool keep_dims = false; @@ -110,27 +113,35 @@ std::shared_ptr StrategyForArgmin(const framework::NodeAt keep_dims = absl::get(attrs.attr_store.at("keep_dim")); } - framework::CINNCompute argmin_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of argmin compute is empty! Please check."; - common::CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 1U) << "There should be 1 input args for argmax compute"; - Expr in_expr = pack_args[0]; - CHECK(in_expr.as_tensor()); - Tensor in_tensor = in_expr.as_tensor_ref(); - auto stages = CreateStages({in_tensor}); - CHECK_EQ(pack_args.size(), 2U); - CHECK(pack_args[1].is_string()); - std::string tensor_name = pack_args[1].operator std::string(); - auto out_tensor = Argmin(in_tensor, target, stages, axis, keep_dims, tensor_name); - - stages->InsertLazily(out_tensor[0]); - std::vector cinn_values{ - CINNValue(out_tensor[0]), CINNValue(out_tensor[1]), CINNValue(out_tensor[2]), CINNValue(stages)}; - *ret = common::CINNValuePack{cinn_values}; - }); + framework::CINNCompute argmin_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of argmin compute is empty! Please check."; + common::CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 1U) + << "There should be at least 1 input arg for argmin compute"; + Expr in_expr = pack_args[0]; + CHECK(in_expr.as_tensor()); + Tensor in_tensor = in_expr.as_tensor_ref(); + auto stages = CreateStages({in_tensor}); + CHECK_EQ(pack_args.size(), 2U); + CHECK(pack_args[1].is_string()); + std::string tensor_name = pack_args[1].operator std::string(); + auto out_tensor = + Argmin(in_tensor, target, stages, axis, keep_dims, tensor_name); + + stages->InsertLazily(out_tensor[0]); + std::vector cinn_values{CINNValue(out_tensor[0]), + CINNValue(out_tensor[1]), + CINNValue(out_tensor[2]), + CINNValue(stages)}; + *ret = common::CINNValuePack{cinn_values}; + }); - framework::CINNSchedule argmin_schedule([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of arange_schedule is empty! Please check.\n"; + framework::CINNSchedule argmin_schedule([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of argmin_schedule is empty!
Please check.\n"; common::CINNValuePack arg_pack = args[0]; std::vector vec_ast; for (int i = 0; i < arg_pack.size(); i++) { @@ -144,15 +155,20 @@ std::shared_ptr StrategyForArgmin(const framework::NodeAt ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); auto blocks = ir_sch.GetAllBlocks(); - // TODO: It needs to be rewritten according to the reduction_min operator to improve performance. - // Do not use local variables, because the size will exceed the limit. + // TODO: It needs to be rewritten according to the reduction_min operator to + // improve performance. Do not use local variables, because the size will + // exceed the limit. ir_sch.SetBuffer(blocks[0], "local"); ir_sch.SetBuffer(blocks[1], "local"); - long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); + long prod_size = std::accumulate(output_shapes[0].begin(), + output_shapes[0].end(), + 1, + std::multiplies()); if (prod_size > 1 && target.arch == Target::Arch::X86) { pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true); } - std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = common::CINNValuePack{res}; }); @@ -162,8 +178,9 @@ std::shared_ptr StrategyForArgmin(const framework::NodeAt return strategy; } -std::vector InferShapeForArgmin(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForArgmin( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { CHECK(inputs_shape.size() == 1UL); auto ndim = inputs_shape[0].size(); CHECK_GT(ndim, 0) << "tensor's dim must be more than 0"; @@ -205,17 +222,22 @@ std::vector InferShapeForArgmin(const std::vector &inputs_shap return {out_shapes}; } -std::vector InferDtypeForArgmin(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForArgmin(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; return {Int(32)}; } -std::vector> InferLayoutForArgmin(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_shapes.size(), 1U) << "The input's shape size is not 1! Please check again."; - CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again."; +std::vector> InferLayoutForArgmin( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_shapes.size(), 1U) + << "The input's shape size is not 1! Please check again."; + CHECK_EQ(input_layouts.size(), 1U) + << "The input's layout size is not 1! 
Please check again."; return {input_layouts, input_layouts}; } } // namespace op @@ -227,10 +249,14 @@ CINN_REGISTER_HELPER(argmin_ops) { .describe("This operator implements the op argmin.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForArgmin) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForArgmin)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForArgmin)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForArgmin) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForArgmin)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForArgmin)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/argmin.h b/paddle/cinn/hlir/op/contrib/argmin.h index 839e5ec6eee79..17b0095b5c8a4 100644 --- a/paddle/cinn/hlir/op/contrib/argmin.h +++ b/paddle/cinn/hlir/op/contrib/argmin.h @@ -25,7 +25,7 @@ std::vector Argmin(const ir::Tensor& in_tensor, const common::Target& target, poly::StageMap stages, const int& axis, - const bool& keep_dims = false, + const bool& keep_dims = false, const std::string& name = "T_Argmin_out"); } // namespace op } // namespace hlir diff --git a/paddle/cinn/hlir/op/contrib/argmin_test.cc b/paddle/cinn/hlir/op/contrib/argmin_test.cc index d4625e14df04c..a979870fe88a9 100644 --- a/paddle/cinn/hlir/op/contrib/argmin_test.cc +++ b/paddle/cinn/hlir/op/contrib/argmin_test.cc @@ -46,11 +46,19 @@ TEST(GenerateCode_Cpu, Argmin_Keep) { lang::Placeholder in("in", {n, in_c, h, w}); poly::StageMap stages = poly::CreateStages({in}); - ir::Tensor res = Argmin(in, target, stages, axis, true, "test_argmin_in").at(0); + ir::Tensor res = + Argmin(in, target, stages, axis, true, "test_argmin_in").at(0); stages->InsertLazily(res); std::vector funcs = - lang::LowerVec("TestGenerateCodeCpu_Argmin_Keep", stages, {in, res}, {}, {}, nullptr, target, true); + lang::LowerVec("TestGenerateCodeCpu_Argmin_Keep", + stages, + {in, res}, + {}, + {}, + nullptr, + target, + true); VLOG(6) << "Expr before CPU codegen:"; VLOG(6) << funcs[0]->body; @@ -62,7 +70,8 @@ TEST(GenerateCode_Cpu, Argmin_Keep) { backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); codegen.SetInlineBuiltinCodes(false); - std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + std::string code = + codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); auto target_source = R"ROC( #include #include diff --git a/paddle/cinn/hlir/op/contrib/assert_true.cc b/paddle/cinn/hlir/op/contrib/assert_true.cc index 0fb0328ed87d5..a91f740c54892 100644 --- a/paddle/cinn/hlir/op/contrib/assert_true.cc +++ b/paddle/cinn/hlir/op/contrib/assert_true.cc @@ -36,37 +36,45 @@ namespace op { using common::CINNValue; using common::CINNValuePack; -std::shared_ptr StrategyForAssertTrue(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute assert_true_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of assert_true is empty! 
Please check."; +std::shared_ptr StrategyForAssertTrue( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute assert_true_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of assert_true is empty! Please check."; CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 2U) << "Two input tensors are required for the computation of assert_true."; - Expr a_expr = pack_args[0]; - Expr b_expr = pack_args[1]; - ir::Tensor a = a_expr.as_tensor_ref(); - ir::Tensor b = b_expr.as_tensor_ref(); + CHECK_GE(pack_args.size(), 2U) + << "Two input tensors are required for the computation of assert_true."; + Expr a_expr = pack_args[0]; + Expr b_expr = pack_args[1]; + ir::Tensor a = a_expr.as_tensor_ref(); + ir::Tensor b = b_expr.as_tensor_ref(); std::string tensor_name = "assert_true_out"; - auto out = pe::Identity(b, tensor_name).front(); - auto stages = CreateStages({out}); + auto out = pe::Identity(b, tensor_name).front(); + auto stages = CreateStages({out}); std::vector res{CINNValue(out), CINNValue(stages)}; *ret = CINNValuePack{res}; }); auto strategy = std::make_shared(); - strategy->AddImpl( - assert_true_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.assert_true.x86", 1); + strategy->AddImpl(assert_true_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy.assert_true.x86", + 1); return strategy; } -std::vector InferShapeForAssertTrue(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForAssertTrue( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { return inputs_shape; } -std::vector InferDtypeForAssertTrue(const std::vector &inputs_type, const framework::AttrMapType &attrs) { +std::vector InferDtypeForAssertTrue(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { return inputs_type; } @@ -79,10 +87,14 @@ CINN_REGISTER_HELPER(assert_true_ops) { .describe("AssertTrue") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForAssertTrue) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForAssertTrue)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForAssertTrue)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForAssertTrue) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForAssertTrue)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForAssertTrue)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/bitcast_convert.cc b/paddle/cinn/hlir/op/contrib/bitcast_convert.cc index 4105ec32d4d93..6973afa3b6239 100644 --- a/paddle/cinn/hlir/op/contrib/bitcast_convert.cc +++ b/paddle/cinn/hlir/op/contrib/bitcast_convert.cc @@ -45,62 +45,77 @@ using common::CINNValue; using common::CINNValuePack; using framework::shape_t; -ir::Tensor BitcastConvert(const ir::Tensor &input, const Type &dtype, const std::string &name) { +ir::Tensor BitcastConvert(const ir::Tensor &input, + const Type &dtype, + const std::string &name) { auto res = Compute( - input->shape, [=](const std::vector &indices) { return input(indices); }, name); + input->shape, + [=](const std::vector &indices) { return input(indices); 
}, + name); return res; } -std::shared_ptr StrategyForBitcastConvert(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForBitcastConvert( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { std::string op_name("bitcast_convert"); - framework::CINNCompute bitcast_convert_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check."; - CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 1U) << "1 input tensor for " << op_name << " compute"; - std::string tensor_name = UniqName(op_name + "_Out"); - Expr A_expr = pack_args[0]; - CHECK(A_expr.as_tensor()); - ir::Tensor A = A_expr.as_tensor_ref(); - auto out = BitcastConvert(A, out_type[0], tensor_name); - auto stages = CreateStages({A}); - std::vector res; - stages->InsertLazily(out); - res.push_back(CINNValue(out)); - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); + framework::CINNCompute bitcast_convert_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of " << op_name + << " compute is empty! Please check."; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 1U) + << "1 input tensor for " << op_name << " compute"; + std::string tensor_name = UniqName(op_name + "_Out"); + Expr A_expr = pack_args[0]; + CHECK(A_expr.as_tensor()); + ir::Tensor A = A_expr.as_tensor_ref(); + auto out = BitcastConvert(A, out_type[0], tensor_name); + auto stages = CreateStages({A}); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); auto strategy = std::make_shared(); - strategy->AddImpl( - bitcast_convert_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.bitcast_convert.x86", 1); + strategy->AddImpl(bitcast_convert_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.bitcast_convert.x86", + 1); return strategy; } -std::vector InferShapeForBitcastConvert(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 1U) << "The input's shape size should be 1! Please check again."; +std::vector InferShapeForBitcastConvert( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 1U) + << "The input's shape size should be 1! 
Please check again."; - auto input_data_type_name = absl::get(attrs.at("input_data_type")); + auto input_data_type_name = + absl::get(attrs.at("input_data_type")); auto output_data_type_name = absl::get(attrs.at("dtype")); - auto input_data_type = common::Str2Type(input_data_type_name); - auto output_data_type = common::Str2Type(output_data_type_name); + auto input_data_type = common::Str2Type(input_data_type_name); + auto output_data_type = common::Str2Type(output_data_type_name); - auto output_shape = std::vector(inputs_shape.begin(), inputs_shape.end()); - auto ratio = input_data_type.bits() / output_data_type.bits(); + auto output_shape = + std::vector(inputs_shape.begin(), inputs_shape.end()); + auto ratio = input_data_type.bits() / output_data_type.bits(); if (ratio == 1) return inputs_shape; if (ratio > 0) { output_shape.back().emplace_back(ratio); } else { - if (output_shape.back().back() != (output_data_type.bits() / input_data_type.bits())) { - LOG(FATAL) - << "The rightmost dimension of input must be equal to sizeof(output_data_type)/sizeof(input_data_type) when " - "sizeof(output_data_type) > sizeof(input_data_type)"; + if (output_shape.back().back() != + (output_data_type.bits() / input_data_type.bits())) { + LOG(FATAL) << "The rightmost dimension of input must be equal to " + "sizeof(output_data_type)/sizeof(input_data_type) when " + "sizeof(output_data_type) > sizeof(input_data_type)"; } output_shape.back().pop_back(); } @@ -108,8 +123,8 @@ std::vector InferShapeForBitcastConvert(const std::vector &inp return output_shape; } -std::vector InferDtypeForBitcastConvert(const std::vector &inputs_type, - const framework::AttrMapType &attrs) { +std::vector InferDtypeForBitcastConvert( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(attrs.count("dtype")); return {common::Str2Type(absl::get(attrs.at("dtype")))}; } @@ -123,10 +138,14 @@ CINN_REGISTER_HELPER(bitcast_convert_ops) { .describe("BitcastConvert") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForBitcastConvert) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBitcastConvert)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForBitcastConvert)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForBitcastConvert) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForBitcastConvert)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForBitcastConvert)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/cholesky.cc b/paddle/cinn/hlir/op/contrib/cholesky.cc index d64cbcab4acd7..93f7649d0cd52 100644 --- a/paddle/cinn/hlir/op/contrib/cholesky.cc +++ b/paddle/cinn/hlir/op/contrib/cholesky.cc @@ -20,6 +20,7 @@ #include #include "absl/types/variant.h" +#include "glog/logging.h" #include "paddle/cinn/common/cas.h" #include "paddle/cinn/common/cinn_value.h" #include "paddle/cinn/common/common.h" @@ -43,7 +44,6 @@ #include "paddle/cinn/lang/compute.h" #include "paddle/cinn/lang/packed_func.h" #include "paddle/cinn/poly/stage.h" -#include "glog/logging.h" namespace cinn { namespace hlir { @@ -52,40 +52,51 @@ namespace op { using common::CINNValue; using common::CINNValuePack; -std::shared_ptr StrategyForCholesky(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - 
const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute cholesky_compute([=](lang::Args args, lang::RetValue *ret) { - CINNValuePack pack_args = args[0]; - CHECK(!pack_args.empty()) << "at least one input tensor for cholesky compute\n"; - Expr x_expr = pack_args[0]; - ir::Tensor x = x_expr.as_tensor_ref(); - std::string tensor_name = "cholesky_out"; - auto out = pe::Identity(x, tensor_name).front(); - auto stages = CreateStages({out}); - std::vector res{CINNValue(out), CINNValue(stages)}; - *ret = CINNValuePack{res}; - }); +std::shared_ptr StrategyForCholesky( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute cholesky_compute( + [=](lang::Args args, lang::RetValue *ret) { + CINNValuePack pack_args = args[0]; + CHECK(!pack_args.empty()) + << "at least one input tensor for cholesky compute\n"; + Expr x_expr = pack_args[0]; + ir::Tensor x = x_expr.as_tensor_ref(); + std::string tensor_name = "cholesky_out"; + auto out = pe::Identity(x, tensor_name).front(); + auto stages = CreateStages({out}); + std::vector res{CINNValue(out), CINNValue(stages)}; + *ret = CINNValuePack{res}; + }); auto strategy = std::make_shared(); - strategy->AddImpl(cholesky_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.cholesky.x86", 1); + strategy->AddImpl(cholesky_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy.cholesky.x86", + 1); return strategy; } -std::vector InferShapeForCholesky(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 1U) << "The input's shape size should be 1! Please check again."; +std::vector InferShapeForCholesky( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 1U) + << "The input's shape size should be 1! Please check again."; framework::shape_t x_shape = inputs_shape[0]; - int x_shape_size = x_shape.size(); - CHECK_GE(x_shape_size, 2U) << "The input x shape size should >= 2! Please check again."; + int x_shape_size = x_shape.size(); + CHECK_GE(x_shape_size, 2U) + << "The input x shape size should be >= 2! Please check again."; CHECK_EQ(x_shape[x_shape_size - 2], x_shape[x_shape_size - 1]) << "The last two dimensions of the input x must be the same!"; return inputs_shape; } -std::vector InferDtypeForCholesky(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 1U) << "The input's shape size should be 1! Please check again."; +std::vector InferDtypeForCholesky(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 1U) + << "The input's type size should be 1! Please check again."; CHECK(inputs_type[0].is_float(32) || inputs_type[0].is_float(64)) << "The input's dtype should be float32 or float64!
Please check again."; return inputs_type; @@ -100,10 +111,14 @@ CINN_REGISTER_HELPER(cholesky_ops) { .describe("Cholesky") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForCholesky) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForCholesky)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForCholesky)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForCholesky) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForCholesky)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForCholesky)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/gather_nd.cc b/paddle/cinn/hlir/op/contrib/gather_nd.cc index 3dd8625172930..ff7844862ae3b 100644 --- a/paddle/cinn/hlir/op/contrib/gather_nd.cc +++ b/paddle/cinn/hlir/op/contrib/gather_nd.cc @@ -47,27 +47,34 @@ namespace op { using common::CINNValue; using common::CINNValuePack; -ir::Tensor GatherNd(const ir::Tensor &x, const ir::Tensor &index, const std::string &name) { - std::vector x_shape = x->shape; +ir::Tensor GatherNd(const ir::Tensor &x, + const ir::Tensor &index, + const std::string &name) { + std::vector x_shape = x->shape; std::vector index_shape = index->shape; - size_t x_shape_size = x_shape.size(); - size_t index_shape_size = index_shape.size(); + size_t x_shape_size = x_shape.size(); + size_t index_shape_size = index_shape.size(); std::vector out_shape; out_shape.insert(out_shape.end(), index_shape.begin(), index_shape.end() - 1); - out_shape.insert(out_shape.end(), x_shape.begin() + index_shape.back().as_int32(), x_shape.end()); + out_shape.insert(out_shape.end(), + x_shape.begin() + index_shape.back().as_int32(), + x_shape.end()); auto res = Compute( out_shape, [=](const std::vector &indices) { std::vector indices_position; for (size_t i = 0; i < index_shape_size - 1; ++i) { - indices_position.push_back(ir::Cast::Make(common::Int(32), indices[i])); + indices_position.push_back( + ir::Cast::Make(common::Int(32), indices[i])); } indices_position.push_back(ir::Cast::Make(common::Int(32), Expr(0))); size_t indices_position_size = indices_position.size(); std::vector real_indices; for (size_t i = 0; i < index_shape.back().as_int32(); ++i) { - indices_position[indices_position_size - 1] = ir::Cast::Make(common::Int(32), Expr(i)); - real_indices.push_back(ir::Cast::Make(common::Int(32), index(indices_position))); + indices_position[indices_position_size - 1] = + ir::Cast::Make(common::Int(32), Expr(i)); + real_indices.push_back( + ir::Cast::Make(common::Int(32), index(indices_position))); } if (real_indices.size() == x_shape_size) { return x(real_indices); @@ -81,45 +88,52 @@ ir::Tensor GatherNd(const ir::Tensor &x, const ir::Tensor &index, const std::str return res; } -std::shared_ptr StrategyForGatherNd(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForGatherNd( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { std::string op_name("gather_nd"); - framework::CINNCompute gather_nd_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of " << op_name << " 
compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 2U) << "2 input tensors for " << op_name << " compute\n"; - Expr x = pack_args[0]; - Expr index = pack_args[1]; - CHECK(x.as_tensor()); - CHECK(index.as_tensor()); - CHECK(!output_shapes.empty()); - auto tensor_x = x.as_tensor_ref(); - auto tensor_index = index.as_tensor_ref(); - auto stages = CreateStages({tensor_x, tensor_index}); - VLOG(3) << "x shape: " << utils::Join(tensor_x->shape, ", ") - << ", index shape: " << utils::Join(tensor_index->shape, ", ") - << ", output_shapes: " << utils::Join(output_shapes[0], ", "); - std::string tensor_name = UniqName("GatherNd_out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 3U); - tensor_name = pack_args[2].operator std::string(); - } - ir::Tensor out = GatherNd(tensor_x, tensor_index, tensor_name); - std::vector res; - stages->InsertLazily(out); - res.push_back(CINNValue(out)); - CHECK(!out_type.empty()) << "Output type of " << op_name << " is empty! Please check.\n"; - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); - - framework::CINNSchedule gather_nd_schedule([=](lang::Args args, lang::RetValue *ret) { + framework::CINNCompute gather_nd_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of " << op_name + << " compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 2U) + << "2 input tensors for " << op_name << " compute\n"; + Expr x = pack_args[0]; + Expr index = pack_args[1]; + CHECK(x.as_tensor()); + CHECK(index.as_tensor()); + CHECK(!output_shapes.empty()); + auto tensor_x = x.as_tensor_ref(); + auto tensor_index = index.as_tensor_ref(); + auto stages = CreateStages({tensor_x, tensor_index}); + VLOG(3) << "x shape: " << utils::Join(tensor_x->shape, ", ") + << ", index shape: " << utils::Join(tensor_index->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + std::string tensor_name = UniqName("GatherNd_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 3U); + tensor_name = pack_args[2].operator std::string(); + } + ir::Tensor out = GatherNd(tensor_x, tensor_index, tensor_name); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + CHECK(!out_type.empty()) + << "Output type of " << op_name << " is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + framework::CINNSchedule gather_nd_schedule([=](lang::Args args, + lang::RetValue *ret) { if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of gather_nd_schedule is empty! Please check.\n"; + CHECK(!args.empty()) << "The input argument of gather_nd_schedule is " + "empty! 
Please check.\n"; common::CINNValuePack arg_pack = args[0]; std::vector vec_ast; for (int i = 0; i < arg_pack.size(); i++) { @@ -132,45 +146,59 @@ std::shared_ptr StrategyForGatherNd(const framework::Node ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); - long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); + long prod_size = std::accumulate(output_shapes[0].begin(), + output_shapes[0].end(), + 1, + std::multiplies()); if (prod_size > 1) { if (target.arch == Target::Arch::NVGPU) { pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target); } else if (target.arch == Target::Arch::X86) { - pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true); + pe::IRScheduleInjectiveCPU( + ir_sch, output_shapes.front(), target, true); } } - std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = common::CINNValuePack{res}; } else { - CHECK(!args.empty()) << "The input argument of gather_nd_schedule is empty! Please check.\n"; + CHECK(!args.empty()) << "The input argument of gather_nd_schedule is " + "empty! Please check.\n"; CINNValuePack arg_pack = args[0]; - Expr out = arg_pack[0]; + Expr out = arg_pack[0]; CHECK(out.as_tensor()); *ret = arg_pack; } }); auto strategy = std::make_shared(); - strategy->AddImpl(gather_nd_compute, gather_nd_schedule, "strategy.gather_nd.x86", 1); + strategy->AddImpl( + gather_nd_compute, gather_nd_schedule, "strategy.gather_nd.x86", 1); return strategy; } -std::vector> InferShapeForGatherNd(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; - std::vector x_shape = inputs_shape[0]; +std::vector> InferShapeForGatherNd( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 2U) + << "The input's shape size should be 2! Please check again."; + std::vector x_shape = inputs_shape[0]; std::vector index_shape = inputs_shape[1]; CHECK_GE(index_shape.size(), 1U) << "Index shape must greater or equal to 1!"; - CHECK_LE(index_shape.back(), x_shape.size()) << "Index shape[-1] must be no more than x.rank! Please check again."; + CHECK_LE(index_shape.back(), x_shape.size()) + << "Index shape[-1] must be no more than x.rank! Please check again."; std::vector output_shape; - output_shape.insert(output_shape.end(), index_shape.begin(), index_shape.end() - 1); - output_shape.insert(output_shape.end(), x_shape.begin() + index_shape.back(), x_shape.end()); + output_shape.insert( + output_shape.end(), index_shape.begin(), index_shape.end() - 1); + output_shape.insert( + output_shape.end(), x_shape.begin() + index_shape.back(), x_shape.end()); return {output_shape}; } -std::vector InferDtypeForGatherNd(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForGatherNd(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! 
Please check again."; std::vector res{inputs_type[0]}; return res; } @@ -184,9 +212,12 @@ CINN_REGISTER_HELPER(gather_nd_ops) { .describe("GatherNd.") .set_num_inputs(2) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForGatherNd) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForGatherNd)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForGatherNd)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForGatherNd) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForGatherNd)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForGatherNd)) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/gather_nd.h b/paddle/cinn/hlir/op/contrib/gather_nd.h index 45bc23d215657..b236086f5ee01 100644 --- a/paddle/cinn/hlir/op/contrib/gather_nd.h +++ b/paddle/cinn/hlir/op/contrib/gather_nd.h @@ -25,7 +25,9 @@ namespace cinn { namespace hlir { namespace op { -ir::Tensor GatherNd(const ir::Tensor& x, const ir::Tensor& index, const std::string& name); +ir::Tensor GatherNd(const ir::Tensor& x, + const ir::Tensor& index, + const std::string& name); } // namespace op } // namespace hlir diff --git a/paddle/cinn/hlir/op/contrib/gather_nd_test.cc b/paddle/cinn/hlir/op/contrib/gather_nd_test.cc index 275a0f6de5d95..ee5f47477a5de 100644 --- a/paddle/cinn/hlir/op/contrib/gather_nd_test.cc +++ b/paddle/cinn/hlir/op/contrib/gather_nd_test.cc @@ -48,7 +48,14 @@ TEST(GenerateCode_Cpu, GatherNd) { poly::StageMap stages = poly::CreateStages({res}); std::vector funcs = - lang::LowerVec("TestGenerateCodeCpu_GatherNd", stages, {res}, {}, {}, nullptr, target, true); + lang::LowerVec("TestGenerateCodeCpu_GatherNd", + stages, + {res}, + {}, + {}, + nullptr, + target, + true); VLOG(6) << "Expr before CPU codegen:"; VLOG(6) << funcs[0]->body; @@ -60,7 +67,8 @@ TEST(GenerateCode_Cpu, GatherNd) { backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); codegen.SetInlineBuiltinCodes(false); - std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + std::string code = + codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); VLOG(6) << "Cpu Codegen result:"; VLOG(6) << code << std::endl; diff --git a/paddle/cinn/hlir/op/contrib/gaussian_random.cc b/paddle/cinn/hlir/op/contrib/gaussian_random.cc index 55478d51be4dc..ef1e30f7f3178 100644 --- a/paddle/cinn/hlir/op/contrib/gaussian_random.cc +++ b/paddle/cinn/hlir/op/contrib/gaussian_random.cc @@ -20,6 +20,7 @@ #include #include "absl/types/variant.h" +#include "glog/logging.h" #include "paddle/cinn/common/cas.h" #include "paddle/cinn/common/cinn_value.h" #include "paddle/cinn/common/common.h" @@ -43,7 +44,6 @@ #include "paddle/cinn/lang/compute.h" #include "paddle/cinn/lang/packed_func.h" #include "paddle/cinn/poly/stage.h" -#include "glog/logging.h" namespace cinn { namespace hlir { @@ -52,43 +52,49 @@ namespace op { using common::CINNValue; using common::CINNValuePack; -std::shared_ptr StrategyForGaussianRandom(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute gaussian_random_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(attrs.attr_store.count("shape")); - ir::Tensor shape_tensor; - std::string tensor_name = "gaussian_random_out"; - auto out = pe::Identity(shape_tensor, tensor_name).front(); - auto stages = CreateStages({out}); - 
std::vector<CINNValue> res{CINNValue(out), CINNValue(stages)};
-    *ret = CINNValuePack{res};
-  });
+std::shared_ptr<framework::OpStrategy> StrategyForGaussianRandom(
+    const framework::NodeAttr &attrs,
+    const std::vector<ir::Tensor> &inputs,
+    const std::vector<Type> &out_type,
+    const std::vector<std::vector<int>> &output_shapes,
+    const Target &target) {
+  framework::CINNCompute gaussian_random_compute(
+      [=](lang::Args args, lang::RetValue *ret) {
+        CHECK(attrs.attr_store.count("shape"));
+        ir::Tensor shape_tensor;
+        std::string tensor_name = "gaussian_random_out";
+        auto out = pe::Identity(shape_tensor, tensor_name).front();
+        auto stages = CreateStages({out});
+        std::vector<CINNValue> res{CINNValue(out), CINNValue(stages)};
+        *ret = CINNValuePack{res};
+      });

   auto strategy = std::make_shared<framework::OpStrategy>();
-  strategy->AddImpl(
-      gaussian_random_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.gaussian_random.x86", 1);
+  strategy->AddImpl(gaussian_random_compute,
+                    GetElementwiseScheduleFunc(output_shapes, target),
+                    "strategy.gaussian_random.x86",
+                    1);
   return strategy;
 }

-std::vector<shape_t> InferShapeForGaussianRandom(const std::vector<shape_t> &inputs_shape,
-                                                 const framework::AttrMapType &attrs) {
+std::vector<shape_t> InferShapeForGaussianRandom(
+    const std::vector<shape_t> &inputs_shape,
+    const framework::AttrMapType &attrs) {
   CHECK(attrs.count("shape"));
   auto shape = absl::get<std::vector<int>>(attrs.at("shape"));
   CHECK(!shape.empty()) << "shape attr is empty!";
   return {shape};
 }

-std::vector<Type> InferDtypeForGaussianRandom(const std::vector<Type> &inputs_type,
-                                              const framework::AttrMapType &attrs) {
+std::vector<Type> InferDtypeForGaussianRandom(
+    const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
   std::string dtype = "float32";
   if (attrs.find("dtype") != attrs.end()) {
     dtype = absl::get<std::string>(attrs.at("dtype"));
   }
   std::vector<Type> res{common::Str2Type(dtype)};
   CHECK(res[0].is_float(32) || res[0].is_float(64))
-      << "gaussian_random only support float32 and float64, but here " << res[0] << "! Please check.";
+      << "gaussian_random only supports float32 and float64, but here " << res[0]
+      << "! 
Please check."; return res; } @@ -101,10 +107,14 @@ CINN_REGISTER_HELPER(gaussian_random_ops) { .describe("GaussianRandom") .set_num_inputs(0) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForGaussianRandom) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForGaussianRandom)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForGaussianRandom)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForGaussianRandom) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForGaussianRandom)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForGaussianRandom)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/logical_right_shift.cc b/paddle/cinn/hlir/op/contrib/logical_right_shift.cc index 4ac382fc756f8..68f117aad4956 100644 --- a/paddle/cinn/hlir/op/contrib/logical_right_shift.cc +++ b/paddle/cinn/hlir/op/contrib/logical_right_shift.cc @@ -17,6 +17,7 @@ #include #include +#include "gflags/gflags.h" #include "paddle/cinn/common/cas.h" #include "paddle/cinn/common/common.h" #include "paddle/cinn/common/context.h" @@ -35,7 +36,6 @@ #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/lang/builtin.h" #include "paddle/cinn/lang/compute.h" -#include "gflags/gflags.h" DECLARE_bool(cinn_ir_schedule); @@ -81,36 +81,40 @@ ir::Tensor LogicalRightShift(const ir::Tensor &A, output_name); } -std::shared_ptr StrategyForLogicalRightShift(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForLogicalRightShift( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { std::string op_name("logical_right_shift"); - framework::CINNCompute logical_right_shift_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 2U) << "2 input tensors for " << op_name << " compute\n"; - - Expr A_expr = pack_args[0]; - Expr B_expr = pack_args[1]; - CHECK(A_expr.as_tensor()); - CHECK(B_expr.as_tensor()); - ir::Tensor A = A_expr.as_tensor_ref(); - ir::Tensor B = B_expr.as_tensor_ref(); - - std::string tensor_name = UniqName("T_LogicalRightShift_out"); - - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 3U); - tensor_name = pack_args[2].operator std::string(); - } - - auto out = LogicalRightShift(A, B, target, tensor_name); - auto stages = CreateStages({out}); - *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; - }); + framework::CINNCompute logical_right_shift_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of " << op_name + << " compute is empty! 
Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 2U) + << "2 input tensors for " << op_name << " compute\n"; + + Expr A_expr = pack_args[0]; + Expr B_expr = pack_args[1]; + CHECK(A_expr.as_tensor()); + CHECK(B_expr.as_tensor()); + ir::Tensor A = A_expr.as_tensor_ref(); + ir::Tensor B = B_expr.as_tensor_ref(); + + std::string tensor_name = UniqName("T_LogicalRightShift_out"); + + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 3U); + tensor_name = pack_args[2].operator std::string(); + } + + auto out = LogicalRightShift(A, B, target, tensor_name); + auto stages = CreateStages({out}); + *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; + }); auto strategy = std::make_shared(); strategy->AddImpl(logical_right_shift_compute, @@ -120,20 +124,25 @@ std::shared_ptr StrategyForLogicalRightShift(const framework::NodeAt return strategy; } -std::vector InferShapeForLogicalRightShift(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; - CHECK_EQ(inputs_shape[0].size(), inputs_shape[1].size()) << "The inputs' dims should be equal."; +std::vector InferShapeForLogicalRightShift( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 2U) + << "The input's shape size should be 2! Please check again."; + CHECK_EQ(inputs_shape[0].size(), inputs_shape[1].size()) + << "The inputs' dims should be equal."; std::vector res{inputs_shape[0]}; return res; } -std::vector InferDtypeForLogicalRightShift(const std::vector &inputs_type, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 2UL) << "The logical_right_shift op should has two inputs! Please check."; +std::vector InferDtypeForLogicalRightShift( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 2UL) + << "The logical_right_shift op should has two inputs! Please check."; CHECK_EQ(inputs_type[0], inputs_type[1]) - << "The data type of input tensors of logical_right_shift op should be equal, but here x:" << inputs_type[0] - << " != y:" << inputs_type[1] << "! Please check."; + << "The data type of input tensors of logical_right_shift op should be " + "equal, but here x:" + << inputs_type[0] << " != y:" << inputs_type[1] << "! 
Please check."; std::vector res{inputs_type[0]}; return res; } @@ -147,10 +156,14 @@ CINN_REGISTER_HELPER(logical_right_shift_ops) { .describe("Logical Right Shift.") .set_num_inputs(2) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForLogicalRightShift) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForLogicalRightShift)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForLogicalRightShift)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForLogicalRightShift) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForLogicalRightShift)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForLogicalRightShift)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/logical_right_shift_test.cc b/paddle/cinn/hlir/op/contrib/logical_right_shift_test.cc index 108fa018b3523..1931502216711 100644 --- a/paddle/cinn/hlir/op/contrib/logical_right_shift_test.cc +++ b/paddle/cinn/hlir/op/contrib/logical_right_shift_test.cc @@ -42,7 +42,14 @@ TEST(GenerateCode_Cpu, LogicalRightShift) { poly::StageMap stages = poly::CreateStages({res}); std::vector funcs = - lang::LowerVec("TestGenerateCodeCpu_LogicalRightShift", stages, {res}, {}, {}, nullptr, target, true); + lang::LowerVec("TestGenerateCodeCpu_LogicalRightShift", + stages, + {res}, + {}, + {}, + nullptr, + target, + true); VLOG(6) << "Expr before CPU codegen:"; VLOG(6) << funcs[0]->body; @@ -54,7 +61,8 @@ TEST(GenerateCode_Cpu, LogicalRightShift) { backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); codegen.SetInlineBuiltinCodes(false); - std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + std::string code = + codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); VLOG(6) << "Cpu Codegen result:"; VLOG(6) << code << std::endl; } diff --git a/paddle/cinn/hlir/op/contrib/lookup_table.cc b/paddle/cinn/hlir/op/contrib/lookup_table.cc index dcd11361644cb..037812ab46348 100644 --- a/paddle/cinn/hlir/op/contrib/lookup_table.cc +++ b/paddle/cinn/hlir/op/contrib/lookup_table.cc @@ -19,6 +19,7 @@ #include #include +#include "gflags/gflags.h" #include "paddle/cinn/common/cas.h" #include "paddle/cinn/common/common.h" #include "paddle/cinn/common/context.h" @@ -37,7 +38,6 @@ #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/lang/builtin.h" #include "paddle/cinn/lang/compute.h" -#include "gflags/gflags.h" DECLARE_bool(cinn_ir_schedule); namespace cinn { @@ -53,7 +53,7 @@ ir::Tensor LookupTable(const ir::Tensor& table, const std::string& output_name) { CHECK_EQ(table->shape.size(), 2); CHECK_GT(ids->shape.size(), 1); - auto output_shape = ids->shape; + auto output_shape = ids->shape; output_shape.back() = table->shape.back(); return lang::Compute( @@ -64,29 +64,37 @@ ir::Tensor LookupTable(const ir::Tensor& table, offsets.emplace_back(indices[i]); } offsets.emplace_back(Expr(0)); - // Because the current conversion rules have not been completed, static conversion is done here. + // Because the current conversion rules have not been completed, static + // conversion is done here. 
auto ids_offset = ir::Cast::Make(common::I32(), ids(offsets)); - auto pred = - ir::And::Make(Expr(padding_idx != -1), ir::EQ::Make(ids_offset, Expr(static_cast(padding_idx)))); - return ir::Select::Make(pred, ir::Cast::Make(table->type(), Expr(0)), table(ids_offset, indices.back())); + auto pred = ir::And::Make( + Expr(padding_idx != -1), + ir::EQ::Make(ids_offset, Expr(static_cast(padding_idx)))); + return ir::Select::Make(pred, + ir::Cast::Make(table->type(), Expr(0)), + table(ids_offset, indices.back())); }, common::UniqName(output_name)); } -std::shared_ptr StrategyForLookupTable(const framework::NodeAttr& attrs, - const std::vector& inputs, - const std::vector& out_type, - const std::vector>& output_shapes, - const Target& target) { +std::shared_ptr StrategyForLookupTable( + const framework::NodeAttr& attrs, + const std::vector& inputs, + const std::vector& out_type, + const std::vector>& output_shapes, + const Target& target) { std::string op_name("lookup_table"); const auto& attr_store = attrs.attr_store; CHECK(attr_store.count("padding_idx")) << "find no attr of axis"; auto padding_idx = absl::get(attr_store.at("padding_idx")); - framework::CINNCompute lookup_table_compute([=](lang::Args args, lang::RetValue* ret) { - CHECK(!args.empty()) << "The input arguments of " << op_name << " compute is empty! Please check.\n"; + framework::CINNCompute lookup_table_compute([=](lang::Args args, + lang::RetValue* ret) { + CHECK(!args.empty()) << "The input arguments of " << op_name + << " compute is empty! Please check.\n"; CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 2U) << "2 input tensors for " << op_name << " compute\n"; + CHECK_GE(pack_args.size(), 2U) + << "2 input tensors for " << op_name << " compute\n"; Expr A = pack_args[0]; Expr B = pack_args[1]; CHECK(A.as_tensor()); @@ -94,8 +102,9 @@ std::shared_ptr StrategyForLookupTable(const framework::N CHECK(!output_shapes.empty()); auto tensor_A = A.as_tensor_ref(); auto tensor_B = B.as_tensor_ref(); - auto stages = CreateStages({tensor_A, tensor_B}); - VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") << ", B shape: " << utils::Join(tensor_B->shape, ", ") + auto stages = CreateStages({tensor_A, tensor_B}); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") + << ", B shape: " << utils::Join(tensor_B->shape, ", ") << ", output_shapes: " << utils::Join(output_shapes[0], ", "); std::string tensor_name = UniqName("LookupTable_out"); if (FLAGS_cinn_ir_schedule) { @@ -106,27 +115,35 @@ std::shared_ptr StrategyForLookupTable(const framework::N std::vector res; stages->InsertLazily(out); res.push_back(CINNValue(out)); - CHECK(!out_type.empty()) << "Output type of " << op_name << " is empty! Please check.\n"; + CHECK(!out_type.empty()) + << "Output type of " << op_name << " is empty! Please check.\n"; res.push_back(CINNValue(stages)); *ret = CINNValuePack{res}; }); auto strategy = std::make_shared(); - strategy->AddImpl(lookup_table_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.lookup_table", 1); + strategy->AddImpl(lookup_table_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.lookup_table", + 1); return strategy; } -std::vector InferShapeForLookupTable(const std::vector& inputs_shape, - const framework::AttrMapType& attrs) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! 
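
// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch, not part of the patch. The
// And/EQ/Select chain built in LookupTable above reduces to this eager rule
// per row (the helper name and float element type are assumptions):
#include <cstdint>
#include <vector>

std::vector<float> LookupRow(const std::vector<std::vector<float>>& table,
                             int64_t id,
                             int64_t padding_idx) {
  if (padding_idx != -1 && id == padding_idx) {
    // Padded ids produce an all-zero row instead of a table lookup.
    return std::vector<float>(table[0].size(), 0.0f);
  }
  return table[id];
}
// ---------------------------------------------------------------------------
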
Please check again."; +std::vector InferShapeForLookupTable( + const std::vector& inputs_shape, + const framework::AttrMapType& attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) + << "The input's shape size is 0! Please check again."; - auto res = inputs_shape[1]; + auto res = inputs_shape[1]; res.back() = inputs_shape[0].back(); return {res}; } -std::vector InferDtypeForLookupTable(const std::vector& inputs_type, const framework::AttrMapType& attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForLookupTable( + const std::vector& inputs_type, const framework::AttrMapType& attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } @@ -140,9 +157,13 @@ CINN_REGISTER_HELPER(lookup_table_ops) { .describe("Lookup table Operator.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForLookupTable) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForLookupTable)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForLookupTable)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kInjective); + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForLookupTable) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForLookupTable)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForLookupTable)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective); return true; } diff --git a/paddle/cinn/hlir/op/contrib/lookup_table_test.cc b/paddle/cinn/hlir/op/contrib/lookup_table_test.cc index 984e6024b6572..d09d4238f6268 100644 --- a/paddle/cinn/hlir/op/contrib/lookup_table_test.cc +++ b/paddle/cinn/hlir/op/contrib/lookup_table_test.cc @@ -44,7 +44,14 @@ TEST(GenerateCode_Cpu, LookupTable) { poly::StageMap stages = poly::CreateStages({res}); std::vector funcs = - lang::LowerVec("TestGenerateCodeCpu_LookupTable", stages, {res}, {}, {}, nullptr, target, true); + lang::LowerVec("TestGenerateCodeCpu_LookupTable", + stages, + {res}, + {}, + {}, + nullptr, + target, + true); VLOG(6) << "Expr before CPU codegen:"; VLOG(6) << funcs[0]->body; @@ -56,7 +63,8 @@ TEST(GenerateCode_Cpu, LookupTable) { backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); codegen.SetInlineBuiltinCodes(false); - std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + std::string code = + codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); VLOG(6) << "codegen code: " << code; } @@ -78,7 +86,14 @@ TEST(GenerateCode_Gpu, LookupTable) { stages[res]->SetBuffer("global"); std::vector funcs = - lang::LowerVec("TestGenerateCodeCuda_LookupTable", stages, {res}, {}, {}, nullptr, target, true); + lang::LowerVec("TestGenerateCodeCuda_LookupTable", + stages, + {res}, + {}, + {}, + nullptr, + target, + true); VLOG(6) << "Expr before CUDA codegen:"; VLOG(6) << funcs[0]->body; diff --git a/paddle/cinn/hlir/op/contrib/one_hot.cc b/paddle/cinn/hlir/op/contrib/one_hot.cc index a782481069c50..111ea234df7e2 100755 --- a/paddle/cinn/hlir/op/contrib/one_hot.cc +++ b/paddle/cinn/hlir/op/contrib/one_hot.cc @@ -55,13 +55,16 @@ ir::Tensor OneHot(const ir::Tensor& indices, const Type& dtype, const std::string& output_name) { int ndim = static_cast(indices->shape.size()); - CHECK(axis == -1 || (0 <= axis && axis <= ndim)) << "one_hot only accepts `axis` in [-1, data.ndim]" - << 
", but got axis = " << axis << ", and data.ndim = " << ndim; + CHECK(axis == -1 || (0 <= axis && axis <= ndim)) + << "one_hot only accepts `axis` in [-1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; CHECK(depth > 0) << "one_hot only accepts `depth > 0`" << ", but got depth = " << depth; - CHECK(on_value->shape.size() == 1U && on_value->shape[0].as_int32() == 1U) << "The shape of on_value must be [1]"; - CHECK(off_value->shape.size() == 1U && off_value->shape[0].as_int32() == 1U) << "The shape of off_value must be [1]"; + CHECK(on_value->shape.size() == 1U && on_value->shape[0].as_int32() == 1U) + << "The shape of on_value must be [1]"; + CHECK(off_value->shape.size() == 1U && off_value->shape[0].as_int32() == 1U) + << "The shape of off_value must be [1]"; int true_axis = (axis == -1) ? ndim : axis; std::vector new_shape; @@ -75,7 +78,7 @@ ir::Tensor OneHot(const ir::Tensor& indices, } } - Expr on_value_cast = ir::Cast::Make(dtype, on_value(Expr(0))); + Expr on_value_cast = ir::Cast::Make(dtype, on_value(Expr(0))); Expr off_value_cast = ir::Cast::Make(dtype, off_value(Expr(0))); ir::Tensor res = lang::Compute( @@ -90,18 +93,21 @@ ir::Tensor OneHot(const ir::Tensor& indices, indices_indices.push_back(iter[i]); } - Expr idx = iter[true_axis]; + Expr idx = iter[true_axis]; Expr elem = ir::Cast::Make(idx.type(), indices(indices_indices)); - return ir::Select::Make(ir::EQ::Make(elem, idx), on_value_cast, off_value_cast); + return ir::Select::Make( + ir::EQ::Make(elem, idx), on_value_cast, off_value_cast); }, common::UniqName(output_name)); return res; } -std::vector InferShapeForOneHot(const std::vector& inputs_shape, - const framework::AttrMapType& attrs) { - CHECK_EQ(inputs_shape.size(), 3UL) << "The number of one_hot's input should be 3"; +std::vector InferShapeForOneHot( + const std::vector& inputs_shape, + const framework::AttrMapType& attrs) { + CHECK_EQ(inputs_shape.size(), 3UL) + << "The number of one_hot's input should be 3"; int depth; int axis; @@ -115,9 +121,9 @@ std::vector InferShapeForOneHot(const std::vector& in_shape = inputs_shape[0]; - int ndim = static_cast(in_shape.size()); - int true_axis = (axis == -1) ? in_shape.size() : axis; - int indices_index = 0; + int ndim = static_cast(in_shape.size()); + int true_axis = (axis == -1) ? in_shape.size() : axis; + int indices_index = 0; std::vector new_shape; for (int i = 0; i < ndim + 1; ++i) { @@ -132,8 +138,10 @@ std::vector InferShapeForOneHot(const std::vector InferDtypeForOneHot(const std::vector& inputs_type, const framework::AttrMapType& attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForOneHot(const std::vector& inputs_type, + const framework::AttrMapType& attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! 
Please check again."; std::string dtype = "float32"; if (attrs.find("dtype") != attrs.end()) { @@ -144,11 +152,12 @@ std::vector InferDtypeForOneHot(const std::vector& inputs_type, cons return res; } -std::shared_ptr StrategyForOneHot(const framework::NodeAttr& attrs, - const std::vector& inputs, - const std::vector& out_type, - const std::vector>& output_shapes, - const Target& target) { +std::shared_ptr StrategyForOneHot( + const framework::NodeAttr& attrs, + const std::vector& inputs, + const std::vector& out_type, + const std::vector>& output_shapes, + const Target& target) { int depth; int axis; std::string dtype = "float32"; @@ -166,20 +175,23 @@ std::shared_ptr StrategyForOneHot(const framework::NodeAt CHECK(depth > 0) << "one_hot only accepts `depth > 0`" << ", but got depth = " << depth; - framework::CINNCompute one_hot_compute([=](lang::Args args, lang::RetValue* ret) { - CHECK(!args.empty()) << "The input argument of one_hot compute is empty! Please check.\n"; + framework::CINNCompute one_hot_compute([=](lang::Args args, + lang::RetValue* ret) { + CHECK(!args.empty()) + << "The input argument of one_hot compute is empty! Please check.\n"; common::CINNValuePack pack_args = args[0]; - CHECK(!pack_args.empty()) << "at least one input tensor for transpose compute\n"; + CHECK(!pack_args.empty()) + << "at least one input tensor for transpose compute\n"; CHECK_GE(pack_args.size(), 3U); - Expr indices_expr = pack_args[0]; - Expr on_value_expr = pack_args[1]; + Expr indices_expr = pack_args[0]; + Expr on_value_expr = pack_args[1]; Expr off_value_expr = pack_args[2]; CHECK(indices_expr.as_tensor()); CHECK(on_value_expr.as_tensor()); CHECK(off_value_expr.as_tensor()); - ir::Tensor indices = indices_expr.as_tensor_ref(); - ir::Tensor on_value = on_value_expr.as_tensor_ref(); + ir::Tensor indices = indices_expr.as_tensor_ref(); + ir::Tensor on_value = on_value_expr.as_tensor_ref(); ir::Tensor off_value = off_value_expr.as_tensor_ref(); std::string tensor_name = common::UniqName("T_OneHot_out"); @@ -189,7 +201,13 @@ std::shared_ptr StrategyForOneHot(const framework::NodeAt tensor_name = pack_args[3].operator std::string(); } - ir::Tensor out = OneHot(indices, on_value, off_value, depth, axis, common::Str2Type(dtype), tensor_name); + ir::Tensor out = OneHot(indices, + on_value, + off_value, + depth, + axis, + common::Str2Type(dtype), + tensor_name); std::vector res; auto stages = CreateStages({indices, on_value, off_value}); @@ -200,7 +218,10 @@ std::shared_ptr StrategyForOneHot(const framework::NodeAt }); auto strategy = std::make_shared(); - strategy->AddImpl(one_hot_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.one_hot.x86", 1); + strategy->AddImpl(one_hot_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.one_hot.x86", + 1); return strategy; } @@ -211,14 +232,19 @@ std::shared_ptr StrategyForOneHot(const framework::NodeAt CINN_REGISTER_HELPER(one_hot_ops) { CINN_REGISTER_OP(one_hot) .describe( - "Returns a one-hot tensor where the locations repsented by indices take value `on_value`, " + "Returns a one-hot tensor where the locations repsented by indices " + "take value `on_value`, " "other locations take value `off_value`.") .set_num_inputs(3) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForOneHot) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForOneHot)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForOneHot)) - .set_attr("OpPattern", 
cinn::hlir::framework::OpPatternKind::kInjective) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForOneHot) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForOneHot)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForOneHot)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/one_hot_test.cc b/paddle/cinn/hlir/op/contrib/one_hot_test.cc index fc96d75179206..572172de0ab41 100644 --- a/paddle/cinn/hlir/op/contrib/one_hot_test.cc +++ b/paddle/cinn/hlir/op/contrib/one_hot_test.cc @@ -39,19 +39,32 @@ TEST(GenerateCode_Cpu, OneHot) { Expr m(4); Expr n(4); - const int depth = 3; - const int axis = 1; + const int depth = 3; + const int axis = 1; const std::string dtype = "float32"; lang::Placeholder in("in", {m, n}); lang::Placeholder on_value("on_value", {Expr(1)}); lang::Placeholder off_value("off_value", {Expr(1)}); - ir::Tensor res = OneHot(in, on_value, off_value, depth, axis, common::Str2Type(dtype), "test_one_hot"); + ir::Tensor res = OneHot(in, + on_value, + off_value, + depth, + axis, + common::Str2Type(dtype), + "test_one_hot"); poly::StageMap stages = poly::CreateStages({res}); std::vector funcs = - lang::LowerVec("TestGenerateCodeCpu_OneHot", stages, {res}, {}, {}, nullptr, target, true); + lang::LowerVec("TestGenerateCodeCpu_OneHot", + stages, + {res}, + {}, + {}, + nullptr, + target, + true); VLOG(6) << "Expr before CPU codegen:"; VLOG(6) << funcs[0]->body; @@ -63,7 +76,8 @@ TEST(GenerateCode_Cpu, OneHot) { backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); codegen.SetInlineBuiltinCodes(false); - std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + std::string code = + codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); VLOG(6) << "Cpu Codegen result:"; VLOG(6) << code << std::endl; diff --git a/paddle/cinn/hlir/op/contrib/randint.cc b/paddle/cinn/hlir/op/contrib/randint.cc index a39c89458fda5..7c6c1a0582107 100644 --- a/paddle/cinn/hlir/op/contrib/randint.cc +++ b/paddle/cinn/hlir/op/contrib/randint.cc @@ -20,6 +20,7 @@ #include #include "absl/types/variant.h" +#include "glog/logging.h" #include "paddle/cinn/common/cas.h" #include "paddle/cinn/common/cinn_value.h" #include "paddle/cinn/common/common.h" @@ -43,7 +44,6 @@ #include "paddle/cinn/lang/compute.h" #include "paddle/cinn/lang/packed_func.h" #include "paddle/cinn/poly/stage.h" -#include "glog/logging.h" namespace cinn { namespace hlir { @@ -52,34 +52,41 @@ namespace op { using common::CINNValue; using common::CINNValuePack; -std::shared_ptr StrategyForRandInt(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute randint_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(attrs.attr_store.count("shape")); - ir::Tensor shape_tensor; - std::string tensor_name = "randint_out"; - auto out = pe::Identity(shape_tensor, tensor_name).front(); - auto stages = CreateStages({out}); - std::vector res{CINNValue(out), CINNValue(stages)}; - *ret = CINNValuePack{res}; - }); +std::shared_ptr StrategyForRandInt( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute randint_compute( + [=](lang::Args args, lang::RetValue *ret) { + 
CHECK(attrs.attr_store.count("shape")); + ir::Tensor shape_tensor; + std::string tensor_name = "randint_out"; + auto out = pe::Identity(shape_tensor, tensor_name).front(); + auto stages = CreateStages({out}); + std::vector res{CINNValue(out), CINNValue(stages)}; + *ret = CINNValuePack{res}; + }); auto strategy = std::make_shared(); - strategy->AddImpl(randint_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.randint.x86", 1); + strategy->AddImpl(randint_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.randint.x86", + 1); return strategy; } -std::vector InferShapeForRandInt(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForRandInt( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { CHECK(attrs.count("shape")); auto shape = absl::get>(attrs.at("shape")); CHECK(!shape.empty()) << "shape attr is empty!"; return {shape}; } -std::vector InferDtypeForRandInt(const std::vector &inputs_type, const framework::AttrMapType &attrs) { +std::vector InferDtypeForRandInt(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { std::string dtype = "int32"; std::vector res{common::Str2Type(dtype)}; return res; @@ -94,10 +101,14 @@ CINN_REGISTER_HELPER(randint_ops) { .describe("RandInt") .set_num_inputs(0) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForRandInt) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForRandInt)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForRandInt)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForRandInt) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForRandInt)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForRandInt)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/reciprocal.cc b/paddle/cinn/hlir/op/contrib/reciprocal.cc index 262aadc8b3d74..09d180601ac10 100644 --- a/paddle/cinn/hlir/op/contrib/reciprocal.cc +++ b/paddle/cinn/hlir/op/contrib/reciprocal.cc @@ -17,6 +17,7 @@ #include #include +#include "gflags/gflags.h" #include "paddle/cinn/common/cas.h" #include "paddle/cinn/common/common.h" #include "paddle/cinn/common/context.h" @@ -35,7 +36,6 @@ #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/lang/builtin.h" #include "paddle/cinn/lang/compute.h" -#include "gflags/gflags.h" DECLARE_bool(cinn_ir_schedule); @@ -77,62 +77,74 @@ ir::Tensor Reciprocal(const ir::Tensor &input, const std::string &output_name) { output_name)}; } -std::shared_ptr StrategyForReciprocal(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForReciprocal( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { std::string op_name("reciprocal"); - framework::CINNCompute reciprocal_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! 
Please check.\n"; - CINNValuePack pack_args = args[0]; - CHECK(!pack_args.empty()) << "at least one input tensor for " << op_name << " compute\n"; - - std::string tensor_name = UniqName("Reciprocal_out"); - - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - } - - Expr A = pack_args[0]; - CHECK(A.as_tensor()); - CHECK(!output_shapes.empty()); - auto tensor_A = A.as_tensor_ref(); - auto stages = CreateStages({tensor_A}); - VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") - << ", output_shapes: " << utils::Join(output_shapes[0], ", "); - - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2U); - tensor_name = pack_args[1].operator std::string(); - } - - ir::Tensor out = Reciprocal(tensor_A, tensor_name); - std::vector res; - stages->InsertLazily(out); - res.push_back(CINNValue(out)); - CHECK(!out_type.empty()) << "Output type of Reciprocal is empty! Please check.\n"; - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); + framework::CINNCompute reciprocal_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of " << op_name + << " compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK(!pack_args.empty()) + << "at least one input tensor for " << op_name << " compute\n"; + + std::string tensor_name = UniqName("Reciprocal_out"); + + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } + + Expr A = pack_args[0]; + CHECK(A.as_tensor()); + CHECK(!output_shapes.empty()); + auto tensor_A = A.as_tensor_ref(); + auto stages = CreateStages({tensor_A}); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2U); + tensor_name = pack_args[1].operator std::string(); + } + + ir::Tensor out = Reciprocal(tensor_A, tensor_name); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + CHECK(!out_type.empty()) + << "Output type of Reciprocal is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); auto strategy = std::make_shared(); - strategy->AddImpl(reciprocal_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.reciprocal.x86", 1); + strategy->AddImpl(reciprocal_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.reciprocal.x86", + 1); return strategy; } -std::vector InferShapeForReciprocal(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(!inputs_shape.empty()) << "The input's shape size is empty! Please check again."; +std::vector InferShapeForReciprocal( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(!inputs_shape.empty()) + << "The input's shape size is empty! Please check again."; std::vector res{inputs_shape[0]}; return res; } -std::vector InferDtypeForReciprocal(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForReciprocal(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! 
Please check again."; std::vector res{inputs_type[0]}; return res; } @@ -146,10 +158,14 @@ CINN_REGISTER_HELPER(reciprocal_ops) { .describe("Counting Leading Zeros.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForReciprocal) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForReciprocal)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForReciprocal)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForReciprocal) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForReciprocal)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForReciprocal)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/reciprocal.h b/paddle/cinn/hlir/op/contrib/reciprocal.h index 0011b6840f3a4..447af81737600 100644 --- a/paddle/cinn/hlir/op/contrib/reciprocal.h +++ b/paddle/cinn/hlir/op/contrib/reciprocal.h @@ -23,7 +23,8 @@ namespace cinn { namespace hlir { namespace op { -ir::Tensor Reciprocal(const ir::Tensor& input, const std::string& name = "T_Reciprocal_out"); +ir::Tensor Reciprocal(const ir::Tensor& input, + const std::string& name = "T_Reciprocal_out"); } // namespace op } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/op/contrib/reciprocal_test.cc b/paddle/cinn/hlir/op/contrib/reciprocal_test.cc index e5a4cc20ad75d..c23afb3e7fd46 100644 --- a/paddle/cinn/hlir/op/contrib/reciprocal_test.cc +++ b/paddle/cinn/hlir/op/contrib/reciprocal_test.cc @@ -46,7 +46,14 @@ TEST(GenerateCode_Cpu, Reciprocal) { poly::StageMap stages = poly::CreateStages({res}); std::vector funcs = - lang::LowerVec("TestGenerateCodeCpu_Reciprocal", stages, {res}, {}, {}, nullptr, target, true); + lang::LowerVec("TestGenerateCodeCpu_Reciprocal", + stages, + {res}, + {}, + {}, + nullptr, + target, + true); VLOG(6) << "Expr before CPU codegen:"; VLOG(6) << funcs[0]->body; @@ -58,7 +65,8 @@ TEST(GenerateCode_Cpu, Reciprocal) { backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); codegen.SetInlineBuiltinCodes(false); - std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + std::string code = + codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); VLOG(6) << "Cpu Codegen result:"; VLOG(6) << code << std::endl; } diff --git a/paddle/cinn/hlir/op/contrib/repeat.cc b/paddle/cinn/hlir/op/contrib/repeat.cc index 2c770af2164e4..90608fbe36a36 100755 --- a/paddle/cinn/hlir/op/contrib/repeat.cc +++ b/paddle/cinn/hlir/op/contrib/repeat.cc @@ -46,10 +46,14 @@ namespace op { using common::CINNValuePack; -std::vector Repeat(const ir::Tensor &tensor, int repeats, int axis, const std::string &output_name) { +std::vector Repeat(const ir::Tensor &tensor, + int repeats, + int axis, + const std::string &output_name) { int ndim = static_cast(tensor->shape.size()); - CHECK(-ndim - 1 <= axis && axis <= ndim) << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis << ", and data.ndim = " << ndim; + CHECK(-ndim - 1 <= axis && axis <= ndim) + << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; CHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`" << ", but got repeats = " << repeats; @@ -83,15 +87,17 @@ std::vector Repeat(const ir::Tensor 
&tensor, int repeats, int axis, return {res}; } -std::vector> InferShapeForRepeat(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 1U) << "The input's shape size should be 1! Please check again."; +std::vector> InferShapeForRepeat( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 1U) + << "The input's shape size should be 1! Please check again."; int repeats = 0; - int axis = 0; + int axis = 0; std::vector new_shape; const std::vector &tensor_shape = inputs_shape[0]; - int ndim = static_cast(tensor_shape.size()); + int ndim = static_cast(tensor_shape.size()); if (attrs.find("repeats") != attrs.end()) { repeats = absl::get(attrs.at("repeats")); @@ -117,19 +123,22 @@ std::vector> InferShapeForRepeat(const std::vector InferDtypeForRepeat(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForRepeat(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::shared_ptr StrategyForRepeat(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForRepeat( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { int repeats = 0; - int axis = 0; + int axis = 0; for (auto &iter : attrs.attr_store) { if (iter.first == "repeats") { repeats = absl::get(iter.second); @@ -141,10 +150,13 @@ std::shared_ptr StrategyForRepeat(const framework::NodeAt CHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`" << ", but got repeats = " << repeats; - framework::CINNCompute repeat_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of Repeat compute is empty! Please check.\n"; + framework::CINNCompute repeat_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of Repeat compute is empty! Please check.\n"; CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 1U) << "at least 1 input tensors for Repeat compute\n"; + CHECK_GE(pack_args.size(), 1U) + << "at least 1 input tensors for Repeat compute\n"; Expr A = pack_args[0]; CHECK(A.as_tensor()); CHECK(!output_shapes.empty()); @@ -172,9 +184,11 @@ std::shared_ptr StrategyForRepeat(const framework::NodeAt *ret = common::CINNValuePack{res}; }); - framework::CINNSchedule repeat_schedule([=](lang::Args args, lang::RetValue *ret) { + framework::CINNSchedule repeat_schedule([=](lang::Args args, + lang::RetValue *ret) { if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of repeat schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of repeat schedule is empty! 
Please check.\n"; common::CINNValuePack arg_pack = args[0]; std::vector vec_ast; for (int i = 0; i < arg_pack.size(); i++) { @@ -187,20 +201,26 @@ std::shared_ptr StrategyForRepeat(const framework::NodeAt ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); - long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); + long prod_size = std::accumulate(output_shapes[0].begin(), + output_shapes[0].end(), + 1, + std::multiplies()); if (prod_size > 1) { if (target.arch == Target::Arch::NVGPU) { pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target); } else if (target.arch == Target::Arch::X86) { - pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true); + pe::IRScheduleInjectiveCPU( + ir_sch, output_shapes.front(), target, true); } } - std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = common::CINNValuePack{res}; } else { - CHECK(!args.empty()) << "The input argument of repeat schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of repeat schedule is empty! Please check.\n"; CINNValuePack arg_pack = args[0]; - Expr out = arg_pack[0]; + Expr out = arg_pack[0]; CHECK(out.as_tensor()); *ret = arg_pack; } @@ -221,9 +241,12 @@ CINN_REGISTER_HELPER(repeat_ops) { .describe("Repeat elements of an array `repeats` times along axis `axis`") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForRepeat) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForRepeat)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForRepeat)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForRepeat) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForRepeat)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForRepeat)) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/repeat.h b/paddle/cinn/hlir/op/contrib/repeat.h index 2a23d03ba2c29..0f3fc8bcc0ee2 100644 --- a/paddle/cinn/hlir/op/contrib/repeat.h +++ b/paddle/cinn/hlir/op/contrib/repeat.h @@ -25,7 +25,10 @@ namespace cinn { namespace hlir { namespace op { -std::vector Repeat(const ir::Tensor &tensor, int repeats, int axis, const std::string &output_name); +std::vector Repeat(const ir::Tensor &tensor, + int repeats, + int axis, + const std::string &output_name); } // namespace op } // namespace hlir diff --git a/paddle/cinn/hlir/op/contrib/repeat_test.cc b/paddle/cinn/hlir/op/contrib/repeat_test.cc index f21dbd031614d..a5abd5bb75804 100755 --- a/paddle/cinn/hlir/op/contrib/repeat_test.cc +++ b/paddle/cinn/hlir/op/contrib/repeat_test.cc @@ -40,15 +40,15 @@ TEST(GenerateCode_Cpu, Repeat) { ir::Expr m(4); ir::Expr n(4); const int repeats = 2; - const int axis = 0; + const int axis = 0; lang::Placeholder in("in", {m, n}); std::vector res = Repeat(in, repeats, axis, "test_repeat"); poly::StageMap stages = poly::CreateStages({res}); - std::vector funcs = - lang::LowerVec("TestGenerateCodeCpu_Repeat", stages, res, {}, {}, nullptr, target, true); + std::vector funcs = lang::LowerVec( + "TestGenerateCodeCpu_Repeat", stages, res, {}, {}, nullptr, target, true); VLOG(6) << "Expr before CPU codegen:"; VLOG(6) << funcs[0]->body; @@ -82,7 +82,8 @@ function TestGenerateCodeCpu_Repeat (_test_repeat) backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); 
codegen.SetInlineBuiltinCodes(false); - std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + std::string code = + codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); VLOG(6) << "Cpu Codegen result:"; VLOG(6) << code << std::endl; diff --git a/paddle/cinn/hlir/op/contrib/resize.cc b/paddle/cinn/hlir/op/contrib/resize.cc index b1f31a6f9bb31..48c5f81f23686 100644 --- a/paddle/cinn/hlir/op/contrib/resize.cc +++ b/paddle/cinn/hlir/op/contrib/resize.cc @@ -45,11 +45,13 @@ namespace op { using common::CINNValuePack; -#define __get_pixel(input, h, w, n, c, y, x) \ - input({n, \ - c, \ - common::AutoSimplify(ir::Max::Make(ir::Min::Make(y, h - Expr(1)), Expr(0))), \ - common::AutoSimplify(ir::Max::Make(ir::Min::Make(x, w - Expr(1)), Expr(0)))}) +#define __get_pixel(input, h, w, n, c, y, x) \ + input({n, \ + c, \ + common::AutoSimplify( \ + ir::Max::Make(ir::Min::Make(y, h - Expr(1)), Expr(0))), \ + common::AutoSimplify( \ + ir::Max::Make(ir::Min::Make(x, w - Expr(1)), Expr(0)))}) ir::Tensor Resize(const ir::Tensor &input, const common::Target &target, @@ -72,13 +74,14 @@ ir::Tensor Resize(const ir::Tensor &input, func_name.append("bicubic"); } - Expr in_h = input->shape[2]; - Expr in_w = input->shape[3]; + Expr in_h = input->shape[2]; + Expr in_w = input->shape[3]; Expr out_h = Expr(out_shape[0]); Expr out_w = Expr(out_shape[1]); - std::vector new_shape = {input->shape[0], input->shape[1], out_h, out_w}; - ir::Tensor res = lang::Compute( + std::vector new_shape = { + input->shape[0], input->shape[1], out_h, out_w}; + ir::Tensor res = lang::Compute( {new_shape}, [=](const std::vector &indices) { Expr out_y = indices[2]; @@ -86,22 +89,43 @@ ir::Tensor Resize(const ir::Tensor &input, Expr value; if (mode == "nearest") { - Expr in_y = ir::Cast::Make(common::F32(), in_h) / ir::Cast::Make(common::F32(), out_h) * + Expr in_y = ir::Cast::Make(common::F32(), in_h) / + ir::Cast::Make(common::F32(), out_h) * ir::Cast::Make(common::F32(), out_y); - Expr in_x = ir::Cast::Make(common::F32(), in_w) / ir::Cast::Make(common::F32(), out_w) * + Expr in_x = ir::Cast::Make(common::F32(), in_w) / + ir::Cast::Make(common::F32(), out_w) * ir::Cast::Make(common::F32(), out_x); - Expr in_y_int = ir::Cast::Make(common::Int(32), lang::Floor(in_y)); - Expr in_x_int = ir::Cast::Make(common::Int(32), lang::Floor(in_x)); - std::vector in_indices = {indices[0], indices[1], in_y_int, in_x_int}; - value = input(in_indices); + Expr in_y_int = ir::Cast::Make(common::Int(32), lang::Floor(in_y)); + Expr in_x_int = ir::Cast::Make(common::Int(32), lang::Floor(in_x)); + std::vector in_indices = { + indices[0], indices[1], in_y_int, in_x_int}; + value = input(in_indices); } else if (mode == "bilinear") { - value = lang::CallExtern( - func_name, {input, input->shape[1], in_h, in_w, out_h, out_w, indices[0], indices[1], out_y, out_x}); + value = lang::CallExtern(func_name, + {input, + input->shape[1], + in_h, + in_w, + out_h, + out_w, + indices[0], + indices[1], + out_y, + out_x}); } else if (mode == "bicubic") { - value = lang::CallExtern( - func_name, {input, input->shape[1], in_h, in_w, out_h, out_w, indices[0], indices[1], out_y, out_x}); + value = lang::CallExtern(func_name, + {input, + input->shape[1], + in_h, + in_w, + out_h, + out_w, + indices[0], + indices[1], + out_y, + out_x}); } return value; @@ -111,18 +135,22 @@ ir::Tensor Resize(const ir::Tensor &input, return res; } -std::vector> InferShapeForResize(const std::vector> &inputs_shape, - const framework::AttrMapType 
&attrs) { - CHECK_EQ(inputs_shape[0].size(), 4U) << "The input's shape size should be 4! Please check again."; +std::vector> InferShapeForResize( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape[0].size(), 4U) + << "The input's shape size should be 4! Please check again."; CHECK(attrs.find("out_shape") != attrs.end()) << "Cannot find \"out_shape\" attribute in \"resize\" op, Please Check."; std::vector out_shape; out_shape = absl::get>(attrs.at("out_shape")); CHECK_EQ(out_shape.size(), 2U) << "The length of out_shape must be 2."; - CHECK(out_shape[0] > 0 && out_shape[1] > 0) << "The element of out_shape must be great that 0."; + CHECK(out_shape[0] > 0 && out_shape[1] > 0) + << "The element of out_shape must be great that 0."; - CHECK(attrs.find("mode") != attrs.end()) << "Cannot find \"mode\" attribute in \"resize\" op, Please Check."; + CHECK(attrs.find("mode") != attrs.end()) + << "Cannot find \"mode\" attribute in \"resize\" op, Please Check."; std::string mode = absl::get(attrs.at("mode")); CHECK(mode == "nearest" || mode == "bilinear" || mode == "bicubic") << "Resize only supports `nearest`, `bilinear` and `bicubic` mode."; @@ -137,18 +165,21 @@ std::vector> InferShapeForResize(const std::vector InferDtypeForResize(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForResize(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; CHECK(inputs_type[0] == Int(32)) << "Resize only supports int32 type input."; std::vector res{inputs_type[0]}; return res; } -std::shared_ptr StrategyForResize(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForResize( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { std::vector out_shape; std::string mode = "bilinear"; @@ -163,10 +194,13 @@ std::shared_ptr StrategyForResize(const framework::NodeAt CHECK(mode == "nearest" || mode == "bilinear" || mode == "bicubic") << "Resize only supports `nearest`, `bilinear` and `bicubic` mode."; - framework::CINNCompute resize_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of Resize compute is empty! Please check.\n"; + framework::CINNCompute resize_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of Resize compute is empty! Please check.\n"; CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 1U) << "at least 1 input tensors for Resize compute\n"; + CHECK_GE(pack_args.size(), 1U) + << "at least 1 input tensors for Resize compute\n"; Expr A = pack_args[0]; CHECK(A.as_tensor()); CHECK(!output_shapes.empty()); @@ -190,8 +224,10 @@ std::shared_ptr StrategyForResize(const framework::NodeAt *ret = common::CINNValuePack{res}; }); - framework::CINNSchedule resize_schedule([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of resize schedule is empty! Please check.\n"; + framework::CINNSchedule resize_schedule([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of resize schedule is empty! 
Please check.\n"; common::CINNValuePack arg_pack = args[0]; std::vector vec_ast; for (int i = 0; i < arg_pack.size(); i++) { @@ -204,7 +240,10 @@ std::shared_ptr StrategyForResize(const framework::NodeAt ir::ModuleExpr mod_expr(vec_ast); ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); - long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); + long prod_size = std::accumulate(output_shapes[0].begin(), + output_shapes[0].end(), + 1, + std::multiplies()); if (prod_size > 1) { if (target.arch == Target::Arch::NVGPU) { pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target); @@ -212,7 +251,8 @@ std::shared_ptr StrategyForResize(const framework::NodeAt pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true); } } - std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = common::CINNValuePack{res}; }); @@ -231,10 +271,14 @@ CINN_REGISTER_HELPER(resize_ops) { .describe(" ") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForResize) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForResize)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForResize)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForResize) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForResize)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForResize)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/sort.cc b/paddle/cinn/hlir/op/contrib/sort.cc index 73648ff22bc0e..9c0a0aefa0c57 100644 --- a/paddle/cinn/hlir/op/contrib/sort.cc +++ b/paddle/cinn/hlir/op/contrib/sort.cc @@ -63,9 +63,11 @@ std::vector ArgSort(const ir::Tensor &A, LOG(FATAL) << "ArgSort only supports X86 and NVGPU ! 
Please Check.\n"; } if (is_ascend) { - index_func_name = cinn::hlir::GetExternFuncName(target, A->type(), "lt_num"); + index_func_name = + cinn::hlir::GetExternFuncName(target, A->type(), "lt_num"); } else { - index_func_name = cinn::hlir::GetExternFuncName(target, A->type(), "gt_num"); + index_func_name = + cinn::hlir::GetExternFuncName(target, A->type(), "gt_num"); } int pos_axis = axis; if (pos_axis < 0) { @@ -86,10 +88,11 @@ std::vector ArgSort(const ir::Tensor &A, stride = stride * A->shape[i]; } } - offset = common::AutoSimplify(offset); - stride = common::AutoSimplify(stride); + offset = common::AutoSimplify(offset); + stride = common::AutoSimplify(stride); auto A_shape_axis = A->shape[pos_axis]; - return lang::CallExtern(index_func_name, {A, A_shape_axis, A(indices), offset, stride}); + return lang::CallExtern(index_func_name, + {A, A_shape_axis, A(indices), offset, stride}); }, name + "_temp"); auto res = Compute( @@ -111,7 +114,9 @@ std::vector ArgSort(const ir::Tensor &A, stride = common::AutoSimplify(stride); auto A_shape_axis = A->shape[pos_axis]; - auto idx = lang::CallExtern(find_func_name, {positions, A_shape_axis, indices[pos_axis], offset, stride}); + auto idx = lang::CallExtern( + find_func_name, + {positions, A_shape_axis, indices[pos_axis], offset, stride}); return idx; }, name); @@ -129,8 +134,9 @@ std::vector Sort(const ir::Tensor &A, if (pos_axis < 0) { pos_axis += A->shape.size(); } - auto sort_index = ArgSort(A, target, stages, pos_axis, is_ascend, name + "_index"); - auto res = Compute( + auto sort_index = + ArgSort(A, target, stages, pos_axis, is_ascend, name + "_index"); + auto res = Compute( A->shape, [=](const std::vector &indices) { std::vector A_indices(indices); @@ -142,49 +148,58 @@ std::vector Sort(const ir::Tensor &A, return {res, sort_index.at(0), sort_index.at(1)}; } -std::shared_ptr StrategyForSort(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForSort( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { auto attr_store = attrs.attr_store; std::string op_name("sort"); CHECK(attr_store.count("axis")) << "find no attr of axis"; - int axis = absl::get(attr_store.at("axis")); + int axis = absl::get(attr_store.at("axis")); bool is_ascend = true; if (attr_store.count("is_ascend")) { is_ascend = absl::get(attr_store.at("is_ascend")); } - framework::CINNCompute sort_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of Sort compute is empty! 
Please check.\n"; - CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 1U) << "At least 1 input tensors for Sort compute\n"; - Expr A = pack_args[0]; - CHECK(A.as_tensor()); - CHECK(!output_shapes.empty()); - auto tensor_A = A.as_tensor_ref(); - auto stages = CreateStages({tensor_A}); - VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") - << ", output_shapes: " << utils::Join(output_shapes[0], ", "); - auto tensor_name = UniqName("Sort_out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2U); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - } - std::vector out = Sort(tensor_A, target, stages, axis, is_ascend, tensor_name); - stages->InsertLazily(out[0]); - std::vector res{CINNValue(out[0]), CINNValue(out[1]), CINNValue(out[2])}; - CHECK(!out_type.empty()) << "Output type of Sort is empty! Please check.\n"; - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); - - framework::CINNSchedule sort_schedule([=](lang::Args args, lang::RetValue *ret) { + framework::CINNCompute sort_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of Sort compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 1U) + << "At least 1 input tensors for Sort compute\n"; + Expr A = pack_args[0]; + CHECK(A.as_tensor()); + CHECK(!output_shapes.empty()); + auto tensor_A = A.as_tensor_ref(); + auto stages = CreateStages({tensor_A}); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + auto tensor_name = UniqName("Sort_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2U); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } + std::vector out = + Sort(tensor_A, target, stages, axis, is_ascend, tensor_name); + stages->InsertLazily(out[0]); + std::vector res{ + CINNValue(out[0]), CINNValue(out[1]), CINNValue(out[2])}; + CHECK(!out_type.empty()) + << "Output type of Sort is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + framework::CINNSchedule sort_schedule([=](lang::Args args, + lang::RetValue *ret) { if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of sort_schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of sort_schedule is empty! Please check.\n"; common::CINNValuePack arg_pack = args[0]; std::vector vec_ast; for (int i = 0; i < arg_pack.size(); i++) { @@ -203,16 +218,21 @@ std::shared_ptr StrategyForSort(const framework::NodeAttr ir_sch.SetBuffer(blocks[0], "local"); ir_sch.SetBuffer(blocks[1], "local"); - long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); + long prod_size = std::accumulate(output_shapes[0].begin(), + output_shapes[0].end(), + 1, + std::multiplies()); if (prod_size > 1 && target.arch == Target::Arch::X86) { pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true); } - std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = common::CINNValuePack{res}; } else { - CHECK(!args.empty()) << "The input argument of sort_schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of sort_schedule is empty! 
Please check.\n"; CINNValuePack arg_pack = args[0]; - Expr out = arg_pack[0]; + Expr out = arg_pack[0]; CHECK(out.as_tensor()); *ret = arg_pack; } @@ -223,28 +243,32 @@ std::shared_ptr StrategyForSort(const framework::NodeAttr return strategy; } -std::shared_ptr StrategyForArgSort(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForArgSort( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { auto attr_store = attrs.attr_store; CHECK(attr_store.count("axis")) << "find no attr of axis"; - int axis = absl::get(attr_store.at("axis")); + int axis = absl::get(attr_store.at("axis")); bool is_ascend = true; if (attr_store.count("is_ascend")) { is_ascend = absl::get(attr_store.at("is_ascend")); } - framework::CINNCompute argsort_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of ArgSort compute is empty! Please check.\n"; + framework::CINNCompute argsort_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of ArgSort compute is empty! Please check.\n"; CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 1U) << "At least 1 input tensors for ArgSort compute\n"; + CHECK_GE(pack_args.size(), 1U) + << "At least 1 input tensors for ArgSort compute\n"; Expr A = pack_args[0]; CHECK(A.as_tensor()); CHECK(!output_shapes.empty()); auto tensor_A = A.as_tensor_ref(); - auto stages = CreateStages({tensor_A}); + auto stages = CreateStages({tensor_A}); VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") << ", output_shapes: " << utils::Join(output_shapes[0], ", "); auto tensor_name = UniqName("ArgSort_out"); @@ -259,14 +283,17 @@ std::shared_ptr StrategyForArgSort(const framework::NodeA stages->InsertLazily(out.at(1)); res.push_back(CINNValue(out.at(0))); res.push_back(CINNValue(out.at(1))); - CHECK(!out_type.empty()) << "Output type of ArgSort is empty! Please check.\n"; + CHECK(!out_type.empty()) + << "Output type of ArgSort is empty! Please check.\n"; res.push_back(CINNValue(stages)); *ret = CINNValuePack{res}; }); - framework::CINNSchedule argsort_schedule([=](lang::Args args, lang::RetValue *ret) { + framework::CINNSchedule argsort_schedule([=](lang::Args args, + lang::RetValue *ret) { if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of argsort_schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of argsort_schedule is empty! Please check.\n"; common::CINNValuePack arg_pack = args[0]; std::vector vec_ast; for (int i = 0; i < arg_pack.size(); i++) { @@ -282,18 +309,23 @@ std::shared_ptr StrategyForArgSort(const framework::NodeA auto blocks = ir_sch.GetAllBlocks(); // TODO: remove external calls, do not use local variables, because // the size will exceed the limit. - // TODO: There is a bug, setting buffer to "local" here will cause the var declared twice at CodeGen. - // ir_sch.SetBuffer(blocks[0], "local"); - long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); + // TODO: There is a bug, setting buffer to "local" here will cause the var + // declared twice at CodeGen. 
ir_sch.SetBuffer(blocks[0], "local"); + long prod_size = std::accumulate(output_shapes[0].begin(), + output_shapes[0].end(), + 1, + std::multiplies()); if (prod_size > 1 && target.arch == Target::Arch::X86) { pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true); } - std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = common::CINNValuePack{res}; } else { - CHECK(!args.empty()) << "The input argument of argsort_schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of argsort_schedule is empty! Please check.\n"; CINNValuePack arg_pack = args[0]; - Expr out = arg_pack[0]; + Expr out = arg_pack[0]; CHECK(out.as_tensor()); *ret = arg_pack; } @@ -304,9 +336,11 @@ std::shared_ptr StrategyForArgSort(const framework::NodeA return strategy; } -std::vector> InferShapeForSort(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 1UL) << "The input's shape size should be 1! Please check again."; +std::vector> InferShapeForSort( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 1UL) + << "The input's shape size should be 1! Please check again."; int axis = 0; for (auto &iter : attrs) { if (iter.first == "axis") { @@ -314,20 +348,25 @@ std::vector> InferShapeForSort(const std::vector> res{inputs_shape[0]}; return res; } -std::vector InferDtypeForSort(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 1UL) << "The input's type size should be 1! Please check again."; +std::vector InferDtypeForSort(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 1UL) + << "The input's type size should be 1! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::vector> InferShapeForArgSort(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 1UL) << "The input's shape size should be 1! Please check again."; +std::vector> InferShapeForArgSort( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 1UL) + << "The input's shape size should be 1! Please check again."; int axis = 0; for (auto &iter : attrs) { if (iter.first == "axis") { @@ -338,24 +377,29 @@ std::vector> InferShapeForArgSort(const std::vector> res{inputs_shape[0], inputs_shape[0]}; return res; } -std::vector InferDtypeForArgSort(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 1UL) << "The input's type size should be 1! Please check again."; +std::vector InferDtypeForArgSort(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 1UL) + << "The input's type size should be 1! Please check again."; return {Int(32), Int(32)}; } -std::vector> InferShapeForTopK(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 1UL) << "The input's shape size should be 1! Please check again."; - auto res = inputs_shape; +std::vector> InferShapeForTopK( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 1UL) + << "The input's shape size should be 1! 
Please check again."; + auto res = inputs_shape; auto k_it = attrs.find("k"); CHECK(k_it != attrs.end()) << "The attr k of topk does not exist."; - int k = absl::get(k_it->second); + int k = absl::get(k_it->second); auto axis_it = attrs.find("axis"); CHECK(axis_it != attrs.end()) << "The attr axis of topk does not exist."; int axis = absl::get(axis_it->second); @@ -368,8 +412,10 @@ std::vector> InferShapeForTopK(const std::vector InferDtypeForTopK(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 1UL) << "The input's type size should be 1! Please check again."; +std::vector InferDtypeForTopK(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 1UL) + << "The input's type size should be 1! Please check again."; std::vector res{inputs_type[0], Int(64)}; return res; } @@ -380,32 +426,42 @@ std::vector InferDtypeForTopK(const std::vector &inputs_type, const CINN_REGISTER_HELPER(sort_ops) { CINN_REGISTER_OP(sort) - .describe("Sort a variable x along the given axis and return sorted Variable.") + .describe( + "Sort a variable x along the given axis and return sorted Variable.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForSort) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForSort) .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForSort)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForSort)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); CINN_REGISTER_OP(argsort) .describe("Sort a variable x along the given axis and return indices.") .set_num_inputs(1) .set_num_outputs(2) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForArgSort) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForArgSort)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForArgSort)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForArgSort) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForArgSort)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForArgSort)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); CINN_REGISTER_OP(top_k) - .describe("Find values and indices of the k largest entries for the last dimension..") + .describe( + "Find values and indices of the k largest entries for the last " + "dimension..") .set_num_inputs(1) .set_num_outputs(2) .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForTopK)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForTopK)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/sort_test.cc b/paddle/cinn/hlir/op/contrib/sort_test.cc index ed39e957914de..3d2a8f6c73e38 100644 --- a/paddle/cinn/hlir/op/contrib/sort_test.cc +++ b/paddle/cinn/hlir/op/contrib/sort_test.cc @@ -42,10 +42,18 @@ TEST(GenerateCode_Cpu, ArgSort) { lang::Placeholder in("in", {n, h}); poly::StageMap stages = poly::CreateStages({in}); - ir::Tensor res = ArgSort(in.tensor(), target, stages, 1, true, "test_arg_sort_out").at(0); + ir::Tensor res = + ArgSort(in.tensor(), 
target, stages, 1, true, "test_arg_sort_out").at(0); stages->InsertLazily(res); std::vector funcs = - lang::LowerVec("TestGenerateCodeCpu_ArgSort", stages, {in, res}, {}, {}, nullptr, target, true); + lang::LowerVec("TestGenerateCodeCpu_ArgSort", + stages, + {in, res}, + {}, + {}, + nullptr, + target, + true); VLOG(6) << "Expr before CPU codegen:"; VLOG(6) << funcs[0]->body; @@ -57,7 +65,8 @@ TEST(GenerateCode_Cpu, ArgSort) { backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); codegen.SetInlineBuiltinCodes(false); - std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + std::string code = + codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); VLOG(6) << "Cpu Codegen result:"; VLOG(6) << code << std::endl; } @@ -71,11 +80,18 @@ TEST(GenerateCode_Cpu, Sort) { ir::Expr h(28); lang::Placeholder in("in", {n, h}); - auto stages = poly::CreateStages({in}); + auto stages = poly::CreateStages({in}); ir::Tensor out = Sort(in, target, stages, 1, true, "test_sort_out").at(0); stages->InsertLazily(out); std::vector funcs = - lang::LowerVec("TestGenerateCodeCpu_Sort", stages, {in, out}, {}, {}, nullptr, target, true); + lang::LowerVec("TestGenerateCodeCpu_Sort", + stages, + {in, out}, + {}, + {}, + nullptr, + target, + true); VLOG(6) << "Expr before CPU codegen:"; VLOG(6) << funcs[0]->body; @@ -87,7 +103,8 @@ TEST(GenerateCode_Cpu, Sort) { backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); codegen.SetInlineBuiltinCodes(false); - std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + std::string code = + codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); auto target_source = R"ROC( #include #include diff --git a/paddle/cinn/hlir/op/contrib/triangular_solve.cc b/paddle/cinn/hlir/op/contrib/triangular_solve.cc index 0b9b120b71083..3ec35013fc417 100644 --- a/paddle/cinn/hlir/op/contrib/triangular_solve.cc +++ b/paddle/cinn/hlir/op/contrib/triangular_solve.cc @@ -36,40 +36,51 @@ namespace op { using common::CINNValue; using common::CINNValuePack; -std::shared_ptr StrategyForTriangularSolve(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute triangular_solve_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of triangular_solve is empty! Please check."; - CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 2U) << "Two input tensors are required for the computation of triangular_solve."; - Expr a_expr = pack_args[0]; - Expr b_expr = pack_args[1]; - ir::Tensor a = a_expr.as_tensor_ref(); - ir::Tensor b = b_expr.as_tensor_ref(); - std::string tensor_name = "triangular_solve_out"; - auto out = pe::Identity(b, tensor_name).front(); - auto stages = CreateStages({out}); - std::vector res{CINNValue(out), CINNValue(stages)}; - *ret = CINNValuePack{res}; - }); +std::shared_ptr StrategyForTriangularSolve( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute triangular_solve_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of triangular_solve is empty! 
Please check."; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 2U) + << "Two input tensors are required for the computation of " + "triangular_solve."; + Expr a_expr = pack_args[0]; + Expr b_expr = pack_args[1]; + ir::Tensor a = a_expr.as_tensor_ref(); + ir::Tensor b = b_expr.as_tensor_ref(); + std::string tensor_name = "triangular_solve_out"; + auto out = pe::Identity(b, tensor_name).front(); + auto stages = CreateStages({out}); + std::vector res{CINNValue(out), CINNValue(stages)}; + *ret = CINNValuePack{res}; + }); auto strategy = std::make_shared(); - strategy->AddImpl( - triangular_solve_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.triangular_solve.x86", 1); + strategy->AddImpl(triangular_solve_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.triangular_solve.x86", + 1); return strategy; } -std::vector InferShapeForTriangularSolve(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; +std::vector InferShapeForTriangularSolve( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 2U) + << "The input's shape size should be 2! Please check again."; framework::shape_t a_shape = inputs_shape[0]; framework::shape_t b_shape = inputs_shape[1]; - int a_shape_size = a_shape.size(); - int b_shape_size = b_shape.size(); - CHECK_GE(a_shape_size, 2U) << "The input matrix A shape size should >= 2! Please check again."; - CHECK_GE(b_shape_size, 2U) << "The input matrix B shape size should >= 2! Please check again."; + int a_shape_size = a_shape.size(); + int b_shape_size = b_shape.size(); + CHECK_GE(a_shape_size, 2U) + << "The input matrix A shape size should >= 2! Please check again."; + CHECK_GE(b_shape_size, 2U) + << "The input matrix B shape size should >= 2! Please check again."; int left_side = -1; for (auto &iter : attrs) { @@ -92,9 +103,10 @@ std::vector InferShapeForTriangularSolve(const std::vector InferDtypeForTriangularSolve(const std::vector &inputs_type, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 2U) << "The input's shape size should be 2! Please check again."; +std::vector InferDtypeForTriangularSolve( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 2U) + << "The input's shape size should be 2! Please check again."; CHECK(inputs_type[0].is_float(32) || inputs_type[0].is_float(64)) << "The input's dtype should be float32 or float64! 
Please check again."; CHECK(inputs_type[1].is_float(32) || inputs_type[1].is_float(64)) @@ -111,10 +123,14 @@ CINN_REGISTER_HELPER(triangular_solve_ops) { .describe("TriangularSolve") .set_num_inputs(2) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForTriangularSolve) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForTriangularSolve)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForTriangularSolve)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForTriangularSolve) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForTriangularSolve)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForTriangularSolve)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/contrib/uniform_random.cc b/paddle/cinn/hlir/op/contrib/uniform_random.cc index dfa3fb17743ab..4895c0a16289c 100644 --- a/paddle/cinn/hlir/op/contrib/uniform_random.cc +++ b/paddle/cinn/hlir/op/contrib/uniform_random.cc @@ -20,6 +20,7 @@ #include #include "absl/types/variant.h" +#include "glog/logging.h" #include "paddle/cinn/common/cas.h" #include "paddle/cinn/common/cinn_value.h" #include "paddle/cinn/common/common.h" @@ -43,7 +44,6 @@ #include "paddle/cinn/lang/compute.h" #include "paddle/cinn/lang/packed_func.h" #include "paddle/cinn/poly/stage.h" -#include "glog/logging.h" namespace cinn { namespace hlir { @@ -52,43 +52,49 @@ namespace op { using common::CINNValue; using common::CINNValuePack; -std::shared_ptr StrategyForUniformRandom(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute uniform_random_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(attrs.attr_store.count("shape")); - ir::Tensor shape_tensor; - std::string tensor_name = "uniform_random_out"; - auto out = pe::Identity(shape_tensor, tensor_name).front(); - auto stages = CreateStages({out}); - std::vector res{CINNValue(out), CINNValue(stages)}; - *ret = CINNValuePack{res}; - }); +std::shared_ptr StrategyForUniformRandom( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute uniform_random_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(attrs.attr_store.count("shape")); + ir::Tensor shape_tensor; + std::string tensor_name = "uniform_random_out"; + auto out = pe::Identity(shape_tensor, tensor_name).front(); + auto stages = CreateStages({out}); + std::vector res{CINNValue(out), CINNValue(stages)}; + *ret = CINNValuePack{res}; + }); auto strategy = std::make_shared(); - strategy->AddImpl( - uniform_random_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.uniform_random.x86", 1); + strategy->AddImpl(uniform_random_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy.uniform_random.x86", + 1); return strategy; } -std::vector InferShapeForUniformRandom(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForUniformRandom( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { CHECK(attrs.count("shape")); auto shape = absl::get>(attrs.at("shape")); CHECK(!shape.empty()) << "shape attr is empty!"; return {shape}; } 
-std::vector InferDtypeForUniformRandom(const std::vector &inputs_type, - const framework::AttrMapType &attrs) { +std::vector InferDtypeForUniformRandom( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { std::string dtype = "float32"; if (attrs.find("dtype") != attrs.end()) { dtype = absl::get(attrs.at("dtype")); } std::vector res{common::Str2Type(dtype)}; CHECK(res[0].is_float(32) || res[0].is_float(64)) - << "uniform_random only support float32 and float64, but here " << res[0] << "! Please check."; + << "uniform_random only support float32 and float64, but here " << res[0] + << "! Please check."; return res; } @@ -101,10 +107,14 @@ CINN_REGISTER_HELPER(uniform_random_ops) { .describe("UniformRandom") .set_num_inputs(0) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForUniformRandom) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForUniformRandom)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForUniformRandom)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForUniformRandom) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForUniformRandom)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForUniformRandom)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/custom_call.cc b/paddle/cinn/hlir/op/custom_call.cc index 548141837fdab..c109df2ec2866 100644 --- a/paddle/cinn/hlir/op/custom_call.cc +++ b/paddle/cinn/hlir/op/custom_call.cc @@ -41,8 +41,10 @@ using framework::OpStrategy; using framework::shape_t; using framework::StrategyFunction; -using ArgsFunc = std::function( - const framework::NodeAttr &, const std::vector &, const std::vector> &)>; +using ArgsFunc = + std::function(const framework::NodeAttr &, + const std::vector &, + const std::vector> &)>; class CustomCallArgsFuncRegistry { public: @@ -51,14 +53,18 @@ class CustomCallArgsFuncRegistry { return instance; } - void Register(const std::string &custom_call, const common::Target &target, ArgsFunc args_func) { - auto id = custom_call + "_" + target.arch_str(); + void Register(const std::string &custom_call, + const common::Target &target, + ArgsFunc args_func) { + auto id = custom_call + "_" + target.arch_str(); func_map_[id] = args_func; } - ArgsFunc Lookup(const std::string &custom_call, const common::Target &target) { + ArgsFunc Lookup(const std::string &custom_call, + const common::Target &target) { auto id = custom_call + "_" + target.arch_str(); - CHECK(func_map_.count(id)) << "Can't find " << custom_call << " for target " << target.arch_str(); + CHECK(func_map_.count(id)) + << "Can't find " << custom_call << " for target " << target.arch_str(); return func_map_[id]; } @@ -67,29 +73,32 @@ class CustomCallArgsFuncRegistry { std::unordered_map func_map_; }; -std::shared_ptr StrategyForCustomCall(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForCustomCall( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { framework::CINNCompute compute([=](lang::Args args, lang::RetValue *ret) { CHECK_EQ(args.size(), 1UL); CINNValuePack pack_args = args[0]; CHECK_EQ(pack_args.size(), 2UL); 
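    // The two pack entries here are strings rather than tensors:
    // pack_args[0] carries the generated kernel's function name and
    // pack_args[1] the extern API symbol, which the registry Lookup
    // below resolves to an args-assembling function.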
CHECK(pack_args[0].is_string() && pack_args[1].is_string()); - std::string func_name = pack_args[0].operator std::string(); + std::string func_name = pack_args[0].operator std::string(); std::string custom_call_api = pack_args[1].operator std::string(); - auto args_func = CustomCallArgsFuncRegistry::Global().Lookup(custom_call_api, target); + auto args_func = + CustomCallArgsFuncRegistry::Global().Lookup(custom_call_api, target); // create call function. ir::Var kernel_args(KERNEL_ARGS, type_of()); ir::Var kernel_args_num(KERNEL_ARGS_NUM, type_of()); - auto args_list = args_func(attrs, inputs, output_shapes); + auto args_list = args_func(attrs, inputs, output_shapes); std::vector host_args = {kernel_args, kernel_args_num}; host_args.insert(host_args.end(), args_list.begin(), args_list.end()); - std::vector arguments = {ir::Argument(kernel_args, ir::Argument::IO::kOutput), - ir::Argument(kernel_args_num, ir::Argument::IO::kInput)}; + std::vector arguments = { + ir::Argument(kernel_args, ir::Argument::IO::kOutput), + ir::Argument(kernel_args_num, ir::Argument::IO::kInput)}; // if target is nvgpu, add stream. if (target == common::DefaultNVGPUTarget()) { ir::Var kernel_stream(KERNEL_STREAM, type_of()); @@ -97,15 +106,22 @@ std::shared_ptr StrategyForCustomCall(const framework::NodeAttr &att host_args.push_back(kernel_stream); arguments.emplace_back(kernel_stream, ir::Argument::IO::kOutput); } - auto call_extern_api = - ir::Call::Make(Void(), custom_call_api, host_args, {}, ir::CallType::Extern, ir::FunctionRef(), 0); - auto func = ir::_LoweredFunc_::Make(func_name, arguments, call_extern_api, {}); + auto call_extern_api = ir::Call::Make(Void(), + custom_call_api, + host_args, + {}, + ir::CallType::Extern, + ir::FunctionRef(), + 0); + auto func = + ir::_LoweredFunc_::Make(func_name, arguments, call_extern_api, {}); VLOG(3) << func; *ret = CINNValuePack{{CINNValue(ir::Expr(func))}}; }); - framework::CINNSchedule schedule([=](lang::Args args, lang::RetValue *ret) {}); + framework::CINNSchedule schedule( + [=](lang::Args args, lang::RetValue *ret) {}); auto strategy = std::make_shared(); strategy->AddImpl(compute, schedule, "strategy.custom_call.x86", 1); @@ -113,25 +129,42 @@ std::shared_ptr StrategyForCustomCall(const framework::NodeAttr &att } #ifdef CINN_WITH_CUDA -std::vector CustomCallArgsForCublas(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForCublas( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_EQ(inputs.size(), 2); CHECK_EQ(output_shapes.size(), 1); CHECK_LE(inputs[0]->shape.size(), 4); CHECK_LE(inputs[1]->shape.size(), 4); const auto &attr_store = attrs.attr_store; - bool trans_a = attr_store.count("trans_a") ? absl::get(attr_store.at("trans_a")) : false; - bool trans_b = attr_store.count("trans_b") ? absl::get(attr_store.at("trans_b")) : false; - bool trans_out = attr_store.count("trans_out") ? absl::get(attr_store.at("trans_out")) : false; - float alpha = attr_store.count("alpha") ? absl::get(attr_store.at("alpha")) : 1.0f; - float beta = attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; - - int x_num_col_dims = attr_store.count("x_num_col_dims") ? absl::get(attr_store.at("x_num_col_dims")) : 0; - int y_num_col_dims = attr_store.count("y_num_col_dims") ? absl::get(attr_store.at("y_num_col_dims")) : 0; - bool is_infer = attr_store.count("is_infer") ? 
absl::get(attr_store.at("is_infer")) : false; - CHECK((x_num_col_dims == 0 && y_num_col_dims == 0) || (x_num_col_dims > 0 && y_num_col_dims > 0)); + bool trans_a = attr_store.count("trans_a") + ? absl::get(attr_store.at("trans_a")) + : false; + bool trans_b = attr_store.count("trans_b") + ? absl::get(attr_store.at("trans_b")) + : false; + bool trans_out = attr_store.count("trans_out") + ? absl::get(attr_store.at("trans_out")) + : false; + float alpha = attr_store.count("alpha") + ? absl::get(attr_store.at("alpha")) + : 1.0f; + float beta = + attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; + + int x_num_col_dims = attr_store.count("x_num_col_dims") + ? absl::get(attr_store.at("x_num_col_dims")) + : 0; + int y_num_col_dims = attr_store.count("y_num_col_dims") + ? absl::get(attr_store.at("y_num_col_dims")) + : 0; + bool is_infer = attr_store.count("is_infer") + ? absl::get(attr_store.at("is_infer")) + : false; + CHECK((x_num_col_dims == 0 && y_num_col_dims == 0) || + (x_num_col_dims > 0 && y_num_col_dims > 0)); std::vector a_shape, b_shape; if (x_num_col_dims == 0 && y_num_col_dims == 0) { @@ -147,7 +180,7 @@ std::vector CustomCallArgsForCublas(const framework::NodeAttr &attrs, a_shape[3] = inputs[0]->shape[0]; } } else { - a_shape = inputs[0]->shape; + a_shape = inputs[0]->shape; int insert_1_to_a = 4 - a_shape.size(); for (int idx = 0; idx < insert_1_to_a; ++idx) { a_shape.insert(a_shape.begin(), ir::Expr(1)); @@ -163,7 +196,7 @@ std::vector CustomCallArgsForCublas(const framework::NodeAttr &attrs, b_shape[2] = inputs[1]->shape[0]; } } else { - b_shape = inputs[1]->shape; + b_shape = inputs[1]->shape; int insert_1_to_b = 4 - b_shape.size(); for (int idx = 0; idx < insert_1_to_b; ++idx) { b_shape.insert(b_shape.begin(), ir::Expr(1)); @@ -171,9 +204,9 @@ std::vector CustomCallArgsForCublas(const framework::NodeAttr &attrs, } } else if (x_num_col_dims > 0 && y_num_col_dims > 0) { // input a shape. - a_shape = {Expr(1), Expr(1)}; + a_shape = {Expr(1), Expr(1)}; int a_height = 1; - int a_width = 1; + int a_width = 1; for (int idx = 0; idx < x_num_col_dims; ++idx) { a_height *= inputs[0]->shape[idx].as_int32(); } @@ -184,9 +217,9 @@ std::vector CustomCallArgsForCublas(const framework::NodeAttr &attrs, a_shape.emplace_back(a_width); // input b shape. - b_shape = {Expr(1), Expr(1)}; + b_shape = {Expr(1), Expr(1)}; int b_height = 1; - int b_width = 1; + int b_width = 1; for (int idx = 0; idx < y_num_col_dims; ++idx) { b_height *= inputs[1]->shape[idx].as_int32(); } @@ -197,10 +230,12 @@ std::vector CustomCallArgsForCublas(const framework::NodeAttr &attrs, b_shape.emplace_back(b_width); if (is_infer) { - CHECK_EQ(a_width, b_width) << "The K dimension of mul shold be equal! Please check."; + CHECK_EQ(a_width, b_width) + << "The K dimension of mul shold be equal! Please check."; trans_b = true; } else { - CHECK_EQ(a_width, b_height) << "The K dimension of mul shold be equal! Please check."; + CHECK_EQ(a_width, b_height) + << "The K dimension of mul shold be equal! 
Please check."; } } else { LOG(FATAL) << "Unkown Matmul Setting!"; @@ -209,39 +244,59 @@ std::vector CustomCallArgsForCublas(const framework::NodeAttr &attrs, CHECK_EQ(a_shape.size(), 4); CHECK_EQ(b_shape.size(), 4); // func args - std::vector args = { - ir::Expr(trans_a), ir::Expr(trans_b), ir::Expr(trans_out), ir::Expr(alpha), ir::Expr(beta)}; + std::vector args = {ir::Expr(trans_a), + ir::Expr(trans_b), + ir::Expr(trans_out), + ir::Expr(alpha), + ir::Expr(beta)}; args.insert(args.end(), a_shape.begin(), a_shape.end()); args.insert(args.end(), b_shape.begin(), b_shape.end()); return args; } -std::vector CustomCallArgsForBatchedCublas(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForBatchedCublas( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_GT(inputs.size(), 2); CHECK_GT(output_shapes.size(), 1); CHECK_EQ(inputs.size() - 1, output_shapes.size()); const auto &attr_store = attrs.attr_store; - bool trans_a = attr_store.count("trans_a") ? absl::get(attr_store.at("trans_a")) : false; - bool trans_b = attr_store.count("trans_b") ? absl::get(attr_store.at("trans_b")) : false; - bool trans_out = attr_store.count("trans_out") ? absl::get(attr_store.at("trans_out")) : false; - float alpha = attr_store.count("alpha") ? absl::get(attr_store.at("alpha")) : 1.0f; - float beta = attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; - - int x_num_col_dims = attr_store.count("x_num_col_dims") ? absl::get(attr_store.at("x_num_col_dims")) : 0; - int y_num_col_dims = attr_store.count("y_num_col_dims") ? absl::get(attr_store.at("y_num_col_dims")) : 0; - bool is_infer = attr_store.count("is_infer") ? absl::get(attr_store.at("is_infer")) : false; - CHECK((x_num_col_dims == 0 && y_num_col_dims == 0) || (x_num_col_dims > 0 && y_num_col_dims > 0)); + bool trans_a = attr_store.count("trans_a") + ? absl::get(attr_store.at("trans_a")) + : false; + bool trans_b = attr_store.count("trans_b") + ? absl::get(attr_store.at("trans_b")) + : false; + bool trans_out = attr_store.count("trans_out") + ? absl::get(attr_store.at("trans_out")) + : false; + float alpha = attr_store.count("alpha") + ? absl::get(attr_store.at("alpha")) + : 1.0f; + float beta = + attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; + + int x_num_col_dims = attr_store.count("x_num_col_dims") + ? absl::get(attr_store.at("x_num_col_dims")) + : 0; + int y_num_col_dims = attr_store.count("y_num_col_dims") + ? absl::get(attr_store.at("y_num_col_dims")) + : 0; + bool is_infer = attr_store.count("is_infer") + ? 
absl::get(attr_store.at("is_infer")) + : false; + CHECK((x_num_col_dims == 0 && y_num_col_dims == 0) || + (x_num_col_dims > 0 && y_num_col_dims > 0)); ir::Tensor left, right; CHECK(attr_store.count("side")); if (absl::get(attr_store.at("side")) == "left") { - left = inputs[0]; + left = inputs[0]; right = inputs[1]; } else { - left = inputs[1]; + left = inputs[1]; right = inputs[0]; } @@ -259,7 +314,7 @@ std::vector CustomCallArgsForBatchedCublas(const framework::NodeAttr & a_shape[3] = left->shape[0]; } } else { - a_shape = left->shape; + a_shape = left->shape; int insert_1_to_a = 4 - a_shape.size(); for (int idx = 0; idx < insert_1_to_a; ++idx) { a_shape.insert(a_shape.begin(), ir::Expr(1)); @@ -275,7 +330,7 @@ std::vector CustomCallArgsForBatchedCublas(const framework::NodeAttr & b_shape[2] = right->shape[0]; } } else { - b_shape = right->shape; + b_shape = right->shape; int insert_1_to_b = 4 - b_shape.size(); for (int idx = 0; idx < insert_1_to_b; ++idx) { b_shape.insert(b_shape.begin(), ir::Expr(1)); @@ -283,9 +338,9 @@ std::vector CustomCallArgsForBatchedCublas(const framework::NodeAttr & } } else if (x_num_col_dims > 0 && y_num_col_dims > 0) { // input a shape. - a_shape = {Expr(1), Expr(1)}; + a_shape = {Expr(1), Expr(1)}; int a_height = 1; - int a_width = 1; + int a_width = 1; for (int idx = 0; idx < x_num_col_dims; ++idx) { a_height *= left->shape[idx].as_int32(); } @@ -296,9 +351,9 @@ std::vector CustomCallArgsForBatchedCublas(const framework::NodeAttr & a_shape.emplace_back(a_width); // input b shape. - b_shape = {Expr(1), Expr(1)}; + b_shape = {Expr(1), Expr(1)}; int b_height = 1; - int b_width = 1; + int b_width = 1; for (int idx = 0; idx < y_num_col_dims; ++idx) { b_height *= right->shape[idx].as_int32(); } @@ -309,10 +364,12 @@ std::vector CustomCallArgsForBatchedCublas(const framework::NodeAttr & b_shape.emplace_back(b_width); if (is_infer) { - CHECK_EQ(a_width, b_width) << "The K dimension of mul shold be equal! Please check."; + CHECK_EQ(a_width, b_width) + << "The K dimension of mul shold be equal! Please check."; trans_b = true; } else { - CHECK_EQ(a_width, b_height) << "The K dimension of mul shold be equal! Please check."; + CHECK_EQ(a_width, b_height) + << "The K dimension of mul shold be equal! Please check."; } } else { LOG(FATAL) << "Unkown Matmul Setting!"; @@ -321,12 +378,14 @@ std::vector CustomCallArgsForBatchedCublas(const framework::NodeAttr & CHECK_EQ(a_shape.size(), 4); CHECK_EQ(b_shape.size(), 4); // func args - std::vector args = {absl::get(attr_store.at("side")) == "left" ? ir::Expr(0) : ir::Expr(1), - ir::Expr(trans_a), - ir::Expr(trans_b), - ir::Expr(trans_out), - ir::Expr(alpha), - ir::Expr(beta)}; + std::vector args = { + absl::get(attr_store.at("side")) == "left" ? 
ir::Expr(0) + : ir::Expr(1), + ir::Expr(trans_a), + ir::Expr(trans_b), + ir::Expr(trans_out), + ir::Expr(alpha), + ir::Expr(beta)}; args.insert(args.end(), a_shape.begin(), a_shape.end()); args.insert(args.end(), b_shape.begin(), b_shape.end()); return args; @@ -335,44 +394,55 @@ std::vector CustomCallArgsForBatchedCublas(const framework::NodeAttr & #endif #ifdef CINN_WITH_CUDNN -std::vector CustomCallArgsForCudnnConvForward(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForCudnnConvForward( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_EQ(inputs.size(), 2UL); // CHECK_EQ(output_shapes.size(), 1UL); const auto &attr_store = attrs.attr_store; - float alpha = attr_store.count("alpha") ? absl::get(attr_store.at("alpha")) : 1.0f; - float beta = attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; + float alpha = attr_store.count("alpha") + ? absl::get(attr_store.at("alpha")) + : 1.0f; + float beta = + attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; CHECK(attr_store.count("padding")); auto padding = absl::get>(attr_store.at("padding")); CHECK(attr_store.count("stride")); auto stride = absl::get>(attr_store.at("stride")); - auto dilation = - attr_store.count("dilation") ? absl::get>(attr_store.at("dilation")) : std::vector({1, 1}); + auto dilation = attr_store.count("dilation") + ? absl::get>(attr_store.at("dilation")) + : std::vector({1, 1}); std::string data_format = - attr_store.count("data_format") ? absl::get(attr_store.at("data_format")) : "NCHW"; + attr_store.count("data_format") + ? absl::get(attr_store.at("data_format")) + : "NCHW"; if (data_format == "AnyLayout") { data_format = "NCHW"; } - int groups = attr_store.count("groups") ? absl::get(attr_store.at("groups")) : 1; - cudnnTensorFormat_t format = data_format == "NCHW" ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + int groups = + attr_store.count("groups") ? absl::get(attr_store.at("groups")) : 1; + cudnnTensorFormat_t format = + data_format == "NCHW" ? 
CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - std::vector input = inputs[0]->shape; + std::vector input = inputs[0]->shape; std::vector filter = inputs[1]->shape; std::vector output = {}; - std::transform(output_shapes[0].begin(), output_shapes[0].end(), std::back_inserter(output), [](const int dim) { - return ir::Expr(dim); - }); + std::transform(output_shapes[0].begin(), + output_shapes[0].end(), + std::back_inserter(output), + [](const int dim) { return ir::Expr(dim); }); // if format is nhwc if (format == CUDNN_TENSOR_NHWC) { - input = {input[0], input[3], input[1], input[2]}; + input = {input[0], input[3], input[1], input[2]}; filter = {filter[0], filter[3], filter[1], filter[2]}; output = {output[0], output[3], output[1], output[2]}; } - std::vector args = {ir::Expr(static_cast(format)), ir::Expr(alpha), ir::Expr(beta)}; + std::vector args = { + ir::Expr(static_cast(format)), ir::Expr(alpha), ir::Expr(beta)}; args.insert(args.end(), input.begin(), input.end()); args.insert(args.end(), filter.begin(), filter.end()); args.push_back(ir::Expr(padding[0])); @@ -387,44 +457,55 @@ std::vector CustomCallArgsForCudnnConvForward(const framework::NodeAtt return args; } -std::vector CustomCallArgsForCudnnConvBackwardData(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForCudnnConvBackwardData( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_EQ(inputs.size(), 2UL); CHECK_EQ(output_shapes.size(), 1UL); const auto &attr_store = attrs.attr_store; - float alpha = attr_store.count("alpha") ? absl::get(attr_store.at("alpha")) : 1.0f; - float beta = attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; + float alpha = attr_store.count("alpha") + ? absl::get(attr_store.at("alpha")) + : 1.0f; + float beta = + attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; CHECK(attr_store.count("padding")); auto padding = absl::get>(attr_store.at("padding")); CHECK(attr_store.count("stride")); auto stride = absl::get>(attr_store.at("stride")); - auto dilation = - attr_store.count("dilation") ? absl::get>(attr_store.at("dilation")) : std::vector({1, 1}); + auto dilation = attr_store.count("dilation") + ? absl::get>(attr_store.at("dilation")) + : std::vector({1, 1}); std::string data_format = - attr_store.count("data_format") ? absl::get(attr_store.at("data_format")) : "NCHW"; + attr_store.count("data_format") + ? absl::get(attr_store.at("data_format")) + : "NCHW"; if (data_format == "AnyLayout") { data_format = "NCHW"; } - int groups = attr_store.count("groups") ? absl::get(attr_store.at("groups")) : 1; - cudnnTensorFormat_t format = data_format == "NCHW" ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + int groups = + attr_store.count("groups") ? absl::get(attr_store.at("groups")) : 1; + cudnnTensorFormat_t format = + data_format == "NCHW" ? 
CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; std::vector input = {}; - std::transform(output_shapes[0].begin(), output_shapes[0].end(), std::back_inserter(input), [](const int dim) { - return ir::Expr(dim); - }); + std::transform(output_shapes[0].begin(), + output_shapes[0].end(), + std::back_inserter(input), + [](const int dim) { return ir::Expr(dim); }); std::vector filter = inputs[0]->shape; std::vector output = inputs[1]->shape; // if format is nhwc if (format == CUDNN_TENSOR_NHWC) { - input = {input[0], input[3], input[1], input[2]}; + input = {input[0], input[3], input[1], input[2]}; filter = {filter[0], filter[3], filter[1], filter[2]}; output = {output[0], output[3], output[1], output[2]}; } - std::vector args = {ir::Expr(static_cast(format)), ir::Expr(alpha), ir::Expr(beta)}; + std::vector args = { + ir::Expr(static_cast(format)), ir::Expr(alpha), ir::Expr(beta)}; args.insert(args.end(), input.begin(), input.end()); args.insert(args.end(), filter.begin(), filter.end()); args.push_back(ir::Expr(padding[0])); @@ -438,45 +519,56 @@ std::vector CustomCallArgsForCudnnConvBackwardData(const framework::No return args; } -std::vector CustomCallArgsForCudnnConvBackwardFilter(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForCudnnConvBackwardFilter( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_EQ(inputs.size(), 2UL); CHECK_EQ(output_shapes.size(), 1UL); const auto &attr_store = attrs.attr_store; - float alpha = attr_store.count("alpha") ? absl::get(attr_store.at("alpha")) : 1.0f; - float beta = attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; + float alpha = attr_store.count("alpha") + ? absl::get(attr_store.at("alpha")) + : 1.0f; + float beta = + attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; CHECK(attr_store.count("padding")); auto padding = absl::get>(attr_store.at("padding")); CHECK(attr_store.count("stride")); auto stride = absl::get>(attr_store.at("stride")); - auto dilation = - attr_store.count("dilation") ? absl::get>(attr_store.at("dilation")) : std::vector({1, 1}); + auto dilation = attr_store.count("dilation") + ? absl::get>(attr_store.at("dilation")) + : std::vector({1, 1}); std::string data_format = - attr_store.count("data_format") ? absl::get(attr_store.at("data_format")) : "NCHW"; + attr_store.count("data_format") + ? absl::get(attr_store.at("data_format")) + : "NCHW"; if (data_format == "AnyLayout") { data_format = "NCHW"; } - int groups = attr_store.count("groups") ? absl::get(attr_store.at("groups")) : 1; + int groups = + attr_store.count("groups") ? absl::get(attr_store.at("groups")) : 1; - cudnnTensorFormat_t format = data_format == "NCHW" ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + cudnnTensorFormat_t format = + data_format == "NCHW" ? 
CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - std::vector input = inputs[0]->shape; + std::vector input = inputs[0]->shape; std::vector filter = {}; - std::transform(output_shapes[0].begin(), output_shapes[0].end(), std::back_inserter(filter), [](const int dim) { - return ir::Expr(dim); - }); + std::transform(output_shapes[0].begin(), + output_shapes[0].end(), + std::back_inserter(filter), + [](const int dim) { return ir::Expr(dim); }); std::vector output = inputs[1]->shape; // if format is nhwc if (format == CUDNN_TENSOR_NHWC) { - input = {input[0], input[3], input[1], input[2]}; + input = {input[0], input[3], input[1], input[2]}; filter = {filter[0], filter[3], filter[1], filter[2]}; output = {output[0], output[3], output[1], output[2]}; } - std::vector args = {ir::Expr(static_cast(format)), ir::Expr(alpha), ir::Expr(beta)}; + std::vector args = { + ir::Expr(static_cast(format)), ir::Expr(alpha), ir::Expr(beta)}; args.insert(args.end(), input.begin(), input.end()); args.insert(args.end(), filter.begin(), filter.end()); args.push_back(ir::Expr(padding[0])); @@ -490,14 +582,18 @@ std::vector CustomCallArgsForCudnnConvBackwardFilter(const framework:: return args; } -std::vector CustomCallArgsForCudnnPoolForward(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForCudnnPoolForward( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_EQ(inputs.size(), 1UL); CHECK_EQ(output_shapes.size(), 1UL); const auto &attr_store = attrs.attr_store; - float alpha = attr_store.count("alpha") ? absl::get(attr_store.at("alpha")) : 1.0f; - float beta = attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; + float alpha = attr_store.count("alpha") + ? absl::get(attr_store.at("alpha")) + : 1.0f; + float beta = + attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; CHECK(attr_store.count("kernel_size")); auto kernel = absl::get>(attr_store.at("kernel_size")); @@ -508,27 +604,36 @@ std::vector CustomCallArgsForCudnnPoolForward(const framework::NodeAtt CHECK(attr_store.count("pool_type")); auto pool_type = absl::get(attr_store.at("pool_type")); CHECK(attr_store.count("data_format")); - std::string data_format = absl::get(attr_store.at("data_format")); - - bool exclusive = attr_store.count("exclusive") ? absl::get(attrs.attr_store.at("exclusive")) : true; - cudnnPoolingMode_t mode = pool_type == "max" ? CUDNN_POOLING_MAX - : (exclusive ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING - : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING); - cudnnTensorFormat_t format = data_format == "NCHW" ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + std::string data_format = + absl::get(attr_store.at("data_format")); + + bool exclusive = attr_store.count("exclusive") + ? absl::get(attrs.attr_store.at("exclusive")) + : true; + cudnnPoolingMode_t mode = + pool_type == "max" + ? CUDNN_POOLING_MAX + : (exclusive ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING); + cudnnTensorFormat_t format = + data_format == "NCHW" ? 
CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; std::vector input = inputs[0]->shape; std::vector output; - std::transform(output_shapes[0].begin(), output_shapes[0].end(), std::back_inserter(output), [](const int dim) { - return ir::Expr(dim); - }); + std::transform(output_shapes[0].begin(), + output_shapes[0].end(), + std::back_inserter(output), + [](const int dim) { return ir::Expr(dim); }); // if format is nhwc if (format == CUDNN_TENSOR_NHWC) { - input = {input[0], input[3], input[1], input[2]}; + input = {input[0], input[3], input[1], input[2]}; output = {output[0], output[3], output[1], output[2]}; } - std::vector args = { - ir::Expr(static_cast(mode)), ir::Expr(static_cast(format)), ir::Expr(alpha), ir::Expr(beta)}; + std::vector args = {ir::Expr(static_cast(mode)), + ir::Expr(static_cast(format)), + ir::Expr(alpha), + ir::Expr(beta)}; args.insert(args.end(), input.begin(), input.end()); args.push_back(ir::Expr(kernel[0])); args.push_back(ir::Expr(kernel[1])); @@ -540,14 +645,18 @@ std::vector CustomCallArgsForCudnnPoolForward(const framework::NodeAtt return args; } -std::vector CustomCallArgsForCudnnPoolBackward(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForCudnnPoolBackward( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_EQ(inputs.size(), 3UL); CHECK_EQ(output_shapes.size(), 1UL); const auto &attr_store = attrs.attr_store; - float alpha = attr_store.count("alpha") ? absl::get(attr_store.at("alpha")) : 1.0f; - float beta = attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; + float alpha = attr_store.count("alpha") + ? absl::get(attr_store.at("alpha")) + : 1.0f; + float beta = + attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; CHECK(attr_store.count("kernel_size")); auto kernel = absl::get>(attr_store.at("kernel_size")); @@ -558,24 +667,32 @@ std::vector CustomCallArgsForCudnnPoolBackward(const framework::NodeAt CHECK(attr_store.count("pool_type")); auto pool_type = absl::get(attrs.attr_store.at("pool_type")); CHECK(attr_store.count("data_format")); - std::string data_format = absl::get(attrs.attr_store.at("data_format")); - - bool exclusive = attr_store.count("exclusive") ? absl::get(attrs.attr_store.at("exclusive")) : true; - cudnnPoolingMode_t mode = pool_type == "max" ? CUDNN_POOLING_MAX - : (exclusive ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING - : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING); - cudnnTensorFormat_t format = data_format == "NCHW" ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - std::vector input = inputs[0]->shape; // 'x' + std::string data_format = + absl::get(attrs.attr_store.at("data_format")); + + bool exclusive = attr_store.count("exclusive") + ? absl::get(attrs.attr_store.at("exclusive")) + : true; + cudnnPoolingMode_t mode = + pool_type == "max" + ? CUDNN_POOLING_MAX + : (exclusive ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING); + cudnnTensorFormat_t format = + data_format == "NCHW" ? 
CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + std::vector input = inputs[0]->shape; // 'x' std::vector output = inputs[1]->shape; // 'y' // if format is nhwc if (format == CUDNN_TENSOR_NHWC) { - input = {input[0], input[3], input[1], input[2]}; + input = {input[0], input[3], input[1], input[2]}; output = {output[0], output[3], output[1], output[2]}; } - std::vector args = { - ir::Expr(static_cast(mode)), ir::Expr(static_cast(format)), ir::Expr(alpha), ir::Expr(beta)}; + std::vector args = {ir::Expr(static_cast(mode)), + ir::Expr(static_cast(format)), + ir::Expr(alpha), + ir::Expr(beta)}; args.insert(args.end(), input.begin(), input.end()); args.push_back(ir::Expr(kernel[0])); args.push_back(ir::Expr(kernel[1])); @@ -589,48 +706,66 @@ std::vector CustomCallArgsForCudnnPoolBackward(const framework::NodeAt } #endif -std::vector CustomCallArgsForAssertTrue(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForAssertTrue( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_EQ(inputs.size(), 1UL); CHECK_EQ(output_shapes.size(), 1UL); const auto &attr_store = attrs.attr_store; CHECK(attr_store.count("msg")); - // TODO(thisjiang): change type from 'int' to 'std::string' when custom call support 'std::string' type - int msg = absl::get(attr_store.at("msg")); - bool only_warning = attr_store.count("only_warning") ? absl::get(attrs.attr_store.at("only_warning")) : false; + // TODO(thisjiang): change type from 'int' to 'std::string' when custom call + // support 'std::string' type + int msg = absl::get(attr_store.at("msg")); + bool only_warning = attr_store.count("only_warning") + ? absl::get(attrs.attr_store.at("only_warning")) + : false; std::vector args = {ir::Expr(msg), ir::Expr(only_warning)}; return args; } -std::vector CustomCallArgsForGaussianRandom(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForGaussianRandom( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_EQ(output_shapes.size(), 1UL); const auto &attr_store = attrs.attr_store; - float mean = attr_store.count("mean") ? absl::get(attrs.attr_store.at("mean")) : 0.0f; - float std = attr_store.count("std") ? absl::get(attrs.attr_store.at("std")) : 1.0f; - int seed = attr_store.count("seed") ? absl::get(attrs.attr_store.at("seed")) : 0; + float mean = attr_store.count("mean") + ? absl::get(attrs.attr_store.at("mean")) + : 0.0f; + float std = attr_store.count("std") + ? absl::get(attrs.attr_store.at("std")) + : 1.0f; + int seed = attr_store.count("seed") + ? absl::get(attrs.attr_store.at("seed")) + : 0; std::vector args = {ir::Expr(mean), ir::Expr(std), ir::Expr(seed)}; return args; } -std::vector CustomCallArgsForUniformRandom(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForUniformRandom( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_EQ(output_shapes.size(), 1UL); const auto &attr_store = attrs.attr_store; - float min = attr_store.count("min") ? absl::get(attrs.attr_store.at("min")) : -1.0f; - float max = attr_store.count("max") ? absl::get(attrs.attr_store.at("max")) : 1.0f; - int seed = attr_store.count("seed") ? absl::get(attrs.attr_store.at("seed")) : 0; + float min = attr_store.count("min") + ? 
absl::get(attrs.attr_store.at("min")) + : -1.0f; + float max = attr_store.count("max") + ? absl::get(attrs.attr_store.at("max")) + : 1.0f; + int seed = attr_store.count("seed") + ? absl::get(attrs.attr_store.at("seed")) + : 0; CHECK_GE(max, min) << "Arg max must greater than min, please check."; @@ -639,29 +774,33 @@ std::vector CustomCallArgsForUniformRandom(const framework::NodeAttr & return args; } -std::vector CustomCallArgsForRandInt(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForRandInt( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_EQ(output_shapes.size(), 1UL); const auto &attr_store = attrs.attr_store; - int seed = attr_store.count("seed") ? absl::get(attrs.attr_store.at("seed")) : 0; + int seed = attr_store.count("seed") + ? absl::get(attrs.attr_store.at("seed")) + : 0; std::vector args = {ir::Expr(seed)}; return args; } -std::vector CustomCallArgsForCholesky(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForCholesky( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_EQ(inputs.size(), 1UL); const auto &attr_store = attrs.attr_store; CHECK(attr_store.count("upper")); - ir::Tensor x = inputs.front(); - int ndim = static_cast(x->shape.size()); + ir::Tensor x = inputs.front(); + int ndim = static_cast(x->shape.size()); int batch_size = 1; for (int i = 0; i < ndim - 2; i++) { batch_size *= x->shape[i].as_int32(); @@ -670,14 +809,16 @@ std::vector CustomCallArgsForCholesky(const framework::NodeAttr &attrs auto upper = absl::get(attrs.attr_store.at("upper")); - std::vector args = {ir::Expr(batch_size), ir::Expr(m), ir::Expr(upper)}; + std::vector args = { + ir::Expr(batch_size), ir::Expr(m), ir::Expr(upper)}; return args; } -std::vector CustomCallArgsForTriangularSolve(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForTriangularSolve( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { CHECK_EQ(inputs.size(), 2UL); const auto &attr_store = attrs.attr_store; CHECK(attr_store.count("left_side")); @@ -685,22 +826,23 @@ std::vector CustomCallArgsForTriangularSolve(const framework::NodeAttr CHECK(attr_store.count("transpose_a")); CHECK(attr_store.count("unit_diagonal")); - ir::Tensor a = inputs[0]; - ir::Tensor b = inputs[1]; - int a_ndim = static_cast(a->shape.size()); - int b_ndim = static_cast(b->shape.size()); + ir::Tensor a = inputs[0]; + ir::Tensor b = inputs[1]; + int a_ndim = static_cast(a->shape.size()); + int b_ndim = static_cast(b->shape.size()); int batch_size = 1; for (int i = 0; i < a_ndim - 2; i++) { batch_size *= a->shape[i].as_int32(); } - auto left_side = absl::get(attrs.attr_store.at("left_side")); - auto upper = absl::get(attrs.attr_store.at("upper")); - auto transpose_a = absl::get(attrs.attr_store.at("transpose_a")); + auto left_side = absl::get(attrs.attr_store.at("left_side")); + auto upper = absl::get(attrs.attr_store.at("upper")); + auto transpose_a = absl::get(attrs.attr_store.at("transpose_a")); auto unit_diagonal = absl::get(attrs.attr_store.at("unit_diagonal")); int m = a->shape[a_ndim - 1].as_int32(); - int k = left_side ? b->shape[b_ndim - 1].as_int32() : b->shape[b_ndim - 2].as_int32(); + int k = left_side ? 
b->shape[b_ndim - 1].as_int32() + : b->shape[b_ndim - 2].as_int32(); std::vector args = {ir::Expr(batch_size), ir::Expr(m), @@ -713,13 +855,16 @@ std::vector CustomCallArgsForTriangularSolve(const framework::NodeAttr return args; } -std::vector CustomCallArgsForMemset(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector> &output_shapes) { +std::vector CustomCallArgsForMemset( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector> &output_shapes) { const auto &attr_store = attrs.attr_store; - CHECK(attr_store.count("value")) << "The memset custom_call must has attribute \"value\""; + CHECK(attr_store.count("value")) + << "The memset custom_call must have attribute \"value\""; CHECK(inputs.empty()) << "The memset custom_call should not has any input"; - CHECK_EQ(output_shapes.size(), 1) << "The memset custom_call should only has one output"; + CHECK_EQ(output_shapes.size(), 1) + << "The memset custom_call should only have one output"; struct Visitor { int *scalar_; @@ -733,8 +878,11 @@ std::vector CustomCallArgsForMemset(const framework::NodeAttr &attrs, void operator()(int64_t v) { *scalar_ = static_cast(v); } void operator()(bool v) { *scalar_ = v ? 0xFFFFFFFF : 0; } -#define EXPAND_MEMSET_TYPE_UNSUPPORT(TYPE) \ - void operator()(const TYPE &) { LOG(FATAL) << "The type of \"value\" of memset custom_call not support: " << #TYPE; } +#define EXPAND_MEMSET_TYPE_UNSUPPORT(TYPE) \ + void operator()(const TYPE &) { \ + LOG(FATAL) << "The type of \"value\" of memset custom_call not supported: " \ + << #TYPE; \ + } EXPAND_MEMSET_TYPE_UNSUPPORT(std::string) EXPAND_MEMSET_TYPE_UNSUPPORT(std::vector) @@ -746,7 +894,7 @@ std::vector CustomCallArgsForMemset(const framework::NodeAttr &attrs, #undef EXPAND_MEMSET_TYPE_UNSUPPORT }; - int value = 0; + int value = 0; const auto &value_attr = attr_store.at("value"); absl::visit(Visitor(&value), value_attr); // can support memset non-0 ?
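The hunk that follows then sizes the fill: count multiplies in every output dim and finally the element width from common::Str2Type. A worked example under assumed values (a float32 output of shape {2, 3}, chosen purely for illustration):

// numel = 2 * 3 = 6 elements; float32 is 4 bytes wide, so the memset
// custom_call receives count = 6 * 4 = 24 bytes alongside the 32-bit
// fill pattern computed by the visitor above.
int count = 1;
for (int dim : {2, 3}) count *= dim;  // 6 elements
count *= 4;                           // common::Str2Type("float32").bytes()
// -> returned args are {Expr(value), Expr(24)}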
@@ -757,19 +905,24 @@ std::vector<ir::Expr> CustomCallArgsForMemset(const framework::NodeAttr &attrs,
     count *= dim;
   }
 
-  const auto &dtype = common::Str2Type(absl::get<std::string>(attr_store.at("dtype")));
+  const auto &dtype =
+      common::Str2Type(absl::get<std::string>(attr_store.at("dtype")));
   count *= dtype.bytes();
 
-  VLOG(4) << "call memset custom_call with value=" << utils::Attribute2String(value_attr) << " (" << value
+  VLOG(4) << "call memset custom_call with value="
+          << utils::Attribute2String(value_attr) << " (" << value
           << "), count=" << count;
 
   return {Expr(value), Expr(count)};
 }
 
-std::vector<ir::Expr> CustomCallArgsForMemcpy(const framework::NodeAttr &attrs,
-                                              const std::vector<ir::Tensor> &inputs,
-                                              const std::vector<std::vector<int>> &output_shapes) {
-  CHECK_EQ(inputs.size(), 1) << "The memcpy custom_call should only has one input";
-  CHECK_EQ(output_shapes.size(), 1) << "The memcpy custom_call should only has one output";
+std::vector<ir::Expr> CustomCallArgsForMemcpy(
+    const framework::NodeAttr &attrs,
+    const std::vector<ir::Tensor> &inputs,
+    const std::vector<std::vector<int>> &output_shapes) {
+  CHECK_EQ(inputs.size(), 1)
+      << "The memcpy custom_call should only have one input";
+  CHECK_EQ(output_shapes.size(), 1)
+      << "The memcpy custom_call should only have one output";
 
   const auto &input_shape = ToPodVector<int>(inputs[0]->shape);
 
@@ -786,39 +939,61 @@ std::vector<ir::Expr> CustomCallArgsForMemcpy(const framework::NodeAttr &attrs,
 
 bool RegisteryCustomCallArgsFunc() {
 #ifdef CINN_WITH_CUDA
+  CustomCallArgsFuncRegistry::Global().Register("cinn_call_cublas",
+                                                common::DefaultNVGPUTarget(),
+                                                CustomCallArgsForCublas);
   CustomCallArgsFuncRegistry::Global().Register(
-      "cinn_call_cublas", common::DefaultNVGPUTarget(), CustomCallArgsForCublas);
-  CustomCallArgsFuncRegistry::Global().Register(
-      "cinn_call_gaussian_random", common::DefaultNVGPUTarget(), CustomCallArgsForGaussianRandom);
-  CustomCallArgsFuncRegistry::Global().Register(
-      "cinn_call_uniform_random", common::DefaultNVGPUTarget(), CustomCallArgsForUniformRandom);
-  CustomCallArgsFuncRegistry::Global().Register(
-      "cinn_call_randint", common::DefaultNVGPUTarget(), CustomCallArgsForRandInt);
-  CustomCallArgsFuncRegistry::Global().Register(
-      "cinn_call_cholesky_nvgpu", common::DefaultNVGPUTarget(), CustomCallArgsForCholesky);
-  CustomCallArgsFuncRegistry::Global().Register(
-      "cinn_call_batched_cublas", common::DefaultNVGPUTarget(), CustomCallArgsForBatchedCublas);
+      "cinn_call_gaussian_random",
+      common::DefaultNVGPUTarget(),
+      CustomCallArgsForGaussianRandom);
+  CustomCallArgsFuncRegistry::Global().Register("cinn_call_uniform_random",
+                                                common::DefaultNVGPUTarget(),
+                                                CustomCallArgsForUniformRandom);
+  CustomCallArgsFuncRegistry::Global().Register("cinn_call_randint",
+                                                common::DefaultNVGPUTarget(),
+                                                CustomCallArgsForRandInt);
+  CustomCallArgsFuncRegistry::Global().Register("cinn_call_cholesky_nvgpu",
+                                                common::DefaultNVGPUTarget(),
+                                                CustomCallArgsForCholesky);
+  CustomCallArgsFuncRegistry::Global().Register("cinn_call_batched_cublas",
+                                                common::DefaultNVGPUTarget(),
+                                                CustomCallArgsForBatchedCublas);
   CustomCallArgsFuncRegistry::Global().Register(
-      "cinn_call_triangular_solve_nvgpu", common::DefaultNVGPUTarget(), CustomCallArgsForTriangularSolve);
-  CustomCallArgsFuncRegistry::Global().Register(
-      "cinn_assert_true_nvgpu", common::DefaultNVGPUTarget(), CustomCallArgsForAssertTrue);
-  CustomCallArgsFuncRegistry::Global().Register(
-      "cinn_call_cuda_memset", common::DefaultNVGPUTarget(), CustomCallArgsForMemset);
-  CustomCallArgsFuncRegistry::Global().Register(
-      "cinn_call_cuda_memcpy", common::DefaultNVGPUTarget(),
CustomCallArgsForMemcpy); + "cinn_call_triangular_solve_nvgpu", + common::DefaultNVGPUTarget(), + CustomCallArgsForTriangularSolve); + CustomCallArgsFuncRegistry::Global().Register("cinn_assert_true_nvgpu", + common::DefaultNVGPUTarget(), + CustomCallArgsForAssertTrue); + CustomCallArgsFuncRegistry::Global().Register("cinn_call_cuda_memset", + common::DefaultNVGPUTarget(), + CustomCallArgsForMemset); + CustomCallArgsFuncRegistry::Global().Register("cinn_call_cuda_memcpy", + common::DefaultNVGPUTarget(), + CustomCallArgsForMemcpy); #endif #ifdef CINN_WITH_CUDNN CustomCallArgsFuncRegistry::Global().Register( - "cinn_call_cudnn_conv2d_forward", common::DefaultNVGPUTarget(), CustomCallArgsForCudnnConvForward); + "cinn_call_cudnn_conv2d_forward", + common::DefaultNVGPUTarget(), + CustomCallArgsForCudnnConvForward); CustomCallArgsFuncRegistry::Global().Register( - "cinn_call_cudnn_conv2d_backward_data", common::DefaultNVGPUTarget(), CustomCallArgsForCudnnConvBackwardData); + "cinn_call_cudnn_conv2d_backward_data", + common::DefaultNVGPUTarget(), + CustomCallArgsForCudnnConvBackwardData); CustomCallArgsFuncRegistry::Global().Register( - "cinn_call_cudnn_conv2d_backward_filter", common::DefaultNVGPUTarget(), CustomCallArgsForCudnnConvBackwardFilter); + "cinn_call_cudnn_conv2d_backward_filter", + common::DefaultNVGPUTarget(), + CustomCallArgsForCudnnConvBackwardFilter); CustomCallArgsFuncRegistry::Global().Register( - "cinn_call_cudnn_pool2d_forward", common::DefaultNVGPUTarget(), CustomCallArgsForCudnnPoolForward); + "cinn_call_cudnn_pool2d_forward", + common::DefaultNVGPUTarget(), + CustomCallArgsForCudnnPoolForward); CustomCallArgsFuncRegistry::Global().Register( - "cinn_call_cudnn_pool2d_backward", common::DefaultNVGPUTarget(), CustomCallArgsForCudnnPoolBackward); + "cinn_call_cudnn_pool2d_backward", + common::DefaultNVGPUTarget(), + CustomCallArgsForCudnnPoolBackward); #endif #ifdef CINN_WITH_MKLDNN @@ -827,13 +1002,15 @@ bool RegisteryCustomCallArgsFunc() { #ifdef CINN_WITH_MKL_CBLAS - CustomCallArgsFuncRegistry::Global().Register( - "cinn_call_cholesky_host", common::DefaultHostTarget(), CustomCallArgsForCholesky); + CustomCallArgsFuncRegistry::Global().Register("cinn_call_cholesky_host", + common::DefaultHostTarget(), + CustomCallArgsForCholesky); #endif - CustomCallArgsFuncRegistry::Global().Register( - "cinn_assert_true_host", common::DefaultHostTarget(), CustomCallArgsForAssertTrue); + CustomCallArgsFuncRegistry::Global().Register("cinn_assert_true_host", + common::DefaultHostTarget(), + CustomCallArgsForAssertTrue); return true; } @@ -846,8 +1023,10 @@ static bool registry_custom_call_list_func = RegisteryCustomCallArgsFunc(); CINN_REGISTER_HELPER(custom_call_op) { CINN_REGISTER_OP(custom_call) .describe("This operator implements the call of extern api!") - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForCustomCall) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible); + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForCustomCall) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible); return true; } diff --git a/paddle/cinn/hlir/op/elementwise.cc b/paddle/cinn/hlir/op/elementwise.cc index 8074fa8a89d94..c2ac4e1a1b2ec 100644 --- a/paddle/cinn/hlir/op/elementwise.cc +++ b/paddle/cinn/hlir/op/elementwise.cc @@ -38,89 +38,105 @@ using common::CINNValuePack; using framework::OpStrategy; using framework::shape_t; using framework::StrategyFunction; -using PeFunc = std::function(const ir::Tensor &A, const std::string &out_name)>; - 
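The PeFunc alias being reflowed here types the primitive-emitter callbacks, and the StrategyForUnary macro just below stamps out one strategy function per unary op, each delegating to StrategyForElementwise. For reference, an invocation such as StrategyForUnary(exp, Exp) expands to roughly the following (whitespace aside):

std::shared_ptr<OpStrategy> StrategyForExp(
    const framework::NodeAttr &attrs,
    const std::vector<ir::Tensor> &inputs,
    const std::vector<Type> &out_type,
    const std::vector<std::vector<int>> &output_shapes,
    const Target &target) {
  // #op_name__ stringizes to "exp"; pe::Exp is the matching primitive emitter.
  return StrategyForElementwise(
      attrs, inputs, out_type, output_shapes, target, "exp", pe::Exp);
}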
-#define StrategyForUnary(op_name__, pe__) \ - std::shared_ptr StrategyFor##pe__(const framework::NodeAttr &attrs, \ - const std::vector &inputs, \ - const std::vector &out_type, \ - const std::vector> &output_shapes, \ - const Target &target) { \ - return StrategyForElementwise(attrs, inputs, out_type, output_shapes, target, #op_name__, pe::pe__); \ +using PeFunc = std::function( + const ir::Tensor &A, const std::string &out_name)>; + +#define StrategyForUnary(op_name__, pe__) \ + std::shared_ptr StrategyFor##pe__( \ + const framework::NodeAttr &attrs, \ + const std::vector &inputs, \ + const std::vector &out_type, \ + const std::vector> &output_shapes, \ + const Target &target) { \ + return StrategyForElementwise( \ + attrs, inputs, out_type, output_shapes, target, #op_name__, pe::pe__); \ } -std::shared_ptr StrategyForElementwise(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target, - const std::string &op_name, - const PeFunc &pe_func) { - framework::CINNCompute unary_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check."; - CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 1U) << "1 input tensor for " << op_name << " compute"; - std::string tensor_name = UniqName(op_name + "_Out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2U); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - } - Expr A_expr = pack_args[0]; - CHECK(A_expr.as_tensor()); - ir::Tensor A = A_expr.as_tensor_ref(); - auto out = pe_func(A, tensor_name); - auto stages = CreateStages({A}); - std::vector res; - for (auto &t : out) { - stages->InsertLazily(t); - res.push_back(CINNValue(t)); - } - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); +std::shared_ptr StrategyForElementwise( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target, + const std::string &op_name, + const PeFunc &pe_func) { + framework::CINNCompute unary_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of " << op_name + << " compute is empty! Please check."; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 1U) + << "1 input tensor for " << op_name << " compute"; + std::string tensor_name = UniqName(op_name + "_Out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2U); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } + Expr A_expr = pack_args[0]; + CHECK(A_expr.as_tensor()); + ir::Tensor A = A_expr.as_tensor_ref(); + auto out = pe_func(A, tensor_name); + auto stages = CreateStages({A}); + std::vector res; + for (auto &t : out) { + stages->InsertLazily(t); + res.push_back(CINNValue(t)); + } + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); auto strategy = std::make_shared(); - strategy->AddImpl( - unary_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy." + op_name + ".x86", 1); + strategy->AddImpl(unary_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy." 
+ op_name + ".x86", + 1); return strategy; } -std::vector InferShapeForElementwise(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForElementwise( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 1UL); std::vector res{inputs_shape[0]}; return res; } -std::vector InferDtypeForElementwise(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForElementwise( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::vector InferDtypeForElementwiseBool(const std::vector &inputs_type, - const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForElementwiseBool( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; return {Bool()}; } -std::vector> InferLayoutForElementwise(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 1U) << "The input's layouts size is not 1! Please check again."; +std::vector> InferLayoutForElementwise( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 1U) + << "The input's layouts size is not 1! Please check again."; return {input_layouts, input_layouts}; } -std::shared_ptr StrategyForScale(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - float scale = 1.f; - float bias = 0.f; +std::shared_ptr StrategyForScale( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + float scale = 1.f; + float bias = 0.f; bool bias_after_scale = true; for (auto &iter : attrs.attr_store) { if (iter.first == "scale") { @@ -131,42 +147,50 @@ std::shared_ptr StrategyForScale(const framework::NodeAttr &attrs, bias_after_scale = absl::get(iter.second); } } - framework::CINNCompute scale_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of scale compute is empty! Please check."; - CINNValuePack pack_args = args[0]; - CHECK(!pack_args.empty()) << "The input tensors of scale compute is empty! 
Please check."; - Expr A_expr = pack_args[0]; - CHECK(A_expr.as_tensor()); - ir::Tensor A = A_expr.as_tensor_ref(); - ir::Tensor out; - std::string tensor_name = UniqName("Scale_out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - } - - if (bias_after_scale) { - out = Compute( - A->shape, - [=](const std::vector &indice) { - return ir::Cast::Make(A->type(), Expr(scale)) * A(indice) + ir::Cast::Make(A->type(), Expr(bias)); - }, - tensor_name); - } else { - out = Compute( - A->shape, - [=](const std::vector &indice) { - return ir::Cast::Make(A->type(), Expr(scale)) * (A(indice) + ir::Cast::Make(A->type(), Expr(bias))); - }, - tensor_name); - } - auto stages = CreateStages({out}); - *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; - }); + framework::CINNCompute scale_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of scale compute is empty! Please check."; + CINNValuePack pack_args = args[0]; + CHECK(!pack_args.empty()) + << "The input tensors of scale compute is empty! Please check."; + Expr A_expr = pack_args[0]; + CHECK(A_expr.as_tensor()); + ir::Tensor A = A_expr.as_tensor_ref(); + ir::Tensor out; + std::string tensor_name = UniqName("Scale_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } + + if (bias_after_scale) { + out = Compute( + A->shape, + [=](const std::vector &indice) { + return ir::Cast::Make(A->type(), Expr(scale)) * A(indice) + + ir::Cast::Make(A->type(), Expr(bias)); + }, + tensor_name); + } else { + out = Compute( + A->shape, + [=](const std::vector &indice) { + return ir::Cast::Make(A->type(), Expr(scale)) * + (A(indice) + ir::Cast::Make(A->type(), Expr(bias))); + }, + tensor_name); + } + auto stages = CreateStages({out}); + *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; + }); auto strategy = std::make_shared(); - strategy->AddImpl(scale_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.scale.x86", 1); + strategy->AddImpl(scale_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy.scale.x86", + 1); return strategy; } @@ -182,26 +206,41 @@ Expr GetScalarExpr(const framework::NodeAttr::attr_t &attr) { void operator()(int64_t v) { scalar_ = Expr(v); } void operator()(bool v) { scalar_ = Expr(v); } void operator()(const std::string &v) { scalar_ = Expr(v); } - void operator()(const std::vector &) { LOG(FATAL) << "wrong type std::vector"; } - void operator()(const std::vector &) { LOG(FATAL) << "wrong type std::vector"; } - void operator()(const std::vector &) { LOG(FATAL) << "wrong type std::vector"; } - void operator()(const std::vector &) { LOG(FATAL) << "wrong type std::vector"; } - void operator()(const std::vector &) { LOG(FATAL) << "wrong type std::vector"; } - void operator()(const std::vector &) { LOG(FATAL) << "wrong type std::vector"; } + void operator()(const std::vector &) { + LOG(FATAL) << "wrong type std::vector"; + } + void operator()(const std::vector &) { + LOG(FATAL) << "wrong type std::vector"; + } + void operator()(const std::vector &) { + LOG(FATAL) << "wrong type std::vector"; + } + void operator()(const std::vector &) { + LOG(FATAL) << "wrong type std::vector"; + } + void operator()(const std::vector &) { + LOG(FATAL) << "wrong type std::vector"; + } + void operator()(const std::vector &) { + 
LOG(FATAL) << "wrong type std::vector"; + } }; absl::visit(Visitor{scalar}, attr); return scalar; } -std::shared_ptr StrategyForConstScalar(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute const_scalar_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of const_float compute is empty! Please check."; - auto scalar = GetScalarExpr(attrs.attr_store.at("value")); - auto scalar_type = out_type.at(0); +std::shared_ptr StrategyForConstScalar( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute const_scalar_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of const_float compute is empty! Please check."; + auto scalar = GetScalarExpr(attrs.attr_store.at("value")); + auto scalar_type = out_type.at(0); CINNValuePack pack_args = args[0]; std::string tensor_name = UniqName("const_scalar_Out"); if (FLAGS_cinn_ir_schedule) { @@ -213,28 +252,35 @@ std::shared_ptr StrategyForConstScalar(const framework::NodeAttr &at auto out = lang::Compute( {Expr(1)}, [=](const std::vector &indice) { - auto res = (scalar_type == scalar->type()) ? scalar : ir::Cast::Make(scalar_type, scalar); + auto res = (scalar_type == scalar->type()) + ? scalar + : ir::Cast::Make(scalar_type, scalar); return res; }, tensor_name); - CHECK(out.defined()) << "can't create const scalar with the given type " << out_type[0]; + CHECK(out.defined()) << "can't create const scalar with the given type " + << out_type[0]; auto stages = CreateStages({out}); - *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; }); auto strategy = std::make_shared(); - strategy->AddImpl( - const_scalar_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.const_scalar.x86", 1); + strategy->AddImpl(const_scalar_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy.const_scalar.x86", + 1); return strategy; } -std::vector InferShapeForConstScalar(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForConstScalar( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { return {{1}}; } -std::vector InferDtypeForConstScalar(const std::vector &inputs_type, const framework::AttrMapType &attrs) { +std::vector InferDtypeForConstScalar( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { Type out_type; if (attrs.find("dtype") != attrs.end()) { auto dtype_str = absl::get(attrs.at("dtype")); @@ -243,34 +289,40 @@ std::vector InferDtypeForConstScalar(const std::vector &inputs_type, } } else { auto scalar = GetScalarExpr(attrs.at("value")); - out_type = scalar->type(); + out_type = scalar->type(); } VLOG(3) << "scalar type: " << out_type; return {out_type}; } -std::vector> InferLayoutForConstScalar(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector> InferLayoutForConstScalar( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { return {{"C"}, input_layouts}; } -std::shared_ptr StrategyForSum(const framework::NodeAttr &attrs, - const std::vector &inputs, - const 
std::vector<Type> &out_type,
-                                          const std::vector<std::vector<int>> &output_shapes,
-                                          const Target &target) {
-  LOG(FATAL) << "The operator will be decomposed into several primitive operators. Please Use Decomposer Program Pass.";
+std::shared_ptr<OpStrategy> StrategyForSum(
+    const framework::NodeAttr &attrs,
+    const std::vector<ir::Tensor> &inputs,
+    const std::vector<Type> &out_type,
+    const std::vector<std::vector<int>> &output_shapes,
+    const Target &target) {
+  LOG(FATAL) << "The operator will be decomposed into several primitive "
+                "operators. Please use the Decomposer program pass.";
 }
 
-std::vector<shape_t> InferShapeForSum(const std::vector<shape_t> &inputs_shape, const framework::AttrMapType &attrs) {
+std::vector<shape_t> InferShapeForSum(const std::vector<shape_t> &inputs_shape,
+                                      const framework::AttrMapType &attrs) {
   CHECK(!inputs_shape.empty()) << "At least 1 input tensor for sum operator.";
   auto shape = inputs_shape[0];
   for (size_t i = 1; i < inputs_shape.size(); ++i) {
     if (inputs_shape[i] != shape) {
-      LOG(FATAL) << "The input shapes must be the same. But received: the i-th(" << i << ") input shape is "
-                 << utils::Join(inputs_shape[i], ",") << " and the first input shape is " << utils::Join(shape, ",");
+      LOG(FATAL) << "The input shapes must be the same. But received: the i-th("
+                 << i << ") input shape is "
+                 << utils::Join(inputs_shape[i], ",")
+                 << " and the first input shape is " << utils::Join(shape, ",");
     }
   }
   std::vector<shape_t> out_shape{shape};
@@ -278,92 +330,109 @@ std::vector<shape_t> InferShapeForSum(const std::vector<shape_t> &inputs_shape,
 
   return out_shape;
 }
 
-std::vector<Type> InferDtypeForSum(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
+std::vector<Type> InferDtypeForSum(const std::vector<Type> &inputs_type,
+                                   const framework::AttrMapType &attrs) {
   CHECK(!inputs_type.empty()) << "At least 1 input tensor for sum operator.";
   auto type = inputs_type[0];
   for (size_t i = 1; i < inputs_type.size(); ++i) {
     if (inputs_type[i] != type) {
-      LOG(FATAL) << "The input types must be the same. But received: the i-th(" << i << ") input type is "
-                 << inputs_type[i] << " and the first input type is " << type;
+      LOG(FATAL) << "The input types must be the same. But received: the i-th("
+                 << i << ") input type is " << inputs_type[i]
+                 << " and the first input type is " << type;
     }
   }
   std::vector<Type> res{type};
   return res;
 }
 
-std::shared_ptr<OpStrategy> StrategyForFillConstant(const framework::NodeAttr &attrs,
-                                                    const std::vector<ir::Tensor> &inputs,
-                                                    const std::vector<Type> &out_type,
-                                                    const std::vector<std::vector<int>> &output_shapes,
-                                                    const Target &target) {
-  framework::CINNCompute fill_constant_compute([=](lang::Args args, lang::RetValue *ret) {
-    CHECK(!args.empty()) << "The input argument of fill_constant compute is empty! Please check.";
-    bool force_cpu = false;
-    CHECK(attrs.attr_store.count("shape"));
-    auto shape = absl::get<std::vector<int>>(attrs.attr_store.at("shape"));
-    CHECK(attrs.attr_store.count("value"));
-    auto value = GetScalarExpr(attrs.attr_store.at("value"));
-    CHECK(attrs.attr_store.count("force_cpu"));
-    force_cpu = absl::get<bool>(attrs.attr_store.at("force_cpu"));
-
-    if (force_cpu && target != common::DefaultHostTarget()) {
-      LOG(WARNING) << "The attribute \"force_cpu\" of \"fill_constant\" not supported in CINN! The \"fill_constant\"'s "
-                      "output tensor will placed on "
-                   << target;
-    }
-
-    CINNValuePack arg_pack = args[0];
-    std::string tensor_name = UniqName("fill_constant_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 1U);
-      CHECK(arg_pack[0].is_string());
-      tensor_name = arg_pack[0].operator std::string();
-    }
-    CHECK(!shape.empty()) << "shape attr is empty!";
-    auto shape_exprs = ToCinnExprs(shape);
-    auto out = lang::Compute(
-        shape_exprs, [=](const std::vector<Expr> &indice) { return ir::Cast::Make(out_type[0], value); }, tensor_name);
-    CHECK(out.defined()) << "can't create fill_constant with the given type " << out_type[0];
-    auto stages = CreateStages({out});
-    *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
-  });
+std::shared_ptr<OpStrategy> StrategyForFillConstant(
+    const framework::NodeAttr &attrs,
+    const std::vector<ir::Tensor> &inputs,
+    const std::vector<Type> &out_type,
+    const std::vector<std::vector<int>> &output_shapes,
+    const Target &target) {
+  framework::CINNCompute fill_constant_compute(
+      [=](lang::Args args, lang::RetValue *ret) {
+        CHECK(!args.empty()) << "The input argument of fill_constant compute "
+                                "is empty! Please check.";
+        bool force_cpu = false;
+        CHECK(attrs.attr_store.count("shape"));
+        auto shape = absl::get<std::vector<int>>(attrs.attr_store.at("shape"));
+        CHECK(attrs.attr_store.count("value"));
+        auto value = GetScalarExpr(attrs.attr_store.at("value"));
+        CHECK(attrs.attr_store.count("force_cpu"));
+        force_cpu = absl::get<bool>(attrs.attr_store.at("force_cpu"));
+
+        if (force_cpu && target != common::DefaultHostTarget()) {
+          LOG(WARNING) << "The attribute \"force_cpu\" of \"fill_constant\" "
+                          "is not supported in CINN! The \"fill_constant\"'s "
+                          "output tensor will be placed on "
+                       << target;
+        }
+
+        CINNValuePack arg_pack = args[0];
+        std::string tensor_name = UniqName("fill_constant_Out");
+        if (FLAGS_cinn_ir_schedule) {
+          CHECK_EQ(arg_pack.size(), 1U);
+          CHECK(arg_pack[0].is_string());
+          tensor_name = arg_pack[0].operator std::string();
+        }
+        CHECK(!shape.empty()) << "shape attr is empty!";
+        auto shape_exprs = ToCinnExprs(shape);
+        auto out = lang::Compute(
+            shape_exprs,
+            [=](const std::vector<Expr> &indice) {
+              return ir::Cast::Make(out_type[0], value);
+            },
+            tensor_name);
+        CHECK(out.defined())
+            << "can't create fill_constant with the given type " << out_type[0];
+        auto stages = CreateStages({out});
+        *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
+      });
 
   auto strategy = std::make_shared<OpStrategy>();
-  strategy->AddImpl(
-      fill_constant_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.fill_constant.x86", 1);
+  strategy->AddImpl(fill_constant_compute,
+                    GetElementwiseScheduleFunc(output_shapes, target),
+                    "strategy.fill_constant.x86",
+                    1);
   return strategy;
 }
 
-std::vector<shape_t> InferShapeForFillConstant(const std::vector<shape_t> &inputs_shape,
-                                               const framework::AttrMapType &attrs) {
+std::vector<shape_t> InferShapeForFillConstant(
+    const std::vector<shape_t> &inputs_shape,
+    const framework::AttrMapType &attrs) {
   CHECK(attrs.count("shape"));
   auto shape = absl::get<std::vector<int>>(attrs.at("shape"));
   CHECK(!shape.empty()) << "shape attr is empty!";
   return {shape};
 }
 
-std::vector<Type> InferDtypeForFillConstant(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
+std::vector<Type> InferDtypeForFillConstant(
    const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
   common::Type out_type;
   CHECK(attrs.count("value"));
   if (attrs.find("dtype") != attrs.end()) {
     // attribute [dtype] are given
     auto dtype_str = absl::get<std::string>(attrs.at("dtype"));
-    out_type       = common::Str2Type(dtype_str);
+    out_type =
common::Str2Type(dtype_str); VLOG(3) << "FillConstant output dtype (from [dtype]): " << dtype_str; } else { // attribute [dtype] no given, inferred by value's type auto scalar = GetScalarExpr(attrs.at("value")); - out_type = scalar->type(); - VLOG(3) << "FillConstant scalar type (from [value]): " << common::Type2Str(out_type); + out_type = scalar->type(); + VLOG(3) << "FillConstant scalar type (from [value]): " + << common::Type2Str(out_type); } return {out_type}; } -std::vector> InferLayoutForFillConstant(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector> InferLayoutForFillConstant( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { return {{""}, input_layouts}; } @@ -374,17 +443,21 @@ std::vector> InferLayoutForFillConstant(const std::vect MACRO(double) \ MACRO(float) -std::shared_ptr StrategyForAssignValue(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute assign_value_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of assign_value compute is empty! Please check."; - CHECK(attrs.attr_store.count("values")) << "assign_value should set attribute [values]! Please check."; +std::shared_ptr StrategyForAssignValue( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute assign_value_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of assign_value compute is empty! Please check."; + CHECK(attrs.attr_store.count("values")) + << "assign_value should set attribute [values]! 
Please check."; const auto &value = attrs.attr_store.at("values"); - CINNValuePack arg_pack = args[0]; + CINNValuePack arg_pack = args[0]; std::string tensor_name = UniqName("T_assign_value_out"); if (FLAGS_cinn_ir_schedule) { CHECK_EQ(arg_pack.size(), 1U); @@ -393,12 +466,14 @@ std::shared_ptr StrategyForAssignValue(const framework::NodeAttr &at } absl::optional out; -#define EXPAND_VALUE_TO_TENSOR(TYPE) \ - else if (absl::get_if(&value)) { \ - out = pe::AssignValue(std::vector{absl::get(value)}, out_type[0], tensor_name); \ - } \ - else if (absl::get_if>(&value)) { \ - out = pe::AssignValue(absl::get>(value), out_type[0], tensor_name); \ +#define EXPAND_VALUE_TO_TENSOR(TYPE) \ + else if (absl::get_if(&value)) { \ + out = pe::AssignValue( \ + std::vector{absl::get(value)}, out_type[0], tensor_name); \ + } \ + else if (absl::get_if>(&value)) { \ + out = pe::AssignValue( \ + absl::get>(value), out_type[0], tensor_name); \ } if (false) { @@ -409,22 +484,28 @@ std::shared_ptr StrategyForAssignValue(const framework::NodeAttr &at } #undef EXPAND_VALUE_TO_TENSOR - CHECK(out && out.value().defined()) << "can't create assign_value with the given type " << out_type[0]; + CHECK(out && out.value().defined()) + << "can't create assign_value with the given type " << out_type[0]; auto stages = CreateStages({out.value()}); - *ret = CINNValuePack{{CINNValue(Expr(out.value().get())), CINNValue(stages)}}; + *ret = + CINNValuePack{{CINNValue(Expr(out.value().get())), CINNValue(stages)}}; }); auto strategy = std::make_shared(); - strategy->AddImpl( - assign_value_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.assign_value.x86", 1); + strategy->AddImpl(assign_value_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy.assign_value.x86", + 1); return strategy; } -std::vector InferShapeForAssignValue(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(attrs.count("values")) << "assign_value should set attribute [values]! Please check."; +std::vector InferShapeForAssignValue( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(attrs.count("values")) + << "assign_value should set attribute [values]! Please check."; const auto &value = attrs.at("values"); shape_t shape; @@ -444,12 +525,14 @@ std::vector InferShapeForAssignValue(const std::vector &inputs } #undef EXPAND_ATTR_TO_GET_SHAPE - VLOG(3) << "The output shape of assign_value is [" << cinn::utils::Join(shape, ", ") << "]"; + VLOG(3) << "The output shape of assign_value is [" + << cinn::utils::Join(shape, ", ") << "]"; return {shape}; } -std::vector InferDtypeForAssignValue(const std::vector &inputs_type, const framework::AttrMapType &attrs) { +std::vector InferDtypeForAssignValue( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { Type out_type; if (attrs.find("dtype") != attrs.end()) { // attribute [dtype] are given @@ -463,7 +546,8 @@ std::vector InferDtypeForAssignValue(const std::vector &inputs_type, // attribute [dtype] not given or is empty if (out_type.is_unk()) { // infer from [values]'s dtype - CHECK(attrs.count("values")) << "assign_value should set attribute [values]! Please check."; + CHECK(attrs.count("values")) + << "assign_value should set attribute [values]! 
Please check."; const auto &value = attrs.at("values"); #define EXPAND_ATTR_TO_GET_DTYPE(TYPE) \ @@ -488,10 +572,11 @@ std::vector InferDtypeForAssignValue(const std::vector &inputs_type, return {out_type}; } -std::vector> InferLayoutForAssignValue(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector> InferLayoutForAssignValue( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { return {{""}, input_layouts}; } @@ -538,23 +623,29 @@ StrategyForUnary(popc, Popc); #undef StrategyForUnary -std::shared_ptr StrategyForSqueeze(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForSqueeze( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { const std::vector &axes = - attrs.attr_store.count("axes") ? absl::get>(attrs.attr_store.at("axes")) : std::vector{}; - - framework::CINNCompute squeeze_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of Squeeze compute is empty! Please check.\n"; + attrs.attr_store.count("axes") + ? absl::get>(attrs.attr_store.at("axes")) + : std::vector{}; + + framework::CINNCompute squeeze_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of Squeeze compute is empty! Please check.\n"; CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 1U) << "at least 1 input tensors for Squeeze compute\n"; + CHECK_GE(pack_args.size(), 1U) + << "at least 1 input tensors for Squeeze compute\n"; Expr A = pack_args[0]; CHECK(A.as_tensor()); CHECK(!output_shapes.empty()); auto tensor_A = A.as_tensor_ref(); - auto stages = CreateStages({tensor_A}); + auto stages = CreateStages({tensor_A}); VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") << ", output_shapes: " << utils::Join(output_shapes[0], ", "); @@ -568,29 +659,37 @@ std::shared_ptr StrategyForSqueeze(const framework::NodeA std::vector res; stages->InsertLazily(out); res.push_back(CINNValue(out)); - CHECK(!out_type.empty()) << "Output type of Squeeze is empty! Please check.\n"; + CHECK(!out_type.empty()) + << "Output type of Squeeze is empty! Please check.\n"; res.push_back(CINNValue(stages)); *ret = CINNValuePack{res}; }); auto strategy = std::make_shared(); - strategy->AddImpl(squeeze_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.squeeze.x86", 1); + strategy->AddImpl(squeeze_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy.squeeze.x86", + 1); return strategy; } -std::vector> InferShapeForSqueeze(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector> InferShapeForSqueeze( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 1U); const std::vector &axes = - attrs.count("axes") ? absl::get>(attrs.at("axes")) : std::vector{}; - VLOG(4) << "The [axis] value used in Squeeze: " << cinn::utils::Join(axes, ","); + attrs.count("axes") ? 
absl::get>(attrs.at("axes")) + : std::vector{}; + VLOG(4) << "The [axis] value used in Squeeze: " + << cinn::utils::Join(axes, ","); const auto &posi_axes = utils::GetPositiveAxes(axes, inputs_shape[0].size()); std::vector output_shape; if (posi_axes.size()) { for (int idx = 0; idx < inputs_shape[0].size(); ++idx) { // if can't find idx in axis - if (std::find(posi_axes.begin(), posi_axes.end(), idx) == posi_axes.end()) { + if (std::find(posi_axes.begin(), posi_axes.end(), idx) == + posi_axes.end()) { output_shape.push_back(inputs_shape[0][idx]); } else { CHECK_EQ(inputs_shape[0][idx], 1); @@ -604,7 +703,8 @@ std::vector> InferShapeForSqueeze(const std::vector> InferShapeForSqueeze(const std::vector StrategyForExpandDims(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForExpandDims( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { const std::vector &axes = - attrs.attr_store.count("axes") ? absl::get>(attrs.attr_store.at("axes")) : std::vector{}; + attrs.attr_store.count("axes") + ? absl::get>(attrs.attr_store.at("axes")) + : std::vector{}; - framework::CINNCompute expand_dims_compute{[=](lang::Args args, lang::RetValue *ret) { + framework::CINNCompute expand_dims_compute{[=](lang::Args args, + lang::RetValue *ret) { CHECK(!args.empty()) << "The input args are empty! Please check again."; CINNValuePack input_args = args[0]; - int input_size = input_args.size(); - CHECK_GE(input_size, 1U) << "Require 1 input tensors for expand_dims compute."; + int input_size = input_args.size(); + CHECK_GE(input_size, 1U) + << "Require 1 input tensors for expand_dims compute."; Expr x = input_args[0]; CHECK(x.as_tensor()); @@ -635,7 +740,8 @@ std::shared_ptr StrategyForExpandDims(const framework::NodeAttr &att tensor_name = input_args[1].operator std::string(); } - auto out = pe::ExpandDims(x.as_tensor_ref(), axes, output_shapes[0], tensor_name); + auto out = + pe::ExpandDims(x.as_tensor_ref(), axes, output_shapes[0], tensor_name); auto stages = CreateStages({x.as_tensor_ref()}); stages->InsertLazily(out); std::vector res{CINNValue(out), CINNValue(stages)}; @@ -643,19 +749,25 @@ std::shared_ptr StrategyForExpandDims(const framework::NodeAttr &att }}; auto strategy = std::make_shared(); - strategy->AddImpl( - expand_dims_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.expand_dims.x86", 1); + strategy->AddImpl(expand_dims_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy.expand_dims.x86", + 1); return strategy; } -std::vector> InferShapeForExpandDims(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; +std::vector> InferShapeForExpandDims( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) + << "The input's shape size is 0! Please check again."; CHECK_EQ(inputs_shape.size(), 1U); const std::vector &axes = - attrs.count("axes") ? absl::get>(attrs.at("axes")) : std::vector{}; - VLOG(4) << "The [axes] value used in ExpandDims: " << cinn::utils::Join(axes, ","); + attrs.count("axes") ? 
absl::get>(attrs.at("axes")) + : std::vector{}; + VLOG(4) << "The [axes] value used in ExpandDims: " + << cinn::utils::Join(axes, ","); std::vector out_shape(inputs_shape[0].size() + axes.size(), 1); const auto &posi_axes = utils::GetPositiveAxes(axes, out_shape.size()); @@ -671,27 +783,33 @@ std::vector> InferShapeForExpandDims(const std::vector StrategyForReshape(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute reshape_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of Reshape compute is empty! Please check.\n"; +std::shared_ptr StrategyForReshape( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute reshape_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of Reshape compute is empty! Please check.\n"; CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 1U) << "at least 1 input tensors for Reshape compute\n"; + CHECK_GE(pack_args.size(), 1U) + << "at least 1 input tensors for Reshape compute\n"; Expr A = pack_args[0]; CHECK(A.as_tensor()); CHECK(!output_shapes.empty()); auto attr_store = attrs.attr_store; CHECK(attr_store.count("shape")) << "find no attr of shape"; - std::vector new_shape = absl::get>(attr_store.at("shape")); - auto tensor_A = A.as_tensor_ref(); - auto stages = CreateStages({tensor_A}); + std::vector new_shape = + absl::get>(attr_store.at("shape")); + auto tensor_A = A.as_tensor_ref(); + auto stages = CreateStages({tensor_A}); VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") << ", output_shapes: " << utils::Join(output_shapes[0], ", "); @@ -706,20 +824,26 @@ std::shared_ptr StrategyForReshape(const framework::NodeAttr &attrs, std::vector res; stages->InsertLazily(out); res.push_back(CINNValue(out)); - CHECK(!out_type.empty()) << "Output type of Reshape is empty! Please check.\n"; + CHECK(!out_type.empty()) + << "Output type of Reshape is empty! Please check.\n"; res.push_back(CINNValue(stages)); *ret = CINNValuePack{res}; }); auto strategy = std::make_shared(); - strategy->AddImpl(reshape_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.reshape.x86", 1); + strategy->AddImpl(reshape_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy.reshape.x86", + 1); return strategy; } -std::vector> InferShapeForReshape(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 1U) << "The input's shape size should be 1! Please check again."; +std::vector> InferShapeForReshape( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 1U) + << "The input's shape size should be 1! 
Please check again."; std::vector output_shape; for (auto &iter : attrs) { if (iter.first == "shape") { @@ -735,15 +859,18 @@ std::vector> InferShapeForReshape(const std::vector 0) { CHECK_EQ(tensor_size % output_shape[i], 0) - << "Incompatible input shape and output shape in op reshape: " << tensor_size << ", " << output_shape[i]; + << "Incompatible input shape and output shape in op reshape: " + << tensor_size << ", " << output_shape[i]; tensor_size /= output_shape[i]; } else if (output_shape[i] == 0) { CHECK_LT(i, inputs_shape[0].size()) - << "In op reshape, when attribute shape[i] == 0, shape[i] = input_shape[i]. But now the size of input_shape " + << "In op reshape, when attribute shape[i] == 0, shape[i] = " + "input_shape[i]. But now the size of input_shape " "<= i, which is incompatible. Please check!"; output_shape[i] = inputs_shape[0][i]; CHECK_EQ(tensor_size % output_shape[i], 0) - << "Incompatible input shape and output shape in op reshape: " << tensor_size << ", " << output_shape[i]; + << "Incompatible input shape and output shape in op reshape: " + << tensor_size << ", " << output_shape[i]; tensor_size /= output_shape[i]; } else if (output_shape[i] == -1 && flag_index == -1) { flag_index = i; @@ -758,51 +885,61 @@ std::vector> InferShapeForReshape(const std::vector StrategyForCast(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute cast_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of Cast compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 1U) << "at least 1 input tensors for Cast compute\n"; - Expr A = pack_args[0]; - CHECK(A.as_tensor()); - CHECK(!output_shapes.empty()); - auto tensor_A = A.as_tensor_ref(); - auto stages = CreateStages({tensor_A}); - VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") - << ", output_shapes: " << utils::Join(output_shapes[0], ", "); - std::string tensor_name = UniqName("Cast_out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2U); - tensor_name = pack_args[1].operator std::string(); - } - ir::Tensor out = pe::Cast(tensor_A, out_type[0], tensor_name); - std::vector res; - stages->InsertLazily(out); - res.push_back(CINNValue(out)); - CHECK(!out_type.empty()) << "Output type of Cast is empty! Please check.\n"; - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); +std::shared_ptr StrategyForCast( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute cast_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of Cast compute is empty! 
Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 1U) + << "at least 1 input tensors for Cast compute\n"; + Expr A = pack_args[0]; + CHECK(A.as_tensor()); + CHECK(!output_shapes.empty()); + auto tensor_A = A.as_tensor_ref(); + auto stages = CreateStages({tensor_A}); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + std::string tensor_name = UniqName("Cast_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2U); + tensor_name = pack_args[1].operator std::string(); + } + ir::Tensor out = pe::Cast(tensor_A, out_type[0], tensor_name); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + CHECK(!out_type.empty()) + << "Output type of Cast is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); auto strategy = std::make_shared(); - strategy->AddImpl(cast_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.reshape.x86", 1); + strategy->AddImpl(cast_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy.reshape.x86", + 1); return strategy; } -std::vector InferDtypeForCast(const std::vector &inputs_type, const framework::AttrMapType &attrs) { +std::vector InferDtypeForCast(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { CHECK(attrs.count("dtype")); return {common::Str2Type(absl::get(attrs.at("dtype")))}; } -std::shared_ptr StrategyForArange(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForArange( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { auto attr_store = attrs.attr_store; CHECK(attr_store.count("start")); CHECK(attr_store.count("stop")); @@ -810,55 +947,64 @@ std::shared_ptr StrategyForArange(const framework::NodeAt CHECK(attr_store.count("dtype")); auto start = absl::get(attr_store.at("start")); - auto stop = absl::get(attr_store.at("stop")); - auto step = absl::get(attr_store.at("step")); + auto stop = absl::get(attr_store.at("stop")); + auto step = absl::get(attr_store.at("step")); auto dtype = common::Str2Type(absl::get(attr_store.at("dtype"))); - framework::CINNCompute arange_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of arange compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; - - std::string tensor_name = common::UniqName("T_Arange_out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 1U); - tensor_name = pack_args[0].operator std::string(); - } - - auto out = pe::Arange(start, stop, step, dtype, tensor_name); - std::vector res; - auto stages = CreateStages({out}); - res.push_back(common::CINNValue(out)); - res.push_back(common::CINNValue(stages)); - *ret = CINNValuePack{res}; - }); + framework::CINNCompute arange_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of arange compute is empty! 
Please check.\n"; + CINNValuePack pack_args = args[0]; + + std::string tensor_name = common::UniqName("T_Arange_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 1U); + tensor_name = pack_args[0].operator std::string(); + } + + auto out = pe::Arange(start, stop, step, dtype, tensor_name); + std::vector res; + auto stages = CreateStages({out}); + res.push_back(common::CINNValue(out)); + res.push_back(common::CINNValue(stages)); + *ret = CINNValuePack{res}; + }); auto strategy = std::make_shared(); - strategy->AddImpl(arange_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.reshape.x86", 1); + strategy->AddImpl(arange_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy.reshape.x86", + 1); return strategy; } -std::vector> InferShapeForArange(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector> InferShapeForArange( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { CHECK(attrs.count("start")); CHECK(attrs.count("stop")); CHECK(attrs.count("step")); float start = absl::get(attrs.at("start")); - float stop = absl::get(attrs.at("stop")); - float step = absl::get(attrs.at("step")); + float stop = absl::get(attrs.at("stop")); + float step = absl::get(attrs.at("step")); CHECK_NE(step, 0.0f) << "The value of step can't be 0!"; int num = static_cast(std::ceil((stop - start) / step)); - CHECK(num) << "Invalid arange parameters, start = " << start << ", stop = " << stop << ", step = " << step + CHECK(num) << "Invalid arange parameters, start = " << start + << ", stop = " << stop << ", step = " << step << ", cause num_elem = " << num << " which is negative."; return {{num}}; } -std::vector InferDtypeForArange(const std::vector &inputs_type, const framework::AttrMapType &attrs) { +std::vector InferDtypeForArange(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { CHECK(attrs.count("dtype")); return {common::Str2Type(absl::get(attrs.at("dtype")))}; } -std::vector InferDtypeForLogicalNot(const std::vector &inputs_type, const framework::AttrMapType &attrs) { +std::vector InferDtypeForLogicalNot(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { return {common::Bool()}; } @@ -867,16 +1013,21 @@ std::vector InferDtypeForLogicalNot(const std::vector &inputs_type, } // namespace cinn CINN_REGISTER_HELPER(elementwise_ops) { -#define CINN_REGISTER_UNARY(op__, op_stragegy__) \ - CINN_REGISTER_OP(op__) \ - .describe(#op__ " function") \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr("CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) \ - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) \ - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) \ - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) \ +#define CINN_REGISTER_UNARY(op__, op_stragegy__) \ + CINN_REGISTER_OP(op__) \ + .describe(#op__ " function") \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr( \ + "CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ + .set_attr("infershape", \ + MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) \ + .set_attr("inferdtype", \ + MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) \ + .set_attr("inferlayout", \ + MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) \ + .set_attr( \ + "OpPattern", 
cinn::hlir::framework::OpPatternKind::kElementWise) \ .set_support_level(4); CINN_REGISTER_UNARY(exp, Exp); @@ -915,16 +1066,21 @@ CINN_REGISTER_HELPER(elementwise_ops) { #undef CINN_REGISTER_UNARY -#define CINN_REGISTER_COMPARE(op__, op_stragegy__) \ - CINN_REGISTER_OP(op__) \ - .describe(#op__ " function") \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr("CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) \ - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForElementwiseBool)) \ - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) \ - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) \ +#define CINN_REGISTER_COMPARE(op__, op_stragegy__) \ + CINN_REGISTER_OP(op__) \ + .describe(#op__ " function") \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr( \ + "CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ + .set_attr("infershape", \ + MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) \ + .set_attr("inferdtype", \ + MakeOpFunction(cinn::hlir::op::InferDtypeForElementwiseBool)) \ + .set_attr("inferlayout", \ + MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) \ + .set_attr( \ + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) \ .set_support_level(4); CINN_REGISTER_COMPARE(isnan, IsNan) @@ -937,133 +1093,186 @@ CINN_REGISTER_HELPER(elementwise_ops) { .describe("Putting scale and bias to the input Tensor") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForScale) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForScale) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(const_scalar) .describe("create const scalar with the given value") .set_num_inputs(0) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForConstScalar) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForConstScalar)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForConstScalar)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForConstScalar) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForConstScalar)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForConstScalar)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForConstScalar)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForConstScalar)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(sum) .describe("Sum the input tensors.") .set_num_outputs(1) - .set_attr("CINNStrategy", 
cinn::hlir::op::StrategyForSum) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForSum) .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForSum)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForSum)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise); + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise); CINN_REGISTER_OP(fill_constant) .describe("create tensor with the given value, type and shape") .set_num_inputs(0) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForFillConstant) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForFillConstant)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForFillConstant)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForFillConstant) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForFillConstant)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForFillConstant)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForFillConstant)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForFillConstant)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(assign_value) .describe("create tensor with the given value, type and shape") .set_num_inputs(0) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForAssignValue) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForAssignValue)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForAssignValue)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForAssignValue) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForAssignValue)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForAssignValue)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForAssignValue)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForAssignValue)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(squeeze) .describe("The operator is used to squeeze input tensor's dims") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForSqueeze) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForSqueeze)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForSqueeze) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForSqueeze)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(expand_dims) .describe("This operator is used to expand input tensor's dims.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForExpandDims) 
- .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForExpandDims)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForExpandDims) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForExpandDims)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(reshape) .describe("This operator is used to reshape input tensor X.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForReshape) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForReshape)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForReshape) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForReshape)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(cast) .describe("This operator is used to cast input tensor's type to target.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForCast) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForCast) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForCast)) - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(arange) .describe("Returns evenly spaced values within a given interval.") .set_num_inputs(0) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForArange) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForArange)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForArange)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForArange) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForArange)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForArange)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(gelu) .describe("The implement of gelu.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) - .set_attr("inferdtype", 
MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise); + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise); CINN_REGISTER_OP(logical_not) .describe("Logical not function") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForLogicalNot) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForLogicalNot)) - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForLogicalNot) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForLogicalNot)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/external_api_registry.cc b/paddle/cinn/hlir/op/external_api_registry.cc index e2b9c3a1a3f40..000f8b92de905 100644 --- a/paddle/cinn/hlir/op/external_api_registry.cc +++ b/paddle/cinn/hlir/op/external_api_registry.cc @@ -18,25 +18,33 @@ namespace cinn { namespace hlir { namespace op { -ExternalApiInfo& ExternalApiRegistry::Register(const std::string& op_name, const common::Target& target) { +ExternalApiInfo& ExternalApiRegistry::Register(const std::string& op_name, + const common::Target& target) { return __REGISTER__(GenKey(op_name, target)); } -std::string ExternalApiRegistry::GetExternalApi(const framework::Node* op_node, const common::Target& target) { - CHECK(op_node->attrs.attr_store.count("original_op")) << "a custom_call op must store its original op name"; - std::string op_name = absl::get(op_node->attrs.attr_store.at("original_op")); +std::string ExternalApiRegistry::GetExternalApi(const framework::Node* op_node, + const common::Target& target) { + CHECK(op_node->attrs.attr_store.count("original_op")) + << "a custom_call op must store its original op name"; + std::string op_name = + absl::get(op_node->attrs.attr_store.at("original_op")); const ExternalApiInfo* external_api_info = Find(GenKey(op_name, target)); - CHECK(external_api_info) << "Op:" << op_name << " doesn't register external_api on " << target; + CHECK(external_api_info) << "Op:" << op_name + << " doesn't register external_api on " << target; std::string external_api = external_api_info->api_name; - if (external_api.empty()) { // if api_name not set directly, call trans_func to acquire + if (external_api.empty()) { // if api_name not set directly, call trans_func + // to acquire auto&& trans_func = external_api_info->trans_func; - CHECK(trans_func) << "Op:" << op_name << " register invalid ExternalApiInfo on " << target; + CHECK(trans_func) << "Op:" << op_name + << " register invalid ExternalApiInfo on " << target; external_api = trans_func(op_node); } return external_api; } -std::string ExternalApiRegistry::GenKey(const std::string& op_name, const common::Target& target) { +std::string ExternalApiRegistry::GenKey(const std::string& op_name, + const common::Target& target) { 
std::ostringstream oss; oss << target; return op_name + "_" + oss.str(); @@ -48,39 +56,58 @@ std::string ExternalApiRegistry::GenKey(const std::string& op_name, const common CINN_REGISTER_HELPER(op_external_api) { const auto& default_nvgpu = ::cinn::common::DefaultNVGPUTarget(); - const auto& default_host = ::cinn::common::DefaultHostTarget(); + const auto& default_host = ::cinn::common::DefaultHostTarget(); - CINN_OP_REGISTER_EXTERNAL_API(matmul, default_nvgpu).set_api_name("cinn_call_cublas"); - CINN_OP_REGISTER_EXTERNAL_API(mul, default_nvgpu).set_api_name("cinn_call_cublas"); - CINN_OP_REGISTER_EXTERNAL_API(cublas_gemm, default_nvgpu).set_api_name("cinn_call_cublas"); - CINN_OP_REGISTER_EXTERNAL_API(cublas_matmul, default_nvgpu).set_api_name("cinn_call_cublas"); - CINN_OP_REGISTER_EXTERNAL_API(gaussian_random, default_nvgpu).set_api_name("cinn_call_gaussian_random"); - CINN_OP_REGISTER_EXTERNAL_API(uniform_random, default_nvgpu).set_api_name("cinn_call_uniform_random"); - CINN_OP_REGISTER_EXTERNAL_API(randint, default_nvgpu).set_api_name("cinn_call_randint"); - CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_nvgpu).set_api_name("cinn_call_cholesky_nvgpu"); - CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_host).set_api_name("cinn_call_cholesky_host"); - CINN_OP_REGISTER_EXTERNAL_API(triangular_solve, default_nvgpu).set_api_name("cinn_call_triangular_solve_nvgpu"); - CINN_OP_REGISTER_EXTERNAL_API(assert_true, default_nvgpu).set_api_name("cinn_assert_true_nvgpu"); - CINN_OP_REGISTER_EXTERNAL_API(assert_true, default_host).set_api_name("cinn_assert_true_host"); + CINN_OP_REGISTER_EXTERNAL_API(matmul, default_nvgpu) + .set_api_name("cinn_call_cublas"); + CINN_OP_REGISTER_EXTERNAL_API(mul, default_nvgpu) + .set_api_name("cinn_call_cublas"); + CINN_OP_REGISTER_EXTERNAL_API(cublas_gemm, default_nvgpu) + .set_api_name("cinn_call_cublas"); + CINN_OP_REGISTER_EXTERNAL_API(cublas_matmul, default_nvgpu) + .set_api_name("cinn_call_cublas"); + CINN_OP_REGISTER_EXTERNAL_API(gaussian_random, default_nvgpu) + .set_api_name("cinn_call_gaussian_random"); + CINN_OP_REGISTER_EXTERNAL_API(uniform_random, default_nvgpu) + .set_api_name("cinn_call_uniform_random"); + CINN_OP_REGISTER_EXTERNAL_API(randint, default_nvgpu) + .set_api_name("cinn_call_randint"); + CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_nvgpu) + .set_api_name("cinn_call_cholesky_nvgpu"); + CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_host) + .set_api_name("cinn_call_cholesky_host"); + CINN_OP_REGISTER_EXTERNAL_API(triangular_solve, default_nvgpu) + .set_api_name("cinn_call_triangular_solve_nvgpu"); + CINN_OP_REGISTER_EXTERNAL_API(assert_true, default_nvgpu) + .set_api_name("cinn_assert_true_nvgpu"); + CINN_OP_REGISTER_EXTERNAL_API(assert_true, default_host) + .set_api_name("cinn_assert_true_host"); #ifdef CINN_WITH_CUDNN - CINN_OP_REGISTER_EXTERNAL_API(conv2d, default_nvgpu).set_trans_func([](const ::cinn::hlir::framework::Node* node) { - CHECK(node->attrs.attr_store.count("conv_type")); - std::string conv_type = absl::get(node->attrs.attr_store.at("conv_type")); - CHECK(conv_type == "forward" || conv_type == "backward_data" || conv_type == "backward_filter") - << "unknown conv_type=" << conv_type; - return "cinn_call_cudnn_conv2d_" + conv_type; - }); + CINN_OP_REGISTER_EXTERNAL_API(conv2d, default_nvgpu) + .set_trans_func([](const ::cinn::hlir::framework::Node* node) { + CHECK(node->attrs.attr_store.count("conv_type")); + std::string conv_type = + absl::get(node->attrs.attr_store.at("conv_type")); + CHECK(conv_type == "forward" || 
conv_type == "backward_data" || + conv_type == "backward_filter") + << "unknown conv_type=" << conv_type; + return "cinn_call_cudnn_conv2d_" + conv_type; + }); CINN_OP_REGISTER_EXTERNAL_API(depthwise_conv2d, default_nvgpu) .set_trans_func([](const ::cinn::hlir::framework::Node* node) { - std::string conv_type = node->attrs.attr_store.count("conv_type") - ? absl::get(node->attrs.attr_store.at("conv_type")) - : "forward"; - CHECK(conv_type == "forward" || conv_type == "backward_data" || conv_type == "backward_filter") + std::string conv_type = + node->attrs.attr_store.count("conv_type") + ? absl::get(node->attrs.attr_store.at("conv_type")) + : "forward"; + CHECK(conv_type == "forward" || conv_type == "backward_data" || + conv_type == "backward_filter") << "unknown conv_type=" << conv_type; return "cinn_call_cudnn_conv2d_" + conv_type; }); - CINN_OP_REGISTER_EXTERNAL_API(pool2d, default_nvgpu).set_api_name("cinn_call_cudnn_pool2d_forward"); - CINN_OP_REGISTER_EXTERNAL_API(pool2d_grad, default_nvgpu).set_api_name("cinn_call_cudnn_pool2d_backward"); + CINN_OP_REGISTER_EXTERNAL_API(pool2d, default_nvgpu) + .set_api_name("cinn_call_cudnn_pool2d_forward"); + CINN_OP_REGISTER_EXTERNAL_API(pool2d_grad, default_nvgpu) + .set_api_name("cinn_call_cudnn_pool2d_backward"); #endif return true; } diff --git a/paddle/cinn/hlir/op/external_api_registry.h b/paddle/cinn/hlir/op/external_api_registry.h index 189a51d276974..307cac68b2f20 100644 --- a/paddle/cinn/hlir/op/external_api_registry.h +++ b/paddle/cinn/hlir/op/external_api_registry.h @@ -19,19 +19,22 @@ #include "paddle/cinn/hlir/framework/node.h" #include "paddle/cinn/utils/registry.h" -#define CINN_OP_REGISTER_EXTERNAL_API(Name, Target) \ - static ::cinn::hlir::op::ExternalApiInfo& CINN_STR_CONCAT(__make_##ExternalApiInfo##_##Name##__, __COUNTER__) = \ +#define CINN_OP_REGISTER_EXTERNAL_API(Name, Target) \ + static ::cinn::hlir::op::ExternalApiInfo& CINN_STR_CONCAT( \ + __make_##ExternalApiInfo##_##Name##__, __COUNTER__) = \ ::cinn::hlir::op::ExternalApiRegistry::Global()->Register(#Name, Target) namespace cinn { namespace hlir { namespace op { -using OpNodeTransToExternalApiFunction = std::function; +using OpNodeTransToExternalApiFunction = + std::function; // This class contains detail external api information of a specified Operator. 
-// To provide the external api name, we can directly set it through `set_api_name`
-// or set a transform function wth `set_trans_func` that return a api name finally
+// To provide the external api name, we can directly set it through
+// `set_api_name` or set a transform function with `set_trans_func` that
+// returns an api name finally
 struct ExternalApiInfo {
   std::string name;
   std::string api_name;
@@ -42,7 +45,8 @@ struct ExternalApiInfo {
     return *this;
   }
 
-  inline ExternalApiInfo& set_trans_func(OpNodeTransToExternalApiFunction func) {
+  inline ExternalApiInfo& set_trans_func(
+      OpNodeTransToExternalApiFunction func) {
     this->trans_func = func;
     return *this;
   }
@@ -56,14 +60,16 @@ class ExternalApiRegistry : public Registry<ExternalApiInfo> {
     return &x;
   }
 
-  ExternalApiInfo& Register(const std::string& op_name, const common::Target& target);
+  ExternalApiInfo& Register(const std::string& op_name,
+                            const common::Target& target);
 
   bool Has(const std::string& op_name, const common::Target& target) {
     return nullptr != Registry::Find(GenKey(op_name, target));
   }
 
   // return the api name on the specified target
-  std::string GetExternalApi(const framework::Node* op_node, const common::Target& target);
+  std::string GetExternalApi(const framework::Node* op_node,
+                             const common::Target& target);
 
  private:
   ExternalApiRegistry() = default;
diff --git a/paddle/cinn/hlir/op/external_api_registry_test.cc b/paddle/cinn/hlir/op/external_api_registry_test.cc
index 86fa723faece1..186fb8fa53262 100644
--- a/paddle/cinn/hlir/op/external_api_registry_test.cc
+++ b/paddle/cinn/hlir/op/external_api_registry_test.cc
@@ -27,21 +27,27 @@ using cinn::hlir::framework::Node;
 using cinn::hlir::op::ExternalApiRegistry;
 
 TEST(ExternalApiRegistry, Has) {
-  ASSERT_TRUE(ExternalApiRegistry::Global()->Has("matmul", common::DefaultNVGPUTarget()));
-  ASSERT_TRUE(ExternalApiRegistry::Global()->Has("cholesky", common::DefaultHostTarget()));
-  ASSERT_FALSE(ExternalApiRegistry::Global()->Has("op_doesn't_exist", common::DefaultNVGPUTarget()));
+  ASSERT_TRUE(ExternalApiRegistry::Global()->Has("matmul",
+                                                 common::DefaultNVGPUTarget()));
+  ASSERT_TRUE(ExternalApiRegistry::Global()->Has("cholesky",
+                                                 common::DefaultHostTarget()));
+  ASSERT_FALSE(ExternalApiRegistry::Global()->Has(
+      "op_doesn't_exist", common::DefaultNVGPUTarget()));
 }
 
 TEST(ExternalApiRegistry, GetExternalApi) {
-  auto node = std::make_unique<Node>(Operator::Get("custom_call"), "custom_call");
+  auto node =
+      std::make_unique<Node>(Operator::Get("custom_call"), "custom_call");
   node->attrs.attr_store["original_op"] = std::string("matmul");
   ASSERT_EQ("cinn_call_cublas",
-            ExternalApiRegistry::Global()->GetExternalApi(node.get(), common::DefaultNVGPUTarget()));
+            ExternalApiRegistry::Global()->GetExternalApi(
+                node.get(), common::DefaultNVGPUTarget()));
 
 #ifdef CINN_WITH_CUDNN
-  node->attrs.attr_store["conv_type"]   = std::string("backward_data");
+  node->attrs.attr_store["conv_type"] = std::string("backward_data");
   node->attrs.attr_store["original_op"] = std::string("conv2d");
   ASSERT_EQ("cinn_call_cudnn_conv2d_backward_data",
-            ExternalApiRegistry::Global()->GetExternalApi(node.get(), common::DefaultNVGPUTarget()));
+            ExternalApiRegistry::Global()->GetExternalApi(
+                node.get(), common::DefaultNVGPUTarget()));
 #endif
 }
diff --git a/paddle/cinn/hlir/op/nn.cc b/paddle/cinn/hlir/op/nn.cc
index 1658897f30e9e..0f3c93d3124c7 100644
--- a/paddle/cinn/hlir/op/nn.cc
+++ b/paddle/cinn/hlir/op/nn.cc
@@ -40,88 +40,107 @@ using framework::OpStrategy;
 using framework::shape_t;
 using framework::StrategyFunction;
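Note: for orientation, a custom_call node is resolved back to an external
kernel name through the registry declared above; a minimal usage sketch,
assuming the usual CINN headers (the LookupMatmulApi wrapper below is
illustrative, not part of this patch):

#include <memory>
#include <string>

#include "paddle/cinn/common/target.h"  // assumed header for DefaultNVGPUTarget
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/op/external_api_registry.h"

// Resolves the external kernel for a custom_call node that replaced a
// matmul, mirroring external_api_registry_test.cc above.
std::string LookupMatmulApi() {
  using cinn::hlir::framework::Node;
  using cinn::hlir::framework::Operator;
  auto node =
      std::make_unique<Node>(Operator::Get("custom_call"), "custom_call");
  // A custom_call op must store the name of the op it was converted from.
  node->attrs.attr_store["original_op"] = std::string("matmul");
  // Returns "cinn_call_cublas" per the matmul registration in this file.
  return cinn::hlir::op::ExternalApiRegistry::Global()->GetExternalApi(
      node.get(), cinn::common::DefaultNVGPUTarget());
}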
-std::shared_ptr StrategyForRelu(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute relu_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of relu compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; - CHECK(!pack_args.empty()) << "at least one input tensor for relu compute\n"; - Expr A = pack_args[0]; - CHECK(A.as_tensor()); - std::string tensor_name = UniqName("Relu_output"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - } - auto out = pe::Relu(A.as_tensor_ref(), 0.0, tensor_name); - auto stages = CreateStages({out}); - *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; - }); +std::shared_ptr StrategyForRelu( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute relu_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of relu compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK(!pack_args.empty()) + << "at least one input tensor for relu compute\n"; + Expr A = pack_args[0]; + CHECK(A.as_tensor()); + std::string tensor_name = UniqName("Relu_output"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } + auto out = pe::Relu(A.as_tensor_ref(), 0.0, tensor_name); + auto stages = CreateStages({out}); + *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; + }); auto strategy = std::make_shared(); CHECK(out_type.size()) << "Out_type of relu op is empty! Please check."; - strategy->AddImpl(relu_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.relu.x86", 1); + strategy->AddImpl(relu_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.relu.x86", + 1); return strategy; } -std::vector InferShapeForRelu(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(!inputs_shape.empty()) << "The input's shape is empty! Please check again."; +std::vector InferShapeForRelu( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(!inputs_shape.empty()) + << "The input's shape is empty! Please check again."; std::vector res{inputs_shape[0]}; return res; } -std::vector InferDtypeForRelu(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForRelu(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::shared_ptr StrategyForRelu6(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute relu6_compute([](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of relu6 compute is empty! 
Please check.\n"; - CINNValuePack pack_args = args[0]; - CHECK(!pack_args.empty()) << "at least one input tensor for relu6 compute\n"; - Expr A = pack_args[0]; - CHECK(A.as_tensor()); - std::string tensor_name = UniqName("Relu6_output"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - } - auto out = pe::Relu6(A.as_tensor_ref(), 0.0, tensor_name); - auto stages = CreateStages({out}); - *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; - }); +std::shared_ptr StrategyForRelu6( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute relu6_compute( + [](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of relu6 compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK(!pack_args.empty()) + << "at least one input tensor for relu6 compute\n"; + Expr A = pack_args[0]; + CHECK(A.as_tensor()); + std::string tensor_name = UniqName("Relu6_output"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } + auto out = pe::Relu6(A.as_tensor_ref(), 0.0, tensor_name); + auto stages = CreateStages({out}); + *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; + }); auto strategy = std::make_shared(); CHECK(out_type.size()) << "Out_type of relu6 op is empty! Please check."; - strategy->AddImpl(relu6_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.relu6.x86", 1); + strategy->AddImpl(relu6_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.relu6.x86", + 1); return strategy; } -std::shared_ptr StrategyForConv2d(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForConv2d( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { std::vector padding({0, 0}); std::vector stride({1, 1}); std::vector dilation({1, 1}); std::string data_format = "NCHW"; - int groups = 1; - std::string key = ""; - std::string conv_type = ""; - bool use_mkldnn = false; + int groups = 1; + std::string key = ""; + std::string conv_type = ""; + bool use_mkldnn = false; if (attrs.attr_store.find("padding") != attrs.attr_store.end()) { padding = absl::get>(attrs.attr_store.at("padding")); } @@ -151,48 +170,68 @@ std::shared_ptr StrategyForConv2d(const framework::NodeAttr &attrs, } #ifndef CINN_WITH_CUDNN - CHECK_EQ(conv_type, "forward") << "cudnn is not found, backward_data/backward_filter is not supported!"; + CHECK_EQ(conv_type, "forward") + << "cudnn is not found, backward_data/backward_filter is not supported!"; #endif - framework::CINNCompute conv2d_compute([=](lang::Args args, lang::RetValue *ret) { - std::vector res; - CHECK(!args.empty()) << "The input argument of conv2d compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 2U) << "at least 2 input tensors for conv2d compute\n"; - Expr A = pack_args[0]; - Expr B = pack_args[1]; - CHECK(A.as_tensor()); - CHECK(B.as_tensor()); - CHECK_EQ(padding.size(), 2) << "The size of padding in conv2d op is not 2! 
Please check."; - CHECK_EQ(stride.size(), 2) << "The size of stride in conv2d op is not 2! Please check."; - CHECK_EQ(dilation.size(), 2) << "The size of stride in conv2d op is not 2! Please check."; - std::vector out; - VLOG(3) << "input shape: " << utils::Join(A.as_tensor_ref()->shape, ", "); - VLOG(3) << "weight shape: " << utils::Join(B.as_tensor_ref()->shape, ", "); - std::string tensor_name = UniqName("Conv2d_out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_GE(pack_args.size(), 3); - CHECK(pack_args[2].is_string()); - tensor_name = pack_args[2].operator std::string(); - } - if (data_format == "NCHW") { - // A is input: [N, C, H, W], B is filter: [C_out, C_in/group, filter_h, filter_w] - if (target.arch == Target::Arch::X86) { - if (groups == 1 && !use_mkldnn) { - out = pe::Conv2d_NCHW_5D(A.as_tensor_ref(), - B.as_tensor_ref(), - padding[0], - padding[1], - stride[0], - stride[1], - dilation[0], - dilation[1], - key, - tensor_name, - target); - } else { + framework::CINNCompute conv2d_compute( + [=](lang::Args args, lang::RetValue *ret) { + std::vector res; + CHECK(!args.empty()) + << "The input argument of conv2d compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 2U) + << "at least 2 input tensors for conv2d compute\n"; + Expr A = pack_args[0]; + Expr B = pack_args[1]; + CHECK(A.as_tensor()); + CHECK(B.as_tensor()); + CHECK_EQ(padding.size(), 2) + << "The size of padding in conv2d op is not 2! Please check."; + CHECK_EQ(stride.size(), 2) + << "The size of stride in conv2d op is not 2! Please check."; + CHECK_EQ(dilation.size(), 2) + << "The size of stride in conv2d op is not 2! Please check."; + std::vector out; + VLOG(3) << "input shape: " + << utils::Join(A.as_tensor_ref()->shape, ", "); + VLOG(3) << "weight shape: " + << utils::Join(B.as_tensor_ref()->shape, ", "); + std::string tensor_name = UniqName("Conv2d_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_GE(pack_args.size(), 3); + CHECK(pack_args[2].is_string()); + tensor_name = pack_args[2].operator std::string(); + } + if (data_format == "NCHW") { + // A is input: [N, C, H, W], B is filter: [C_out, C_in/group, + // filter_h, filter_w] + if (target.arch == Target::Arch::X86) { + if (groups == 1 && !use_mkldnn) { + out = pe::Conv2d_NCHW_5D(A.as_tensor_ref(), + B.as_tensor_ref(), + padding[0], + padding[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + key, + tensor_name, + target); + } else { #ifdef CINN_WITH_MKLDNN - out = pe::Conv2d_NCHW_MKLDNN(A.as_tensor_ref(), + out = pe::Conv2d_NCHW_MKLDNN(A.as_tensor_ref(), + B.as_tensor_ref(), + padding[0], + padding[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + tensor_name); +#else + out = pe::Conv2d_NCHW_5D(A.as_tensor_ref(), B.as_tensor_ref(), padding[0], padding[1], @@ -200,23 +239,39 @@ std::shared_ptr StrategyForConv2d(const framework::NodeAttr &attrs, stride[1], dilation[0], dilation[1], + key, tensor_name); -#else - out = pe::Conv2d_NCHW_5D(A.as_tensor_ref(), - B.as_tensor_ref(), - padding[0], - padding[1], - stride[0], - stride[1], - dilation[0], - dilation[1], - key, - tensor_name); #endif - } - } else { - if (conv_type == "forward") { - out = pe::Conv2d_NCHW(A.as_tensor_ref(), + } + } else { + if (conv_type == "forward") { + out = pe::Conv2d_NCHW(A.as_tensor_ref(), + B.as_tensor_ref(), + padding[0], + padding[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + tensor_name); + out.push_back(B.as_tensor_ref()); + } else { +#ifdef CINN_WITH_CUDNN + // as backward_data and backward_filter is not 
supported yet, we
+            // build a fake op instead. As the runtime uses cudnn to compute
+            // the conv2d, this fake op is never called. When cinn
+            // supports backward_filter/backward_data code gen, this code
+            // will be removed.
+            out = pe::Identity(A.as_tensor_ref());
+            out.push_back(A.as_tensor_ref());
+            out.push_back(B.as_tensor_ref());
+#endif
+          }
+        }
+      } else if (data_format == "NHWC") {
+        // A is input: [N, H, W, C], B is filter: [C_out, C_in/group,
+        // filter_h, filter_w]
+        out = pe::Conv2d_NHWC(A.as_tensor_ref(),
                               B.as_tensor_ref(),
                               padding[0],
                               padding[1],
@@ -225,48 +280,29 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
                               dilation[0],
                               dilation[1],
                               tensor_name);
-      out.push_back(B.as_tensor_ref());
     } else {
-#ifdef CINN_WITH_CUDNN
-      // as backward_data and backward_filter is not support now, we built a fake op to instead.
-      // as the runtime use cudnn to compute the conv2d, so this fake op is not been called.
-      // When cinn support backward_filter/backward_data code gen, this code is to be removed.
-      out = pe::Identity(A.as_tensor_ref());
-      out.push_back(A.as_tensor_ref());
-      out.push_back(B.as_tensor_ref());
-#endif
+        LOG(FATAL) << "Only support NCHW and NHWC data layout\n";
     }
-    }
-  } else if (data_format == "NHWC") {
-    // A is input: [N, H, W, C], B is filter: [C_out, C_in/group, filter_h, filter_w]
-    out = pe::Conv2d_NHWC(A.as_tensor_ref(),
-                          B.as_tensor_ref(),
-                          padding[0],
-                          padding[1],
-                          stride[0],
-                          stride[1],
-                          dilation[0],
-                          dilation[1],
-                          tensor_name);
-  } else {
-    LOG(FATAL) << "Only support NCHW and NHWC data layout\n";
-  }
-  auto stages = CreateStages({A.as_tensor_ref(), B.as_tensor_ref()});
+        auto stages = CreateStages({A.as_tensor_ref(), B.as_tensor_ref()});
 
-  for (auto &t : out) {
-    stages->InsertLazily(t);
-    res.push_back(CINNValue(t));
-  }
-  CHECK(out.size() == 3U || out.size() == 2U || out.size() == 5U || out.size() == 12U)
-      << "The output tensor sizes of conv2d op in conv2d op should be 2 or 3 or 5\n";
+        for (auto &t : out) {
+          stages->InsertLazily(t);
+          res.push_back(CINNValue(t));
+        }
+        CHECK(out.size() == 3U || out.size() == 2U || out.size() == 5U ||
+              out.size() == 12U)
+            << "The output tensor sizes of conv2d op should be 2, 3, 5 "
+               "or 12\n";
 
-    res.push_back(CINNValue(stages));
-    *ret = CINNValuePack{res};
-  });
+        res.push_back(CINNValue(stages));
+        *ret = CINNValuePack{res};
+      });
 
-  framework::CINNSchedule conv2d_schedule([=](lang::Args args, lang::RetValue *ret) {
+  framework::CINNSchedule conv2d_schedule([=](lang::Args args,
+                                              lang::RetValue *ret) {
     if (FLAGS_cinn_ir_schedule) {
-      CHECK(!args.empty()) << "The input argument of conv2d schedule is empty! Please check.\n";
+      CHECK(!args.empty())
+          << "The input argument of conv2d schedule is empty! Please check.\n";
       CINNValuePack arg_pack = args[0];
       std::vector<Expr> vec_ast;
       for (int i = 0; i < arg_pack.size(); i++) {
@@ -282,12 +318,14 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
       if (target.arch == Target::Arch::NVGPU) {
 #ifdef CINN_WITH_CUDNN
         // If conv_type is backward_filter or backward_data, we built a fake op.
-        // As runtime use cudnn to compute conv2d, this fake op is not to be called.
-        // When cinn support backward_filter/backward_data code gen, this code is to be removed.
+        // As runtime uses cudnn to compute conv2d, this fake op is not to be
+        // called. When cinn supports backward_filter/backward_data code gen,
+        // this code will be removed.
        if (conv_type != "forward") {
          CHECK_EQ(vec_ast.size(), 1);
          pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
-        std::vector<CINNValue> res{CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+        std::vector<CINNValue> res{
+            CINNValue(ir_sch.GetModule().GetExprs().at(0))};
          *ret = CINNValuePack{res};
          return;
        }
@@ -295,8 +333,10 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
       int expr_size = vec_ast.size();
       if (expr_size == 2) {
         pe::IRCudaScheduleConv(ir_sch, target);
-        VLOG(3) << "After IRCudaScheduleConv, arg_pack[0] is : " << ir_sch.GetModule().GetExprs().at(0);
-        std::vector<CINNValue> res{CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+        VLOG(3) << "After IRCudaScheduleConv, arg_pack[0] is : "
+                << ir_sch.GetModule().GetExprs().at(0);
+        std::vector<CINNValue> res{
+            CINNValue(ir_sch.GetModule().GetExprs().at(0))};
         *ret = CINNValuePack{res};
         return;
       } else {
@@ -307,59 +347,64 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
       }
       LOG(FATAL) << "This target [" << target << "] is not supported yet.";
     } else {
-      CHECK(!args.empty()) << "The input argument of conv2d schedule is empty! Please check.\n";
+      CHECK(!args.empty())
+          << "The input argument of conv2d schedule is empty! Please check.\n";
       CINNValuePack arg_pack = args[0];
-      CHECK(arg_pack.size() == 4UL || arg_pack.size() == 3UL || arg_pack.size() == 6UL || arg_pack.size() == 13UL);
+      CHECK(arg_pack.size() == 4UL || arg_pack.size() == 3UL ||
+            arg_pack.size() == 6UL || arg_pack.size() == 13UL);
       poly::StageMap stages = arg_pack.back();
       if (target.arch == Target::Arch::NVGPU) {
 #ifdef CINN_WITH_CUDNN
         // If conv_type is backward_filter or backward_data, we built a fake op.
-        // As runtime use cudnn to compute conv2d, this fake op is not to be called.
-        // When cinn support backward_filter/backward_data code gen, this code is to be removed.
+        // As runtime uses cudnn to compute conv2d, this fake op is not to be
+        // called. When cinn supports backward_filter/backward_data code gen,
+        // this code will be removed.
if (conv_type != "forward") { Expr out = arg_pack[0]; - pe::CudaScheduleInjective(stages[out.as_tensor_ref()], output_shapes.front(), target); + pe::CudaScheduleInjective( + stages[out.as_tensor_ref()], output_shapes.front(), target); *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; return; } #endif if (arg_pack.size() == 4UL) { - Expr Out = arg_pack[0]; - Expr input_pad = arg_pack[1]; - Expr weights = arg_pack[2]; - ir::Tensor out_t = Out.as_tensor_ref(); - ir::Tensor input_t = input_pad.as_tensor_ref(); + Expr Out = arg_pack[0]; + Expr input_pad = arg_pack[1]; + Expr weights = arg_pack[2]; + ir::Tensor out_t = Out.as_tensor_ref(); + ir::Tensor input_t = input_pad.as_tensor_ref(); ir::Tensor weights_t = weights.as_tensor_ref(); CHECK(Out.as_tensor()); pe::CudaScheduleConv(stages, input_t, weights_t, out_t, target); arg_pack[0] = Expr(out_t); arg_pack[1] = Expr(input_t); arg_pack[2] = Expr(weights_t); - *ret = CINNValuePack{{arg_pack[0], CINNValue(stages)}}; + *ret = CINNValuePack{{arg_pack[0], CINNValue(stages)}}; return; } else if (arg_pack.size() == 13UL) { - Expr wino_weights_dilation = arg_pack[0]; - Expr wino_input_pad = arg_pack[1]; - Expr wino_A = arg_pack[2]; - Expr wino_B = arg_pack[3]; - Expr wino_G = arg_pack[4]; - Expr kernel_pack = arg_pack[5]; - Expr input_tile = arg_pack[6]; - Expr data_pack = arg_pack[7]; - Expr bgemm = arg_pack[8]; - Expr inverse = arg_pack[9]; - Expr wino_conv = arg_pack[10]; - ir::Tensor wino_weights_dilation_t = wino_weights_dilation.as_tensor_ref(); - ir::Tensor wino_input_pad_t = wino_input_pad.as_tensor_ref(); - ir::Tensor wino_A_t = wino_A.as_tensor_ref(); - ir::Tensor wino_B_t = wino_B.as_tensor_ref(); - ir::Tensor wino_G_t = wino_G.as_tensor_ref(); - ir::Tensor kernel_pack_t = kernel_pack.as_tensor_ref(); - ir::Tensor input_tile_t = input_tile.as_tensor_ref(); - ir::Tensor data_pack_t = data_pack.as_tensor_ref(); - ir::Tensor bgemm_t = bgemm.as_tensor_ref(); - ir::Tensor inverse_t = inverse.as_tensor_ref(); - ir::Tensor wino_conv_t = wino_conv.as_tensor_ref(); + Expr wino_weights_dilation = arg_pack[0]; + Expr wino_input_pad = arg_pack[1]; + Expr wino_A = arg_pack[2]; + Expr wino_B = arg_pack[3]; + Expr wino_G = arg_pack[4]; + Expr kernel_pack = arg_pack[5]; + Expr input_tile = arg_pack[6]; + Expr data_pack = arg_pack[7]; + Expr bgemm = arg_pack[8]; + Expr inverse = arg_pack[9]; + Expr wino_conv = arg_pack[10]; + ir::Tensor wino_weights_dilation_t = + wino_weights_dilation.as_tensor_ref(); + ir::Tensor wino_input_pad_t = wino_input_pad.as_tensor_ref(); + ir::Tensor wino_A_t = wino_A.as_tensor_ref(); + ir::Tensor wino_B_t = wino_B.as_tensor_ref(); + ir::Tensor wino_G_t = wino_G.as_tensor_ref(); + ir::Tensor kernel_pack_t = kernel_pack.as_tensor_ref(); + ir::Tensor input_tile_t = input_tile.as_tensor_ref(); + ir::Tensor data_pack_t = data_pack.as_tensor_ref(); + ir::Tensor bgemm_t = bgemm.as_tensor_ref(); + ir::Tensor inverse_t = inverse.as_tensor_ref(); + ir::Tensor wino_conv_t = wino_conv.as_tensor_ref(); std::vector all_tensors = {wino_weights_dilation_t, wino_input_pad_t, wino_A_t, @@ -372,50 +417,58 @@ std::shared_ptr StrategyForConv2d(const framework::NodeAttr &attrs, inverse_t, wino_conv_t}; hlir::pe::CudaScheduleWinogradConv(stages, all_tensors, target); - arg_pack[0] = Expr(all_tensors[0]); - arg_pack[1] = Expr(all_tensors[1]); - arg_pack[2] = Expr(all_tensors[2]); - arg_pack[3] = Expr(all_tensors[3]); - arg_pack[4] = Expr(all_tensors[4]); - arg_pack[5] = Expr(all_tensors[5]); - arg_pack[6] = Expr(all_tensors[6]); - arg_pack[7] = 
Expr(all_tensors[7]); - arg_pack[8] = Expr(all_tensors[8]); - arg_pack[9] = Expr(all_tensors[9]); + arg_pack[0] = Expr(all_tensors[0]); + arg_pack[1] = Expr(all_tensors[1]); + arg_pack[2] = Expr(all_tensors[2]); + arg_pack[3] = Expr(all_tensors[3]); + arg_pack[4] = Expr(all_tensors[4]); + arg_pack[5] = Expr(all_tensors[5]); + arg_pack[6] = Expr(all_tensors[6]); + arg_pack[7] = Expr(all_tensors[7]); + arg_pack[8] = Expr(all_tensors[8]); + arg_pack[9] = Expr(all_tensors[9]); arg_pack[10] = Expr(all_tensors[10]); - *ret = CINNValuePack{{arg_pack[10], arg_pack[5], arg_pack[7], arg_pack[8], CINNValue(stages)}}; + *ret = CINNValuePack{{arg_pack[10], + arg_pack[5], + arg_pack[7], + arg_pack[8], + CINNValue(stages)}}; return; } } else if (target.arch == Target::Arch::X86) { if (arg_pack.size() == 6UL) { - Expr res = arg_pack[0]; - Expr packed_out = arg_pack[1]; + Expr res = arg_pack[0]; + Expr packed_out = arg_pack[1]; Expr weights_dilation = arg_pack[2]; - Expr input_pad = arg_pack[3]; - Expr data = arg_pack[4]; + Expr input_pad = arg_pack[3]; + Expr data = arg_pack[4]; CHECK(res.as_tensor()); CHECK(packed_out.as_tensor()); CHECK(input_pad.as_tensor()); CHECK(weights_dilation.as_tensor()); CHECK(data.as_tensor()); - std::vector kernel_shape = weights_dilation.as_tensor_ref()->shape; + std::vector kernel_shape = + weights_dilation.as_tensor_ref()->shape; // kernel_h == 1 && kernel_w == 1 - CHECK_EQ(kernel_shape.size(), 6U) << "kernel_dialtion shape size should be 6"; - bool is_1x1 = (is_zero(kernel_shape[2] - 1)) && (is_zero(kernel_shape[3] - 1)); + CHECK_EQ(kernel_shape.size(), 6U) + << "kernel_dialtion shape size should be 6"; + bool is_1x1 = + (is_zero(kernel_shape[2] - 1)) && (is_zero(kernel_shape[3] - 1)); ir::Tensor packed_out_tensor = packed_out.as_tensor_ref(); - bool do_padding = (padding[0] == 0 && padding[1] == 0) ? false : true; + bool do_padding = (padding[0] == 0 && padding[1] == 0) ? 
false : true;
          if (groups == 1) {
            if (is_1x1) {
-              pe::Conv2d_NCHWc_1X1_Schedule_CPU(stages,
-                                                res.as_tensor_ref(),
-                                                packed_out_tensor,
-                                                input_pad.as_tensor_ref(),
-                                                weights_dilation.as_tensor_ref(),
-                                                data.as_tensor_ref(),
-                                                target,
-                                                key,
-                                                do_padding);
+              pe::Conv2d_NCHWc_1X1_Schedule_CPU(
+                  stages,
+                  res.as_tensor_ref(),
+                  packed_out_tensor,
+                  input_pad.as_tensor_ref(),
+                  weights_dilation.as_tensor_ref(),
+                  data.as_tensor_ref(),
+                  target,
+                  key,
+                  do_padding);
            } else {
              pe::Conv2d_NCHWc_Schedule_CPU(stages,
                                            res.as_tensor_ref(),
@@ -428,10 +481,16 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
                                            do_padding);
            }
            if (do_padding) {
-              *ret = CINNValuePack{
-                  {CINNValue(res), CINNValue(packed_out_tensor), arg_pack[2], arg_pack[3], CINNValue(stages)}};
+              *ret = CINNValuePack{{CINNValue(res),
+                                    CINNValue(packed_out_tensor),
+                                    arg_pack[2],
+                                    arg_pack[3],
+                                    CINNValue(stages)}};
            } else {
-              *ret = CINNValuePack{{CINNValue(res), CINNValue(packed_out_tensor), arg_pack[2], CINNValue(stages)}};
+              *ret = CINNValuePack{{CINNValue(res),
+                                    CINNValue(packed_out_tensor),
+                                    arg_pack[2],
+                                    CINNValue(stages)}};
            }
            return;
          } else {
@@ -440,7 +499,8 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
            stages[input_pad.as_tensor_ref()]->ComputeInline();
            stages[weights_dilation.as_tensor_ref()]->ComputeInline();
            stages[data.as_tensor_ref()]->ComputeInline();
-            *ret = CINNValuePack{{arg_pack[0], CINNValue(packed_out_tensor), CINNValue(stages)}};
+            *ret = CINNValuePack{
+                {arg_pack[0], CINNValue(packed_out_tensor), CINNValue(stages)}};
          }
          return;
        } else if (arg_pack.size() == 4UL) {
@@ -464,18 +524,22 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
   return strategy;
 }
 
-std::vector<shape_t> InferShapeForConv2d(const std::vector<shape_t> &inputs_shape,
-                                         const framework::AttrMapType &attrs) {
-  CHECK_EQ(inputs_shape.size(), 2) << "The conv2d should has and only has 2 inputs";
-  CHECK_EQ(inputs_shape[0].size(), 4) << "The conv2d's first input only support 4-dimension tensor";
-  CHECK_EQ(inputs_shape[1].size(), 4) << "The conv2d's first input only support 4-dimension tensor";
+std::vector<shape_t> InferShapeForConv2d(
+    const std::vector<shape_t> &inputs_shape,
+    const framework::AttrMapType &attrs) {
+  CHECK_EQ(inputs_shape.size(), 2)
+      << "The conv2d should have and only have 2 inputs";
+  CHECK_EQ(inputs_shape[0].size(), 4)
+      << "The conv2d's first input only supports 4-dimension tensor";
+  CHECK_EQ(inputs_shape[1].size(), 4)
+      << "The conv2d's second input only supports 4-dimension tensor";
   std::vector<int> padding({0, 0});
   std::vector<int> stride({1, 1});
   std::vector<int> dilation({1, 1});
-  int groups               = 1;
+  int groups = 1;
   std::string data_format = "NCHW";
-  std::string conv_type    = "forward";
+  std::string conv_type = "forward";
 
   if (attrs.find("padding") != attrs.end()) {
     padding = absl::get<std::vector<int>>(attrs.at("padding"));
@@ -499,11 +563,16 @@ std::vector<shape_t> InferShapeForConv2d(const std::vector<shape_t> &inputs_shap
     conv_type = absl::get<std::string>(attrs.at("conv_type"));
   }
 
-  CHECK_EQ(padding.size(), 2) << "The size of padding in conv2d op is not 2! Please check.";
-  CHECK_EQ(stride.size(), 2) << "The size of stride in conv2d op is not 2! Please check.";
-  CHECK_GE(inputs_shape[0].size(), 3) << "The first input tensor's shape size of conv2d op is < 3! Please check.";
-  CHECK(conv_type == "forward" || conv_type == "backward_data" || conv_type == "backward_filter")
-      << "The conv type should be one of {forward, backward_data, backward_filter}.";
+  CHECK_EQ(padding.size(), 2)
+      << "The size of padding in conv2d op is not 2! 
Please check."; + CHECK_EQ(stride.size(), 2) + << "The size of stride in conv2d op is not 2! Please check."; + CHECK_GE(inputs_shape[0].size(), 3) << "The first input tensor's shape size " + "of conv2d op is < 3! Please check."; + CHECK(conv_type == "forward" || conv_type == "backward_data" || + conv_type == "backward_filter") + << "The conv type should be one of {forward, backward_data, " + "backward_filter}."; CHECK(data_format == "NCHW" || data_format == "NHWC") << "The conv2d only support NCHW/NHWC, but here " << data_format; @@ -518,16 +587,26 @@ std::vector InferShapeForConv2d(const std::vector &inputs_shap std::vector output_shape(4, 0); int out_shape_h = 0, out_shape_w = 0; if (conv_type == "forward") { - // A is input: [N, C, H, W], B is filter: [C_out, C_in/group, filter_h, filter_w] - out_shape_h = (inputs_shape[0][h] - ((inputs_shape[1][h] - 1) * dilation[0] + 1) + 2 * padding[0]) / stride[0] + 1; - out_shape_w = (inputs_shape[0][w] - ((inputs_shape[1][w] - 1) * dilation[1] + 1) + 2 * padding[1]) / stride[1] + 1; + // A is input: [N, C, H, W], B is filter: [C_out, C_in/group, filter_h, + // filter_w] + out_shape_h = + (inputs_shape[0][h] - ((inputs_shape[1][h] - 1) * dilation[0] + 1) + + 2 * padding[0]) / + stride[0] + + 1; + out_shape_w = + (inputs_shape[0][w] - ((inputs_shape[1][w] - 1) * dilation[1] + 1) + + 2 * padding[1]) / + stride[1] + + 1; output_shape[n] = inputs_shape[0][n]; output_shape[c] = inputs_shape[1][n]; output_shape[h] = out_shape_h; output_shape[w] = out_shape_w; } else if (conv_type == "backward_data") { - CHECK(attrs.find("output_shape") != attrs.end()) << "The shape of backward_data is not found! Please check."; + CHECK(attrs.find("output_shape") != attrs.end()) + << "The shape of backward_data is not found! Please check."; const auto &x_shape = absl::get>(attrs.at("output_shape")); CHECK_EQ(x_shape.size(), 4) << "The rank of x shape is not 4! Please check"; @@ -539,9 +618,12 @@ std::vector InferShapeForConv2d(const std::vector &inputs_shap output_shape[h] = x_shape[h]; output_shape[w] = x_shape[w]; } else if (conv_type == "backward_filter") { - CHECK(attrs.find("output_shape") != attrs.end()) << "The shape of backward_filter is not found! Please check."; - const auto &weight_shape = absl::get>(attrs.at("output_shape")); - CHECK_EQ(weight_shape.size(), 4) << "The rank of weight shape is not 4! Please check"; + CHECK(attrs.find("output_shape") != attrs.end()) + << "The shape of backward_filter is not found! Please check."; + const auto &weight_shape = + absl::get>(attrs.at("output_shape")); + CHECK_EQ(weight_shape.size(), 4) + << "The rank of weight shape is not 4! 
Please check"; // input[0] = x(batch, C_in, h, w) // input[1] = dy(batch, C_out, h, w) @@ -555,35 +637,53 @@ std::vector InferShapeForConv2d(const std::vector &inputs_shap std::vector res = {output_shape}; if (data_format == "NCHW") { absl::flat_hash_map conv2d_factors; - int batch = inputs_shape[0][0]; - int oc = inputs_shape[1][0]; - int ic = inputs_shape[0][1]; - int fc = inputs_shape[1][1]; - int h_in = inputs_shape[0][2]; - int w_in = inputs_shape[0][3]; - int h_f = inputs_shape[1][2]; - int w_f = inputs_shape[1][3]; - int pad_h = padding[0]; - int pad_w = padding[1]; - std::string key = pe::GenerateX86ConvKey(inputs_shape[0], inputs_shape[1], stride, padding, dilation); + int batch = inputs_shape[0][0]; + int oc = inputs_shape[1][0]; + int ic = inputs_shape[0][1]; + int fc = inputs_shape[1][1]; + int h_in = inputs_shape[0][2]; + int w_in = inputs_shape[0][3]; + int h_f = inputs_shape[1][2]; + int w_f = inputs_shape[1][3]; + int pad_h = padding[0]; + int pad_w = padding[1]; + std::string key = pe::GenerateX86ConvKey( + inputs_shape[0], inputs_shape[1], stride, padding, dilation); VLOG(3) << "key: " << key; - pe::GetConv2dFactors(&conv2d_factors, oc, ic, fc, -1, -1, Float(32), common::DefaultHostTarget(), key); + pe::GetConv2dFactors(&conv2d_factors, + oc, + ic, + fc, + -1, + -1, + Float(32), + common::DefaultHostTarget(), + key); int ic_bn = conv2d_factors["ic_bn"]; int oc_bn = conv2d_factors["oc_bn"]; int fc_bn = conv2d_factors["fc_bn"]; VLOG(3) << "ic_bn: " << ic_bn; VLOG(3) << "oc_bn: " << oc_bn; VLOG(3) << "fc_bn: " << fc_bn; - int oc_chunk = oc / oc_bn; - int ic_chunk = ic / ic_bn; - int fc_chunk = fc / fc_bn; - std::vector packed_out_shape = {batch, oc_chunk, out_shape_h, out_shape_w, oc_bn}; - std::vector input_pad_shape = {batch, ic_chunk, h_in + 2 * pad_h, w_in + 2 * pad_w, ic_bn}; - std::vector weights_dilation_shape = { - oc_chunk, fc_chunk, dilation[0] * (h_f - 1) + 1, dilation[1] * (w_f - 1) + 1, fc_bn, oc_bn}; + int oc_chunk = oc / oc_bn; + int ic_chunk = ic / ic_bn; + int fc_chunk = fc / fc_bn; + std::vector packed_out_shape = { + batch, oc_chunk, out_shape_h, out_shape_w, oc_bn}; + std::vector input_pad_shape = { + batch, ic_chunk, h_in + 2 * pad_h, w_in + 2 * pad_w, ic_bn}; + std::vector weights_dilation_shape = {oc_chunk, + fc_chunk, + dilation[0] * (h_f - 1) + 1, + dilation[1] * (w_f - 1) + 1, + fc_bn, + oc_bn}; std::vector data_shape = {batch, ic_chunk, h_in, w_in, ic_bn}; - res = {output_shape, packed_out_shape, weights_dilation_shape, input_pad_shape}; + res = {output_shape, + packed_out_shape, + weights_dilation_shape, + input_pad_shape}; } else if (data_format == "NHWC") { // now conv2d codegen version only support NCHW data format res = {output_shape}; @@ -591,31 +691,39 @@ std::vector InferShapeForConv2d(const std::vector &inputs_shap return res; } -std::vector InferDtypeForConv2d(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; - std::vector res{inputs_type[0], inputs_type[0], inputs_type[0], inputs_type[0]}; +std::vector InferDtypeForConv2d(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! 
Please check again."; + std::vector res{ + inputs_type[0], inputs_type[0], inputs_type[0], inputs_type[0]}; return res; } -std::vector> InferLayoutForConv2d(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 2U) << "The input's layouts size is not 2! Please check again."; +std::vector> InferLayoutForConv2d( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 2U) + << "The input's layouts size is not 2! Please check again."; ir::Layout weight_layout(input_layouts[1]); - return {{input_layouts[0], input_layouts[0], input_layouts[0], input_layouts[0]}, input_layouts}; + return { + {input_layouts[0], input_layouts[0], input_layouts[0], input_layouts[0]}, + input_layouts}; } -std::shared_ptr StrategyForConv2dNCHWc(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForConv2dNCHWc( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { std::vector padding({0, 0}); std::vector stride({1, 1}); std::vector dilation({1, 1}); std::string data_format = "NCHWc"; - int groups = 1; + int groups = 1; if (attrs.attr_store.find("padding") != attrs.attr_store.end()) { padding = absl::get>(attrs.attr_store.at("padding")); } @@ -631,101 +739,138 @@ std::shared_ptr StrategyForConv2dNCHWc(const framework::NodeAttr &at if (attrs.attr_store.find("groups") != attrs.attr_store.end()) { groups = absl::get(attrs.attr_store.at("groups")); } - CHECK(data_format == "NCHWc") << "conv2d_NCHWc op's data_format should be NCHWc"; - framework::CINNCompute conv2d_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of conv2d_NCHWc compute is empty! Please check.\n"; - CINNValuePack a = args[0]; - CHECK_GE(a.size(), 2U) << "at least 2 input tensors for conv2d_NCHWc compute\n"; - Expr A = a[0]; - Expr B = a[1]; - CHECK(A.as_tensor()); - CHECK(B.as_tensor()); - auto tensor_a = A.as_tensor_ref(); - auto tensor_b = B.as_tensor_ref(); - CHECK_EQ(tensor_a->shape.size(), 5) << "input's shape should be 5"; - CHECK_EQ(tensor_b->shape.size(), 6) << "weight's shape should be 6"; - CHECK_EQ(padding.size(), 2) << "The size of padding in conv2d_NCHWc op is not 2! Please check."; - CHECK_EQ(stride.size(), 2) << "The size of stride in conv2d_NCHWc op is not 2! Please check."; - CHECK_EQ(dilation.size(), 2) << "The size of stride in conv2d_NCHWc op is not 2! 
Please check."; - std::vector out; - CHECK(target.arch == Target::Arch::X86) << "conv2d_NCHWc op is only used in x86"; - // A is input: [N, C_in_outer, H, W, C_in_inner], B is filter: [C_out, C_in_group_outer, filter_h, filter_w, - // C_in_group_inner] - std::string key; - VLOG(3) << "input[" << utils::Join(tensor_a->shape, ", ") << "], weight shape[" - << utils::Join(tensor_b->shape, ", ") << "]"; - out = pe::Conv2d_NCHWc(tensor_a, - tensor_b, - padding[0], - padding[1], - stride[0], - stride[1], - dilation[0], - dilation[1], - UniqName("T_conv2d_NCHWc_out"), - target); - - auto stages = CreateStages({tensor_a, tensor_b}); - - std::vector res; - CHECK(out.size() == 2U) << "The output tensor sizes of conv2d_NCHWc op should be 2\n"; - for (auto &t : out) { - stages->InsertLazily(t); - res.push_back(CINNValue(t)); - } - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); - - framework::CINNSchedule conv2d_schedule([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of conv2d_NCHWc schedule is empty! Please check.\n"; - CINNValuePack arg_pack = args[0]; - CHECK_EQ(arg_pack.size(), 3UL); - poly::StageMap stages = arg_pack.back(); - Expr packed_out = arg_pack[0]; - Expr input_pad = arg_pack[1]; - CHECK(packed_out.as_tensor()); - CHECK(input_pad.as_tensor()); - std::vector kernel_shape = inputs[1]->shape; - // kernel_h == 1 && kernel_w == 1 - CHECK_EQ(kernel_shape.size(), 6U) << "kernel_dialtion shape size should be 6"; - bool is_1x1 = (is_zero(kernel_shape[2] - 1)) && (is_zero(kernel_shape[3] - 1)); - ir::Tensor res; - ir::Tensor data; - ir::Tensor weights; - ir::Tensor packed_out_tensor = packed_out.as_tensor_ref(); - std::string key; - bool do_padding = (padding[0] == 0 && padding[1] == 0) ? false : true; - if (attrs.attr_store.find("key") != attrs.attr_store.end()) { - key = absl::get(attrs.attr_store.at("key")); - } - if (is_1x1) { - pe::Conv2d_NCHWc_1X1_Schedule_CPU( - stages, res, packed_out_tensor, input_pad.as_tensor_ref(), weights, data, target, key, do_padding); - } else { - pe::Conv2d_NCHWc_Schedule_CPU( - stages, res, packed_out_tensor, input_pad.as_tensor_ref(), weights, data, target, key, do_padding); - } - if (do_padding) { - *ret = CINNValuePack{{CINNValue(packed_out_tensor), arg_pack[0], arg_pack[1], CINNValue(stages)}}; - } else { - *ret = CINNValuePack{{CINNValue(packed_out_tensor), arg_pack[0], CINNValue(stages)}}; - } - }); + CHECK(data_format == "NCHWc") + << "conv2d_NCHWc op's data_format should be NCHWc"; + framework::CINNCompute conv2d_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of conv2d_NCHWc compute is " + "empty! Please check.\n"; + CINNValuePack a = args[0]; + CHECK_GE(a.size(), 2U) + << "at least 2 input tensors for conv2d_NCHWc compute\n"; + Expr A = a[0]; + Expr B = a[1]; + CHECK(A.as_tensor()); + CHECK(B.as_tensor()); + auto tensor_a = A.as_tensor_ref(); + auto tensor_b = B.as_tensor_ref(); + CHECK_EQ(tensor_a->shape.size(), 5) << "input's shape should be 5"; + CHECK_EQ(tensor_b->shape.size(), 6) << "weight's shape should be 6"; + CHECK_EQ(padding.size(), 2) + << "The size of padding in conv2d_NCHWc op is not 2! Please check."; + CHECK_EQ(stride.size(), 2) + << "The size of stride in conv2d_NCHWc op is not 2! Please check."; + CHECK_EQ(dilation.size(), 2) + << "The size of stride in conv2d_NCHWc op is not 2! 
Please check."; + std::vector out; + CHECK(target.arch == Target::Arch::X86) + << "conv2d_NCHWc op is only used in x86"; + // A is input: [N, C_in_outer, H, W, C_in_inner], B is filter: [C_out, + // C_in_group_outer, filter_h, filter_w, C_in_group_inner] + std::string key; + VLOG(3) << "input[" << utils::Join(tensor_a->shape, ", ") + << "], weight shape[" << utils::Join(tensor_b->shape, ", ") + << "]"; + out = pe::Conv2d_NCHWc(tensor_a, + tensor_b, + padding[0], + padding[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + UniqName("T_conv2d_NCHWc_out"), + target); + + auto stages = CreateStages({tensor_a, tensor_b}); + + std::vector res; + CHECK(out.size() == 2U) + << "The output tensor sizes of conv2d_NCHWc op should be 2\n"; + for (auto &t : out) { + stages->InsertLazily(t); + res.push_back(CINNValue(t)); + } + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + framework::CINNSchedule conv2d_schedule( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of conv2d_NCHWc schedule " + "is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; + CHECK_EQ(arg_pack.size(), 3UL); + poly::StageMap stages = arg_pack.back(); + Expr packed_out = arg_pack[0]; + Expr input_pad = arg_pack[1]; + CHECK(packed_out.as_tensor()); + CHECK(input_pad.as_tensor()); + std::vector kernel_shape = inputs[1]->shape; + // kernel_h == 1 && kernel_w == 1 + CHECK_EQ(kernel_shape.size(), 6U) + << "kernel_dialtion shape size should be 6"; + bool is_1x1 = + (is_zero(kernel_shape[2] - 1)) && (is_zero(kernel_shape[3] - 1)); + ir::Tensor res; + ir::Tensor data; + ir::Tensor weights; + ir::Tensor packed_out_tensor = packed_out.as_tensor_ref(); + std::string key; + bool do_padding = (padding[0] == 0 && padding[1] == 0) ? false : true; + if (attrs.attr_store.find("key") != attrs.attr_store.end()) { + key = absl::get(attrs.attr_store.at("key")); + } + if (is_1x1) { + pe::Conv2d_NCHWc_1X1_Schedule_CPU(stages, + res, + packed_out_tensor, + input_pad.as_tensor_ref(), + weights, + data, + target, + key, + do_padding); + } else { + pe::Conv2d_NCHWc_Schedule_CPU(stages, + res, + packed_out_tensor, + input_pad.as_tensor_ref(), + weights, + data, + target, + key, + do_padding); + } + if (do_padding) { + *ret = CINNValuePack{{CINNValue(packed_out_tensor), + arg_pack[0], + arg_pack[1], + CINNValue(stages)}}; + } else { + *ret = CINNValuePack{ + {CINNValue(packed_out_tensor), arg_pack[0], CINNValue(stages)}}; + } + }); auto strategy = std::make_shared(); - CHECK(out_type.size()) << "Out_type of conv2d_NCHWc op is empty! Please check."; + CHECK(out_type.size()) + << "Out_type of conv2d_NCHWc op is empty! Please check."; if (out_type[0] == Float(32)) { - strategy->AddImpl(conv2d_compute, conv2d_schedule, "strategy.conv2d_NCHWc.x86", 1); + strategy->AddImpl( + conv2d_compute, conv2d_schedule, "strategy.conv2d_NCHWc.x86", 1); } else { - LOG(FATAL) << "conv2d_NCHWc op with dtype != float32 is not implemented yet!"; + LOG(FATAL) + << "conv2d_NCHWc op with dtype != float32 is not implemented yet!"; } return strategy; } -std::vector InferShapeForConv2dNCHWc(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; +std::vector InferShapeForConv2dNCHWc( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) + << "The input's shape size is 0! 
Please check again."; std::vector padding({0, 0}); std::vector stride({1, 1}); std::vector dilation({1, 1}); @@ -742,72 +887,91 @@ std::vector InferShapeForConv2dNCHWc(const std::vector &inputs if (attrs.find("data_format") != attrs.end()) { data_format = absl::get(attrs.at("data_format")); } - CHECK_EQ(padding.size(), 2) << "The size of padding in conv2d_NCHWc op is not 2! Please check."; - CHECK_EQ(stride.size(), 2) << "The size of stride in conv2d_NCHWc op is not 2! Please check."; + CHECK_EQ(padding.size(), 2) + << "The size of padding in conv2d_NCHWc op is not 2! Please check."; + CHECK_EQ(stride.size(), 2) + << "The size of stride in conv2d_NCHWc op is not 2! Please check."; CHECK_EQ(inputs_shape[0].size(), 5) - << "The first input tensor's shape size of conv2d_NCHWc op should be 5! Please check."; + << "The first input tensor's shape size of conv2d_NCHWc op should be 5! " + "Please check."; CHECK_EQ(inputs_shape[1].size(), 6) - << "The second input tensor's shape size of conv2d_NCHWc op should be 6! Please check."; + << "The second input tensor's shape size of conv2d_NCHWc op should be 6! " + "Please check."; std::vector res; CHECK(data_format == "NCHWc") << "NCHWc op's data_format should be NCHWc"; int out_shape_h = - (inputs_shape[0][2] - ((inputs_shape[1][2] - 1) * dilation[0] + 1) + 2 * padding[0]) / stride[0] + 1; + (inputs_shape[0][2] - ((inputs_shape[1][2] - 1) * dilation[0] + 1) + + 2 * padding[0]) / + stride[0] + + 1; int out_shape_w = - (inputs_shape[0][3] - ((inputs_shape[1][3] - 1) * dilation[1] + 1) + 2 * padding[1]) / stride[1] + 1; + (inputs_shape[0][3] - ((inputs_shape[1][3] - 1) * dilation[1] + 1) + + 2 * padding[1]) / + stride[1] + + 1; // A: NCHWc, B: OIHWio - int batch = inputs_shape[0][0]; - int h_in = inputs_shape[0][2]; - int w_in = inputs_shape[0][3]; - int oc = inputs_shape[1][0]; - int h_f = inputs_shape[1][2]; - int w_f = inputs_shape[1][3]; - int pad_h = padding[0]; - int pad_w = padding[1]; - int ic_bn = inputs_shape[0][4]; - int ic_chunk = inputs_shape[0][1]; - int oc_bn = inputs_shape[1][5]; - int oc_chunk = inputs_shape[1][0]; - std::vector packed_out_shape = {batch, oc_chunk, out_shape_h, out_shape_w, oc_bn}; - auto pad_h_bound = (out_shape_h - 1) * stride[0] + (h_f - 1) * dilation[0] + 1; - auto pad_w_bound = (out_shape_w - 1) * stride[1] + (w_f - 1) * dilation[1] + 1; - auto input_pad_h = std::min(pad_h_bound, h_in + 2 * pad_h); - auto input_pad_w = std::min(pad_w_bound, w_in + 2 * pad_w); - std::vector input_pad_shape = {batch, ic_chunk, input_pad_h, input_pad_w, ic_bn}; + int batch = inputs_shape[0][0]; + int h_in = inputs_shape[0][2]; + int w_in = inputs_shape[0][3]; + int oc = inputs_shape[1][0]; + int h_f = inputs_shape[1][2]; + int w_f = inputs_shape[1][3]; + int pad_h = padding[0]; + int pad_w = padding[1]; + int ic_bn = inputs_shape[0][4]; + int ic_chunk = inputs_shape[0][1]; + int oc_bn = inputs_shape[1][5]; + int oc_chunk = inputs_shape[1][0]; + std::vector packed_out_shape = { + batch, oc_chunk, out_shape_h, out_shape_w, oc_bn}; + auto pad_h_bound = + (out_shape_h - 1) * stride[0] + (h_f - 1) * dilation[0] + 1; + auto pad_w_bound = + (out_shape_w - 1) * stride[1] + (w_f - 1) * dilation[1] + 1; + auto input_pad_h = std::min(pad_h_bound, h_in + 2 * pad_h); + auto input_pad_w = std::min(pad_w_bound, w_in + 2 * pad_w); + std::vector input_pad_shape = { + batch, ic_chunk, input_pad_h, input_pad_w, ic_bn}; VLOG(3) << "packed_out_shape: " << utils::Join(packed_out_shape, ", "); return {packed_out_shape, packed_out_shape, input_pad_shape}; } 
-std::vector> InferLayoutForConv2dNCHWc(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 2U) << "The input's layouts size is not 2! Please check again."; +std::vector> InferLayoutForConv2dNCHWc( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 2U) + << "The input's layouts size is not 2! Please check again."; ir::Layout weight_layout(input_layouts[1]); CHECK_EQ(weight_layout.ndims(), 6U); - auto var = weight_layout.axes().back(); + auto var = weight_layout.axes().back(); int factor = var->upper_bound.as_int32(); CHECK_GE(factor, 1) << "factor should be larger than 1"; std::string outlayout = "NCHW" + std::to_string(factor) + "c"; return {{outlayout, outlayout, input_layouts[0]}, input_layouts}; } -std::vector InferDtypeForConv2dNCHWc(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForConv2dNCHWc( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0], inputs_type[0], inputs_type[0]}; return res; } -std::shared_ptr StrategyForDepthwiseConv2d(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - std::vector padding = {0, 0}; - std::vector stride = {1, 1}; +std::shared_ptr StrategyForDepthwiseConv2d( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + std::vector padding = {0, 0}; + std::vector stride = {1, 1}; std::vector dilation = {1, 1}; - std::string data_format = "NCHW"; + std::string data_format = "NCHW"; std::string key; if (attrs.attr_store.find("padding") != attrs.attr_store.end()) { padding = absl::get>(attrs.attr_store.at("padding")); @@ -825,17 +989,23 @@ std::shared_ptr StrategyForDepthwiseConv2d(const framework::NodeAttr key = absl::get(attrs.attr_store.at("key")); } - framework::CINNCompute depthwise_conv2d_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of depthwise_conv compute is empty! Please check.\n"; + framework::CINNCompute depthwise_conv2d_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of depthwise_conv compute is " + "empty! Please check.\n"; CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 2U) << "at least 2 input tensors for depthwise_conv compute\n"; + CHECK_GE(pack_args.size(), 2U) + << "at least 2 input tensors for depthwise_conv compute\n"; Expr A = pack_args[0]; Expr B = pack_args[1]; CHECK(A.as_tensor()); CHECK(B.as_tensor()); - CHECK_EQ(padding.size(), 2) << "The size of padding in depthwise_conv op is not 2! Please check.\n"; - CHECK_EQ(stride.size(), 2) << "The size of stride in depthwise_conv op is not 2! Please check.\n"; - CHECK(data_format == "NCHW" || data_format == "NHWC") << "only support NCHW/NHWC data_format.\n"; + CHECK_EQ(padding.size(), 2) + << "The size of padding in depthwise_conv op is not 2! Please check.\n"; + CHECK_EQ(stride.size(), 2) + << "The size of stride in depthwise_conv op is not 2! 
Please check.\n"; + CHECK(data_format == "NCHW" || data_format == "NHWC") + << "only support NCHW/NHWC data_format.\n"; std::vector out; std::string tensor_name = UniqName("Depthwise_Conv2d_out"); if (FLAGS_cinn_ir_schedule) { @@ -857,12 +1027,22 @@ std::shared_ptr StrategyForDepthwiseConv2d(const framework::NodeAttr tensor_name, target); } else { - out = pe::Depthwise_Conv2d_NCHW( - A.as_tensor_ref(), B.as_tensor_ref(), padding[0], padding[1], stride[0], stride[1], tensor_name); + out = pe::Depthwise_Conv2d_NCHW(A.as_tensor_ref(), + B.as_tensor_ref(), + padding[0], + padding[1], + stride[0], + stride[1], + tensor_name); } } else if (data_format == "NHWC") { - out = pe::Depthwise_Conv2d_NHWC( - A.as_tensor_ref(), B.as_tensor_ref(), padding[0], padding[1], stride[0], stride[1], tensor_name); + out = pe::Depthwise_Conv2d_NHWC(A.as_tensor_ref(), + B.as_tensor_ref(), + padding[0], + padding[1], + stride[0], + stride[1], + tensor_name); } else { LOG(FATAL) << "Only support NCHW and NHWC data layout\n"; } @@ -874,14 +1054,17 @@ std::shared_ptr StrategyForDepthwiseConv2d(const framework::NodeAttr res.push_back(CINNValue(t)); } CHECK(out.size() == 2U || out.size() == 1U || out.size() == 5U) - << "The output tensor sizes of depthwise_conv op in depthwise_conv op should be 1 or 2 or 5\n"; + << "The output tensor sizes of depthwise_conv op in depthwise_conv op " + "should be 1 or 2 or 5\n"; res.push_back(CINNValue(stages)); *ret = CINNValuePack{res}; }); - framework::CINNSchedule depthwise_conv2d_schedule([=](lang::Args args, lang::RetValue *ret) { + framework::CINNSchedule depthwise_conv2d_schedule([=](lang::Args args, + lang::RetValue *ret) { if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of InjectiveSchedule is empty! Please check.\n"; + CHECK(!args.empty()) << "The input argument of InjectiveSchedule is " + "empty! Please check.\n"; common::CINNValuePack arg_pack = args[0]; std::vector vec_ast; std::vector vec_tensor; @@ -903,14 +1086,17 @@ std::shared_ptr StrategyForDepthwiseConv2d(const framework::NodeAttr } else { CINN_NOT_IMPLEMENTED } - std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = common::CINNValuePack{res}; } else { - CHECK(!args.empty()) << "The input argument of depthwise_conv schedule is empty! Please check.\n"; + CHECK(!args.empty()) << "The input argument of depthwise_conv schedule " + "is empty! 
Please check.\n"; CINNValuePack arg_pack = args[0]; - CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL || arg_pack.size() == 6UL); + CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL || + arg_pack.size() == 6UL); poly::StageMap stages = arg_pack[arg_pack.size() - 1]; - Expr Out = arg_pack[0]; + Expr Out = arg_pack[0]; CHECK(Out.as_tensor()); if (arg_pack.size() == 3UL) { Expr input_pad = arg_pack[1]; @@ -924,31 +1110,38 @@ std::shared_ptr StrategyForDepthwiseConv2d(const framework::NodeAttr arg_pack[0] = Expr(output); } else if (target.arch == Target::Arch::X86) { if (arg_pack.size() == 6UL) { - Expr res = arg_pack[0]; - Expr packed_out = arg_pack[1]; + Expr res = arg_pack[0]; + Expr packed_out = arg_pack[1]; Expr weights_dilation = arg_pack[2]; - Expr input_pad = arg_pack[3]; - Expr data = arg_pack[4]; + Expr input_pad = arg_pack[3]; + Expr data = arg_pack[4]; CHECK(res.as_tensor()); CHECK(packed_out.as_tensor()); CHECK(input_pad.as_tensor()); CHECK(weights_dilation.as_tensor()); CHECK(data.as_tensor()); ir::Tensor packed_out_tensor = packed_out.as_tensor_ref(); - bool do_padding = (padding[0] == 0 && padding[1] == 0) ? false : true; - pe::Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse(stages, - res.as_tensor_ref(), - packed_out_tensor, - input_pad.as_tensor_ref(), - weights_dilation.as_tensor_ref(), - data.as_tensor_ref(), - target, - do_padding); + bool do_padding = (padding[0] == 0 && padding[1] == 0) ? false : true; + pe::Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse( + stages, + res.as_tensor_ref(), + packed_out_tensor, + input_pad.as_tensor_ref(), + weights_dilation.as_tensor_ref(), + data.as_tensor_ref(), + target, + do_padding); if (do_padding) { - *ret = CINNValuePack{ - {CINNValue(res), CINNValue(packed_out_tensor), arg_pack[2], arg_pack[3], CINNValue(stages)}}; + *ret = CINNValuePack{{CINNValue(res), + CINNValue(packed_out_tensor), + arg_pack[2], + arg_pack[3], + CINNValue(stages)}}; } else { - *ret = CINNValuePack{{CINNValue(res), CINNValue(packed_out_tensor), arg_pack[2], CINNValue(stages)}}; + *ret = CINNValuePack{{CINNValue(res), + CINNValue(packed_out_tensor), + arg_pack[2], + CINNValue(stages)}}; } return; } @@ -959,23 +1152,32 @@ std::shared_ptr StrategyForDepthwiseConv2d(const framework::NodeAttr }); auto strategy = std::make_shared(); - CHECK(out_type.size()) << "Out_type of depthwise_conv op is empty! Please check."; + CHECK(out_type.size()) + << "Out_type of depthwise_conv op is empty! Please check."; if (out_type[0] == Float(32)) { - strategy->AddImpl(depthwise_conv2d_compute, depthwise_conv2d_schedule, "strategy.depthwise_conv.x86", 1); + strategy->AddImpl(depthwise_conv2d_compute, + depthwise_conv2d_schedule, + "strategy.depthwise_conv.x86", + 1); } else { - VLOG(3) << "depthwise_conv op with dtype != float32 is not implemented yet!"; + VLOG(3) + << "depthwise_conv op with dtype != float32 is not implemented yet!"; } return strategy; } -std::vector InferShapeForDepthwiseConv2d(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 2U) << "at least 2 input tensors for depthwise_conv2d op\n"; - CHECK_EQ(inputs_shape[0].size(), 4U) << "The input tensor's shape should be 4! Please check again."; - CHECK_EQ(inputs_shape[1].size(), 4U) << "The input tensor's shape should be 4! 
Please check again."; +std::vector InferShapeForDepthwiseConv2d( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 2U) + << "at least 2 input tensors for depthwise_conv2d op\n"; + CHECK_EQ(inputs_shape[0].size(), 4U) + << "The input tensor's shape should be 4! Please check again."; + CHECK_EQ(inputs_shape[1].size(), 4U) + << "The input tensor's shape should be 4! Please check again."; std::vector padding = {0, 0}; - std::vector stride = {1, 1}; - std::string data_format = "NCHW"; + std::vector stride = {1, 1}; + std::string data_format = "NCHW"; if (attrs.find("padding") != attrs.end()) { padding = absl::get>(attrs.at("padding")); } @@ -986,59 +1188,83 @@ std::vector InferShapeForDepthwiseConv2d(const std::vector &in data_format = absl::get(attrs.at("data_format")); } std::vector res; - CHECK_EQ(padding.size(), 2U) << "The size of padding in depthwise_conv2d op is not 2! Please check."; - CHECK_EQ(stride.size(), 2U) << "The size of stride in depthwise_conv2d op is not 2! Please check."; + CHECK_EQ(padding.size(), 2U) + << "The size of padding in depthwise_conv2d op is not 2! Please check."; + CHECK_EQ(stride.size(), 2U) + << "The size of stride in depthwise_conv2d op is not 2! Please check."; if (data_format == "NCHW") { - // A is input: [N, C, H, W], and B is filter: [C_in, channel_multiplier, f_h, f_w] - int out_shape_h = (inputs_shape[0][2] - inputs_shape[1][2] + 2 * padding[0]) / stride[0] + 1; - int out_shape_w = (inputs_shape[0][3] - inputs_shape[1][3] + 2 * padding[1]) / stride[1] + 1; - res = {{inputs_shape[0][0], inputs_shape[1][1] * inputs_shape[0][1], out_shape_h, out_shape_w}}; + // A is input: [N, C, H, W], and B is filter: [C_in, channel_multiplier, + // f_h, f_w] + int out_shape_h = + (inputs_shape[0][2] - inputs_shape[1][2] + 2 * padding[0]) / stride[0] + + 1; + int out_shape_w = + (inputs_shape[0][3] - inputs_shape[1][3] + 2 * padding[1]) / stride[1] + + 1; + res = {{inputs_shape[0][0], + inputs_shape[1][1] * inputs_shape[0][1], + out_shape_h, + out_shape_w}}; } else if (data_format == "NHWC") { - // A is input: [N, H, W, C], and B is filter: [C_in, channel_multiplier, f_h, f_w] - int out_shape_h = (inputs_shape[0][1] - inputs_shape[1][1] + 2 * padding[0]) / stride[0] + 1; - int out_shape_w = (inputs_shape[0][2] - inputs_shape[1][2] + 2 * padding[1]) / stride[1] + 1; - res = {{inputs_shape[0][0], out_shape_h, out_shape_w, inputs_shape[1][1] * inputs_shape[0][3]}}; + // A is input: [N, H, W, C], and B is filter: [C_in, channel_multiplier, + // f_h, f_w] + int out_shape_h = + (inputs_shape[0][1] - inputs_shape[1][1] + 2 * padding[0]) / stride[0] + + 1; + int out_shape_w = + (inputs_shape[0][2] - inputs_shape[1][2] + 2 * padding[1]) / stride[1] + + 1; + res = {{inputs_shape[0][0], + out_shape_h, + out_shape_w, + inputs_shape[1][1] * inputs_shape[0][3]}}; } else { LOG(FATAL) << "Only support NCHW and NHWC data layout\n"; } return res; } -std::vector InferDtypeForDepthwiseConv2d(const std::vector &inputs_type, - const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForDepthwiseConv2d( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! 
Please check again."; std::vector res{inputs_type[0]}; return res; } -std::shared_ptr StrategyForBatchNorm(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForBatchNorm( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { float epsilon = 0.00001f; std::vector input_layouts; if (attrs.attr_store.find("epsilon") != attrs.attr_store.end()) { epsilon = absl::get(attrs.attr_store.at("epsilon")); } if (attrs.attr_store.find("input_layouts") != attrs.attr_store.end()) { - input_layouts = absl::get>(attrs.attr_store.at("input_layouts")); + input_layouts = absl::get>( + attrs.attr_store.at("input_layouts")); } - framework::CINNCompute batchnorm_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of batchnorm compute is empty! Please check.\n"; + framework::CINNCompute batchnorm_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of batchnorm compute is empty! Please check.\n"; CINNValuePack arg_pack = args[0]; - CHECK_GE(arg_pack.size(), 5U) << "at least 5 input tensors for batchnorm compute\n"; - Expr A = arg_pack[0]; - Expr Scale = arg_pack[1]; - Expr Bias = arg_pack[2]; - Expr Mean = arg_pack[3]; - Expr Variance = arg_pack[4]; + CHECK_GE(arg_pack.size(), 5U) + << "at least 5 input tensors for batchnorm compute\n"; + Expr A = arg_pack[0]; + Expr Scale = arg_pack[1]; + Expr Bias = arg_pack[2]; + Expr Mean = arg_pack[3]; + Expr Variance = arg_pack[4]; std::string out_name = UniqName("BatchNorm_output"); if (FLAGS_cinn_ir_schedule) { CHECK_EQ(arg_pack.size(), 6U); CHECK(arg_pack[5].is_string()); std::string str = arg_pack[5]; - out_name = str; + out_name = str; } CHECK(A.as_tensor()); CHECK(Scale.as_tensor()); @@ -1048,7 +1274,8 @@ std::shared_ptr StrategyForBatchNorm(const framework::NodeAttr &attr ir::Tensor out; auto tensor_input = A.as_tensor_ref(); if (tensor_input->shape.size() != 4 && target.arch == Target::Arch::X86) { - CHECK_EQ(input_layouts.size(), 5U) << "batch_norm_NCHWc's input layout should be 5"; + CHECK_EQ(input_layouts.size(), 5U) + << "batch_norm_NCHWc's input layout should be 5"; std::string input_layout = input_layouts[0]; CHECK_GE(input_layout.size(), 5U); CHECK_EQ(input_layout.substr(0, 4), "NCHW"); @@ -1070,119 +1297,145 @@ std::shared_ptr StrategyForBatchNorm(const framework::NodeAttr &attr out_name); } auto stages = CreateStages({out}); - *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; }); auto strategy = std::make_shared(); CHECK(out_type.size()) << "Out_type of batchnorm op is empty! Please check."; if (out_type[0] == Float(32)) { - strategy->AddImpl(batchnorm_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.batchnorm.x86", 1); + strategy->AddImpl(batchnorm_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.batchnorm.x86", + 1); } else { LOG(FATAL) << "BatchNorm op with dtype != float32 is not implemented yet!"; } return strategy; } -std::vector InferShapeForBatchNorm(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! 
Please check again."; +std::vector InferShapeForBatchNorm( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) + << "The input's shape size is 0! Please check again."; std::vector res{inputs_shape[0]}; return res; } -std::vector InferDtypeForBatchNorm(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 5U) << "The BatchNorm Infer input's type size should be 5! Please check again."; - CHECK_EQ(inputs_type[1], inputs_type[2]) << "The BatchNorm Infer scale type should the same as bias type"; - CHECK_EQ(inputs_type[1], inputs_type[3]) << "The BatchNorm Infer scale type should the same as moving_mean type"; - CHECK_EQ(inputs_type[1], inputs_type[4]) << "The BatchNorm Infer scale type should the same as moving_variance type"; +std::vector InferDtypeForBatchNorm(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 5U) << "The BatchNorm Infer input's type size " + "should be 5! Please check again."; + CHECK_EQ(inputs_type[1], inputs_type[2]) + << "The BatchNorm Infer scale type should the same as bias type"; + CHECK_EQ(inputs_type[1], inputs_type[3]) + << "The BatchNorm Infer scale type should the same as moving_mean type"; + CHECK_EQ(inputs_type[1], inputs_type[4]) + << "The BatchNorm Infer scale type should the same as moving_variance " + "type"; std::vector res{inputs_type[0]}; return res; } -std::vector> InferLayoutForBatchNorm(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 5U) << "The input's layouts size is not 5! Please check again."; +std::vector> InferLayoutForBatchNorm( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 5U) + << "The input's layouts size is not 5! Please check again."; std::string input_layout = input_layouts[0]; - CHECK_GE(input_layout.size(), 4) << "batchnorm's first input layout size should be >= 4"; + CHECK_GE(input_layout.size(), 4) + << "batchnorm's first input layout size should be >= 4"; return {{input_layout}, input_layouts}; } -std::shared_ptr StrategyForPool1d(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute pool1d_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of pool1d compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; - CHECK(!pack_args.empty()) << "The input tensor of pool1d compute is empty! 
Please check.\n"; - Expr A = pack_args[0]; - CHECK(A.as_tensor()); - auto attr_store = attrs.attr_store; - std::vector kernel_size; // [kernel_w] - std::vector stride_size; // [stride_w] - std::vector padding_size; // [padding_left, padding_right] - std::string pool_type = "max"; - bool ceil_mode = false; - bool exclusive = true; - std::string data_format = "NCW"; - for (auto &iter : attrs.attr_store) { - if (iter.first == "kernel_size") { - kernel_size = absl::get>(iter.second); - } else if (iter.first == "stride_size") { - stride_size = absl::get>(iter.second); - } else if (iter.first == "padding_size") { - padding_size = absl::get>(iter.second); - } else if (iter.first == "pool_type") { - pool_type = absl::get(iter.second); - } else if (iter.first == "ceil_mode") { - ceil_mode = absl::get(iter.second); - } else if (iter.first == "exclusive") { - exclusive = absl::get(iter.second); - } else if (iter.first == "data_format") { - data_format = absl::get(iter.second); - } else { - LOG(ERROR) << "Unsupported attr: " << iter.first << std::endl; - } - } - CHECK(!kernel_size.empty()) << "kernel_size for pool1d is empty. Please check.\n"; - CHECK(!stride_size.empty()) << "stride_size for pool1d is empty. Please check.\n"; - CHECK(!padding_size.empty()) << "padding_size for pool1d is empty. Please check.\n"; - CHECK(pool_type == "max" || pool_type == "avg") << "pool_type for pool1d should be max or avg.\n"; - - std::string tensor_name = UniqName("Pool1d_out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - } +std::shared_ptr StrategyForPool1d( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute pool1d_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of pool1d compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK(!pack_args.empty()) + << "The input tensor of pool1d compute is empty! Please check.\n"; + Expr A = pack_args[0]; + CHECK(A.as_tensor()); + auto attr_store = attrs.attr_store; + std::vector kernel_size; // [kernel_w] + std::vector stride_size; // [stride_w] + std::vector padding_size; // [padding_left, padding_right] + std::string pool_type = "max"; + bool ceil_mode = false; + bool exclusive = true; + std::string data_format = "NCW"; + for (auto &iter : attrs.attr_store) { + if (iter.first == "kernel_size") { + kernel_size = absl::get>(iter.second); + } else if (iter.first == "stride_size") { + stride_size = absl::get>(iter.second); + } else if (iter.first == "padding_size") { + padding_size = absl::get>(iter.second); + } else if (iter.first == "pool_type") { + pool_type = absl::get(iter.second); + } else if (iter.first == "ceil_mode") { + ceil_mode = absl::get(iter.second); + } else if (iter.first == "exclusive") { + exclusive = absl::get(iter.second); + } else if (iter.first == "data_format") { + data_format = absl::get(iter.second); + } else { + LOG(ERROR) << "Unsupported attr: " << iter.first << std::endl; + } + } + CHECK(!kernel_size.empty()) + << "kernel_size for pool1d is empty. Please check.\n"; + CHECK(!stride_size.empty()) + << "stride_size for pool1d is empty. Please check.\n"; + CHECK(!padding_size.empty()) + << "padding_size for pool1d is empty. 
Please check.\n"; + CHECK(pool_type == "max" || pool_type == "avg") + << "pool_type for pool1d should be max or avg.\n"; + + std::string tensor_name = UniqName("Pool1d_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } - auto out = pe::Pool1d(A.as_tensor_ref(), - kernel_size, - stride_size, - padding_size, - pool_type, - ceil_mode, - exclusive, - data_format, - tensor_name); - - auto stages = CreateStages(out); - CHECK(out.size() == 1U || out.size() == 2U) << "The size of pe::Pool1d's output should be 1 or 2."; - CHECK(!out_type.empty()) << "Output type of Pool1d is empty! Please check.\n"; - std::vector res; - for (auto &t : out) { - res.push_back(CINNValue(Expr(t.get()))); - } - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); + auto out = pe::Pool1d(A.as_tensor_ref(), + kernel_size, + stride_size, + padding_size, + pool_type, + ceil_mode, + exclusive, + data_format, + tensor_name); + + auto stages = CreateStages(out); + CHECK(out.size() == 1U || out.size() == 2U) + << "The size of pe::Pool1d's output should be 1 or 2."; + CHECK(!out_type.empty()) + << "Output type of Pool1d is empty! Please check.\n"; + std::vector res; + for (auto &t : out) { + res.push_back(CINNValue(Expr(t.get()))); + } + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); - framework::CINNSchedule pool1d_schedule([=](lang::Args args, lang::RetValue *ret) { + framework::CINNSchedule pool1d_schedule([=](lang::Args args, + lang::RetValue *ret) { if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of pool1d schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of pool1d schedule is empty! Please check.\n"; CINNValuePack arg_pack = args[0]; std::vector vec_ast; std::vector vec_tensor; @@ -1216,13 +1469,15 @@ std::shared_ptr StrategyForPool1d(const framework::NodeAttr &attrs, ir_sch.Bind(loops[0], "blockIdx.x"); ir_sch.Bind(loops[1], "threadIdx.x"); } - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else { - CHECK(!args.empty()) << "The input argument of pool1d schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of pool1d schedule is empty! Please check.\n"; CINNValuePack arg_pack = args[0]; CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL); - Expr Out = arg_pack[0]; + Expr Out = arg_pack[0]; poly::StageMap stages = arg_pack[arg_pack.size() - 1]; if (arg_pack.size() == 3UL) { Expr input_pad = arg_pack[1]; @@ -1246,15 +1501,17 @@ std::shared_ptr StrategyForPool1d(const framework::NodeAttr &attrs, return strategy; } -std::vector> InferShapeForPool1d(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; +std::vector> InferShapeForPool1d( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) + << "The input's shape size is 0! 
Please check again."; std::vector kernel_size; // [kernel_w] std::vector stride_size; // [stride_w] std::vector padding_size; // [padding_left, padding_right] - std::string pool_type = "max"; - bool ceil_mode = false; - bool exclusive = true; + std::string pool_type = "max"; + bool ceil_mode = false; + bool exclusive = true; std::string data_format = "NCW"; for (auto &iter : attrs) { if (iter.first == "kernel_size") { @@ -1272,9 +1529,12 @@ std::vector> InferShapeForPool1d(const std::vector output_shape1 = inputs_shape[0]; CHECK_EQ(output_shape1.size(), 3U); @@ -1289,32 +1549,37 @@ std::vector> InferShapeForPool1d(const std::vector> res{output_shape1}; return res; } -std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForPool2d( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { auto attr_store = attrs.attr_store; std::vector kernel_size; // [kernel_h, kernel_w] std::vector stride_size; // [stride_h, stride_w] - std::vector padding_size; // [padding_top, padding_left, padding_bottom, padding_right] - std::string pool_type = "max"; - bool ceil_mode = false; - bool exclusive = true; - bool global_pooling = false; - bool adaptive = false; + std::vector padding_size; // [padding_top, padding_left, padding_bottom, + // padding_right] + std::string pool_type = "max"; + bool ceil_mode = false; + bool exclusive = true; + bool global_pooling = false; + bool adaptive = false; std::string data_format = "NCHW"; for (auto &iter : attrs.attr_store) { if (iter.first == "kernel_size") { @@ -1339,10 +1604,12 @@ std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, } // It can be removed after fixing the global_pool2d problem if (attr_store.count("origin_kernel_size")) { - kernel_size = absl::get>(attr_store.at("origin_kernel_size")); + kernel_size = + absl::get>(attr_store.at("origin_kernel_size")); } if (attr_store.count("origin_padding_size")) { - padding_size = absl::get>(attr_store.at("origin_padding_size")); + padding_size = + absl::get>(attr_store.at("origin_padding_size")); } if (attr_store.count("origin_global_pooling")) { global_pooling = absl::get(attr_store.at("origin_global_pooling")); @@ -1351,63 +1618,78 @@ std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, adaptive = absl::get(attr_store.at("origin_adaptive")); } - CHECK(!kernel_size.empty()) << "kernel_size for pool2d is empty. Please check.\n"; - CHECK(!stride_size.empty()) << "stride_size for pool2d is empty. Please check.\n"; - CHECK(!padding_size.empty()) << "padding_size for pool2d is empty. Please check.\n"; - CHECK(pool_type == "max" || pool_type == "avg") << "pool_type for pool2d should be max or avg.\n"; + CHECK(!kernel_size.empty()) + << "kernel_size for pool2d is empty. Please check.\n"; + CHECK(!stride_size.empty()) + << "stride_size for pool2d is empty. Please check.\n"; + CHECK(!padding_size.empty()) + << "padding_size for pool2d is empty. Please check.\n"; + CHECK(pool_type == "max" || pool_type == "avg") + << "pool_type for pool2d should be max or avg.\n"; - CHECK(!inputs.empty()) << "The input tensor of pool2d compute is empty! Please check.\n"; + CHECK(!inputs.empty()) + << "The input tensor of pool2d compute is empty! 
Please check.\n"; const ir::Tensor &A_tensor = inputs[0]; CHECK(A_tensor->shape.size() == 4U || A_tensor->shape.size() == 5U) << "pool2d requires tensor's shape_size to be 4 or 5\n"; if (global_pooling) { int height_index = -1; - int width_index = -1; + int width_index = -1; if (data_format == "NCHW") { height_index = 2; - width_index = 3; + width_index = 3; } else if (data_format == "NHWC") { height_index = 1; - width_index = 2; + width_index = 2; } else if (data_format == "AnyLayout") { height_index = 2; - width_index = 3; - data_format = "NCHW"; + width_index = 3; + data_format = "NCHW"; } else { - LOG(FATAL) << "Only support 'NCHW' or 'NHWC' or 'AnyLayout' data_format.\n"; + LOG(FATAL) + << "Only support 'NCHW' or 'NHWC' or 'AnyLayout' data_format.\n"; } - kernel_size = {A_tensor->shape[height_index].as_int32(), A_tensor->shape[width_index].as_int32()}; + kernel_size = {A_tensor->shape[height_index].as_int32(), + A_tensor->shape[width_index].as_int32()}; padding_size = {0, 0, 0, 0}; } if (kernel_size.size() == padding_size.size()) { - padding_size.insert(padding_size.end(), padding_size.begin(), padding_size.end()); - } - - framework::CINNCompute global_pool2d_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of pool2d compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; - Expr A = pack_args[0]; - CHECK(A.as_tensor()); - ir::Tensor A_tensor = A.as_tensor_ref(); - - std::string tensor_name = UniqName("GlobalPool2d_out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - } - - auto out = pe::GlobalPool2d(A_tensor, pool_type, tensor_name); - CHECK(out.size() == 2U) << "The size of pe::GlobalPool2d's output should be 2."; - auto stages = CreateStages({A_tensor, out[0], out[1]}); - *ret = CINNValuePack{{CINNValue(out[0]), CINNValue(out[1]), CINNValue(stages)}}; - }); + padding_size.insert( + padding_size.end(), padding_size.begin(), padding_size.end()); + } + + framework::CINNCompute global_pool2d_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of pool2d compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + Expr A = pack_args[0]; + CHECK(A.as_tensor()); + ir::Tensor A_tensor = A.as_tensor_ref(); + + std::string tensor_name = UniqName("GlobalPool2d_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } - framework::CINNSchedule global_pool2d_schedule([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of pool2d schedule is empty! Please check.\n"; + auto out = pe::GlobalPool2d(A_tensor, pool_type, tensor_name); + CHECK(out.size() == 2U) + << "The size of pe::GlobalPool2d's output should be 2."; + auto stages = CreateStages({A_tensor, out[0], out[1]}); + *ret = CINNValuePack{ + {CINNValue(out[0]), CINNValue(out[1]), CINNValue(stages)}}; + }); + + framework::CINNSchedule global_pool2d_schedule([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of pool2d schedule is empty! Please check.\n"; if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of pool1d schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of pool1d schedule is empty! 
Please check.\n"; CINNValuePack arg_pack = args[0]; std::vector vec_ast; std::vector vec_tensor; @@ -1429,60 +1711,68 @@ std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, } else { CINN_NOT_IMPLEMENTED } - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else { CINNValuePack arg_pack = args[0]; CHECK(arg_pack.size() == 3UL); - Expr out = arg_pack[0]; + Expr out = arg_pack[0]; Expr reduce = arg_pack[1]; CHECK(out.as_tensor() && reduce.as_tensor()); poly::StageMap stages = arg_pack[arg_pack.size() - 1]; - pe::GlobalPoolScheduleGPU(stages, {out.as_tensor_ref(), reduce.as_tensor_ref()}, target); + pe::GlobalPoolScheduleGPU( + stages, {out.as_tensor_ref(), reduce.as_tensor_ref()}, target); *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; } }); - framework::CINNCompute pool2d_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of pool2d compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; - Expr A = pack_args[0]; - CHECK(A.as_tensor()); - ir::Tensor A_tensor = A.as_tensor_ref(); - - std::string tensor_name = UniqName("Pool2d_out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - } + framework::CINNCompute pool2d_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of pool2d compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + Expr A = pack_args[0]; + CHECK(A.as_tensor()); + ir::Tensor A_tensor = A.as_tensor_ref(); + + std::string tensor_name = UniqName("Pool2d_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } - auto out = pe::Pool2d(A_tensor, - kernel_size, - stride_size, - padding_size, - pool_type, - ceil_mode, - exclusive, - data_format, - adaptive, - tensor_name); - - auto stages = CreateStages({A_tensor}); - CHECK(out.size() == 1U || out.size() == 2U) << "The size of pe::Pool2d's output should be 1 or 2."; - std::vector res; - for (auto &t : out) { - stages->InsertLazily(t); - res.push_back(CINNValue(t)); - } - CHECK(!out_type.empty()) << "Output type of Pool2d is empty! Please check.\n"; - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); + auto out = pe::Pool2d(A_tensor, + kernel_size, + stride_size, + padding_size, + pool_type, + ceil_mode, + exclusive, + data_format, + adaptive, + tensor_name); + + auto stages = CreateStages({A_tensor}); + CHECK(out.size() == 1U || out.size() == 2U) + << "The size of pe::Pool2d's output should be 1 or 2."; + std::vector res; + for (auto &t : out) { + stages->InsertLazily(t); + res.push_back(CINNValue(t)); + } + CHECK(!out_type.empty()) + << "Output type of Pool2d is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); - framework::CINNSchedule pool2d_schedule([=](lang::Args args, lang::RetValue *ret) { + framework::CINNSchedule pool2d_schedule([=](lang::Args args, + lang::RetValue *ret) { if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of pool2d schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of pool2d schedule is empty! 
Please check.\n"; CINNValuePack arg_pack = args[0]; std::vector vec_ast; std::vector vec_tensor; @@ -1514,10 +1804,12 @@ std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, if (target.arch == Target::Arch::NVGPU) { pe::IRPoolScheduleGPU(ir_sch, target, arg_pack_size); } - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else { - CHECK(!args.empty()) << "The input argument of pool2d schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of pool2d schedule is empty! Please check.\n"; CINNValuePack arg_pack = args[0]; CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL); Expr Out = arg_pack[0]; @@ -1540,7 +1832,8 @@ std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, auto strategy = std::make_shared(); bool use_warp_reduce = false; - if (global_pooling && data_format == "NCHW" && target.arch == Target::Arch::NVGPU) { + if (global_pooling && data_format == "NCHW" && + target.arch == Target::Arch::NVGPU) { // TODO 32 may not be the exact number, try also 16 or 8 or other number // we choose 32 to make sure all the threads in a warp has work to do, if ((A_tensor->shape[2].as_int32() * A_tensor->shape[3].as_int32()) >= 32) { @@ -1549,25 +1842,30 @@ std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, } strategy->AddImpl(pool2d_compute, pool2d_schedule, "strategy.pool2d.x86", 1); if (use_warp_reduce) { - strategy->AddImpl(global_pool2d_compute, global_pool2d_schedule, "strategy.pool2d.gpu.global", 2); + strategy->AddImpl(global_pool2d_compute, + global_pool2d_schedule, + "strategy.pool2d.gpu.global", + 2); } return strategy; } -std::vector> InferShapeForPool2d(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector> InferShapeForPool2d( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { CHECK(inputs_shape[0].size() == 4 || inputs_shape[0].size() == 5) - << "The input's shape size of pool2d should be 4 or 5! Please check again."; + << "The input's shape size of pool2d should be 4 or 5! 
Please check " + "again."; std::vector kernel_size; std::vector stride_size; std::vector padding_size; - std::string pool_type = "max"; - bool ceil_mode = false; - bool exclusive = true; + std::string pool_type = "max"; + bool ceil_mode = false; + bool exclusive = true; std::string data_format = "NCHW"; - bool global_pooling = false; - bool adaptive = false; + bool global_pooling = false; + bool adaptive = false; for (auto &iter : attrs) { if (iter.first == "kernel_size") { kernel_size = absl::get>(iter.second); @@ -1591,126 +1889,149 @@ std::vector> InferShapeForPool2d(const std::vector output_shape1 = inputs_shape[0]; if (ceil_mode) { output_shape1[height_axis] = - (inputs_shape[0][height_axis] - kernel_size[0] + padding_size[0] + padding_size[2] + stride_size[0] - 1) / + (inputs_shape[0][height_axis] - kernel_size[0] + padding_size[0] + + padding_size[2] + stride_size[0] - 1) / stride_size[0] + 1; output_shape1[width_axis] = - (inputs_shape[0][width_axis] - kernel_size[1] + padding_size[1] + padding_size[3] + stride_size[1] - 1) / + (inputs_shape[0][width_axis] - kernel_size[1] + padding_size[1] + + padding_size[3] + stride_size[1] - 1) / stride_size[1] + 1; } else { output_shape1[height_axis] = - (inputs_shape[0][height_axis] - kernel_size[0] + padding_size[0] + padding_size[2]) / stride_size[0] + 1; - output_shape1[width_axis] = - (inputs_shape[0][width_axis] - kernel_size[1] + padding_size[1] + padding_size[3]) / stride_size[1] + 1; + (inputs_shape[0][height_axis] - kernel_size[0] + padding_size[0] + + padding_size[2]) / + stride_size[0] + + 1; + output_shape1[width_axis] = (inputs_shape[0][width_axis] - kernel_size[1] + + padding_size[1] + padding_size[3]) / + stride_size[1] + + 1; } if (adaptive) { kernel_size = absl::get>(attrs.at("kernel_size")); if (kernel_size.size() == 1UL) kernel_size.push_back(kernel_size[0]); - CHECK(kernel_size.size() >= 2UL) << "In pool2d, kernel_size's size should be >= 2, please check!"; + CHECK(kernel_size.size() >= 2UL) + << "In pool2d, kernel_size's size should be >= 2, please check!"; output_shape1[height_axis] = kernel_size[0]; - output_shape1[width_axis] = kernel_size[1]; + output_shape1[width_axis] = kernel_size[1]; } - VLOG(4) << std::boolalpha << "y[" << cinn::utils::Join(output_shape1, ", ") << "] = pool2d(x[" - << cinn::utils::Join(inputs_shape[0], ", ") << "], kernel_size=[" << cinn::utils::Join(kernel_size, ", ") - << "], stride_size=[" << cinn::utils::Join(stride_size, ", ") << "], padding_size=[" - << cinn::utils::Join(padding_size, ", ") << "], pool_type=" << pool_type << ", ceil_mode=" << ceil_mode - << ", exclusive=" << exclusive << ", data_format=" << data_format << ", global_pooling=" << global_pooling - << ", adaptive=" << adaptive; + VLOG(4) << std::boolalpha << "y[" << cinn::utils::Join(output_shape1, ", ") + << "] = pool2d(x[" << cinn::utils::Join(inputs_shape[0], ", ") + << "], kernel_size=[" << cinn::utils::Join(kernel_size, ", ") + << "], stride_size=[" << cinn::utils::Join(stride_size, ", ") + << "], padding_size=[" << cinn::utils::Join(padding_size, ", ") + << "], pool_type=" << pool_type << ", ceil_mode=" << ceil_mode + << ", exclusive=" << exclusive << ", data_format=" << data_format + << ", global_pooling=" << global_pooling << ", adaptive=" << adaptive; std::vector> res{output_shape1}; return res; } -std::shared_ptr StrategyForPool3d(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute 
pool3d_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of pool3d compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; - CHECK(!pack_args.empty()) << "The input tensor of pool3d compute is empty! Please check.\n"; - Expr A = pack_args[0]; - CHECK(A.as_tensor()); - auto attr_store = attrs.attr_store; - std::vector kernel_size; // [kernel_d, kernel_h, kernel_w] - std::vector stride_size; // [stride_d, stride_h, stride_w] - std::vector - padding_size; // [padding_front, padding_top, padding_left, padding_back, padding_bottom, padding_right] - std::string pool_type = "max"; - bool ceil_mode = false; - bool exclusive = true; - std::string data_format = "NCDHW"; - for (auto &iter : attrs.attr_store) { - if (iter.first == "kernel_size") { - kernel_size = absl::get>(iter.second); - } else if (iter.first == "stride_size") { - stride_size = absl::get>(iter.second); - } else if (iter.first == "padding_size") { - padding_size = absl::get>(iter.second); - } else if (iter.first == "pool_type") { - pool_type = absl::get(iter.second); - } else if (iter.first == "ceil_mode") { - ceil_mode = absl::get(iter.second); - } else if (iter.first == "exclusive") { - exclusive = absl::get(iter.second); - } else if (iter.first == "data_format") { - data_format = absl::get(iter.second); - } else { - LOG(ERROR) << "Unsupported attr: " << iter.first << std::endl; - } - } - CHECK(!kernel_size.empty()) << "kernel_size for pool3d is empty. Please check.\n"; - CHECK(!stride_size.empty()) << "stride_size for pool3d is empty. Please check.\n"; - CHECK(!padding_size.empty()) << "padding_size for pool3d is empty. Please check.\n"; - CHECK(pool_type == "max" || pool_type == "avg") << "pool_type for pool3d should be max or avg.\n"; - - std::string tensor_name = UniqName("Pool3d_out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - } - - auto out = pe::Pool3d(A.as_tensor_ref(), - kernel_size, - stride_size, - padding_size, - pool_type, - ceil_mode, - exclusive, - data_format, - tensor_name); - - auto stages = CreateStages(out); - CHECK(out.size() == 1U || out.size() == 2U) << "The size of pe::Pool3d's output should be 1 or 2."; - CHECK(!out_type.empty()) << "Output type of Pool3d is empty! Please check.\n"; +std::shared_ptr StrategyForPool3d( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute pool3d_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of pool3d compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK(!pack_args.empty()) + << "The input tensor of pool3d compute is empty! 
Please check.\n"; + Expr A = pack_args[0]; + CHECK(A.as_tensor()); + auto attr_store = attrs.attr_store; + std::vector kernel_size; // [kernel_d, kernel_h, kernel_w] + std::vector stride_size; // [stride_d, stride_h, stride_w] + std::vector + padding_size; // [padding_front, padding_top, padding_left, + // padding_back, padding_bottom, padding_right] + std::string pool_type = "max"; + bool ceil_mode = false; + bool exclusive = true; + std::string data_format = "NCDHW"; + for (auto &iter : attrs.attr_store) { + if (iter.first == "kernel_size") { + kernel_size = absl::get>(iter.second); + } else if (iter.first == "stride_size") { + stride_size = absl::get>(iter.second); + } else if (iter.first == "padding_size") { + padding_size = absl::get>(iter.second); + } else if (iter.first == "pool_type") { + pool_type = absl::get(iter.second); + } else if (iter.first == "ceil_mode") { + ceil_mode = absl::get(iter.second); + } else if (iter.first == "exclusive") { + exclusive = absl::get(iter.second); + } else if (iter.first == "data_format") { + data_format = absl::get(iter.second); + } else { + LOG(ERROR) << "Unsupported attr: " << iter.first << std::endl; + } + } + CHECK(!kernel_size.empty()) + << "kernel_size for pool3d is empty. Please check.\n"; + CHECK(!stride_size.empty()) + << "stride_size for pool3d is empty. Please check.\n"; + CHECK(!padding_size.empty()) + << "padding_size for pool3d is empty. Please check.\n"; + CHECK(pool_type == "max" || pool_type == "avg") + << "pool_type for pool3d should be max or avg.\n"; + + std::string tensor_name = UniqName("Pool3d_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 2); + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } - std::vector res; - for (auto &t : out) { - res.push_back(CINNValue(Expr(t.get()))); - } - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); + auto out = pe::Pool3d(A.as_tensor_ref(), + kernel_size, + stride_size, + padding_size, + pool_type, + ceil_mode, + exclusive, + data_format, + tensor_name); + + auto stages = CreateStages(out); + CHECK(out.size() == 1U || out.size() == 2U) + << "The size of pe::Pool3d's output should be 1 or 2."; + CHECK(!out_type.empty()) + << "Output type of Pool3d is empty! Please check.\n"; + + std::vector res; + for (auto &t : out) { + res.push_back(CINNValue(Expr(t.get()))); + } + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); - framework::CINNSchedule pool3d_schedule([=](lang::Args args, lang::RetValue *ret) { + framework::CINNSchedule pool3d_schedule([=](lang::Args args, + lang::RetValue *ret) { if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of pool3d schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of pool3d schedule is empty! Please check.\n"; CINNValuePack arg_pack = args[0]; std::vector vec_ast; std::vector vec_tensor; @@ -1744,13 +2065,15 @@ std::shared_ptr StrategyForPool3d(const framework::NodeAttr &attrs, ir_sch.Bind(loops[0], "blockIdx.x"); ir_sch.Bind(loops[1], "threadIdx.x"); } - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else { - CHECK(!args.empty()) << "The input argument of pool3d schedule is empty! Please check.\n"; + CHECK(!args.empty()) + << "The input argument of pool3d schedule is empty! 
Please check.\n"; CINNValuePack arg_pack = args[0]; CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL); - Expr Out = arg_pack[0]; + Expr Out = arg_pack[0]; poly::StageMap stages = arg_pack[arg_pack.size() - 1]; if (arg_pack.size() == 3UL) { Expr input_pad = arg_pack[1]; @@ -1774,16 +2097,19 @@ std::shared_ptr StrategyForPool3d(const framework::NodeAttr &attrs, return strategy; } -std::vector> InferShapeForPool3d(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; +std::vector> InferShapeForPool3d( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) + << "The input's shape size is 0! Please check again."; std::vector kernel_size; // [kernel_d, kernel_h, kernel_w] std::vector stride_size; // [stride_d, stride_h, stride_w] std::vector - padding_size; // [padding_front, padding_top, padding_left, padding_bottom, padding_right, padding_back] - std::string pool_type = "max"; - bool ceil_mode = false; - bool exclusive = true; + padding_size; // [padding_front, padding_top, padding_left, + // padding_bottom, padding_right, padding_back] + std::string pool_type = "max"; + bool ceil_mode = false; + bool exclusive = true; std::string data_format = "NCDHW"; for (auto &iter : attrs) { if (iter.first == "kernel_size") { @@ -1803,71 +2129,88 @@ std::vector> InferShapeForPool3d(const std::vector output_shape1 = inputs_shape[0]; - CHECK_EQ(inputs_shape[0].size(), 5U) << "input_shape size for pool3d should be 5.\n"; - int depth_axis = -1; + CHECK_EQ(inputs_shape[0].size(), 5U) + << "input_shape size for pool3d should be 5.\n"; + int depth_axis = -1; int height_axis = -1; - int width_axis = -1; + int width_axis = -1; if (data_format == "NCDHW") { - depth_axis = 2; + depth_axis = 2; height_axis = 3; - width_axis = 4; + width_axis = 4; } else if (data_format == "NDHWC") { - depth_axis = 1; + depth_axis = 1; height_axis = 2; - width_axis = 3; + width_axis = 3; } else { LOG(ERROR) << "unsupported data_format: " << data_format << std::endl; } if (ceil_mode) { output_shape1[depth_axis] = - (inputs_shape[0][depth_axis] - kernel_size[0] + padding_size[0] + padding_size[3] + stride_size[0] - 1) / + (inputs_shape[0][depth_axis] - kernel_size[0] + padding_size[0] + + padding_size[3] + stride_size[0] - 1) / stride_size[0] + 1; output_shape1[height_axis] = - (inputs_shape[0][height_axis] - kernel_size[1] + padding_size[1] + padding_size[4] + stride_size[1] - 1) / + (inputs_shape[0][height_axis] - kernel_size[1] + padding_size[1] + + padding_size[4] + stride_size[1] - 1) / stride_size[1] + 1; output_shape1[width_axis] = - (inputs_shape[0][width_axis] - kernel_size[2] + padding_size[2] + padding_size[5] + stride_size[2] - 1) / + (inputs_shape[0][width_axis] - kernel_size[2] + padding_size[2] + + padding_size[5] + stride_size[2] - 1) / stride_size[2] + 1; } else { - output_shape1[depth_axis] = - (inputs_shape[0][depth_axis] - kernel_size[0] + padding_size[0] + padding_size[3]) / stride_size[0] + 1; + output_shape1[depth_axis] = (inputs_shape[0][depth_axis] - kernel_size[0] + + padding_size[0] + padding_size[3]) / + stride_size[0] + + 1; output_shape1[height_axis] = - (inputs_shape[0][height_axis] - kernel_size[1] + padding_size[1] + padding_size[4]) / stride_size[1] + 1; - output_shape1[width_axis] = - (inputs_shape[0][width_axis] - kernel_size[2] + padding_size[2] + padding_size[5]) / stride_size[2] + 
1; + (inputs_shape[0][height_axis] - kernel_size[1] + padding_size[1] + + padding_size[4]) / + stride_size[1] + + 1; + output_shape1[width_axis] = (inputs_shape[0][width_axis] - kernel_size[2] + + padding_size[2] + padding_size[5]) / + stride_size[2] + + 1; } std::vector> res{output_shape1}; return res; } -std::vector InferDtypeForPool(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForPool(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::vector> InferLayoutForPool(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again."; +std::vector> InferLayoutForPool( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 1U) + << "The input's layout size is not 1! Please check again."; return {input_layouts, input_layouts}; } -std::shared_ptr StrategyForSoftmax(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - int axis = -1; +std::shared_ptr StrategyForSoftmax( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + int axis = -1; bool use_mkldnn = false; if (attrs.attr_store.count("axis")) { axis = absl::get(attrs.attr_store.at("axis")); @@ -1875,50 +2218,57 @@ std::shared_ptr StrategyForSoftmax(const framework::NodeAttr &attrs, if (attrs.attr_store.count("use_mkldnn")) { use_mkldnn = absl::get(attrs.attr_store.at("use_mkldnn")); } - framework::CINNCompute softmax_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of softmax compute is empty! Please check."; - CINNValuePack pack_args = args[0]; - CHECK(!pack_args.empty()) << "The input tensors of softmax compute is empty! Please check."; - Expr A_expr = pack_args[0]; - CHECK(A_expr.as_tensor()); - ir::Tensor A = A_expr.as_tensor_ref(); - auto stages = CreateStages({A}); - int new_axis = axis; - if (axis == -1) { - new_axis = A->shape.size() - 1; - } - std::vector out; + framework::CINNCompute softmax_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of softmax compute is empty! Please check."; + CINNValuePack pack_args = args[0]; + CHECK(!pack_args.empty()) + << "The input tensors of softmax compute is empty! 
Please check."; + Expr A_expr = pack_args[0]; + CHECK(A_expr.as_tensor()); + ir::Tensor A = A_expr.as_tensor_ref(); + auto stages = CreateStages({A}); + int new_axis = axis; + if (axis == -1) { + new_axis = A->shape.size() - 1; + } + std::vector out; - std::string tensor_name = UniqName("Softmax_out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_GE(pack_args.size(), 2); - CHECK(pack_args[pack_args.size() - 1].is_string()); - tensor_name = pack_args[pack_args.size() - 1].operator std::string(); - } + std::string tensor_name = UniqName("Softmax_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_GE(pack_args.size(), 2); + CHECK(pack_args[pack_args.size() - 1].is_string()); + tensor_name = pack_args[pack_args.size() - 1].operator std::string(); + } #ifdef CINN_WITH_MKLDNN - if (use_mkldnn) { - out = pe::SoftmaxMKLDNN(A, new_axis, tensor_name); - } else { - out = pe::Softmax(A, new_axis, tensor_name); - } + if (use_mkldnn) { + out = pe::SoftmaxMKLDNN(A, new_axis, tensor_name); + } else { + out = pe::Softmax(A, new_axis, tensor_name); + } #else - out = pe::Softmax(A, new_axis, tensor_name); + out = pe::Softmax(A, new_axis, tensor_name); #endif - std::vector res; - for (auto &t : out) { - stages->InsertLazily(t); - res.push_back(CINNValue(t)); - } - CHECK_EQ(out.size(), 2U) << "The size of pe::Softmax's output should be 2."; - CHECK(!out_type.empty()) << "Output type of Softmax is empty! Please check.\n"; - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); + std::vector res; + for (auto &t : out) { + stages->InsertLazily(t); + res.push_back(CINNValue(t)); + } + CHECK_EQ(out.size(), 2U) + << "The size of pe::Softmax's output should be 2."; + CHECK(!out_type.empty()) + << "Output type of Softmax is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); - framework::CINNSchedule softmax_schedule([=](lang::Args args, lang::RetValue *ret) { + framework::CINNSchedule softmax_schedule([=](lang::Args args, + lang::RetValue *ret) { if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input arguments of softmax schedule is empty! Please check."; + CHECK(!args.empty()) + << "The input arguments of softmax schedule is empty! Please check."; CINNValuePack arg_pack = args[0]; std::vector vec_ast; for (int i = 0; i < arg_pack.size(); i++) { @@ -1942,31 +2292,35 @@ std::shared_ptr StrategyForSoftmax(const framework::NodeAttr &attrs, ir_sch.SimpleComputeAt(all_blocks[0], loops[0]); } - loops = ir_sch.GetLoops(all_blocks[2]); + loops = ir_sch.GetLoops(all_blocks[2]); int loop_index = 1; if (output_shapes[0][0] == 1) loop_index--; CHECK_GE(loops.size(), loop_index + 1); auto splited_loops = ir_sch.Split(loops[loop_index], {-1, 5}); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[2]); + loops = ir_sch.GetLoops(all_blocks[2]); ir_sch.Bind(loops[0], "blockIdx.x"); ir_sch.Bind(loops[1], "threadIdx.x"); } - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else if (target.arch == Target::Arch::X86) { pe::IRSoftmaxScheduleCPU(ir_sch, axis); - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } } else { - CHECK(!args.empty()) << "The input arguments of softmax schedule is empty! Please check."; + CHECK(!args.empty()) + << "The input arguments of softmax schedule is empty! 
Please check."; CINNValuePack arg_pack = args[0]; - CHECK_EQ(arg_pack.size(), 3UL) << "The input tensor's size of softmax schedule is " << arg_pack.size() - << "and it should be equal to 3! Please check."; - Expr out1 = arg_pack[0]; - Expr out2 = arg_pack[1]; + CHECK_EQ(arg_pack.size(), 3UL) + << "The input tensor's size of softmax schedule is " + << arg_pack.size() << "and it should be equal to 3! Please check."; + Expr out1 = arg_pack[0]; + Expr out2 = arg_pack[1]; poly::StageMap stages = arg_pack[2]; CHECK(out1.as_tensor()); CHECK(out2.as_tensor()); @@ -1988,29 +2342,36 @@ std::shared_ptr StrategyForSoftmax(const framework::NodeAttr &attrs, }); auto strategy = std::make_shared(); - strategy->AddImpl(softmax_compute, softmax_schedule, "strategy.softmax.x86", 1); + strategy->AddImpl( + softmax_compute, softmax_schedule, "strategy.softmax.x86", 1); return strategy; } -std::vector> InferShapeForSoftmax(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; +std::vector> InferShapeForSoftmax( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) + << "The input's shape size is 0! Please check again."; std::vector> res{inputs_shape[0]}; return res; } -std::vector InferDtypeForSoftmax(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForSoftmax(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::vector> InferLayoutForSoftmax(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again."; +std::vector> InferLayoutForSoftmax( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 1U) + << "The input's layout size is not 1! 
Please check again."; if (input_shapes[0].size() > 4) { // input tensor needs to be transformed back to NCHW for mkldnn return {{"NCHW", "NCHW"}, {"NCHW"}}; @@ -2018,24 +2379,30 @@ std::vector> InferLayoutForSoftmax(const std::vector StrategyForDropoutInfer(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - float dropout_prob = 0; +std::shared_ptr StrategyForDropoutInfer( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + float dropout_prob = 0; std::string dropout_implementation = "downgrade_in_infer"; if (attrs.attr_store.find("dropout_prob") != attrs.attr_store.end()) { dropout_prob = absl::get(attrs.attr_store.at("dropout_prob")); } - if (attrs.attr_store.find("dropout_implementation") != attrs.attr_store.end()) { - dropout_implementation = absl::get(attrs.attr_store.at("dropout_implementation")); + if (attrs.attr_store.find("dropout_implementation") != + attrs.attr_store.end()) { + dropout_implementation = + absl::get(attrs.attr_store.at("dropout_implementation")); } - framework::CINNCompute dropout_infer_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of dropout_infer compute is empty! Please check."; + framework::CINNCompute dropout_infer_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of dropout_infer compute is " + "empty! Please check."; CINNValuePack pack_args = args[0]; - CHECK(!pack_args.empty()) << "The input tensors of dropout_infer compute is empty! Please check."; + CHECK(!pack_args.empty()) + << "The input tensors of dropout_infer compute is empty! Please check."; Expr A_expr = pack_args[0]; CHECK(A_expr.as_tensor()); ir::Tensor A = A_expr.as_tensor_ref(); @@ -2047,22 +2414,27 @@ std::shared_ptr StrategyForDropoutInfer(const framework::NodeAttr &a tensor_name = pack_args[1].operator std::string(); } - auto out = pe::DropoutInfer(A, dropout_prob, dropout_implementation, tensor_name); + auto out = + pe::DropoutInfer(A, dropout_prob, dropout_implementation, tensor_name); auto stages = CreateStages({A, out}); - *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; }); auto strategy = std::make_shared(); - strategy->AddImpl( - dropout_infer_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.dropout_infer.x86", 1); + strategy->AddImpl(dropout_infer_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.dropout_infer.x86", + 1); return strategy; } -std::vector> InferShapeForDropoutInfer(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(!inputs_shape.empty()) << "The input's shape size is 0! Please check again."; - float dropout_prob = 0; +std::vector> InferShapeForDropoutInfer( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(!inputs_shape.empty()) + << "The input's shape size is 0! Please check again."; + float dropout_prob = 0; std::string dropout_implementation = "downgrade_in_infer"; for (auto &iter : attrs) { if (iter.first == "dropout_prob") { @@ -2078,79 +2450,103 @@ std::vector> InferShapeForDropoutInfer(const std::vector InferDtypeForDropoutInfer(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! 
Please check again."; +std::vector InferDtypeForDropoutInfer( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::shared_ptr StrategyForSelect(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute select_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of select compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 3U) << "at least three input tensor for select compute\n"; - Expr condition = pack_args[0]; - Expr true_value = pack_args[1]; - Expr false_value = pack_args[2]; - CHECK(condition.as_tensor()); - CHECK(true_value.as_tensor()); - CHECK(false_value.as_tensor()); - - std::string tensor_name = UniqName("Select_output"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 4U); - CHECK(pack_args[3].is_string()); - tensor_name = pack_args[3].operator std::string(); - } +std::shared_ptr StrategyForSelect( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute select_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of select compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 3U) + << "at least three input tensor for select compute\n"; + Expr condition = pack_args[0]; + Expr true_value = pack_args[1]; + Expr false_value = pack_args[2]; + CHECK(condition.as_tensor()); + CHECK(true_value.as_tensor()); + CHECK(false_value.as_tensor()); + + std::string tensor_name = UniqName("Select_output"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 4U); + CHECK(pack_args[3].is_string()); + tensor_name = pack_args[3].operator std::string(); + } - auto out = - pe::Select(condition.as_tensor_ref(), true_value.as_tensor_ref(), false_value.as_tensor_ref(), tensor_name); - auto stages = - CreateStages({condition.as_tensor_ref(), true_value.as_tensor_ref(), false_value.as_tensor_ref(), out}); - *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; - }); + auto out = pe::Select(condition.as_tensor_ref(), + true_value.as_tensor_ref(), + false_value.as_tensor_ref(), + tensor_name); + auto stages = CreateStages({condition.as_tensor_ref(), + true_value.as_tensor_ref(), + false_value.as_tensor_ref(), + out}); + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + }); auto strategy = std::make_shared(); CHECK(out_type.size()) << "Out_type of select op is empty! Please check."; - strategy->AddImpl(select_compute, GetInjectiveScheduleFunc(output_shapes, target, false), "strategy.select.x86", 1); + strategy->AddImpl(select_compute, + GetInjectiveScheduleFunc(output_shapes, target, false), + "strategy.select.x86", + 1); return strategy; } -std::vector InferShapeForSelect(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_GE(inputs_shape.size(), 3) << "The input's shape size is 0! 
Please check again."; - CHECK(inputs_shape[0].size() == inputs_shape[1].size() && inputs_shape[1].size() == inputs_shape[2].size()) +std::vector InferShapeForSelect( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_GE(inputs_shape.size(), 3) + << "The input's shape size is 0! Please check again."; + CHECK(inputs_shape[0].size() == inputs_shape[1].size() && + inputs_shape[1].size() == inputs_shape[2].size()) << "input tensors n_dim is not equal!"; - CHECK(inputs_shape[0] == inputs_shape[1] && inputs_shape[1] == inputs_shape[2]) + CHECK(inputs_shape[0] == inputs_shape[1] && + inputs_shape[1] == inputs_shape[2]) << "input tensor shapes is not equal!"; std::vector res{inputs_shape[0]}; return res; } -std::vector InferDtypeForSelect(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK_GE(inputs_type.size(), 3) << "The input's type size is less than three! Please check again."; +std::vector InferDtypeForSelect(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK_GE(inputs_type.size(), 3) + << "The input's type size is less than three! Please check again."; CHECK(inputs_type[0].is_bool()) << "The condition tensor type should be bool"; - CHECK_EQ(inputs_type[1], inputs_type[2]) << "The true or false tensor type should be equal"; + CHECK_EQ(inputs_type[1], inputs_type[2]) + << "The true or false tensor type should be equal"; std::vector res{inputs_type[1]}; return res; } -std::vector> InferLayoutForUnary(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again."; +std::vector> InferLayoutForUnary( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 1U) + << "The input's layout size is not 1! Please check again."; return {input_layouts, input_layouts}; } // batch norm train -std::vector InferShapeForBatchNormTrain(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 5U) << "The input's layout size is not 5! Please check again."; +std::vector InferShapeForBatchNormTrain( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 5U) + << "The input's layout size is not 5! Please check again."; std::string data_layout = ""; if (attrs.find("data_layout") != attrs.end()) { data_layout = absl::get(attrs.at("data_layout")); @@ -2159,50 +2555,77 @@ std::vector InferShapeForBatchNormTrain(const std::vector InferDtypeForBatchNormTrain(const std::vector &inputs_type, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 5U) << "The BatchNormTrain input's type size should be 5! Please check again."; - CHECK_EQ(inputs_type[1], inputs_type[2]) << "The BatchNormTrain scale type should the same as bias type"; - CHECK_EQ(inputs_type[1], inputs_type[3]) << "The BatchNormTrain scale type should the same as moving_mean type"; - CHECK_EQ(inputs_type[1], inputs_type[4]) << "The BatchNormTrain scale type should the same as moving_variance type"; - return {inputs_type[0], inputs_type[1], inputs_type[1], inputs_type[1], inputs_type[1]}; +std::vector InferDtypeForBatchNormTrain( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 5U) << "The BatchNormTrain input's type size " + "should be 5! 
Please check again."; + CHECK_EQ(inputs_type[1], inputs_type[2]) + << "The BatchNormTrain scale type should the same as bias type"; + CHECK_EQ(inputs_type[1], inputs_type[3]) + << "The BatchNormTrain scale type should the same as moving_mean type"; + CHECK_EQ(inputs_type[1], inputs_type[4]) + << "The BatchNormTrain scale type should the same as moving_variance " + "type"; + return {inputs_type[0], + inputs_type[1], + inputs_type[1], + inputs_type[1], + inputs_type[1]}; } -std::shared_ptr StrategyForGradOp(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - LOG(FATAL) - << "Gradient operator will be decomposed into several primitive operators. Please Use Decomposer Program Pass."; +std::shared_ptr StrategyForGradOp( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + LOG(FATAL) << "Gradient operator will be decomposed into several primitive " + "operators. Please Use Decomposer Program Pass."; } // batch norm grad -std::vector InferShapeForBatchNormGrad(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 5U) << "The input's layout size is not 5! Please check again."; +std::vector InferShapeForBatchNormGrad( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 5U) + << "The input's layout size is not 5! Please check again."; std::string data_layout = ""; if (attrs.find("data_layout") != attrs.end()) { data_layout = absl::get(attrs.at("data_layout")); @@ -2212,19 +2635,28 @@ std::vector InferShapeForBatchNormGrad(const std::vector InferShapeForBatchNormGrad(const std::vector InferDtypeForBatchNormGrad(const std::vector &inputs_type, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 5U) << "The BatchNormGrad input's type size should be 5! Please check again."; - - CHECK_EQ(inputs_type[0], inputs_type[1]) << "The BatchNormGrad y_grad type should the same as x type"; - CHECK_EQ(inputs_type[2], inputs_type[3]) << "The BatchNormGrad scale type should the same as save_mean type"; - CHECK_EQ(inputs_type[2], inputs_type[4]) << "The BatchNormGrad scale type should the same as save_variance type"; +std::vector InferDtypeForBatchNormGrad( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 5U) + << "The BatchNormGrad input's type size should be 5! Please check again."; + + CHECK_EQ(inputs_type[0], inputs_type[1]) + << "The BatchNormGrad y_grad type should the same as x type"; + CHECK_EQ(inputs_type[2], inputs_type[3]) + << "The BatchNormGrad scale type should the same as save_mean type"; + CHECK_EQ(inputs_type[2], inputs_type[4]) + << "The BatchNormGrad scale type should the same as save_variance type"; return {inputs_type[0], inputs_type[2], inputs_type[2]}; } // pool2d grad -std::vector InferShapeForPool2dGrad(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 3U) << "The operator pool2d_grad should has 3 inputs! Please check again."; +std::vector InferShapeForPool2dGrad( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 3U) + << "The operator pool2d_grad should has 3 inputs! 
Please check again."; return {inputs_shape[0]}; } -std::vector InferDtypeForPool2dGrad(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 3U) << "The operator pool2d_grad should has 3 inputs! Please check again."; +std::vector InferDtypeForPool2dGrad(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 3U) + << "The operator pool2d_grad should has 3 inputs! Please check again."; return {inputs_type[0]}; } @@ -2260,163 +2700,229 @@ std::vector InferDtypeForPool2dGrad(const std::vector &inputs_type, CINN_REGISTER_HELPER(nn_ops) { CINN_REGISTER_OP(relu) - .describe("Output 0 for each input element < 0. Output itself for each input element >= 0.") + .describe( + "Output 0 for each input element < 0. Output itself for each input " + "element >= 0.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForRelu) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForRelu) .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForRelu)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForRelu)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForUnary)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForUnary)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(relu6) - .describe("Output 0 for each input element < 0. Output itself for each input element >= 0 and <=6.") + .describe( + "Output 0 for each input element < 0. Output itself for each input " + "element >= 0 and <=6.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForRelu6) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForRelu6) .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForRelu)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForRelu)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForUnary)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForUnary)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(conv2d) .describe("Do a 2-D convolution with an NCHW/NHWC layout.") .set_num_inputs(2) // here we consider filter as another input .set_num_outputs(4) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForConv2d) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForConv2d)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForConv2d)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForConv2d) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForConv2d)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForConv2d)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForConv2d)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForConv2d)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); CINN_REGISTER_OP(conv2d_NCHWc) - .describe("Do a 2-D convolution with an NCHWc layout. 
Input is 5D tensor and weight is 6D tensor.") + .describe( + "Do a 2-D convolution with an NCHWc layout. Input is 5D tensor and " + "weight is 6D tensor.") .set_num_inputs(2) // here we consider filter as another input .set_num_outputs(3) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForConv2dNCHWc) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForConv2dNCHWc)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForConv2dNCHWc)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForConv2dNCHWc) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForConv2dNCHWc)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForConv2dNCHWc)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForConv2dNCHWc)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForConv2dNCHWc)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kOutFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kOutFusible) .set_support_level(4); CINN_REGISTER_OP(depthwise_conv2d) .describe("Do a 2-D depthwise convolution with an NCHW/NHWC layout.") .set_num_inputs(2) // here we consider filter as another input .set_num_outputs(4) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForDepthwiseConv2d) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForConv2d)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForConv2d)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForDepthwiseConv2d) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForConv2d)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForConv2d)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForConv2d)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForConv2d)) #endif #ifdef CINN_WITH_CUDNN - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) #else - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kOutFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kOutFusible) #endif .set_support_level(4); CINN_REGISTER_OP(batch_norm) - .describe("Can be used as a normalizer function for convolution or fully_connected operations.") - .set_num_inputs(5) // here we consider batchnorm's 4 attrs(mean, variance, scale, bias) as other 4 inputs + .describe( + "Can be used as a normalizer function for convolution or " + "fully_connected operations.") + .set_num_inputs(5) // here we consider batchnorm's 4 attrs(mean, + // variance, scale, bias) as other 4 inputs .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForBatchNorm) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBatchNorm)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForBatchNorm)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForBatchNorm) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForBatchNorm)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForBatchNorm)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForBatchNorm)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForBatchNorm)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + 
"OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(pool1d) .describe("Do pooling on the width dimension of the input tensor.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForPool1d) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForPool1d)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForPool1d) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForPool1d)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForPool)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForPool)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForPool)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); CINN_REGISTER_OP(pool2d) - .describe("Do pooling on the height and width dimension of the input tensor.") + .describe( + "Do pooling on the height and width dimension of the input tensor.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForPool2d) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForPool2d)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForPool2d) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForPool2d)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForPool)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForPool)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForPool)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); CINN_REGISTER_OP(pool3d) - .describe("Do pooling on the depth, height and width dimension of the input tensor.") + .describe( + "Do pooling on the depth, height and width dimension of the input " + "tensor.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForPool3d) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForPool3d)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForPool3d) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForPool3d)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForPool)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForPool)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForPool)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); CINN_REGISTER_OP(softmax) .describe("This operator implements the softmax layer") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForSoftmax) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForSoftmax)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForSoftmax)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForSoftmax) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForSoftmax)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForSoftmax)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", 
MakeOpFunction(cinn::hlir::op::InferLayoutForSoftmax)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForSoftmax)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); CINN_REGISTER_OP(dropout_infer) .describe("Downgrade the outcome at inference or keep the same.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForDropoutInfer) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForDropoutInfer)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForDropoutInfer)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForDropoutInfer) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForDropoutInfer)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForDropoutInfer)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForUnary)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForUnary)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) .set_support_level(4); CINN_REGISTER_OP(select) .describe("This operator implements the meta op 'Select'.") .set_num_inputs(3) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForSelect) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForSelect)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForSelect)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForSelect) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForSelect)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForSelect)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForUnary)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForUnary)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); return true; @@ -2427,33 +2933,42 @@ CINN_REGISTER_HELPER(nn_grad_ops) { .describe("The gradient of relu.") .set_num_inputs(2) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForGradOp) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForGradOp) .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForRelu)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForRelu)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise); + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise); CINN_REGISTER_OP(batch_norm_train) - .describe("This operator implements the batch normalization training forward.") + .describe( + "This operator implements the batch normalization training forward.") .set_num_inputs(5) .set_num_outputs(5) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBatchNormTrain)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForBatchNormTrain)) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForBatchNormTrain)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForBatchNormTrain)) .set_support_level(4); CINN_REGISTER_OP(batch_norm_grad) .describe("This operator implements the batch normalization 
backward.") .set_num_inputs(5) .set_num_outputs(3) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForBatchNormGrad)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForBatchNormGrad)) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForBatchNormGrad)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForBatchNormGrad)) .set_support_level(4); CINN_REGISTER_OP(pool2d_grad) .describe("This operator implements the batch normalization backward.") .set_num_inputs(3) .set_num_outputs(1) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForPool2dGrad)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForPool2dGrad)) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForPool2dGrad)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForPool2dGrad)) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/op_broadcast_test.cc b/paddle/cinn/hlir/op/op_broadcast_test.cc index f8af967ff62ec..086cb43528aa6 100755 --- a/paddle/cinn/hlir/op/op_broadcast_test.cc +++ b/paddle/cinn/hlir/op/op_broadcast_test.cc @@ -35,10 +35,11 @@ namespace cinn { namespace hlir { namespace framework { -using CCompute = std::function(const std::vector)>; +using CCompute = + std::function(const std::vector)>; TEST(Operator, Operator_ElementWise_Add_Test0) { - auto add = Operator::Get("elementwise_add"); + auto add = Operator::Get("elementwise_add"); Operator temp = *add; auto strategy = Operator::GetAttrs("CINNStrategy"); @@ -50,7 +51,8 @@ TEST(Operator, Operator_ElementWise_Add_Test0) { std::vector inputs{A.tensor(), B.tensor()}; std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); - auto impl = OpStrategy::SelectImpl(strategy[add](attrs, inputs, type, {{M.as_int32(), N.as_int32()}}, target)); + auto impl = OpStrategy::SelectImpl(strategy[add]( + attrs, inputs, type, {{M.as_int32(), N.as_int32()}}, target)); ASSERT_EQ(impl->name, "strategy.elementwise_add.x86"); ASSERT_EQ(add->description, "elementwise_add function"); @@ -60,19 +62,24 @@ TEST(Operator, Operator_ElementWise_Add_Test0) { if (FLAGS_cinn_ir_schedule) { std::string out_name = "C"; common::CINNValuePack cinn_input = - common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B), common::CINNValue(out_name)}}; + common::CINNValuePack{{common::CINNValue(A), + common::CINNValue(B), + common::CINNValue(out_name)}}; std::vector input_output_names{"A", "B", out_name}; - auto funcs = framework::GetFuncFromImpl(impl, cinn_input, inputs, input_output_names, func_name, target); + auto funcs = framework::GetFuncFromImpl( + impl, cinn_input, inputs, input_output_names, func_name, target); for (auto func : funcs) { - LOG(INFO) << "Test Operator_ElementWise_Add_Test0's Strategy, func is :\n" << func; + LOG(INFO) << "Test Operator_ElementWise_Add_Test0's Strategy, func is :\n" + << func; builder.AddFunction(func); } } else { - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); + common::CINNValuePack cinn_input = + common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); ASSERT_EQ(rets.size(), 2UL); rets = impl->fschedule(rets); ASSERT_EQ(rets.size(), 2UL); @@ -86,7 +93,7 @@ TEST(Operator, Operator_ElementWise_Add_Test0) { builder.AddFunction(func); } - auto jit = backends::ExecutionEngine::Create({}); + auto 
jit = backends::ExecutionEngine::Create({}); auto module = builder.Build(); jit->Link(module); auto fn = jit->Lookup("fn_" + func_name); @@ -96,13 +103,28 @@ TEST(Operator, Operator_ElementWise_Add_Test0) { cinn_buffer_t *B_buf; int set_value = 0; if (set_value != 0) { - A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_align(512).set_val(set_value).Build(); - B_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_align(512).set_val(set_value).Build(); + A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_align(512) + .set_val(set_value) + .Build(); + B_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_align(512) + .set_val(set_value) + .Build(); } else { - A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_align(512).set_random().Build(); - B_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_align(512).set_random().Build(); + A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_align(512) + .set_random() + .Build(); + B_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_align(512) + .set_random() + .Build(); } - auto *C_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_align(512).set_zero().Build(); + auto *C_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_align(512) + .set_zero() + .Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; @@ -117,7 +139,7 @@ TEST(Operator, Operator_ElementWise_Add_Test0) { } #ifdef CINN_WITH_CUDA TEST(Operator, Operator_ElementWise_Add_Test1) { - auto add = Operator::Get("elementwise_add"); + auto add = Operator::Get("elementwise_add"); Operator temp = *add; auto strategy = Operator::GetAttrs("CINNStrategy"); @@ -130,7 +152,8 @@ TEST(Operator, Operator_ElementWise_Add_Test1) { std::vector inputs{A.tensor(), B.tensor()}; std::vector type{Float(32)}; common::Target target = common::DefaultNVGPUTarget(); - auto impl = OpStrategy::SelectImpl(strategy[add](attrs, inputs, type, {{100, 32}}, target)); + auto impl = OpStrategy::SelectImpl( + strategy[add](attrs, inputs, type, {{100, 32}}, target)); ASSERT_EQ(impl->name, "strategy.elementwise_add.x86"); ASSERT_EQ(add->description, "elementwise_add function"); @@ -140,18 +163,23 @@ TEST(Operator, Operator_ElementWise_Add_Test1) { if (FLAGS_cinn_ir_schedule) { std::string out_name = "C"; common::CINNValuePack cinn_input = - common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B), common::CINNValue(out_name)}}; + common::CINNValuePack{{common::CINNValue(A), + common::CINNValue(B), + common::CINNValue(out_name)}}; std::vector input_output_names{"A", "B", out_name}; - auto funcs = framework::GetFuncFromImpl(impl, cinn_input, inputs, input_output_names, func_name, target); + auto funcs = framework::GetFuncFromImpl( + impl, cinn_input, inputs, input_output_names, func_name, target); for (auto func : funcs) { builder.AddFunction(func); - LOG(INFO) << "Test Operator_ElementWise_Add_Test1's Strategy, func is :\n" << func; + LOG(INFO) << "Test Operator_ElementWise_Add_Test1's Strategy, func is :\n" + << func; } } else { - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); + common::CINNValuePack cinn_input = + common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}}; + common::CINNValuePack 
rets = impl->fcompute(cinn_input); ASSERT_EQ(rets.size(), 2UL); rets = impl->fschedule(rets); ASSERT_EQ(rets.size(), 2UL); @@ -167,7 +195,7 @@ TEST(Operator, Operator_ElementWise_Add_Test1) { backends::CodeGenCUDA_Dev codegen(target); - auto module = builder.Build(); + auto module = builder.Build(); auto source_code = codegen.Compile(module); LOG(INFO) << "Operator_ElementWise_Add_Test1 source code:\n" << source_code; } @@ -175,40 +203,44 @@ TEST(Operator, Operator_ElementWise_Add_Test1) { TEST(Operator, Operator_BroadcastTo) { auto broadcast_to = Operator::Get("broadcast_to"); - Operator temp = *broadcast_to; - auto strategy = Operator::GetAttrs("CINNStrategy"); + Operator temp = *broadcast_to; + auto strategy = Operator::GetAttrs("CINNStrategy"); Expr N(1); Placeholder B("B", {N}); NodeAttr attrs; - std::vector out_shape = {16}; + std::vector out_shape = {16}; attrs.attr_store["out_shape"] = out_shape; - std::vector broadcast_axes = {0}; + std::vector broadcast_axes = {0}; attrs.attr_store["broadcast_axes"] = broadcast_axes; std::vector inputs{B.tensor()}; std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); - auto impl = OpStrategy::SelectImpl(strategy[broadcast_to](attrs, inputs, type, {out_shape}, target)); + auto impl = OpStrategy::SelectImpl( + strategy[broadcast_to](attrs, inputs, type, {out_shape}, target)); std::string func_name = "broadcast_to"; if (FLAGS_cinn_ir_schedule) { - std::string out_name = "C"; - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(B), common::CINNValue(out_name)}}; + std::string out_name = "C"; + common::CINNValuePack cinn_input = common::CINNValuePack{ + {common::CINNValue(B), common::CINNValue(out_name)}}; std::vector input_output_names{"B", out_name}; - auto funcs = framework::GetFuncFromImpl(impl, cinn_input, inputs, input_output_names, func_name, target); + auto funcs = framework::GetFuncFromImpl( + impl, cinn_input, inputs, input_output_names, func_name, target); for (auto func : funcs) { LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func; } } else { - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(B)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); + common::CINNValuePack cinn_input = + common::CINNValuePack{{common::CINNValue(B)}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); ASSERT_EQ(rets.size(), 2UL); rets = impl->fschedule(rets); @@ -224,9 +256,10 @@ TEST(Operator, Operator_BroadcastTo) { } } -common::CINNValuePack GetComputeResult(const std::shared_ptr &impl, - std::vector &cinn_inputs, - const std::string &output_name = "") { +common::CINNValuePack GetComputeResult( + const std::shared_ptr &impl, + std::vector &cinn_inputs, + const std::string &output_name = "") { if (FLAGS_cinn_ir_schedule) { cinn_inputs.emplace_back(output_name); } @@ -234,9 +267,9 @@ common::CINNValuePack GetComputeResult(const std::shared_ptr &impl, } TEST(Operator, Operator_BroadcastTo_0) { - auto const_scalar = Operator::Get("const_scalar"); - auto broadcast_to = Operator::Get("broadcast_to"); - auto reduce_sum = Operator::Get("reduce_sum"); + auto const_scalar = Operator::Get("const_scalar"); + auto broadcast_to = Operator::Get("broadcast_to"); + auto reduce_sum = Operator::Get("reduce_sum"); auto elementwise_add = Operator::Get("elementwise_mul"); auto strategy = Operator::GetAttrs("CINNStrategy"); @@ -247,55 +280,65 @@ TEST(Operator, Operator_BroadcastTo_0) { NodeAttr attrs; attrs.attr_store["value"] = 0.5f; - std::vector out_shape = 
{16}; + std::vector out_shape = {16}; attrs.attr_store["out_shape"] = out_shape; - std::vector broadcast_axes = {0}; + std::vector broadcast_axes = {0}; attrs.attr_store["broadcast_axes"] = broadcast_axes; - std::vector dim = {0, 2, 3}; + std::vector dim = {0, 2, 3}; attrs.attr_store["dim"] = dim; std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); - auto impl_0 = - OpStrategy::SelectImpl(strategy[const_scalar](attrs, std::vector{}, type, {out_shape}, target)); + auto impl_0 = OpStrategy::SelectImpl(strategy[const_scalar]( + attrs, std::vector{}, type, {out_shape}, target)); std::vector cinn_inputs; common::CINNValuePack rets_0 = GetComputeResult(impl_0, cinn_inputs, "out_0"); - ir::Expr out_0 = rets_0[0]; - auto tensor_0 = out_0.as_tensor_ref(); - poly::StageMap stages_0 = rets_0.back(); - - auto impl_1 = OpStrategy::SelectImpl(strategy[broadcast_to](attrs, {tensor_0}, type, {out_shape}, target)); - std::vector cinn_inputs_1 = {{common::CINNValue(tensor_0)}}; - common::CINNValuePack rets_1 = GetComputeResult(impl_1, cinn_inputs_1, "out_1"); - - ir::Expr out_1 = rets_1[0]; - auto tensor_1 = out_1.as_tensor_ref(); + ir::Expr out_0 = rets_0[0]; + auto tensor_0 = out_0.as_tensor_ref(); + poly::StageMap stages_0 = rets_0.back(); + + auto impl_1 = OpStrategy::SelectImpl( + strategy[broadcast_to](attrs, {tensor_0}, type, {out_shape}, target)); + std::vector cinn_inputs_1 = { + {common::CINNValue(tensor_0)}}; + common::CINNValuePack rets_1 = + GetComputeResult(impl_1, cinn_inputs_1, "out_1"); + + ir::Expr out_1 = rets_1[0]; + auto tensor_1 = out_1.as_tensor_ref(); poly::StageMap stages_1 = rets_1.back(); - auto impl_2 = OpStrategy::SelectImpl(strategy[reduce_sum](attrs, {A.tensor()}, type, {out_shape}, target)); - std::vector cinn_inputs_2 = {{common::CINNValue(A.tensor())}}; - common::CINNValuePack rets_2 = GetComputeResult(impl_2, cinn_inputs_2, "out_2"); + auto impl_2 = OpStrategy::SelectImpl( + strategy[reduce_sum](attrs, {A.tensor()}, type, {out_shape}, target)); + std::vector cinn_inputs_2 = { + {common::CINNValue(A.tensor())}}; + common::CINNValuePack rets_2 = + GetComputeResult(impl_2, cinn_inputs_2, "out_2"); - ir::Expr out_2 = rets_2[0]; - auto tensor_2 = out_2.as_tensor_ref(); + ir::Expr out_2 = rets_2[0]; + auto tensor_2 = out_2.as_tensor_ref(); poly::StageMap stages_2 = rets_2.back(); - std::vector cinn_inputs_4 = {{common::CINNValue(A.tensor())}}; - common::CINNValuePack rets_4 = GetComputeResult(impl_2, cinn_inputs_4, "out_4"); - ir::Expr out_4 = rets_4[0]; - auto tensor_4 = out_4.as_tensor_ref(); - poly::StageMap stages_4 = rets_4.back(); - - auto impl_3 = - OpStrategy::SelectImpl(strategy[elementwise_add](attrs, {tensor_1, tensor_2}, type, {out_shape}, target)); - std::vector cinn_inputs_3 = {{common::CINNValue(tensor_1), common::CINNValue(tensor_2)}}; - common::CINNValuePack rets_3 = GetComputeResult(impl_3, cinn_inputs_3, "out_3"); - - ir::Expr out_3 = rets_3[0]; - auto tensor_3 = out_3.as_tensor_ref(); + std::vector cinn_inputs_4 = { + {common::CINNValue(A.tensor())}}; + common::CINNValuePack rets_4 = + GetComputeResult(impl_2, cinn_inputs_4, "out_4"); + ir::Expr out_4 = rets_4[0]; + auto tensor_4 = out_4.as_tensor_ref(); + poly::StageMap stages_4 = rets_4.back(); + + auto impl_3 = OpStrategy::SelectImpl(strategy[elementwise_add]( + attrs, {tensor_1, tensor_2}, type, {out_shape}, target)); + std::vector cinn_inputs_3 = { + {common::CINNValue(tensor_1), common::CINNValue(tensor_2)}}; + common::CINNValuePack rets_3 = + GetComputeResult(impl_3, 
cinn_inputs_3, "out_3"); + + ir::Expr out_3 = rets_3[0]; + auto tensor_3 = out_3.as_tensor_ref(); poly::StageMap stages_3 = rets_3.back(); stages_3->InsertLazily(tensor_0, stages_0[tensor_0]); @@ -309,7 +352,7 @@ TEST(Operator, Operator_BroadcastTo_0) { stages_3[tensor_2]->SimpleComputeAt(stages_3[tensor_3], 0); std::vector inputs = {A.tensor(), tensor_3, tensor_4}; - auto func = Lower("broadcast_to", stages_3, inputs); + auto func = Lower("broadcast_to", stages_3, inputs); LOG(INFO) << "Test Strategy Codegen:\n" << func; } diff --git a/paddle/cinn/hlir/op/op_nn_test.cc b/paddle/cinn/hlir/op/op_nn_test.cc index 6f6b26407761d..b2dff5cfdb7ee 100644 --- a/paddle/cinn/hlir/op/op_nn_test.cc +++ b/paddle/cinn/hlir/op/op_nn_test.cc @@ -36,7 +36,8 @@ namespace cinn { namespace hlir { namespace framework { -using CCompute = std::function(const std::vector)>; +using CCompute = + std::function(const std::vector)>; Module LowerToModule(const std::string test_name, const std::string func_name, @@ -53,7 +54,8 @@ Module LowerToModule(const std::string test_name, common::CINNValuePack cinn_input = common::CINNValuePack{cinn_inputs}; input_names.push_back(output_name); - auto funcs = framework::GetFuncFromImpl(impl, cinn_input, inputs, input_names, func_name, target); + auto funcs = framework::GetFuncFromImpl( + impl, cinn_input, inputs, input_names, func_name, target); for (auto func : funcs) { LOG(INFO) << "Test" << test_name << "'s Strategy, func is :\n" << func; @@ -61,8 +63,8 @@ Module LowerToModule(const std::string test_name, } } else { common::CINNValuePack cinn_input = common::CINNValuePack{cinn_inputs}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - rets = impl->fschedule(rets); + common::CINNValuePack rets = impl->fcompute(cinn_input); + rets = impl->fschedule(rets); // the last element is a StageMap for (int i = 0; i < rets->size() - 1; i++) { Expr temp = rets[i]; @@ -78,7 +80,7 @@ Module LowerToModule(const std::string test_name, } TEST(Operator, Operator_Pool2d_Test0) { - auto pool2d = Operator::Get("pool2d"); + auto pool2d = Operator::Get("pool2d"); Operator temp = *pool2d; auto strategy = Operator::GetAttrs("CINNStrategy"); @@ -86,22 +88,29 @@ TEST(Operator, Operator_Pool2d_Test0) { Placeholder A("A", {N, C, H, W}); NodeAttr attrs; - std::vector kernel_size = {2, 2}; - std::vector stride_size = {2, 2}; - std::vector padding_size = {1, 1, 1, 1}; - std::string pool_type = "max"; - attrs.attr_store["kernel_size"] = kernel_size; - attrs.attr_store["stride_size"] = stride_size; + std::vector kernel_size = {2, 2}; + std::vector stride_size = {2, 2}; + std::vector padding_size = {1, 1, 1, 1}; + std::string pool_type = "max"; + attrs.attr_store["kernel_size"] = kernel_size; + attrs.attr_store["stride_size"] = stride_size; attrs.attr_store["padding_size"] = padding_size; - attrs.attr_store["pool_type"] = pool_type; + attrs.attr_store["pool_type"] = pool_type; std::vector inputs{A.tensor()}; std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); - auto impl = OpStrategy::SelectImpl(strategy[pool2d](attrs, inputs, type, {{1, 3, 10, 10}, {1, 3, 5, 5}}, target)); + auto impl = OpStrategy::SelectImpl(strategy[pool2d]( + attrs, inputs, type, {{1, 3, 10, 10}, {1, 3, 5, 5}}, target)); std::string func_name = "pool2d"; - auto module = - LowerToModule("Operator_Pool2d_Test0", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + auto module = LowerToModule("Operator_Pool2d_Test0", + func_name, + impl, + {"A"}, + "B", + inputs, + 
{common::CINNValue(A)}, + target); auto jit = backends::ExecutionEngine::Create({}); @@ -110,19 +119,24 @@ TEST(Operator, Operator_Pool2d_Test0) { CHECK(fn); auto fn_ = reinterpret_cast(fn); - cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), {1, 3, 8, 8}).set_random().Build(); - cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {1, 3, 10, 10}).set_random().Build(); - cinn_buffer_t *C_buf = common::BufferBuilder(Float(32), {1, 3, 5, 5}).set_random().Build(); + cinn_buffer_t *A_buf = + common::BufferBuilder(Float(32), {1, 3, 8, 8}).set_random().Build(); + cinn_buffer_t *B_buf = + common::BufferBuilder(Float(32), {1, 3, 10, 10}).set_random().Build(); + cinn_buffer_t *C_buf = + common::BufferBuilder(Float(32), {1, 3, 5, 5}).set_random().Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; fn_(args, 3); ASSERT_EQ(impl->name, "strategy.pool2d.x86"); - ASSERT_EQ(pool2d->description, "Do pooling on the height and width dimension of the input tensor."); + ASSERT_EQ( + pool2d->description, + "Do pooling on the height and width dimension of the input tensor."); } TEST(Operator, Operator_Pool2d_Test1) { - auto pool2d = Operator::Get("pool2d"); + auto pool2d = Operator::Get("pool2d"); Operator temp = *pool2d; auto strategy = Operator::GetAttrs("CINNStrategy"); @@ -130,25 +144,32 @@ TEST(Operator, Operator_Pool2d_Test1) { Placeholder A("A", {N, C, H, W}); NodeAttr attrs; - std::vector kernel_size = {2, 2}; - std::vector stride_size = {2, 2}; - std::vector padding_size = {1, 1, 1, 1}; - std::string pool_type = "avg"; - attrs.attr_store["kernel_size"] = kernel_size; - attrs.attr_store["stride_size"] = stride_size; + std::vector kernel_size = {2, 2}; + std::vector stride_size = {2, 2}; + std::vector padding_size = {1, 1, 1, 1}; + std::string pool_type = "avg"; + attrs.attr_store["kernel_size"] = kernel_size; + attrs.attr_store["stride_size"] = stride_size; attrs.attr_store["padding_size"] = padding_size; - attrs.attr_store["pool_type"] = pool_type; - attrs.attr_store["ceil_mode"] = true; - attrs.attr_store["exclusive"] = false; + attrs.attr_store["pool_type"] = pool_type; + attrs.attr_store["ceil_mode"] = true; + attrs.attr_store["exclusive"] = false; std::vector inputs{A.tensor()}; std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); - auto impl = OpStrategy::SelectImpl(strategy[pool2d](attrs, inputs, type, {{1, 3, 11, 11}, {1, 3, 5, 5}}, target)); + auto impl = OpStrategy::SelectImpl(strategy[pool2d]( + attrs, inputs, type, {{1, 3, 11, 11}, {1, 3, 5, 5}}, target)); std::string func_name = "pool2d"; - auto module = - LowerToModule("Operator_Pool2d_Test1", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + auto module = LowerToModule("Operator_Pool2d_Test1", + func_name, + impl, + {"A"}, + "B", + inputs, + {common::CINNValue(A)}, + target); auto jit = backends::ExecutionEngine::Create({}); @@ -157,19 +178,24 @@ TEST(Operator, Operator_Pool2d_Test1) { CHECK(fn); auto fn_ = reinterpret_cast(fn); - cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), {1, 3, 8, 8}).set_random().Build(); - cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {1, 3, 11, 11}).set_random().Build(); - cinn_buffer_t *C_buf = common::BufferBuilder(Float(32), {1, 3, 5, 5}).set_random().Build(); + cinn_buffer_t *A_buf = + common::BufferBuilder(Float(32), {1, 3, 8, 8}).set_random().Build(); + cinn_buffer_t *B_buf = + common::BufferBuilder(Float(32), {1, 3, 11, 11}).set_random().Build(); + cinn_buffer_t 
*C_buf = + common::BufferBuilder(Float(32), {1, 3, 5, 5}).set_random().Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; fn_(args, 3); ASSERT_EQ(impl->name, "strategy.pool2d.x86"); - ASSERT_EQ(pool2d->description, "Do pooling on the height and width dimension of the input tensor."); + ASSERT_EQ( + pool2d->description, + "Do pooling on the height and width dimension of the input tensor."); } TEST(Operator, Operator_Pool2d_Test2) { - auto pool2d = Operator::Get("pool2d"); + auto pool2d = Operator::Get("pool2d"); Operator temp = *pool2d; auto strategy = Operator::GetAttrs("CINNStrategy"); @@ -177,27 +203,34 @@ TEST(Operator, Operator_Pool2d_Test2) { Placeholder A("A", {N, H, W, C}); NodeAttr attrs; - std::vector kernel_size = {2, 2}; - std::vector stride_size = {2, 2}; - std::vector padding_size = {1, 1, 1, 1}; - std::string pool_type = "avg"; - std::string data_format = "NHWC"; - attrs.attr_store["kernel_size"] = kernel_size; - attrs.attr_store["stride_size"] = stride_size; + std::vector kernel_size = {2, 2}; + std::vector stride_size = {2, 2}; + std::vector padding_size = {1, 1, 1, 1}; + std::string pool_type = "avg"; + std::string data_format = "NHWC"; + attrs.attr_store["kernel_size"] = kernel_size; + attrs.attr_store["stride_size"] = stride_size; attrs.attr_store["padding_size"] = padding_size; - attrs.attr_store["pool_type"] = pool_type; - attrs.attr_store["ceil_mode"] = true; - attrs.attr_store["exclusive"] = true; - attrs.attr_store["data_format"] = data_format; + attrs.attr_store["pool_type"] = pool_type; + attrs.attr_store["ceil_mode"] = true; + attrs.attr_store["exclusive"] = true; + attrs.attr_store["data_format"] = data_format; std::vector inputs{A.tensor()}; std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); - auto impl = OpStrategy::SelectImpl(strategy[pool2d](attrs, inputs, type, {{1, 11, 11, 3}, {1, 5, 5, 3}}, target)); + auto impl = OpStrategy::SelectImpl(strategy[pool2d]( + attrs, inputs, type, {{1, 11, 11, 3}, {1, 5, 5, 3}}, target)); std::string func_name = "pool2d"; - auto module = - LowerToModule("Operator_Pool2d_Test2", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + auto module = LowerToModule("Operator_Pool2d_Test2", + func_name, + impl, + {"A"}, + "B", + inputs, + {common::CINNValue(A)}, + target); auto jit = backends::ExecutionEngine::Create({}); @@ -206,19 +239,24 @@ TEST(Operator, Operator_Pool2d_Test2) { CHECK(fn); auto fn_ = reinterpret_cast(fn); - cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), {1, 8, 8, 3}).set_random().Build(); - cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {1, 11, 11, 3}).set_random().Build(); - cinn_buffer_t *C_buf = common::BufferBuilder(Float(32), {1, 5, 5, 3}).set_random().Build(); + cinn_buffer_t *A_buf = + common::BufferBuilder(Float(32), {1, 8, 8, 3}).set_random().Build(); + cinn_buffer_t *B_buf = + common::BufferBuilder(Float(32), {1, 11, 11, 3}).set_random().Build(); + cinn_buffer_t *C_buf = + common::BufferBuilder(Float(32), {1, 5, 5, 3}).set_random().Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; fn_(args, 3); ASSERT_EQ(impl->name, "strategy.pool2d.x86"); - ASSERT_EQ(pool2d->description, "Do pooling on the height and width dimension of the input tensor."); + ASSERT_EQ( + pool2d->description, + "Do pooling on the height and width dimension of the input tensor."); } TEST(Operator, Operator_Pool3d_Test0) { - auto pool3d = 
Operator::Get("pool3d"); + auto pool3d = Operator::Get("pool3d"); Operator temp = *pool3d; auto strategy = Operator::GetAttrs("CINNStrategy"); @@ -226,27 +264,33 @@ TEST(Operator, Operator_Pool3d_Test0) { Placeholder A("A", {N, D, H, W, C}); NodeAttr attrs; - std::vector kernel_size = {2, 2, 2}; - std::vector stride_size = {2, 2, 2}; - std::vector padding_size = {1, 1, 1, 1, 1, 1}; - std::string pool_type = "max"; - std::string data_format = "NDHWC"; - attrs.attr_store["kernel_size"] = kernel_size; - attrs.attr_store["stride_size"] = stride_size; + std::vector kernel_size = {2, 2, 2}; + std::vector stride_size = {2, 2, 2}; + std::vector padding_size = {1, 1, 1, 1, 1, 1}; + std::string pool_type = "max"; + std::string data_format = "NDHWC"; + attrs.attr_store["kernel_size"] = kernel_size; + attrs.attr_store["stride_size"] = stride_size; attrs.attr_store["padding_size"] = padding_size; - attrs.attr_store["pool_type"] = pool_type; - attrs.attr_store["ceil_mode"] = false; - attrs.attr_store["exclusive"] = true; - attrs.attr_store["data_format"] = data_format; + attrs.attr_store["pool_type"] = pool_type; + attrs.attr_store["ceil_mode"] = false; + attrs.attr_store["exclusive"] = true; + attrs.attr_store["data_format"] = data_format; std::vector inputs{A.tensor()}; std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); - auto impl = - OpStrategy::SelectImpl(strategy[pool3d](attrs, inputs, type, {{1, 11, 11, 11, 3}, {1, 5, 5, 5, 3}}, target)); + auto impl = OpStrategy::SelectImpl(strategy[pool3d]( + attrs, inputs, type, {{1, 11, 11, 11, 3}, {1, 5, 5, 5, 3}}, target)); std::string func_name = "pool3d"; - auto module = - LowerToModule("Operator_Pool3d_Test0", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + auto module = LowerToModule("Operator_Pool3d_Test0", + func_name, + impl, + {"A"}, + "B", + inputs, + {common::CINNValue(A)}, + target); auto jit = backends::ExecutionEngine::Create({}); @@ -255,19 +299,24 @@ TEST(Operator, Operator_Pool3d_Test0) { CHECK(fn); auto fn_ = reinterpret_cast(fn); - cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), {1, 8, 8, 8, 3}).set_random().Build(); - cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {1, 11, 11, 11, 3}).set_random().Build(); - cinn_buffer_t *C_buf = common::BufferBuilder(Float(32), {1, 5, 5, 5, 3}).set_random().Build(); + cinn_buffer_t *A_buf = + common::BufferBuilder(Float(32), {1, 8, 8, 8, 3}).set_random().Build(); + cinn_buffer_t *B_buf = + common::BufferBuilder(Float(32), {1, 11, 11, 11, 3}).set_random().Build(); + cinn_buffer_t *C_buf = + common::BufferBuilder(Float(32), {1, 5, 5, 5, 3}).set_random().Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; fn_(args, 3); ASSERT_EQ(impl->name, "strategy.pool3d.x86"); - ASSERT_EQ(pool3d->description, "Do pooling on the depth, height and width dimension of the input tensor."); + ASSERT_EQ(pool3d->description, + "Do pooling on the depth, height and width dimension of the input " + "tensor."); } TEST(Operator, Operator_Pool1d_Test0) { - auto pool1d = Operator::Get("pool1d"); + auto pool1d = Operator::Get("pool1d"); Operator temp = *pool1d; auto strategy = Operator::GetAttrs("CINNStrategy"); @@ -275,26 +324,33 @@ TEST(Operator, Operator_Pool1d_Test0) { Placeholder A("A", {N, W, C}); NodeAttr attrs; - std::vector kernel_size = {2}; - std::vector stride_size = {2}; - std::vector padding_size = {1, 1}; - std::string pool_type = "max"; - std::string data_format = "NWC"; - 
attrs.attr_store["kernel_size"] = kernel_size; - attrs.attr_store["stride_size"] = stride_size; + std::vector kernel_size = {2}; + std::vector stride_size = {2}; + std::vector padding_size = {1, 1}; + std::string pool_type = "max"; + std::string data_format = "NWC"; + attrs.attr_store["kernel_size"] = kernel_size; + attrs.attr_store["stride_size"] = stride_size; attrs.attr_store["padding_size"] = padding_size; - attrs.attr_store["pool_type"] = pool_type; - attrs.attr_store["ceil_mode"] = false; - attrs.attr_store["exclusive"] = true; - attrs.attr_store["data_format"] = data_format; + attrs.attr_store["pool_type"] = pool_type; + attrs.attr_store["ceil_mode"] = false; + attrs.attr_store["exclusive"] = true; + attrs.attr_store["data_format"] = data_format; std::vector inputs{A.tensor()}; std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); - auto impl = OpStrategy::SelectImpl(strategy[pool1d](attrs, inputs, type, {{1, 11, 3}, {1, 5, 3}}, target)); + auto impl = OpStrategy::SelectImpl( + strategy[pool1d](attrs, inputs, type, {{1, 11, 3}, {1, 5, 3}}, target)); std::string func_name = "pool1d"; - auto module = - LowerToModule("Operator_Pool1d_Test0", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + auto module = LowerToModule("Operator_Pool1d_Test0", + func_name, + impl, + {"A"}, + "B", + inputs, + {common::CINNValue(A)}, + target); auto jit = backends::ExecutionEngine::Create({}); @@ -303,22 +359,27 @@ TEST(Operator, Operator_Pool1d_Test0) { CHECK(fn); auto fn_ = reinterpret_cast(fn); - cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), {1, 8, 3}).set_random().Build(); - cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {1, 11, 3}).set_random().Build(); - cinn_buffer_t *C_buf = common::BufferBuilder(Float(32), {1, 5, 3}).set_random().Build(); + cinn_buffer_t *A_buf = + common::BufferBuilder(Float(32), {1, 8, 3}).set_random().Build(); + cinn_buffer_t *B_buf = + common::BufferBuilder(Float(32), {1, 11, 3}).set_random().Build(); + cinn_buffer_t *C_buf = + common::BufferBuilder(Float(32), {1, 5, 3}).set_random().Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; fn_(args, 3); ASSERT_EQ(impl->name, "strategy.pool1d.x86"); - ASSERT_EQ(pool1d->description, "Do pooling on the width dimension of the input tensor."); + ASSERT_EQ(pool1d->description, + "Do pooling on the width dimension of the input tensor."); } TEST(Operator, Operator_Select_Test0) { - auto select = Operator::Get("select"); - Operator temp = *select; - auto strategy = Operator::GetAttrs("CINNStrategy"); - auto infer_shape_func = Operator::GetAttrs("infershape")[select]; + auto select = Operator::Get("select"); + Operator temp = *select; + auto strategy = Operator::GetAttrs("CINNStrategy"); + auto infer_shape_func = + Operator::GetAttrs("infershape")[select]; Expr C(16), H(64), W(64); Placeholder condition("condition", {C, H, W}); @@ -326,25 +387,36 @@ TEST(Operator, Operator_Select_Test0) { Placeholder false_value("false_value", {C, H, W}); NodeAttr attrs; - std::vector inputs{condition.tensor(), true_value.tensor(), false_value.tensor()}; + std::vector inputs{ + condition.tensor(), true_value.tensor(), false_value.tensor()}; std::vector type{Float(32)}; const common::Target target = common::DefaultHostTarget(); - const std::vector input_shapes = {{16, 64, 64}, {16, 64, 64}, {16, 64, 64}}; - auto infer_shape = infer_shape_func(input_shapes, attrs.attr_store); + const std::vector input_shapes = { + {16, 64, 64}, 
{16, 64, 64}, {16, 64, 64}}; + auto infer_shape = infer_shape_func(input_shapes, attrs.attr_store); ASSERT_EQ(infer_shape[0][0], 16); ASSERT_EQ(infer_shape[0][1], 64); ASSERT_EQ(infer_shape[0][2], 64); - auto impl = OpStrategy::SelectImpl(strategy[select](attrs, inputs, type, {{16, 64, 64}}, target)); - - std::string func_name = "select"; - std::vector input_names = {"condition", "true_value", "false_value"}; - std::vector cinn_inputs = { - common::CINNValue(condition), common::CINNValue(true_value), common::CINNValue(false_value)}; - - auto module = LowerToModule( - "Operator_Select_Test0", func_name, impl, std::move(input_names), "output", inputs, cinn_inputs, target); + auto impl = OpStrategy::SelectImpl( + strategy[select](attrs, inputs, type, {{16, 64, 64}}, target)); + + std::string func_name = "select"; + std::vector input_names = { + "condition", "true_value", "false_value"}; + std::vector cinn_inputs = {common::CINNValue(condition), + common::CINNValue(true_value), + common::CINNValue(false_value)}; + + auto module = LowerToModule("Operator_Select_Test0", + func_name, + impl, + std::move(input_names), + "output", + inputs, + cinn_inputs, + target); auto jit = backends::ExecutionEngine::Create({}); @@ -353,19 +425,23 @@ TEST(Operator, Operator_Select_Test0) { CHECK(fn); auto fn_ = reinterpret_cast(fn); - cinn_buffer_t *A_buf = common::BufferBuilder(Bool(), {16, 64, 64}).set_random().Build(); - cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {16, 64, 64}).set_random().Build(); - cinn_buffer_t *C_buf = common::BufferBuilder(Float(32), {16, 64, 64}).set_random().Build(); - cinn_buffer_t *D_buf = common::BufferBuilder(Float(32), {16, 64, 64}).set_random().Build(); + cinn_buffer_t *A_buf = + common::BufferBuilder(Bool(), {16, 64, 64}).set_random().Build(); + cinn_buffer_t *B_buf = + common::BufferBuilder(Float(32), {16, 64, 64}).set_random().Build(); + cinn_buffer_t *C_buf = + common::BufferBuilder(Float(32), {16, 64, 64}).set_random().Build(); + cinn_buffer_t *D_buf = + common::BufferBuilder(Float(32), {16, 64, 64}).set_random().Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf), d_arg(D_buf); cinn_pod_value_t args[] = {a_arg, b_arg, c_arg, d_arg}; fn_(args, 4); - auto condition_ = reinterpret_cast(A_buf->memory); - auto true_value_ = reinterpret_cast(B_buf->memory); + auto condition_ = reinterpret_cast(A_buf->memory); + auto true_value_ = reinterpret_cast(B_buf->memory); auto false_value_ = reinterpret_cast(C_buf->memory); - auto output_ = reinterpret_cast(D_buf->memory); + auto output_ = reinterpret_cast(D_buf->memory); for (int i = 0; i < A_buf->num_elements(); i++) { if (static_cast(condition_[i])) { @@ -376,11 +452,12 @@ TEST(Operator, Operator_Select_Test0) { } ASSERT_EQ(impl->name, "strategy.select.x86"); - ASSERT_EQ(select->description, "This operator implements the meta op 'Select'."); + ASSERT_EQ(select->description, + "This operator implements the meta op 'Select'."); } TEST(Operator, Operator_Reverse_Test0) { - auto reverse = Operator::Get("reverse"); + auto reverse = Operator::Get("reverse"); Operator temp = *reverse; auto strategy = Operator::GetAttrs("CINNStrategy"); @@ -389,17 +466,24 @@ TEST(Operator, Operator_Reverse_Test0) { Placeholder A("A", {C, H, W}); NodeAttr attrs; - std::vector axis = {1, 2}; + std::vector axis = {1, 2}; attrs.attr_store["axis"] = axis; std::vector inputs{A.tensor()}; std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); - auto impl = OpStrategy::SelectImpl(strategy[reverse](attrs, inputs, 
type, {{c, h, w}}, target)); + auto impl = OpStrategy::SelectImpl( + strategy[reverse](attrs, inputs, type, {{c, h, w}}, target)); std::string func_name = "reverse"; - auto module = - LowerToModule("Operator_Reverse_Test0", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + auto module = LowerToModule("Operator_Reverse_Test0", + func_name, + impl, + {"A"}, + "B", + inputs, + {common::CINNValue(A)}, + target); auto jit = backends::ExecutionEngine::Create({}); @@ -408,19 +492,21 @@ TEST(Operator, Operator_Reverse_Test0) { CHECK(fn); auto fn_ = reinterpret_cast(fn); - cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), {c, h, w}).set_random().Build(); - cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {c, h, w}).set_random().Build(); + cinn_buffer_t *A_buf = + common::BufferBuilder(Float(32), {c, h, w}).set_random().Build(); + cinn_buffer_t *B_buf = + common::BufferBuilder(Float(32), {c, h, w}).set_random().Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf); cinn_pod_value_t args[] = {a_arg, b_arg}; fn_(args, 2); - auto input = reinterpret_cast(A_buf->memory); + auto input = reinterpret_cast(A_buf->memory); auto output = reinterpret_cast(B_buf->memory); for (int ida = 0; ida < c; ++ida) { for (int idb = 0; idb < h; ++idb) { for (int idc = 0; idc < w; ++idc) { - int index = ida * h * w + idb * h + idc; + int index = ida * h * w + idb * h + idc; int index_ = ida * h * w + (h - 1 - idb) * h + (w - 1 - idc); ASSERT_EQ(output[index], input[index_]); } @@ -428,21 +514,23 @@ TEST(Operator, Operator_Reverse_Test0) { } ASSERT_EQ(impl->name, "strategy.reverse.x86"); - ASSERT_EQ(reverse->description, "This operator implements the meta op reverse."); + ASSERT_EQ(reverse->description, + "This operator implements the meta op reverse."); } TEST(Operator, Operator_Transpose_Test0) { - auto transpose = Operator::Get("transpose"); - Operator temp = *transpose; - auto strategy = Operator::GetAttrs("CINNStrategy"); - auto infer_shape_func = Operator::GetAttrs("infershape")[transpose]; + auto transpose = Operator::Get("transpose"); + Operator temp = *transpose; + auto strategy = Operator::GetAttrs("CINNStrategy"); + auto infer_shape_func = + Operator::GetAttrs("infershape")[transpose]; int n = 16, c = 3, h = 32, w = 32; Expr N(n), C(c), H(h), W(w); Placeholder A("A", {N, C, H, W}); NodeAttr attrs; - std::vector axis = {0, 2, 3, 1}; + std::vector axis = {0, 2, 3, 1}; attrs.attr_store["axis"] = axis; std::vector inputs{A.tensor()}; std::vector type{Float(32)}; @@ -456,23 +544,33 @@ TEST(Operator, Operator_Transpose_Test0) { #ifndef CINN_WITH_CUDA using InferLayoutFunction = - std::function>(const std::vector &, - const std::vector &, - const framework::NodeAttr &, - const Target &target)>; - auto infer_layout_func = Operator::GetAttrs("inferlayout")[transpose]; - auto infer_layout = infer_layout_func({{n, c, h, w}}, {"NCHW"}, attrs, target); + std::function>( + const std::vector &, + const std::vector &, + const framework::NodeAttr &, + const Target &target)>; + auto infer_layout_func = + Operator::GetAttrs("inferlayout")[transpose]; + auto infer_layout = + infer_layout_func({{n, c, h, w}}, {"NCHW"}, attrs, target); ASSERT_EQ(infer_layout[0][0], "NHWC"); #endif - auto input_shape = {n, c, h, w}; + auto input_shape = {n, c, h, w}; auto output_shape = {n, h, w, c}; - auto impl = OpStrategy::SelectImpl(strategy[transpose](attrs, inputs, type, {output_shape}, target)); + auto impl = OpStrategy::SelectImpl( + strategy[transpose](attrs, inputs, type, {output_shape}, target)); std::string 
func_name = "transpose"; - auto module = - LowerToModule("Operator_Transpose_Test0", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + auto module = LowerToModule("Operator_Transpose_Test0", + func_name, + impl, + {"A"}, + "B", + inputs, + {common::CINNValue(A)}, + target); auto jit = backends::ExecutionEngine::Create({}); @@ -481,13 +579,15 @@ TEST(Operator, Operator_Transpose_Test0) { CHECK(fn); auto fn_ = reinterpret_cast(fn); - cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), input_shape).set_random().Build(); - cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), output_shape).set_random().Build(); + cinn_buffer_t *A_buf = + common::BufferBuilder(Float(32), input_shape).set_random().Build(); + cinn_buffer_t *B_buf = + common::BufferBuilder(Float(32), output_shape).set_random().Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf); cinn_pod_value_t args[] = {a_arg, b_arg}; fn_(args, 2); - auto input = reinterpret_cast(A_buf->memory); + auto input = reinterpret_cast(A_buf->memory); auto output = reinterpret_cast(B_buf->memory); for (int idx = 0; idx < n; ++idx) { @@ -505,7 +605,8 @@ TEST(Operator, Operator_Transpose_Test0) { } ASSERT_EQ(impl->name, "strategy.transpose.x86"); - ASSERT_EQ(transpose->description, "This operator implements the meta op transpose."); + ASSERT_EQ(transpose->description, + "This operator implements the meta op transpose."); } } // namespace framework diff --git a/paddle/cinn/hlir/op/op_util.cc b/paddle/cinn/hlir/op/op_util.cc index 4cca01a4fbd5d..52b81542ca0f0 100644 --- a/paddle/cinn/hlir/op/op_util.cc +++ b/paddle/cinn/hlir/op/op_util.cc @@ -26,12 +26,14 @@ DECLARE_bool(cinn_ir_schedule); namespace cinn { namespace hlir { -CINNSchedule GetElementwiseScheduleFunc(const std::vector>& output_shapes, - const Target& target, - bool vectorizable) { +CINNSchedule GetElementwiseScheduleFunc( + const std::vector>& output_shapes, + const Target& target, + bool vectorizable) { return CINNSchedule([=](lang::Args args, lang::RetValue* ret) { if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is empty! Please check.\n"; + CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is " + "empty! Please check.\n"; common::CINNValuePack arg_pack = args[0]; std::vector vec_ast; for (int i = 0; i < arg_pack.size(); i++) { @@ -45,31 +47,39 @@ CINNSchedule GetElementwiseScheduleFunc(const std::vector>& out ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); pe::IRElementwiseSchedule(ir_sch, output_shapes.front(), target); - std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = common::CINNValuePack{res}; } else { - CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is empty! Please check.\n"; + CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is " + "empty! 
Please check.\n"; common::CINNValuePack arg_pack = args[0]; - Expr out = arg_pack[0]; - poly::StageMap stages = arg_pack[1]; + Expr out = arg_pack[0]; + poly::StageMap stages = arg_pack[1]; CHECK(out.as_tensor()); CHECK_EQ(arg_pack.size(), 2UL); if (target.arch == Target::Arch::NVGPU) { - pe::CudaScheduleInjective(stages[out.as_tensor_ref()], output_shapes.front(), target); + pe::CudaScheduleInjective( + stages[out.as_tensor_ref()], output_shapes.front(), target); } else if (target.arch == Target::Arch::X86) { - pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], output_shapes.front(), target, vectorizable); + pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], + output_shapes.front(), + target, + vectorizable); } *ret = arg_pack; } }); } -CINNSchedule GetInjectiveScheduleFunc(const std::vector>& output_shapes, - const Target& target, - bool vectorizable) { +CINNSchedule GetInjectiveScheduleFunc( + const std::vector>& output_shapes, + const Target& target, + bool vectorizable) { return CINNSchedule([=](lang::Args args, lang::RetValue* ret) { if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of InjectiveSchedule is empty! Please check.\n"; + CHECK(!args.empty()) << "The input argument of InjectiveSchedule is " + "empty! Please check.\n"; common::CINNValuePack arg_pack = args[0]; std::vector vec_ast; for (int i = 0; i < arg_pack.size(); i++) { @@ -86,21 +96,28 @@ CINNSchedule GetInjectiveScheduleFunc(const std::vector>& outpu /*if (target.arch == Target::Arch::NVGPU) { pe::IRInjectiveSchedule(ir_sch, output_shapes.front(), target); } else if (target.arch == Target::Arch::X86) { - pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, vectorizable); + pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, + vectorizable); }*/ - std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = common::CINNValuePack{res}; } else { - CHECK(!args.empty()) << "The input argument of InjectiveSchedule is empty! Please check.\n"; + CHECK(!args.empty()) << "The input argument of InjectiveSchedule is " + "empty! Please check.\n"; common::CINNValuePack arg_pack = args[0]; - Expr out = arg_pack[0]; - poly::StageMap stages = arg_pack[1]; + Expr out = arg_pack[0]; + poly::StageMap stages = arg_pack[1]; CHECK(out.as_tensor()); CHECK_EQ(arg_pack.size(), 2UL); if (target.arch == Target::Arch::NVGPU) { - pe::CudaScheduleInjective(stages[out.as_tensor_ref()], output_shapes.front(), target); + pe::CudaScheduleInjective( + stages[out.as_tensor_ref()], output_shapes.front(), target); } else if (target.arch == Target::Arch::X86) { - pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], output_shapes.front(), target, vectorizable); + pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], + output_shapes.front(), + target, + vectorizable); } *ret = arg_pack; } @@ -123,7 +140,8 @@ std::string GetExternFuncName(const common::Target& target, } else if (target.arch == common::Target::Arch::X86) { func_proto_name.append("host_"); } else { - LOG(FATAL) << func_name << " only supports X86 and NVGPU! Please Check.\n"; + LOG(FATAL) << func_name + << " only supports X86 and NVGPU! Please Check.\n"; } } func_proto_name.append(func_name); @@ -160,7 +178,8 @@ std::string GetExternFuncName(const common::Target& target, } else if (type.is_uint(64)) { func_proto_name.append("uint64"); } else { - LOG(FATAL) << "Can not find type: " << type << " for extern function. 
Please Check.\n"; + LOG(FATAL) << "Cannot find type: " << type + << " for extern function. Please Check.\n"; } return func_proto_name; } diff --git a/paddle/cinn/hlir/op/op_util.h b/paddle/cinn/hlir/op/op_util.h index 1a48117ae8025..082c1f258a042 100644 --- a/paddle/cinn/hlir/op/op_util.h +++ b/paddle/cinn/hlir/op/op_util.h @@ -28,16 +28,22 @@ namespace cinn { namespace hlir { template -T GetAttr(const cinn::utils::AttributeMap &attr_map, const std::string &attr_name) { - CHECK(attr_map.count(attr_name)) << "Cannot found attribute \"" << attr_name << "\""; +T GetAttr(const cinn::utils::AttributeMap &attr_map, + const std::string &attr_name) { + CHECK(attr_map.count(attr_name)) + << "Cannot find attribute \"" << attr_name << "\""; const auto &attr = attr_map.at(attr_name); - CHECK(absl::holds_alternative(attr)) << "The type of attribute \"" << attr_name << "\" isn't " << typeid(T).name(); + CHECK(absl::holds_alternative(attr)) + << "The type of attribute \"" << attr_name << "\" isn't " + << typeid(T).name(); return absl::get(attr_map.at(attr_name)); } template -T SafeGetAttr(const cinn::utils::AttributeMap &attrs, const std::string &key, const T &&value) { +T SafeGetAttr(const cinn::utils::AttributeMap &attrs, + const std::string &key, + const T &&value) { if (attrs.find(key) != attrs.end()) { return GetAttr(attrs, key); } @@ -47,7 +53,10 @@ T SafeGetAttr(const cinn::utils::AttributeMap &attrs, const std::string &key, co template std::vector ToCinnExprs(const std::vector &args) { std::vector exprs; - std::transform(args.begin(), args.end(), std::back_inserter(exprs), [](const T &arg) { return Expr(arg); }); + std::transform( + args.begin(), args.end(), std::back_inserter(exprs), [](const T &arg) { + return Expr(arg); + }); return exprs; } @@ -58,7 +67,8 @@ std::vector ToPodVector(const std::vector &args) { } const auto &type = args.front().type(); - CHECK_EQ(type, common::type_of()) << "Cannot get " << common::type_of() << " value from " << type << " vector!"; + CHECK_EQ(type, common::type_of()) << "Cannot get " << common::type_of() + << " value from " << type << " vector!"; std::vector shape_v; if (type.is_bool()) { @@ -121,20 +131,22 @@ std::vector ToPodVector(const std::vector &args) { using CINNSchedule = lang::PackedFunc; -CINNSchedule GetElementwiseScheduleFunc(const std::vector> &output_shapes, - const Target &target, - bool vectorizable = true); +CINNSchedule GetElementwiseScheduleFunc( + const std::vector> &output_shapes, + const Target &target, + bool vectorizable = true); -CINNSchedule GetInjectiveScheduleFunc(const std::vector> &output_shapes, - const Target &target, - bool vectorizable = true); +CINNSchedule GetInjectiveScheduleFunc( + const std::vector> &output_shapes, + const Target &target, + bool vectorizable = true); std::string GetExternFuncName(const common::Target &target, const common::Type &type, const std::string &func_name, - const bool need_cinn = true, + const bool need_cinn = true, const bool need_target = true, - const bool need_type = true); + const bool need_type = true); } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/op/reduction.cc b/paddle/cinn/hlir/op/reduction.cc index 6e40d3da2d00f..51efb04bc82fb 100644 --- a/paddle/cinn/hlir/op/reduction.cc +++ b/paddle/cinn/hlir/op/reduction.cc @@ -41,20 +41,26 @@ using framework::OpStrategy; using framework::shape_t; using framework::StrategyFunction; -using BlockReduceFunc = std::function( - const ir::Tensor &, const std::vector &, const bool, const std::string &)>; -using ReduceFunc = - 
std::function &, const bool, const std::string &)>; - -std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target, - const std::string &op_name, - BlockReduceFunc gpu_reduce_with_last_axis_func, - BlockReduceFunc gpu_reduce_without_last_axis_func, - ReduceFunc cpu_reduce_func) { +using BlockReduceFunc = + std::function(const ir::Tensor &, + const std::vector &, + const bool, + const std::string &)>; +using ReduceFunc = std::function &, + const bool, + const std::string &)>; + +std::shared_ptr StrategyForReduce( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target, + const std::string &op_name, + BlockReduceFunc gpu_reduce_with_last_axis_func, + BlockReduceFunc gpu_reduce_without_last_axis_func, + ReduceFunc cpu_reduce_func) { std::vector reduce_axes; auto ndim = inputs[0]->shape.size(); if (attrs.attr_store.count("dim")) { @@ -84,7 +90,8 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, keep_dim = absl::get(attrs.attr_store.at("keep_dim")); } - auto WithoutLastDimInReduce = [](const std::vector &inshape, const std::vector &axes) { + auto WithoutLastDimInReduce = [](const std::vector &inshape, + const std::vector &axes) { // if last axis is in reduce. if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || std::find(axes.begin(), axes.end(), -1) != axes.end()) { @@ -103,62 +110,71 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, } }; - framework::CINNCompute reduction_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check."; - CINNValuePack arg_packs = args[0]; - std::string tensor_name = UniqName(op_name + "_out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(arg_packs.size(), 2U) << "There should be 2 input args for " << op_name << " compute"; - CHECK(arg_packs[1].is_string()); - tensor_name = arg_packs[1].operator std::string(); - } else { - CHECK_EQ(arg_packs.size(), 1U) << "There should be 1 input args for " << op_name << " compute"; - } - Expr x_expr = arg_packs[0]; - CHECK(x_expr.as_tensor()); - ir::Tensor x = x_expr.as_tensor_ref(); - - std::unordered_set bool_reduce_op = {"reduce_all", "reduce_any"}; - CHECK(!bool_reduce_op.count(op_name) || x->type().is_bool()) - << "The type of input argument " << x->name << " of " << op_name << " should be bool, but get " << x->type() - << "! Please check."; - - if (target == common::DefaultNVGPUTarget()) { - if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) { - VLOG(3) << "Do Two Step Block Reduce Compute!"; - auto res = gpu_reduce_with_last_axis_func(x, reduce_axes, keep_dim, tensor_name); - auto stages = CreateStages(res); - - std::vector cinn_values; - for (auto &t : res) { - cinn_values.emplace_back(t); + framework::CINNCompute reduction_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of " << op_name + << " compute is empty! 
Please check."; + CINNValuePack arg_packs = args[0]; + std::string tensor_name = UniqName(op_name + "_out"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(arg_packs.size(), 2U) + << "There should be 2 input args for " << op_name << " compute"; + CHECK(arg_packs[1].is_string()); + tensor_name = arg_packs[1].operator std::string(); + } else { + CHECK_EQ(arg_packs.size(), 1U) + << "There should be 1 input args for " << op_name << " compute"; } - cinn_values.emplace_back(stages); - *ret = CINNValuePack{cinn_values}; - } else { - VLOG(3) << "Do Block Shuffle Reduce Compute!"; - auto res = gpu_reduce_without_last_axis_func(x, reduce_axes, keep_dim, tensor_name); - auto stages = CreateStages(res); + Expr x_expr = arg_packs[0]; + CHECK(x_expr.as_tensor()); + ir::Tensor x = x_expr.as_tensor_ref(); + + std::unordered_set bool_reduce_op = {"reduce_all", + "reduce_any"}; + CHECK(!bool_reduce_op.count(op_name) || x->type().is_bool()) + << "The type of input argument " << x->name << " of " << op_name + << " should be bool, but get " << x->type() << "! Please check."; + + if (target == common::DefaultNVGPUTarget()) { + if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) { + VLOG(3) << "Do Two Step Block Reduce Compute!"; + auto res = gpu_reduce_with_last_axis_func( + x, reduce_axes, keep_dim, tensor_name); + auto stages = CreateStages(res); + + std::vector cinn_values; + for (auto &t : res) { + cinn_values.emplace_back(t); + } + cinn_values.emplace_back(stages); + *ret = CINNValuePack{cinn_values}; + } else { + VLOG(3) << "Do Block Shuffle Reduce Compute!"; + auto res = gpu_reduce_without_last_axis_func( + x, reduce_axes, keep_dim, tensor_name); + auto stages = CreateStages(res); + + std::vector cinn_values; + for (auto &t : res) { + cinn_values.emplace_back(t); + } + cinn_values.emplace_back(stages); + *ret = CINNValuePack{cinn_values}; + } + } else { + VLOG(3) << "Do Reduce Compute!"; + auto out = cpu_reduce_func(x, reduce_axes, keep_dim, tensor_name); + auto stages = CreateStages({out}); - std::vector cinn_values; - for (auto &t : res) { - cinn_values.emplace_back(t); + std::vector cinn_values{CINNValue(out), CINNValue(stages)}; + *ret = CINNValuePack{cinn_values}; } - cinn_values.emplace_back(stages); - *ret = CINNValuePack{cinn_values}; - } - } else { - VLOG(3) << "Do Reduce Compute!"; - auto out = cpu_reduce_func(x, reduce_axes, keep_dim, tensor_name); - auto stages = CreateStages({out}); - - std::vector cinn_values{CINNValue(out), CINNValue(stages)}; - *ret = CINNValuePack{cinn_values}; - } - }); + }); - framework::CINNSchedule reduction_schedule([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of " << op_name << " schedule is empty! Please check."; + framework::CINNSchedule reduction_schedule([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of " << op_name + << " schedule is empty! 
Please check."; CINNValuePack arg_pack = args[0]; if (FLAGS_cinn_ir_schedule) { @@ -190,32 +206,38 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) { if (arg_pack.size() == 4) { CHECK_EQ(vec_tensor.size(), 2); - Expr out = vec_tensor[0]; + Expr out = vec_tensor[0]; Expr tmp_out = vec_tensor[1]; VLOG(3) << "Do IRCudaScheduleBlockReduceInternal Schedule!"; - pe::IRCudaScheduleBlockReduceInternal(ir_sch, tmp_out.as_tensor_ref(), out.as_tensor_ref(), target); + pe::IRCudaScheduleBlockReduceInternal( + ir_sch, tmp_out.as_tensor_ref(), out.as_tensor_ref(), target); - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else if (arg_pack.size() == 6) { CHECK_EQ(vec_tensor.size(), 3); - Expr out = vec_tensor[0]; - Expr tmp_out = vec_tensor[1]; + Expr out = vec_tensor[0]; + Expr tmp_out = vec_tensor[1]; Expr reduce_tmp_out = vec_tensor[2]; VLOG(3) << "Do IRCudaScheduleBlockReduce Schedule!"; - pe::IRCudaScheduleBlockReduce( - ir_sch, reduce_tmp_out.as_tensor_ref(), tmp_out.as_tensor_ref(), out.as_tensor_ref(), target); + pe::IRCudaScheduleBlockReduce(ir_sch, + reduce_tmp_out.as_tensor_ref(), + tmp_out.as_tensor_ref(), + out.as_tensor_ref(), + target); - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else if (arg_pack.size() == 7) { CHECK_EQ(vec_tensor.size(), 4); - Expr out = vec_tensor[0]; - Expr tmp_out = vec_tensor[1]; + Expr out = vec_tensor[0]; + Expr tmp_out = vec_tensor[1]; Expr reduce_tmp_out = vec_tensor[2]; - Expr reshape = vec_tensor[3]; + Expr reshape = vec_tensor[3]; VLOG(3) << "Do IRCudaTwoStepReduceSchedule Schedule!"; pe::IRCudaTwoStepReduceSchedule(ir_sch, @@ -225,12 +247,13 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, out.as_tensor_ref(), common::DefaultNVGPUTarget()); - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else if (arg_pack.size() == 5) { CHECK_EQ(vec_tensor.size(), 3); - Expr out = vec_tensor[0]; - Expr tmp_out = vec_tensor[1]; + Expr out = vec_tensor[0]; + Expr tmp_out = vec_tensor[1]; Expr reduce_tmp_out = vec_tensor[2]; VLOG(3) << "Do IRCudaScheduleBlockReduce Schedule!"; @@ -240,7 +263,8 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, out.as_tensor_ref(), common::DefaultNVGPUTarget()); - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else { LOG(FATAL) << "Unkown Reduce Type!"; @@ -252,31 +276,38 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, VLOG(3) << "Do IRCudaScheduleReduce Schedule!"; pe::IRCudaScheduleReduce( - ir_sch, reduce_out.as_tensor_ref(), inputs[0]->shape.size() - reduce_axes.back() - 1, target); + ir_sch, + reduce_out.as_tensor_ref(), + inputs[0]->shape.size() - reduce_axes.back() - 1, + target); - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else if (arg_pack.size() == 6) { CHECK_EQ(vec_tensor.size(), 3); - Expr reduce_out = vec_tensor[0]; + Expr reduce_out = vec_tensor[0]; Expr reduce_internal = vec_tensor[1]; - Expr 
reduce_reshape = vec_tensor[2]; + Expr reduce_reshape = vec_tensor[2]; VLOG(3) << "Do IRCudaScheduleBlockShuffleReduce Schedule!"; - pe::IRCudaScheduleBlockShuffleReduce(ir_sch, - reduce_reshape.as_tensor_ref(), - reduce_internal.as_tensor_ref(), - reduce_out.as_tensor_ref(), - target); - - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + pe::IRCudaScheduleBlockShuffleReduce( + ir_sch, + reduce_reshape.as_tensor_ref(), + reduce_internal.as_tensor_ref(), + reduce_out.as_tensor_ref(), + target); + + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else { LOG(FATAL) << "Unkown Reduce Type!"; } } } else { - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } } else { @@ -285,16 +316,18 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, if (target.arch == Target::Arch::NVGPU) { if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) { if (arg_pack.size() == 3) { - Expr out = arg_pack[0]; - Expr tmp_out = arg_pack[1]; + Expr out = arg_pack[0]; + Expr tmp_out = arg_pack[1]; poly::StageMap stages = arg_pack.back(); VLOG(3) << "Do CudaBlockReduceInternalSchedule Schedule!"; - pe::CudaBlockReduceInternalSchedule( - stages, tmp_out.as_tensor_ref(), out.as_tensor_ref(), common::DefaultNVGPUTarget()); + pe::CudaBlockReduceInternalSchedule(stages, + tmp_out.as_tensor_ref(), + out.as_tensor_ref(), + common::DefaultNVGPUTarget()); } else if (arg_pack.size() == 4) { - Expr out = arg_pack[0]; - Expr tmp_out = arg_pack[1]; - Expr reduce_tmp_out = arg_pack[2]; + Expr out = arg_pack[0]; + Expr tmp_out = arg_pack[1]; + Expr reduce_tmp_out = arg_pack[2]; poly::StageMap stages = arg_pack.back(); VLOG(3) << "Do CudaBlockReduceSchedule Schedule!"; pe::CudaBlockReduceSchedule(stages, @@ -303,10 +336,10 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, out.as_tensor_ref(), common::DefaultNVGPUTarget()); } else { - Expr out = arg_pack[0]; - Expr tmp_out = arg_pack[1]; - Expr reduce_tmp_out = arg_pack[2]; - Expr reshape = arg_pack[3]; + Expr out = arg_pack[0]; + Expr tmp_out = arg_pack[1]; + Expr reduce_tmp_out = arg_pack[2]; + Expr reshape = arg_pack[3]; poly::StageMap stages = arg_pack.back(); VLOG(3) << "Do CudaTwoStepReduceSchedule Schedule!"; pe::CudaTwoStepReduceSchedule(stages, @@ -318,16 +351,19 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, } } else { if (arg_pack.size() == 2) { - Expr reduce_out = arg_pack[0]; + Expr reduce_out = arg_pack[0]; poly::StageMap stages = arg_pack.back(); VLOG(3) << "Do CudaReduceSchedule Schedule!"; pe::CudaReduceSchedule( - stages, reduce_out.as_tensor_ref(), inputs[0]->shape.size() - reduce_axes.back() - 1, target); + stages, + reduce_out.as_tensor_ref(), + inputs[0]->shape.size() - reduce_axes.back() - 1, + target); } else { CHECK_EQ(arg_pack.size(), 4) << "args is not equal 4!"; - Expr reduce_reshape = arg_pack[2]; - Expr reduce_internal = arg_pack[1]; - Expr reduce_out = arg_pack[0]; + Expr reduce_reshape = arg_pack[2]; + Expr reduce_internal = arg_pack[1]; + Expr reduce_out = arg_pack[0]; poly::StageMap stages = arg_pack.back(); VLOG(3) << "Do CudaBlockShuffleReduceSchedule Schedule!"; pe::CudaBlockShuffleReduceSchedule(stages, @@ -343,40 +379,70 @@ std::shared_ptr StrategyForReduce(const framework::NodeAttr &attrs, }); auto strategy = std::make_shared(); - strategy->AddImpl(reduction_compute, reduction_schedule, "strategy." 
+ op_name + ".x86", 1); + strategy->AddImpl( + reduction_compute, reduction_schedule, "strategy." + op_name + ".x86", 1); return strategy; } -#define STRATEGY_FOR_REDUCE( \ - op_name_, reduce_op_, gpu_reduce_with_last_axis_func, gpu_reduce_without_last_axis_func, cpu_reduce_func) \ - std::shared_ptr StrategyFor##reduce_op_(const framework::NodeAttr &attrs, \ - const std::vector &inputs, \ - const std::vector &out_type, \ - const std::vector> &output_shapes, \ - const Target &target) { \ - return StrategyForReduce(attrs, \ - inputs, \ - out_type, \ - output_shapes, \ - target, \ - #op_name_, \ - gpu_reduce_with_last_axis_func, \ - gpu_reduce_without_last_axis_func, \ - cpu_reduce_func); \ +#define STRATEGY_FOR_REDUCE(op_name_, \ + reduce_op_, \ + gpu_reduce_with_last_axis_func, \ + gpu_reduce_without_last_axis_func, \ + cpu_reduce_func) \ + std::shared_ptr StrategyFor##reduce_op_( \ + const framework::NodeAttr &attrs, \ + const std::vector &inputs, \ + const std::vector &out_type, \ + const std::vector> &output_shapes, \ + const Target &target) { \ + return StrategyForReduce(attrs, \ + inputs, \ + out_type, \ + output_shapes, \ + target, \ + #op_name_, \ + gpu_reduce_with_last_axis_func, \ + gpu_reduce_without_last_axis_func, \ + cpu_reduce_func); \ } -STRATEGY_FOR_REDUCE(reduce_sum, ReduceSum, pe::TwoStepBlockReduceSum, pe::BlockShuffleReduceSum, pe::ReduceSum); -STRATEGY_FOR_REDUCE(reduce_prod, ReduceProd, pe::TwoStepBlockReduceProd, pe::BlockShuffleReduceProd, pe::ReduceProd); -STRATEGY_FOR_REDUCE(reduce_max, ReduceMax, pe::TwoStepBlockReduceMax, pe::BlockShuffleReduceMax, pe::ReduceMax); -STRATEGY_FOR_REDUCE(reduce_min, ReduceMin, pe::TwoStepBlockReduceMin, pe::BlockShuffleReduceMin, pe::ReduceMin); -STRATEGY_FOR_REDUCE(reduce_all, ReduceAll, pe::TwoStepBlockReduceAll, pe::BlockShuffleReduceAll, pe::ReduceAll); -STRATEGY_FOR_REDUCE(reduce_any, ReduceAny, pe::TwoStepBlockReduceAny, pe::BlockShuffleReduceAny, pe::ReduceAny); +STRATEGY_FOR_REDUCE(reduce_sum, + ReduceSum, + pe::TwoStepBlockReduceSum, + pe::BlockShuffleReduceSum, + pe::ReduceSum); +STRATEGY_FOR_REDUCE(reduce_prod, + ReduceProd, + pe::TwoStepBlockReduceProd, + pe::BlockShuffleReduceProd, + pe::ReduceProd); +STRATEGY_FOR_REDUCE(reduce_max, + ReduceMax, + pe::TwoStepBlockReduceMax, + pe::BlockShuffleReduceMax, + pe::ReduceMax); +STRATEGY_FOR_REDUCE(reduce_min, + ReduceMin, + pe::TwoStepBlockReduceMin, + pe::BlockShuffleReduceMin, + pe::ReduceMin); +STRATEGY_FOR_REDUCE(reduce_all, + ReduceAll, + pe::TwoStepBlockReduceAll, + pe::BlockShuffleReduceAll, + pe::ReduceAll); +STRATEGY_FOR_REDUCE(reduce_any, + ReduceAny, + pe::TwoStepBlockReduceAny, + pe::BlockShuffleReduceAny, + pe::ReduceAny); #undef STRATEGY_FOR_REDUCE -std::vector InferShapeForReduction(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForReduction( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { CHECK(inputs_shape.size() == 1UL || inputs_shape.size() == 3UL); std::vector dim; bool keep_dim = false; @@ -416,57 +482,69 @@ std::vector InferShapeForReduction(const std::vector &inputs_s out_shapes.push_back(1); } - VLOG(4) << "Reduce from input shape [" << cinn::utils::Join(inputs_shape[0], ",") << "] to output shape [" - << cinn::utils::Join(out_shapes, ",") << "] with reduce dim [" << cinn::utils::Join(dim, ",") - << "] and keep_dim is " << keep_dim; + VLOG(4) << "Reduce from input shape [" + << cinn::utils::Join(inputs_shape[0], ",") << "] to output shape [" + << cinn::utils::Join(out_shapes, 
",") << "] with reduce dim [" + << cinn::utils::Join(dim, ",") << "] and keep_dim is " << keep_dim; return {out_shapes}; } -std::vector InferDtypeForReduction(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForReduction(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::vector InferDtypeForReductionBool(const std::vector &inputs_type, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 1UL) << "The reduce should only has one input! Please check again."; - CHECK(inputs_type[0].is_bool()) << "The input's type should be bool! Please check."; +std::vector InferDtypeForReductionBool( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 1UL) + << "The reduce should only has one input! Please check again."; + CHECK(inputs_type[0].is_bool()) + << "The input's type should be bool! Please check."; return inputs_type; } -std::vector> InferLayoutForReduction(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 1U) << "The input's layouts size is not 1! Please check again."; +std::vector> InferLayoutForReduction( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 1U) + << "The input's layouts size is not 1! Please check again."; std::vector new_input_layouts = input_layouts; if (input_shapes[0].size() > 4) { // alter input layout back new_input_layouts[0] = "NCHW"; - VLOG(3) << "alter input layout from " << input_layouts[0] << " to " << new_input_layouts[0]; + VLOG(3) << "alter input layout from " << input_layouts[0] << " to " + << new_input_layouts[0]; } return {{""}, new_input_layouts}; } -std::vector InferShapeForBnOptimize(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForBnOptimize( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { auto shapes = InferShapeForReduction(inputs_shape, attrs); CHECK_GE(shapes.size(), 1) << "shapes's size less than 1, please check!"; return {shapes[0], shapes[0]}; } -std::vector InferDtypeForBnOptimize(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForBnOptimize(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! 
Please check again."; return {inputs_type[0], inputs_type[0]}; } -std::vector> InferLayoutForBnOptimize(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector> InferLayoutForBnOptimize( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { return {{"", ""}, {"", ""}}; } @@ -475,19 +553,26 @@ std::vector> InferLayoutForBnOptimize(const std::vector } // namespace cinn CINN_REGISTER_HELPER(reduce_ops) { -#define CINN_REGISTER_REDUCTION_WITH_DTYPE(op__, op_stragegy__, dtype__) \ - CINN_REGISTER_OP(op__) \ - .describe(#op__ " function") \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr("CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForReduction)) \ - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForReduction##dtype__)) \ - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForReduction)) \ - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kReduction) \ +#define CINN_REGISTER_REDUCTION_WITH_DTYPE(op__, op_stragegy__, dtype__) \ + CINN_REGISTER_OP(op__) \ + .describe(#op__ " function") \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr( \ + "CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ + .set_attr("infershape", \ + MakeOpFunction(cinn::hlir::op::InferShapeForReduction)) \ + .set_attr( \ + "inferdtype", \ + MakeOpFunction(cinn::hlir::op::InferDtypeForReduction##dtype__)) \ + .set_attr("inferlayout", \ + MakeOpFunction(cinn::hlir::op::InferLayoutForReduction)) \ + .set_attr( \ + "OpPattern", cinn::hlir::framework::OpPatternKind::kReduction) \ .set_support_level(4); -#define CINN_REGISTER_REDUCTION(op__, op_stragegy__) CINN_REGISTER_REDUCTION_WITH_DTYPE(op__, op_stragegy__, ) +#define CINN_REGISTER_REDUCTION(op__, op_stragegy__) \ + CINN_REGISTER_REDUCTION_WITH_DTYPE(op__, op_stragegy__, ) CINN_REGISTER_REDUCTION(reduce_sum, ReduceSum); CINN_REGISTER_REDUCTION(reduce_prod, ReduceProd); diff --git a/paddle/cinn/hlir/op/reduction_test.cc b/paddle/cinn/hlir/op/reduction_test.cc index f6f30c75b79d4..e6611d6484af5 100644 --- a/paddle/cinn/hlir/op/reduction_test.cc +++ b/paddle/cinn/hlir/op/reduction_test.cc @@ -52,15 +52,17 @@ using framework::shape_t; using framework::StrategyFunction; using runtime::cuda::CUDAModule; -std::pair GenReduceCode(const std::vector& shape, - const std::vector& dim, - const std::string& func_name, - bool keep_dim = false, - const std::string& op_name = "reduce_sum") { +std::pair GenReduceCode( + const std::vector& shape, + const std::vector& dim, + const std::string& func_name, + bool keep_dim = false, + const std::string& op_name = "reduce_sum") { // code gen Context::Global().ResetNameId(); auto reduce_sum = Operator::Get(op_name); - auto strategy = Operator::GetAttrs("CINNStrategy")[reduce_sum]; + auto strategy = + Operator::GetAttrs("CINNStrategy")[reduce_sum]; // input tensor std::vector shape_as_expr; @@ -71,7 +73,7 @@ std::pair GenReduceCode(const std::vector& shape, // set attrs NodeAttr attrs; - attrs.attr_store["dim"] = dim; + attrs.attr_store["dim"] = dim; attrs.attr_store["keep_dim"] = keep_dim; std::vector inputs{X.tensor()}; std::vector out_type{Float(32)}; @@ -88,28 +90,33 @@ std::pair GenReduceCode(const std::vector& shape, } auto target = common::DefaultNVGPUTarget(); - auto impl = OpStrategy::SelectImpl(strategy(attrs, inputs, 
out_type, {output_shape}, target)); + auto impl = OpStrategy::SelectImpl( + strategy(attrs, inputs, out_type, {output_shape}, target)); std::vector func; if (!FLAGS_cinn_ir_schedule) { - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(X)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - rets = impl->fschedule(rets); - poly::StageMap stages = rets.back(); + common::CINNValuePack cinn_input = + common::CINNValuePack{{common::CINNValue(X)}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); + rets = impl->fschedule(rets); + poly::StageMap stages = rets.back(); // the last element is a StageMap for (int i = 0; i < rets->size() - 1; i++) { Expr temp = rets[i]; - if (!temp.as_tensor_ref()->buffer.defined() && !stages[temp.as_tensor_ref()]->inlined()) { + if (!temp.as_tensor_ref()->buffer.defined() && + !stages[temp.as_tensor_ref()]->inlined()) { inputs.push_back(temp.as_tensor_ref()); } } - func = lang::LowerVec(func_name, rets.back(), inputs, {}, {}, nullptr, target); + func = + lang::LowerVec(func_name, rets.back(), inputs, {}, {}, nullptr, target); } else { std::vector input_output_nodes{"X", op_name}; func = GetFuncFromImpl(impl, - common::CINNValuePack{{common::CINNValue(X), common::CINNValue(op_name)}}, + common::CINNValuePack{{common::CINNValue(X), + common::CINNValue(op_name)}}, inputs, input_output_nodes, func_name, @@ -122,11 +129,13 @@ std::pair GenReduceCode(const std::vector& shape, } // compile the module // Need to create a new compiler for every call of Build, - // because the underneath jit engine does't support addIRModule repeatedly now. - auto module = builder.Build(); - auto host_module_device_module = backends::SplitCudaAndHostModule(module); // NOLINT - auto& host_module = std::get<0>(host_module_device_module); - auto& device_module = std::get<1>(host_module_device_module); + // because the underlying jit engine doesn't support addIRModule repeatedly + // now. 
+ auto module = builder.Build(); + auto host_module_device_module = + backends::SplitCudaAndHostModule(module); // NOLINT + auto& host_module = std::get<0>(host_module_device_module); + auto& device_module = std::get<1>(host_module_device_module); backends::CodeGenCUDA_Dev codegen(target); std::string source_code; @@ -143,7 +152,7 @@ std::pair GenReduceCode(const std::vector& shape, // last dimension not in reduce TEST(Operator, Operator_Reduce_Without_Last_Channel_Case_5) { std::vector shape = {128, 112, 112, 128}; - std::vector dim = {0, 1, 2}; + std::vector dim = {0, 1, 2}; GenReduceCode(shape, dim, "Reduce_Without_Last_Channel_Case_5"); } @@ -151,140 +160,140 @@ TEST(Operator, Operator_Reduce_Without_Last_Channel_Case_5) { // last dimension not in reduce TEST(Operator, Operator_Reduce_Without_Last_Channel_Case_4) { std::vector shape = {16, 16, 8, 8, 16, 16}; - std::vector dim = {0, 2, 3}; + std::vector dim = {0, 2, 3}; GenReduceCode(shape, dim, "Reduce_Without_Last_Channel_Case_4"); } // case 3 TEST(Operator, Operator_Reduce_Without_Last_Channel_Case_3) { std::vector shape = {16, 16, 16, 16, 16}; - std::vector dim = {0, 2}; + std::vector dim = {0, 2}; GenReduceCode(shape, dim, "Reduce_Without_Last_Channel_Case_3"); } // case 2 TEST(Operator, Operator_Reduce_Without_Last_Channel_Case_2) { std::vector shape = {16, 16, 16, 16}; - std::vector dim = {0, 1}; + std::vector dim = {0, 1}; GenReduceCode(shape, dim, "Reduce_Without_Last_Channel_Case_2"); } // case 1 TEST(Operator, Operator_Reduce_Without_Last_Channel_Case_1) { std::vector shape = {16, 16, 16, 16}; - std::vector dim = {1}; + std::vector dim = {1}; GenReduceCode(shape, dim, "Reduce_Without_Last_Channel_Case_1"); } // case 0 TEST(Operator, Operator_Reduce_Without_Last_Channel_Case_0) { std::vector shape = {16, 16, 32}; - std::vector dim = {1}; + std::vector dim = {1}; GenReduceCode(shape, dim, "Reduce_Without_Last_Channel_Case_0"); } TEST(Operator, Operator_Reduction_Case_Last_Dim_1) { std::vector shape = {10, 100, 1}; - std::vector dim = {0, 2}; + std::vector dim = {0, 2}; GenReduceCode(shape, dim, "reduce_cast_with_last_dim_1"); } TEST(Operator, Operator_Reduction_Case_0) { std::vector shape = {16, 16, 8, 16}; - std::vector dim = {2, 3}; + std::vector dim = {2, 3}; GenReduceCode(shape, dim, "reduce_cast_0"); } TEST(Operator, Operator_Reduction_Case_0_0) { std::vector shape = {16, 16, 8, 16}; - std::vector dim = {2, 3}; + std::vector dim = {2, 3}; GenReduceCode(shape, dim, "reduce_cast_0_0", true); } TEST(Operator, Operator_Reduction_Case_1) { std::vector shape = {16, 16, 32, 32}; - std::vector dim = {2, 3}; + std::vector dim = {2, 3}; GenReduceCode(shape, dim, "reduce_cast_1"); } TEST(Operator, Operator_Reduction_Case_1_1) { std::vector shape = {16, 16, 32, 32}; - std::vector dim = {2, 3}; + std::vector dim = {2, 3}; GenReduceCode(shape, dim, "reduce_cast_1_1", true); } TEST(Operator, Operator_Reduction_Case_2) { std::vector shape = {16, 16, 32, 32}; - std::vector dim = {1}; + std::vector dim = {1}; GenReduceCode(shape, dim, "reduce_cast_2", true); } TEST(Operator, Operator_Reduction_Case_2_1) { std::vector shape = {16, 16, 32, 32}; - std::vector dim = {-1}; + std::vector dim = {-1}; GenReduceCode(shape, dim, "reduce_cast_2_1", true); } TEST(Operator, Operator_Reduction_Case_3) { std::vector shape = {16, 16, 64, 64}; - std::vector dim = {1}; + std::vector dim = {1}; GenReduceCode(shape, dim, "reduce_cast_3"); } TEST(Operator, Operator_Reduction_Case_4) { std::vector shape = {16, 16, 16, 16}; - std::vector dim = {0, 2, 3}; + 
std::vector dim = {0, 2, 3}; GenReduceCode(shape, dim, "reduce_cast_4"); } TEST(Operator, Operator_Reduction_Case_4_4) { std::vector shape = {16, 16, 16, 16}; - std::vector dim = {0, 2, 3}; + std::vector dim = {0, 2, 3}; GenReduceCode(shape, dim, "reduce_cast_4_4", true); } TEST(Operator, Operator_Reduction_Case_5) { std::vector shape = {16, 16, 16, 16, 16, 32}; - std::vector dim = {1, 3, 5}; + std::vector dim = {1, 3, 5}; GenReduceCode(shape, dim, "reduce_cast_5"); } TEST(Operator, Operator_Reduction_Case_5_5) { std::vector shape = {16, 16, 16, 16, 16, 32}; - std::vector dim = {1, 3, 5}; + std::vector dim = {1, 3, 5}; GenReduceCode(shape, dim, "reduce_cast_5_5", true); } TEST(Operator, Operator_Reduction_Case_6_0) { std::vector shape = {32, 32, 32}; - std::vector dim = {0, 1, 2}; + std::vector dim = {0, 1, 2}; GenReduceCode(shape, dim, "reduce_cast_6_0", false); } TEST(Operator, Operator_Reduction_Case_6_00) { std::vector shape = {32, 32, 32, 32}; - std::vector dim = {0, 1, 2}; + std::vector dim = {0, 1, 2}; GenReduceCode(shape, dim, "reduce_cast_6_00", false); } TEST(Operator, Operator_Reduction_Case_6_10) { std::vector shape = {32, 32, 32}; - std::vector dim = {-2, -1, 0}; + std::vector dim = {-2, -1, 0}; GenReduceCode(shape, dim, "reduce_cast_6_10", true); } @@ -296,10 +305,14 @@ struct ProdOp { float operator()(const float left, const float right) { return left * right; } }; struct MaxOp { - float operator()(const float left, const float right) { return std::max(left, right); } + float operator()(const float left, const float right) { + return std::max(left, right); + } }; struct MinOp { - float operator()(const float left, const float right) { return std::min(left, right); } + float operator()(const float left, const float right) { + return std::min(left, right); + } }; template @@ -322,8 +335,11 @@ void DoCpuReduce(const float* x, for (int idy = 0; idy < c; ++idy) { for (int idz = 0; idz < h; ++idz) { for (int ida = 0; ida < w; ++ida) { - sum0->at(idy * w + ida) += Op()(sum0->at(idy * w + ida), x[idx * c * h * w + idy * h * w + idz * w + ida]); - sum1->at(idy) = Op()(sum1->at(idy), x[idx * c * h * w + idy * h * w + idz * w + ida]); + sum0->at(idy * w + ida) += + Op()(sum0->at(idy * w + ida), + x[idx * c * h * w + idy * h * w + idz * w + ida]); + sum1->at(idy) = Op()( + sum1->at(idy), x[idx * c * h * w + idy * h * w + idz * w + ida]); } } } @@ -331,13 +347,19 @@ void DoCpuReduce(const float* x, } template -void TestCaseForReduce( - const float init_val, int n, int c, int h, int w, const std::string& test_name, const std::string& op_name) { +void TestCaseForReduce(const float init_val, + int n, + int c, + int h, + int w, + const std::string& test_name, + const std::string& op_name) { std::vector shape = {n, c, h, w}; - std::vector dim = {0, 2, 3}; + std::vector dim = {0, 2, 3}; // get source code - auto source_code = GenReduceCode(shape, dim, test_name, false, op_name).second; + auto source_code = + GenReduceCode(shape, dim, test_name, false, op_name).second; // nv jit compile to ptx backends::nvrtc::Compiler compiler; @@ -350,36 +372,48 @@ void TestCaseForReduce( srand(time(NULL)); CUDA_CALL(cudaSetDevice(0)); - // auto func_0 = reinterpret_cast(fn_reduce_sum); - auto buffer_x = common::BufferBuilder(Float(32), {n, c, h, w}).set_random().Build(); + // auto func_0 = reinterpret_cast(fn_reduce_sum); + auto buffer_x = + common::BufferBuilder(Float(32), {n, c, h, w}).set_random().Build(); auto buffer_z = common::BufferBuilder(Float(32), {c}).set_random().Build(); void *dev_x = nullptr, 
*dev_z = nullptr; CUDA_CALL(cudaMalloc(&dev_x, buffer_x->memory_size)); CUDA_CALL(cudaMalloc(&dev_z, buffer_z->memory_size)); - CUDA_CALL(cudaMemcpy(dev_x, buffer_x->memory, buffer_x->memory_size, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy( + dev_x, buffer_x->memory, buffer_x->memory_size, cudaMemcpyHostToDevice)); dim3 grid; dim3 block; if (!FLAGS_cinn_ir_schedule) { - grid = {n * c, 1, 1}; + grid = {n * c, 1, 1}; block = {h * w, 1, 1}; } else { - grid = {c, 1, 1}; + grid = {c, 1, 1}; int block_dim_x = n * w * h > 1024 ? 1024 : n * w * h; - block = {block_dim_x, 1, 1}; + block = {block_dim_x, 1, 1}; } - void* args[] = {&dev_x, &dev_z}; + void* args[] = {&dev_x, &dev_z}; std::string new_test_name = test_name; if (FLAGS_cinn_ir_schedule) new_test_name = "fn_" + new_test_name + "_kernel"; cuda_module.LaunchKernel(0, new_test_name, grid, block, args); - CUDA_CALL(cudaMemcpy(buffer_z->memory, dev_z, buffer_z->memory_size, cudaMemcpyDeviceToHost)); + CUDA_CALL(cudaMemcpy( + buffer_z->memory, dev_z, buffer_z->memory_size, cudaMemcpyDeviceToHost)); std::vector sum0(c * w); std::vector sum1(c); - DoCpuReduce(reinterpret_cast(buffer_x->memory), &sum0, &sum1, init_val, n, c, h, w); - - std::vector, float*>> results = {{sum1, reinterpret_cast(buffer_z->memory)}}; + DoCpuReduce(reinterpret_cast(buffer_x->memory), + &sum0, + &sum1, + init_val, + n, + c, + h, + w); + + std::vector, float*>> results = { + {sum1, reinterpret_cast(buffer_z->memory)}}; for (auto& res : results) { for (int idx = 0; idx < res.first.size(); ++idx) { ASSERT_LT(abs(res.first[idx] - res.second[idx]) / res.first[idx], 1e-4); @@ -391,21 +425,25 @@ void TestCaseForReduce( } TEST(Operator, Operator_Reduction_Case_6_1) { - TestCaseForReduce(0.0f, 32, 32, 32, 32, "Operator_Reduction_Case_6_1", "reduce_sum"); + TestCaseForReduce( + 0.0f, 32, 32, 32, 32, "Operator_Reduction_Case_6_1", "reduce_sum"); } TEST(Operator, Operator_Reduction_Case_6_2) { - TestCaseForReduce(1.0f, 1, 1, 1, 32, "Operator_Reduction_Case_6_2", "reduce_prod"); + TestCaseForReduce( + 1.0f, 1, 1, 1, 32, "Operator_Reduction_Case_6_2", "reduce_prod"); } TEST(Operator, Operator_Reduction_Case_6_3) { - TestCaseForReduce(-1e38f, 32, 32, 32, 32, "Operator_Reduction_Case_6_3", "reduce_max"); + TestCaseForReduce( + -1e38f, 32, 32, 32, 32, "Operator_Reduction_Case_6_3", "reduce_max"); } TEST(Operator, Operator_Reduction_Case_6_4) { - TestCaseForReduce(1e38f, 32, 32, 32, 32, "Operator_Reduction_Case_6_4", "reduce_min"); + TestCaseForReduce( + 1e38f, 32, 32, 32, 32, "Operator_Reduction_Case_6_4", "reduce_min"); } TEST(Operator, Operator_Reduction_Case_7) { int n = 32, c = 32, h = 16, w = 16; std::vector shape = {n, c, h, w}; - std::vector dim = {0, 1}; + std::vector dim = {0, 1}; std::string func_name = "reduce_cast_7"; // get source code @@ -418,16 +456,19 @@ TEST(Operator, Operator_Reduction_Case_7) { // load ptx CUDA_CALL(cudaSetDevice(0)); - runtime::cuda::CUDAModule cuda_module(ptx, runtime::cuda::CUDAModule::Kind::PTX); + runtime::cuda::CUDAModule cuda_module(ptx, + runtime::cuda::CUDAModule::Kind::PTX); std::string new_func_name = func_name; if (FLAGS_cinn_ir_schedule) new_func_name = "fn_" + new_func_name; - void* reduce_sum_kernel = cuda_module.GetFunction(0, new_func_name + "_kernel"); + void* reduce_sum_kernel = + cuda_module.GetFunction(0, new_func_name + "_kernel"); CHECK(reduce_sum_kernel); // register cufunction and stream void* stream = nullptr; - backends::GlobalSymbolRegistry::Global().RegisterFn(new_func_name + "_kernel_ptr_", - 
reinterpret_cast(&reduce_sum_kernel)); + backends::GlobalSymbolRegistry::Global().RegisterFn( + new_func_name + "_kernel_ptr_", + reinterpret_cast(&reduce_sum_kernel)); // gen host code auto jit = backends::SimpleJIT::Create(); @@ -439,14 +480,16 @@ TEST(Operator, Operator_Reduction_Case_7) { auto func_0 = reinterpret_cast(fn_reduce_sum); srand(time(NULL)); - auto buffer_x = common::BufferBuilder(Float(32), {n, c, h, w}).set_random().Build(); + auto buffer_x = + common::BufferBuilder(Float(32), {n, c, h, w}).set_random().Build(); auto buffer_y = common::BufferBuilder(Float(32), {h, w}).set_random().Build(); void *dev_x = nullptr, *dev_y = nullptr; CUDA_CALL(cudaMalloc(&dev_x, buffer_x->memory_size)); CUDA_CALL(cudaMalloc(&dev_y, buffer_y->memory_size)); - CUDA_CALL(cudaMemcpy(dev_x, buffer_x->memory, buffer_x->memory_size, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy( + dev_x, buffer_x->memory, buffer_x->memory_size, cudaMemcpyHostToDevice)); cinn_buffer_t _x; cinn_buffer_t _y; @@ -461,7 +504,8 @@ TEST(Operator, Operator_Reduction_Case_7) { cinn_pod_value_t args0[] = {x_arg, y_arg}; func_0(args0, 2, stream); - CUDA_CALL(cudaMemcpy(buffer_y->memory, dev_y, buffer_y->memory_size, cudaMemcpyDeviceToHost)); + CUDA_CALL(cudaMemcpy( + buffer_y->memory, dev_y, buffer_y->memory_size, cudaMemcpyDeviceToHost)); CUDA_CALL(cudaFree(dev_x)); CUDA_CALL(cudaFree(dev_y)); @@ -469,91 +513,97 @@ TEST(Operator, Operator_Reduction_Case_7) { TEST(Operator, Operator_Reduction_Case_8) { std::vector shape = {128, 1}; - std::vector dim = {0}; + std::vector dim = {0}; GenReduceCode(shape, dim, "Operator_Reduction_Case_8"); } TEST(Operator, Operator_Reduction_Case_88) { std::vector shape = {128, 1}; - std::vector dim = {0}; + std::vector dim = {0}; GenReduceCode(shape, dim, "Operator_Reduction_Case_88", true); } TEST(Operator, Operator_Reduction_Case_9) { std::vector shape = {2560, 1}; - std::vector dim = {0}; + std::vector dim = {0}; GenReduceCode(shape, dim, "Operator_Reduction_Case_9"); } TEST(Operator, Operator_Reduction_Case_99) { std::vector shape = {2560, 1}; - std::vector dim = {0}; + std::vector dim = {0}; GenReduceCode(shape, dim, "Operator_Reduction_Case_99", true); } TEST(Operator, Operator_Reduction_Case_10) { std::vector shape = {16, 2560, 1}; - std::vector dim = {1}; + std::vector dim = {1}; GenReduceCode(shape, dim, "Operator_Reduction_Case_10"); } TEST(Operator, Operator_Reduction_Case_11) { std::vector shape = {16, 128, 128, 1}; - std::vector dim = {1, 2}; + std::vector dim = {1, 2}; GenReduceCode(shape, dim, "Operator_Reduction_Case_11"); } TEST(Operator, Operator_Reduction_Case_Warp_Reduce) { - int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); - int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = + common::DefaultNVGPUTarget().get_max_threads_per_sm(); int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; std::vector shape = {warp_reduce_threshold + 10, 256}; - std::vector dim = {1}; + std::vector dim = {1}; auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce"); CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); } TEST(Operator, Operator_Reduction_Case_Block_Reduce) { - int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); - int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + 
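The warp/block test pairs above and below pin down the dispatch rule these kernels use: a warp-level reduction (recognizable by the generated "threadIdx.x < 32" guard) is emitted only when there are at least sm_count * max_threads_per_sm / 32 reduction rows to occupy the device. Distilled into a sketch, with illustrative device numbers:

enum class ReduceKind { kWarp, kBlock };

// One warp serves one reduction row, so warp reduce only pays off when the
// row count can keep every warp slot of every SM busy.
ReduceKind ChooseReduceKind(int rows, int sm_count, int max_threads_per_sm) {
  const int warp_reduce_threshold = sm_count * max_threads_per_sm / 32;
  return rows >= warp_reduce_threshold ? ReduceKind::kWarp : ReduceKind::kBlock;
}
// e.g. 80 SMs x 2048 threads/SM gives a threshold of 5120 rows.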
int max_threads_per_sm = + common::DefaultNVGPUTarget().get_max_threads_per_sm(); int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; std::vector shape = {warp_reduce_threshold - 10, 33}; - std::vector dim = {1}; + std::vector dim = {1}; auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce"); CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); } TEST(Operator, Operator_Reduction_Case_Warp_Reduce_Case_1) { - int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); - int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = + common::DefaultNVGPUTarget().get_max_threads_per_sm(); int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; std::vector shape = {(warp_reduce_threshold + 32) / 2, 2, 10, 256}; - std::vector dim = {2, 3}; + std::vector dim = {2, 3}; - auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce_Case_1"); + auto res = + GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce_Case_1"); CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); } TEST(Operator, Operator_Reduction_Case_Block_Reduce_Case_1) { - int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); - int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = + common::DefaultNVGPUTarget().get_max_threads_per_sm(); int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; std::vector shape = {(warp_reduce_threshold - 32) / 2, 2, 10, 33}; - std::vector dim = {2, 3}; + std::vector dim = {2, 3}; - auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce_Case_2"); + auto res = + GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce_Case_2"); CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); } } // namespace framework diff --git a/paddle/cinn/hlir/op/transform.cc b/paddle/cinn/hlir/op/transform.cc index 0b370d6dc027a..1f362dc76fce4 100644 --- a/paddle/cinn/hlir/op/transform.cc +++ b/paddle/cinn/hlir/op/transform.cc @@ -41,29 +41,34 @@ using framework::OpStrategy; using framework::shape_t; using framework::StrategyFunction; -std::shared_ptr StrategyForMatMul(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForMatMul( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { const auto &attr_store = attrs.attr_store; - bool trans_a = SafeGetAttr(attr_store, "trans_a", false); - bool trans_b = SafeGetAttr(attr_store, "trans_b", false); - float alpha = SafeGetAttr(attr_store, "alpha", 1.0f); + bool trans_a = SafeGetAttr(attr_store, "trans_a", false); + bool trans_b = SafeGetAttr(attr_store, "trans_b", false); + float alpha = SafeGetAttr(attr_store, "alpha", 1.0f); const auto &shape_A = ToPodVector(inputs[0]->shape); const auto &shape_B = ToPodVector(inputs[1]->shape); - const auto &new_shape = pe::utils::GetMatmulNewShapes({shape_A, shape_B}, trans_a, trans_b); + const auto &new_shape = + pe::utils::GetMatmulNewShapes({shape_A, shape_B}, trans_a, trans_b); - const auto &new_shape_A = new_shape[0]; - const auto &new_shape_B = new_shape[1]; + const auto &new_shape_A = new_shape[0]; + 
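GetMatmulNewShapes, called just above, normalizes both operands before lowering; for plain 2-D operands the rule collapses to the following sketch (the CINN helper additionally handles 1-D and batched inputs, which this omits):

#include <array>
#include <cassert>

std::array<int, 2> MatmulOutShape2D(std::array<int, 2> a, std::array<int, 2> b,
                                    bool trans_a, bool trans_b) {
  const int m = trans_a ? a[1] : a[0];
  const int ka = trans_a ? a[0] : a[1];  // contracted dim of A
  const int kb = trans_b ? b[1] : b[0];  // contracted dim of B
  const int n = trans_b ? b[0] : b[1];
  assert(ka == kb && "contracted dimensions must match");
  return {m, n};
}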
const auto &new_shape_B = new_shape[1]; const auto &output_shape = new_shape[2]; - framework::CINNCompute matmul_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of Matmul compute is empty! Please check.\n"; + framework::CINNCompute matmul_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of Matmul compute is empty! Please check.\n"; CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 2U) << "at least 2 input tensors for Matmul compute\n"; + CHECK_GE(pack_args.size(), 2U) + << "at least 2 input tensors for Matmul compute\n"; Expr A = pack_args[0]; Expr B = pack_args[1]; CHECK(A.as_tensor()); @@ -78,7 +83,7 @@ std::shared_ptr StrategyForMatMul(const framework::NodeAttr &attrs, auto tensor_A = A.as_tensor_ref(); auto tensor_B = B.as_tensor_ref(); - auto stages = CreateStages({tensor_A, tensor_B}); + auto stages = CreateStages({tensor_A, tensor_B}); auto new_shape_A_e = ToCinnExprs(new_shape_A); auto new_shape_B_e = ToCinnExprs(new_shape_B); @@ -89,9 +94,21 @@ std::shared_ptr StrategyForMatMul(const framework::NodeAttr &attrs, std::vector out; if (target.arch == Target::Arch::X86) { #ifdef CINN_WITH_MKL_CBLAS - out = pe::MatmulMKL(new_A, new_B, trans_a, trans_b, alpha, UniqName("MatmulMKL_output"), target); + out = pe::MatmulMKL(new_A, + new_B, + trans_a, + trans_b, + alpha, + UniqName("MatmulMKL_output"), + target); #else - out = pe::MatmulV2(new_A, new_B, trans_a, trans_b, alpha, UniqName("MatmulV2_output"), target); + out = pe::MatmulV2(new_A, + new_B, + trans_a, + trans_b, + alpha, + UniqName("MatmulV2_output"), + target); #endif } else { out = pe::Matmul(new_A, new_B, trans_a, trans_b, alpha, tensor_name); @@ -109,12 +126,15 @@ std::shared_ptr StrategyForMatMul(const framework::NodeAttr &attrs, *ret = CINNValuePack{res}; }); - framework::CINNSchedule matmul_schedule([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of matmul schedule is empty! Please check.\n"; + framework::CINNSchedule matmul_schedule([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of matmul schedule is empty! Please check.\n"; CINNValuePack arg_pack = args[0]; if (FLAGS_cinn_ir_schedule) { - std::vector results = pe::IRCudaScheduleMatMul(arg_pack, output_shape, target); - *ret = CINNValuePack({results}); + std::vector results = + pe::IRCudaScheduleMatMul(arg_pack, output_shape, target); + *ret = CINNValuePack({results}); } else { CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL); poly::StageMap stages = arg_pack.back(); @@ -127,11 +147,12 @@ std::shared_ptr StrategyForMatMul(const framework::NodeAttr &attrs, CHECK_EQ(arg_pack.size(), 3UL); #else CHECK_EQ(arg_pack.size(), 3UL); - Expr out = arg_pack[0]; + Expr out = arg_pack[0]; Expr packedB = arg_pack[1]; CHECK(packedB.as_tensor()); CHECK(out.as_tensor()); - pe::MatmulScheduleCPU(stages, out.as_tensor_ref(), packedB.as_tensor_ref(), target); + pe::MatmulScheduleCPU( + stages, out.as_tensor_ref(), packedB.as_tensor_ref(), target); #endif } *ret = arg_pack; @@ -144,42 +165,55 @@ std::shared_ptr StrategyForMatMul(const framework::NodeAttr &attrs, return strategy; } -std::vector> InferShapeForMatMul(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 2UL) << "The input's shape size should be 2! 
Please check again."; +std::vector> InferShapeForMatMul( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 2UL) + << "The input's shape size should be 2! Please check again."; bool trans_a = SafeGetAttr(attrs, "trans_a", false); bool trans_b = SafeGetAttr(attrs, "trans_b", false); - VLOG(4) << "During the matmul shape inference, origin shape_A: " << utils::Join(inputs_shape[0], ", "); - VLOG(4) << "During the matmul shape inference, origin shape_B: " << utils::Join(inputs_shape[1], ", "); + VLOG(4) << "During the matmul shape inference, origin shape_A: " + << utils::Join(inputs_shape[0], ", "); + VLOG(4) << "During the matmul shape inference, origin shape_B: " + << utils::Join(inputs_shape[1], ", "); - const auto &new_shape = pe::utils::GetMatmulNewShapes(inputs_shape, trans_a, trans_b); + const auto &new_shape = + pe::utils::GetMatmulNewShapes(inputs_shape, trans_a, trans_b); - const auto &new_shape_A = new_shape[0]; - const auto &new_shape_B = new_shape[1]; + const auto &new_shape_A = new_shape[0]; + const auto &new_shape_B = new_shape[1]; const auto &output_shape = new_shape[2]; - VLOG(4) << "During the matmul shape inference, new_shape_A: " << utils::Join(new_shape_A, ", "); - VLOG(4) << "During the matmul shape inference, new_shape_B: " << utils::Join(new_shape_B, ", "); - VLOG(4) << "During the matmul shape inference, output_shape: " << utils::Join(output_shape, ", "); + VLOG(4) << "During the matmul shape inference, new_shape_A: " + << utils::Join(new_shape_A, ", "); + VLOG(4) << "During the matmul shape inference, new_shape_B: " + << utils::Join(new_shape_B, ", "); + VLOG(4) << "During the matmul shape inference, output_shape: " + << utils::Join(output_shape, ", "); std::vector> res{output_shape}; return res; } -std::vector InferDtypeForMatMul(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 2UL) << "The input's type size should be 2! Please check again."; - CHECK_EQ(inputs_type[0], inputs_type[1]) << "The input's types should be equal! Please check again."; +std::vector InferDtypeForMatMul(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 2UL) + << "The input's type size should be 2! Please check again."; + CHECK_EQ(inputs_type[0], inputs_type[1]) + << "The input's types should be equal! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::vector> InferLayoutForMatMul(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 2U) << "The input's layouts size is not 2! Please check again."; +std::vector> InferLayoutForMatMul( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 2U) + << "The input's layouts size is not 2! 
Please check again."; CHECK_EQ(input_shapes.size(), 2U) << "mul should have 2 input shapes"; std::vector new_input_layouts = input_layouts; for (int i = 0; i < input_shapes.size(); i++) { @@ -192,62 +226,71 @@ std::vector> InferLayoutForMatMul(const std::vector StrategyForSplit(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForSplit( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { // get attribute std::vector sections; int axis = 0; if (attrs.attr_store.find("num_or_sections") != attrs.attr_store.end()) { - sections = absl::get>(attrs.attr_store.at("num_or_sections")); + sections = + absl::get>(attrs.attr_store.at("num_or_sections")); } if (attrs.attr_store.find("axis") != attrs.attr_store.end()) { axis = absl::get(attrs.attr_store.at("axis")); } if (axis < 0) axis += static_cast(output_shapes[0].size()); - CHECK(!output_shapes.empty()) << "The Spilt Op's output shape list should not empty."; + CHECK(!output_shapes.empty()) + << "The Spilt Op's output shape list should not empty."; CHECK_LT(axis, static_cast(output_shapes[0].size())); CHECK(!sections.empty()) - << "The Split op doesn't find [num_or_sections] attrbute! It it a mandatory attribute ! Please check."; - - framework::CINNCompute split_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of split compute is empty! Please check."; - CINNValuePack pack_args = args[0]; - CHECK(!pack_args.empty()) << "The input tensors of split compute is empty! Please check."; - Expr A_expr = pack_args[0]; - CHECK(A_expr.as_tensor()); - ir::Tensor A = A_expr.as_tensor_ref(); - - std::vector tensor_names; - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), output_shapes.size() + 1); - for (int idx = 1; idx < pack_args.size(); ++idx) { - CHECK(pack_args[idx].is_string()); - tensor_names.push_back(pack_args[idx].operator std::string()); - } - } else { - for (int idx = 0; idx < output_shapes.size(); ++idx) { - tensor_names.push_back(UniqName("T_Split_Out")); - } - } + << "The Split op doesn't find [num_or_sections] attrbute! It it a " + "mandatory attribute ! Please check."; + + framework::CINNCompute split_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of split compute is empty! Please check."; + CINNValuePack pack_args = args[0]; + CHECK(!pack_args.empty()) + << "The input tensors of split compute is empty! 
Please check."; + Expr A_expr = pack_args[0]; + CHECK(A_expr.as_tensor()); + ir::Tensor A = A_expr.as_tensor_ref(); + + std::vector tensor_names; + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), output_shapes.size() + 1); + for (int idx = 1; idx < pack_args.size(); ++idx) { + CHECK(pack_args[idx].is_string()); + tensor_names.push_back(pack_args[idx].operator std::string()); + } + } else { + for (int idx = 0; idx < output_shapes.size(); ++idx) { + tensor_names.push_back(UniqName("T_Split_Out")); + } + } - auto out = pe::Split(A, axis, output_shapes, tensor_names); - auto stages = CreateStages(out); + auto out = pe::Split(A, axis, output_shapes, tensor_names); + auto stages = CreateStages(out); - std::vector res; - for (int i = 0; i < out.size(); ++i) { - res.emplace_back(out[i]); - } - res.emplace_back(stages); - *ret = CINNValuePack{res}; - }); + std::vector res; + for (int i = 0; i < out.size(); ++i) { + res.emplace_back(out[i]); + } + res.emplace_back(stages); + *ret = CINNValuePack{res}; + }); - framework::CINNSchedule split_schedule([=](lang::Args args, lang::RetValue *ret) { + framework::CINNSchedule split_schedule([=](lang::Args args, + lang::RetValue *ret) { if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of split schedule is empty! Please check."; + CHECK(!args.empty()) + << "The input argument of split schedule is empty! Please check."; CINNValuePack arg_pack = args[0]; std::vector vec_ast; for (int i = 0; i < arg_pack.size(); i++) { @@ -261,13 +304,16 @@ std::shared_ptr StrategyForSplit(const framework::NodeAttr &attrs, ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); pe::IRCudaSplitSchedule(ir_sch, output_shapes, axis, target); - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else { - CHECK(!args.empty()) << "The input arguments of split schedule is empty! Please check."; + CHECK(!args.empty()) + << "The input arguments of split schedule is empty! Please check."; CINNValuePack arg_pack = args[0]; - CHECK_GE(arg_pack.size(), 2UL) << "The input tensor's size of split schedule is " << arg_pack.size() - << "and it should be greater equal to 2! Please check."; + CHECK_GE(arg_pack.size(), 2UL) + << "The input tensor's size of split schedule is " << arg_pack.size() + << "and it should be greater equal to 2! Please check."; pe::CudaSplitSchedule(&arg_pack, output_shapes, axis, target); *ret = arg_pack; } @@ -279,13 +325,15 @@ std::shared_ptr StrategyForSplit(const framework::NodeAttr &attrs, return strategy; } -std::vector> InferShapeForSplit(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector> InferShapeForSplit( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { std::vector sections; if (attrs.find("num_or_sections") != attrs.end()) { sections = absl::get>(attrs.at("num_or_sections")); } else { - LOG(FATAL) << "The Split op doesn't find [num_or_sections] attrbute! It it a mandatory attribute ! Please check."; + LOG(FATAL) << "The Split op doesn't find [num_or_sections] attrbute! It it " + "a mandatory attribute ! 
Please check."; } if (inputs_shape.empty()) { @@ -297,7 +345,8 @@ std::vector> InferShapeForSplit(const std::vector> InferShapeForSplit(const std::vector 0) { @@ -330,9 +382,11 @@ std::vector> InferShapeForSplit(const std::vector> InferShapeForSplit(const std::vector> InferShapeForSplit(const std::vector InferDtypeForSplit(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForSplit(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector sections; if (attrs.find("num_or_sections") != attrs.end()) { sections = absl::get>(attrs.at("num_or_sections")); } else { - LOG(FATAL) << "The Split op doesn't find [num_or_sections] attrbute! It it a mandatory attribute ! Please check."; + LOG(FATAL) << "The Split op doesn't find [num_or_sections] attrbute! It it " + "a mandatory attribute ! Please check."; } int output_size = sections.size(); @@ -372,16 +431,20 @@ std::vector InferDtypeForSplit(const std::vector &inputs_type, const return res; } -std::vector> InferLayoutForSplit(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK(!input_layouts.empty()) << "The input's layout size is 0! Please check again."; +std::vector> InferLayoutForSplit( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK(!input_layouts.empty()) + << "The input's layout size is 0! Please check again."; std::vector sections; if (attrs.attr_store.find("num_or_sections") != attrs.attr_store.end()) { - sections = absl::get>(attrs.attr_store.at("num_or_sections")); + sections = + absl::get>(attrs.attr_store.at("num_or_sections")); } else { - LOG(FATAL) << "The Split op doesn't find [num_or_sections] attrbute! It it a mandatory attribute ! Please check."; + LOG(FATAL) << "The Split op doesn't find [num_or_sections] attrbute! It it " + "a mandatory attribute ! Please check."; } int output_size = sections.size(); @@ -393,17 +456,23 @@ std::vector> InferLayoutForSplit(const std::vector StrategyForConcat(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute concat_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of Concat compute is empty! Please check.\n"; - CHECK(!out_type.empty()) << "Output type of Concat is empty! Please check.\n"; +std::shared_ptr StrategyForConcat( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute concat_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of Concat compute is empty! Please check.\n"; + CHECK(!out_type.empty()) + << "Output type of Concat is empty! Please check.\n"; CINNValuePack pack_args = args[0]; - int input_size = FLAGS_cinn_ir_schedule ? pack_args.size() - 1 : pack_args.size(); - CHECK_GE(input_size, 1UL) << "at least 2 input tensors for Concat compute\n"; + int input_size = + FLAGS_cinn_ir_schedule ? 
pack_args.size() - 1 : pack_args.size(); + CHECK_GE(input_size, 1UL) + << "at least 2 input tensors for Concat compute\n"; CHECK(!output_shapes.empty()); int axis = 0; if (attrs.attr_store.count("axis")) { @@ -424,20 +493,25 @@ std::shared_ptr StrategyForConcat(const framework::NodeAttr &attrs, } auto stages = CreateStages(input_tensors); - auto out = pe::Concat(input_tensors, axis, tensor_name); + auto out = pe::Concat(input_tensors, axis, tensor_name); stages->InsertLazily(out); *ret = CINNValuePack({CINNValue(out), CINNValue(stages)}); }); auto strategy = std::make_shared(); - strategy->AddImpl(concat_compute, GetInjectiveScheduleFunc(output_shapes, target, false), "strategy.concat.x86", 1); + strategy->AddImpl(concat_compute, + GetInjectiveScheduleFunc(output_shapes, target, false), + "strategy.concat.x86", + 1); return strategy; } -std::vector> InferShapeForConcat(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_GE(inputs_shape.size(), 1UL) << "The input's shape size should be no less than 2! Please check again."; +std::vector> InferShapeForConcat( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_GE(inputs_shape.size(), 1UL) + << "The input's shape size should be no less than 2! Please check again."; int axis = 0; for (auto &iter : attrs) { if (iter.first == "axis") { @@ -449,19 +523,23 @@ std::vector> InferShapeForConcat(const std::vector output_shape = inputs_shape[0]; CHECK(axis >= 0 && axis < inputs_shape[0].size()) - << "In Concat op, the attribute `axis` should be >= 0 and < input shape's size, please check!"; + << "In Concat op, the attribute `axis` should be >= 0 and < input " + "shape's size, please check!"; int input_dim = inputs_shape[0].size(); for (int i = 1; i < inputs_shape.size(); i++) { CHECK_EQ(inputs_shape[i].size(), input_dim) - << "Dimensions of inputs tensors in Concat should be equal! Please check."; + << "Dimensions of inputs tensors in Concat should be equal! Please " + "check."; for (int j = 0; j < input_dim; j++) { if (j != axis) { CHECK_EQ(inputs_shape[0][j], inputs_shape[i][j]) << "The " << j << "-th dimension of input[0] and input[" << i - << "] should be the same, but here input[0].shape=[" << cinn::utils::Join(inputs_shape[0], ", ") - << "], input[" << i << "].shape=[" << cinn::utils::Join(inputs_shape[i], ", ") << "]! Please check."; + << "] should be the same, but here input[0].shape=[" + << cinn::utils::Join(inputs_shape[0], ", ") << "], input[" << i + << "].shape=[" << cinn::utils::Join(inputs_shape[i], ", ") + << "]! Please check."; } } @@ -472,94 +550,108 @@ std::vector> InferShapeForConcat(const std::vector InferDtypeForConcat(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForConcat(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::vector> InferLayoutForConcat(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_GE(input_layouts.size(), 1UL) << "The input's layout size is less than 2! 
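The concat shape rule enforced by InferShapeForConcat above is: every input must agree on all dimensions except axis, and the output's axis extent is the sum of the inputs'. As a standalone sketch:

#include <cassert>
#include <vector>

std::vector<int> ConcatOutShape(const std::vector<std::vector<int>>& shapes,
                                int axis) {
  std::vector<int> out = shapes[0];
  for (size_t i = 1; i < shapes.size(); ++i) {
    assert(shapes[i].size() == out.size());
    for (int j = 0; j < static_cast<int>(out.size()); ++j) {
      if (j == axis) {
        out[j] += shapes[i][j];  // the concatenated axis accumulates
      } else {
        assert(shapes[i][j] == out[j]);  // every other dim must match
      }
    }
  }
  return out;
}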
Please check again."; +std::vector> InferLayoutForConcat( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_GE(input_layouts.size(), 1UL) + << "The input's layout size is less than 2! Please check again."; return {{input_layouts[0]}, input_layouts}; } -std::shared_ptr StrategyForMul(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForMul( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { CHECK_EQ(inputs.size(), 2UL) << "mul should have 2 input"; const auto &attr_store = attrs.attr_store; - int x_num_col_dims = SafeGetAttr(attr_store, "x_num_col_dims", 1); - int y_num_col_dims = SafeGetAttr(attr_store, "y_num_col_dims", 1); - bool is_infer = SafeGetAttr(attr_store, "is_infer", false); + int x_num_col_dims = SafeGetAttr(attr_store, "x_num_col_dims", 1); + int y_num_col_dims = SafeGetAttr(attr_store, "y_num_col_dims", 1); + bool is_infer = SafeGetAttr(attr_store, "is_infer", false); const auto &shape_A = ToPodVector(inputs[0]->shape); const auto &shape_B = ToPodVector(inputs[1]->shape); - const auto &new_shape = pe::utils::GetMulNewShapes({shape_A, shape_B}, x_num_col_dims, y_num_col_dims, is_infer); + const auto &new_shape = pe::utils::GetMulNewShapes( + {shape_A, shape_B}, x_num_col_dims, y_num_col_dims, is_infer); - const auto &new_shape_A = new_shape[0]; - const auto &new_shape_B = new_shape[1]; + const auto &new_shape_A = new_shape[0]; + const auto &new_shape_B = new_shape[1]; const auto &output_shape = new_shape[2]; - framework::CINNCompute mul_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of Mul compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 2U) << "at least 2 input tensors for Mul compute\n"; - Expr A = pack_args[0]; - Expr B = pack_args[1]; - CHECK(A.as_tensor()); - CHECK(B.as_tensor()); - - auto A_tensor = A.as_tensor_ref(); - auto B_tensor = B.as_tensor_ref(); - auto stages = CreateStages({A_tensor, B_tensor}); - - auto new_shape_A_e = ToCinnExprs(new_shape_A); - auto new_shape_B_e = ToCinnExprs(new_shape_B); - - auto new_A = A_tensor->Reshape(new_shape_A_e, stages); - auto new_B = B_tensor->Reshape(new_shape_B_e, stages); - - std::vector out; - std::string tensor_name = UniqName("Mul_output"); - if (FLAGS_cinn_ir_schedule) { - CHECK(pack_args.back().is_string()); - tensor_name = pack_args.back().operator std::string(); - } + framework::CINNCompute mul_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of Mul compute is empty! 
Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 2U) + << "at least 2 input tensors for Mul compute\n"; + Expr A = pack_args[0]; + Expr B = pack_args[1]; + CHECK(A.as_tensor()); + CHECK(B.as_tensor()); + + auto A_tensor = A.as_tensor_ref(); + auto B_tensor = B.as_tensor_ref(); + auto stages = CreateStages({A_tensor, B_tensor}); + + auto new_shape_A_e = ToCinnExprs(new_shape_A); + auto new_shape_B_e = ToCinnExprs(new_shape_B); + + auto new_A = A_tensor->Reshape(new_shape_A_e, stages); + auto new_B = B_tensor->Reshape(new_shape_B_e, stages); + + std::vector out; + std::string tensor_name = UniqName("Mul_output"); + if (FLAGS_cinn_ir_schedule) { + CHECK(pack_args.back().is_string()); + tensor_name = pack_args.back().operator std::string(); + } - if (target.arch == Target::Arch::X86) { + if (target.arch == Target::Arch::X86) { #ifdef CINN_WITH_MKL_CBLAS - out = pe::MatmulMKL(new_A, new_B, false, is_infer, 1.0f, tensor_name, target); + out = pe::MatmulMKL( + new_A, new_B, false, is_infer, 1.0f, tensor_name, target); #else - out = pe::MatmulV2(new_A, new_B, false, is_infer, 1.0f, tensor_name, target); + out = pe::MatmulV2( + new_A, new_B, false, is_infer, 1.0f, tensor_name, target); #endif - } else { - out = pe::Matmul(new_A, new_B, false, is_infer, 1.0f, tensor_name); - } + } else { + out = pe::Matmul(new_A, new_B, false, is_infer, 1.0f, tensor_name); + } - std::vector res; - for (auto &t : out) { - stages->InsertLazily(t); - } + std::vector res; + for (auto &t : out) { + stages->InsertLazily(t); + } - for (auto &t : out) { - res.push_back(CINNValue(t)); - } - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); + for (auto &t : out) { + res.push_back(CINNValue(t)); + } + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); - framework::CINNSchedule mul_schedule([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of matmul schedule is empty! Please check.\n"; + framework::CINNSchedule mul_schedule([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of matmul schedule is empty! Please check.\n"; CINNValuePack arg_pack = args[0]; if (FLAGS_cinn_ir_schedule) { - std::vector results = pe::IRCudaScheduleMatMul(arg_pack, output_shape, target); - *ret = CINNValuePack({results}); + std::vector results = + pe::IRCudaScheduleMatMul(arg_pack, output_shape, target); + *ret = CINNValuePack({results}); } else { CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL); poly::StageMap stages = arg_pack.back(); @@ -572,11 +664,12 @@ std::shared_ptr StrategyForMul(const framework::NodeAttr &attrs, CHECK_EQ(arg_pack.size(), 3UL); #else CHECK_EQ(arg_pack.size(), 3UL); - Expr out = arg_pack[0]; + Expr out = arg_pack[0]; Expr packedB = arg_pack[1]; CHECK(packedB.as_tensor()); CHECK(out.as_tensor()); - pe::MatmulScheduleCPU(stages, out.as_tensor_ref(), packedB.as_tensor_ref(), target); + pe::MatmulScheduleCPU( + stages, out.as_tensor_ref(), packedB.as_tensor_ref(), target); #endif } *ret = arg_pack; @@ -589,28 +682,38 @@ std::shared_ptr StrategyForMul(const framework::NodeAttr &attrs, return strategy; } -std::vector> InferShapeForMul(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; - CHECK_GE(inputs_shape[0].size(), 2U) << "Input matrix X's dim should be >= 2! 
Please check."; - CHECK_GE(inputs_shape[1].size(), 2U) << "Input matrix Y's dim should be >= 2! Please check."; - - VLOG(4) << "During the matmul shape inference, origin shape_A: " << utils::Join(inputs_shape[0], ", "); - VLOG(4) << "During the matmul shape inference, origin shape_B: " << utils::Join(inputs_shape[1], ", "); +std::vector> InferShapeForMul( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 2U) + << "The input's shape size should be 2! Please check again."; + CHECK_GE(inputs_shape[0].size(), 2U) + << "Input matrix X's dim should be >= 2! Please check."; + CHECK_GE(inputs_shape[1].size(), 2U) + << "Input matrix Y's dim should be >= 2! Please check."; + + VLOG(4) << "During the matmul shape inference, origin shape_A: " + << utils::Join(inputs_shape[0], ", "); + VLOG(4) << "During the matmul shape inference, origin shape_B: " + << utils::Join(inputs_shape[1], ", "); int x_num_col_dims = SafeGetAttr(attrs, "x_num_col_dims", 1); int y_num_col_dims = SafeGetAttr(attrs, "y_num_col_dims", 1); - bool is_infer = SafeGetAttr(attrs, "is_infer", false); + bool is_infer = SafeGetAttr(attrs, "is_infer", false); - const auto &new_shape = pe::utils::GetMulNewShapes(inputs_shape, x_num_col_dims, y_num_col_dims, is_infer); + const auto &new_shape = pe::utils::GetMulNewShapes( + inputs_shape, x_num_col_dims, y_num_col_dims, is_infer); - const auto &new_shape_A = new_shape[0]; - const auto &new_shape_B = new_shape[1]; + const auto &new_shape_A = new_shape[0]; + const auto &new_shape_B = new_shape[1]; const auto &output_shape = new_shape[2]; - VLOG(4) << "During the mul shape inference, new_shape_A: " << utils::Join(new_shape_A, ", "); - VLOG(4) << "During the mul shape inference, new_shape_B: " << utils::Join(new_shape_B, ", "); - VLOG(4) << "During the mul shape inference, output_shape: " << utils::Join(output_shape, ", "); + VLOG(4) << "During the mul shape inference, new_shape_A: " + << utils::Join(new_shape_A, ", "); + VLOG(4) << "During the mul shape inference, new_shape_B: " + << utils::Join(new_shape_B, ", "); + VLOG(4) << "During the mul shape inference, output_shape: " + << utils::Join(output_shape, ", "); int a_K = new_shape_A[1]; int b_K = is_infer ? new_shape_B[1] : new_shape_B[0]; @@ -620,18 +723,23 @@ std::vector> InferShapeForMul(const std::vector InferDtypeForMul(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 2U) << "The input's type size should be 2! Please check again."; - CHECK_EQ(inputs_type[0], inputs_type[1]) << "The input's types should be equal! Please check again."; +std::vector InferDtypeForMul(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 2U) + << "The input's type size should be 2! Please check again."; + CHECK_EQ(inputs_type[0], inputs_type[1]) + << "The input's types should be equal! Please check again."; return {inputs_type[0]}; } -std::vector> InferLayoutForMul(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 2U) << "The input's layouts size is not 2! Please check again."; +std::vector> InferLayoutForMul( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 2U) + << "The input's layouts size is not 2! 
Please check again."; CHECK_EQ(input_shapes.size(), 2U) << "mul should have 2 input shapes"; std::vector new_input_layouts = input_layouts; for (int i = 0; i < input_shapes.size(); i++) { @@ -644,69 +752,88 @@ std::vector> InferLayoutForMul(const std::vector StrategyForCublasGemm(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute gemm_compute([attrs](lang::Args args, lang::RetValue *ret) { - auto &attr_store = attrs.attr_store; - CHECK(attr_store.contains("trans_a")) << "The cublas_gemm should have an attr named `trans_a`."; - CHECK(attr_store.contains("trans_b")) << "The cublas_gemm should have an attr named `trans_b`."; - CHECK(!args.empty()) << "The input `args` of cublas_gemm is empty! Please check."; +std::shared_ptr StrategyForCublasGemm( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute gemm_compute( + [attrs](lang::Args args, lang::RetValue *ret) { + auto &attr_store = attrs.attr_store; + CHECK(attr_store.contains("trans_a")) + << "The cublas_gemm should have an attr named `trans_a`."; + CHECK(attr_store.contains("trans_b")) + << "The cublas_gemm should have an attr named `trans_b`."; + CHECK(!args.empty()) + << "The input `args` of cublas_gemm is empty! Please check."; - CINNValuePack input_args = args[0]; - CHECK_GE(input_args.size(), 3U) << "The input number of cublas_gemm should be equal to 3."; - Expr lhs = input_args[0]; - Expr rhs = input_args[1]; - Expr bias = input_args[2]; - CHECK(lhs.as_tensor()); - CHECK(rhs.as_tensor()); - CHECK(bias.as_tensor()); - auto bias_tensor = bias.as_tensor_ref(); - // dummy gemm computation, which will be replaced by cinn_gpu_cublas_gemm in the GemmRewriter pass. - - std::string tensor_name = UniqName("cublas_gemm_output"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(input_args.size(), 4); - CHECK(input_args[3].is_string()); - tensor_name = input_args[3].operator std::string(); - } - auto out = pe::Identity(bias_tensor, tensor_name).front(); - auto stages = CreateStages({lhs.as_tensor_ref(), rhs.as_tensor_ref(), bias_tensor}); - stages->InsertLazily(out); - std::vector res{CINNValue(out), CINNValue(stages)}; - *ret = CINNValuePack{res}; - }); + CINNValuePack input_args = args[0]; + CHECK_GE(input_args.size(), 3U) + << "The input number of cublas_gemm should be equal to 3."; + Expr lhs = input_args[0]; + Expr rhs = input_args[1]; + Expr bias = input_args[2]; + CHECK(lhs.as_tensor()); + CHECK(rhs.as_tensor()); + CHECK(bias.as_tensor()); + auto bias_tensor = bias.as_tensor_ref(); + // dummy gemm computation, which will be replaced by + // cinn_gpu_cublas_gemm in the GemmRewriter pass. 
+ + std::string tensor_name = UniqName("cublas_gemm_output"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(input_args.size(), 4); + CHECK(input_args[3].is_string()); + tensor_name = input_args[3].operator std::string(); + } + auto out = pe::Identity(bias_tensor, tensor_name).front(); + auto stages = CreateStages( + {lhs.as_tensor_ref(), rhs.as_tensor_ref(), bias_tensor}); + stages->InsertLazily(out); + std::vector res{CINNValue(out), CINNValue(stages)}; + *ret = CINNValuePack{res}; + }); auto strategy = std::make_shared(); - strategy->AddImpl(gemm_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.cublas.gemm", 1); + strategy->AddImpl(gemm_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.cublas.gemm", + 1); return strategy; } -std::vector InferShapeForCublasGemm(const std::vector> &input_shapes, - const framework::AttrMapType &attrs) { - CHECK_EQ(input_shapes.size(), 3UL) << "cublas_gemm should have 3 input shapes"; +std::vector InferShapeForCublasGemm( + const std::vector> &input_shapes, + const framework::AttrMapType &attrs) { + CHECK_EQ(input_shapes.size(), 3UL) + << "cublas_gemm should have 3 input shapes"; CHECK_EQ(input_shapes[0].size(), input_shapes[1].size()); CHECK_EQ(input_shapes[0].size(), input_shapes[2].size()); CHECK((input_shapes[0].size() == 2 || input_shapes[0].size() == 3)); return {input_shapes[2]}; } -std::vector InferDtypeForCublasGemm(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_type.size(), 3UL) << "The input's type size is 0! Please check again."; - CHECK_EQ(inputs_type[0], inputs_type[1]) << "The input A and B's types should be equal! Please check again."; - CHECK_EQ(inputs_type[0], inputs_type[2]) << "The input A and C's types should be equal! Please check again."; +std::vector InferDtypeForCublasGemm(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_type.size(), 3UL) + << "The input's type size is 0! Please check again."; + CHECK_EQ(inputs_type[0], inputs_type[1]) + << "The input A and B's types should be equal! Please check again."; + CHECK_EQ(inputs_type[0], inputs_type[2]) + << "The input A and C's types should be equal! Please check again."; return {inputs_type[0]}; } -std::shared_ptr StrategyForLayoutTransform(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute layout_transform_compute([=](lang::Args args, lang::RetValue *ret) { +std::shared_ptr StrategyForLayoutTransform( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute layout_transform_compute([=](lang::Args args, + lang::RetValue *ret) { std::string src_layout; std::string dst_layout; if (attrs.attr_store.find("src_layout") != attrs.attr_store.end()) { @@ -715,9 +842,11 @@ std::shared_ptr StrategyForLayoutTransform(const framework::NodeAttr if (attrs.attr_store.find("dst_layout") != attrs.attr_store.end()) { dst_layout = absl::get(attrs.attr_store.at("dst_layout")); } - CHECK(!args.empty()) << "The input argument of layout_transform compute is empty! Please check.\n"; + CHECK(!args.empty()) << "The input argument of layout_transform compute is " + "empty! 
Please check.\n"; CINNValuePack input_args = args[0]; - CHECK(!input_args.empty()) << "at least one input tensor for layout_transform compute\n"; + CHECK(!input_args.empty()) + << "at least one input tensor for layout_transform compute\n"; Expr A = input_args[0]; CHECK(A.as_tensor()); @@ -728,66 +857,76 @@ std::shared_ptr StrategyForLayoutTransform(const framework::NodeAttr tensor_name = input_args[1].operator std::string(); } - auto out = pe::LayoutTransform(A.as_tensor_ref(), src_layout, dst_layout, tensor_name); + auto out = pe::LayoutTransform( + A.as_tensor_ref(), src_layout, dst_layout, tensor_name); auto stages = CreateStages({A.as_tensor_ref()}); std::vector res; stages->InsertLazily(out); - res = {CINNValue(out), CINNValue(stages)}; + res = {CINNValue(out), CINNValue(stages)}; *ret = CINNValuePack{res}; }); - framework::CINNSchedule layout_transform_schedule([=](lang::Args args, lang::RetValue *ret) { - if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of CublasGemm schedule is empty! Please check."; - CINNValuePack arg_pack = args[0]; - std::vector vec_ast; - for (int i = 0; i < arg_pack.size(); i++) { - if (arg_pack[i].is_expr()) { - Expr temp = arg_pack[i]; - vec_ast.emplace_back(temp); + framework::CINNSchedule layout_transform_schedule( + [=](lang::Args args, lang::RetValue *ret) { + if (FLAGS_cinn_ir_schedule) { + CHECK(!args.empty()) << "The input argument of CublasGemm schedule " + "is empty! Please check."; + CINNValuePack arg_pack = args[0]; + std::vector vec_ast; + for (int i = 0; i < arg_pack.size(); i++) { + if (arg_pack[i].is_expr()) { + Expr temp = arg_pack[i]; + vec_ast.emplace_back(temp); + } + } + CHECK(!vec_ast.empty()); + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + ir_sch.MergeExprs(); + + if (target.arch == Target::Arch::X86) { + pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target); + } else { + CINN_NOT_IMPLEMENTED + } + std::vector res{ + CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + *ret = CINNValuePack{res}; + } else { + CHECK(!args.empty()) << "The input argument of layout_transform " + "schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; + CHECK_EQ(arg_pack.size(), 2UL); + Expr out = arg_pack[0]; + poly::StageMap stages = arg_pack[1]; + CHECK(out.as_tensor()); + auto tensor_out = out.as_tensor_ref(); + std::vector out_shape; + for (auto shape : tensor_out->shape) { + out_shape.push_back(shape.as_int32()); + } + if (target.arch == Target::Arch::X86) { + pe::ScheduleInjectiveCPU(stages[tensor_out], out_shape, target); + } else { + CINN_NOT_IMPLEMENTED + } + *ret = arg_pack; } - } - CHECK(!vec_ast.empty()); - ir::ModuleExpr mod_expr(vec_ast); - ir::IRSchedule ir_sch(mod_expr); - ir_sch.MergeExprs(); - - if (target.arch == Target::Arch::X86) { - pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target); - } else { - CINN_NOT_IMPLEMENTED - } - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; - *ret = CINNValuePack{res}; - } else { - CHECK(!args.empty()) << "The input argument of layout_transform schedule is empty! 
Please check.\n"; - CINNValuePack arg_pack = args[0]; - CHECK_EQ(arg_pack.size(), 2UL); - Expr out = arg_pack[0]; - poly::StageMap stages = arg_pack[1]; - CHECK(out.as_tensor()); - auto tensor_out = out.as_tensor_ref(); - std::vector out_shape; - for (auto shape : tensor_out->shape) { - out_shape.push_back(shape.as_int32()); - } - if (target.arch == Target::Arch::X86) { - pe::ScheduleInjectiveCPU(stages[tensor_out], out_shape, target); - } else { - CINN_NOT_IMPLEMENTED - } - *ret = arg_pack; - } - }); + }); auto strategy = std::make_shared(); - CHECK(out_type.size()) << "Out_type of layout_transform op is empty! Please check."; - strategy->AddImpl(layout_transform_compute, layout_transform_schedule, "strategy.layout_transform.x86", 1); + CHECK(out_type.size()) + << "Out_type of layout_transform op is empty! Please check."; + strategy->AddImpl(layout_transform_compute, + layout_transform_schedule, + "strategy.layout_transform.x86", + 1); return strategy; } -std::vector InferShapeForLayoutTransform(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForLayoutTransform( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { std::string src_layout; std::string dst_layout; if (attrs.find("src_layout") != attrs.end()) { @@ -803,8 +942,11 @@ std::vector InferShapeForLayoutTransform(const std::vector &in input_shapes_expr.push_back(Expr(shape)); } absl::flat_hash_map> split_index_map; - std::vector out_shapes = pe::InferShapeLayoutTransform( - input_shapes_expr, ir::Layout(src_layout), ir::Layout(dst_layout), &split_index_map); + std::vector out_shapes = + pe::InferShapeLayoutTransform(input_shapes_expr, + ir::Layout(src_layout), + ir::Layout(dst_layout), + &split_index_map); VLOG(4) << "out_shapes: " << out_shapes; std::vector output_shapes; for (auto &shape : out_shapes) { @@ -813,26 +955,30 @@ std::vector InferShapeForLayoutTransform(const std::vector &in return {output_shapes}; } -std::vector InferDtypeForLayoutTransform(const std::vector &inputs_type, - const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForLayoutTransform( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::shared_ptr StrategyForReverse(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForReverse( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { // check output shape - CHECK(!output_shapes.empty() && !output_shapes[0].empty()) << "Output shape is empty! Please check.\n"; + CHECK(!output_shapes.empty() && !output_shapes[0].empty()) + << "Output shape is empty! 
Please check.\n"; // get axis[0, n_dim) std::vector axis; if (attrs.attr_store.find("axis") != attrs.attr_store.end()) { axis = absl::get>(attrs.attr_store.at("axis")); for (auto &e : axis) { - if (e >= static_cast(output_shapes[0].size()) || e < -1 * static_cast(output_shapes[0].size())) { + if (e >= static_cast(output_shapes[0].size()) || + e < -1 * static_cast(output_shapes[0].size())) { LOG(FATAL) << "axis is not in [0, n_dim), Please check."; } if (e < 0) { @@ -841,10 +987,13 @@ std::shared_ptr StrategyForReverse(const framework::NodeAttr &attrs, } } - framework::CINNCompute reverse_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of reverse compute is empty! Please check.\n"; + framework::CINNCompute reverse_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of reverse compute is empty! Please check.\n"; CINNValuePack input_args = args[0]; - CHECK(!input_args.empty()) << "at least one input tensor for reverse compute\n"; + CHECK(!input_args.empty()) + << "at least one input tensor for reverse compute\n"; Expr A = input_args[0]; CHECK(A.as_tensor()); @@ -855,25 +1004,31 @@ std::shared_ptr StrategyForReverse(const framework::NodeAttr &attrs, tensor_name = input_args[1].operator std::string(); } - auto out = pe::Reverse(A.as_tensor_ref(), axis, tensor_name); + auto out = pe::Reverse(A.as_tensor_ref(), axis, tensor_name); auto stages = CreateStages({A.as_tensor_ref(), out}); - *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; }); auto strategy = std::make_shared(); CHECK(out_type.size()) << "Out_type of reverse op is empty! Please check."; - strategy->AddImpl(reverse_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.reverse.x86", 1); + strategy->AddImpl(reverse_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.reverse.x86", + 1); return strategy; } -std::vector InferShapeForReverse(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; +std::vector InferShapeForReverse( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) + << "The input's shape size is 0! 
Please check again."; std::vector res{inputs_shape[0]}; if (attrs.find("axis") != attrs.end()) { auto axis = absl::get>(attrs.at("axis")); for (auto &e : axis) { - if (e >= static_cast(inputs_shape[0].size()) || e < -1 * static_cast(inputs_shape[0].size())) { + if (e >= static_cast(inputs_shape[0].size()) || + e < -1 * static_cast(inputs_shape[0].size())) { LOG(FATAL) << "axis is not in [-n_dim, n_dim), Please check."; } if (e < 0) { @@ -884,27 +1039,32 @@ std::vector InferShapeForReverse(const std::vector> InferLayoutForReverse(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { +std::vector> InferLayoutForReverse( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { if (attrs.attr_store.find("axis") != attrs.attr_store.end()) { auto axis = absl::get>(attrs.attr_store.at("axis")); for (auto &e : axis) { - if (e >= static_cast(input_shapes[0].size()) || e < -1 * static_cast(input_shapes[0].size())) { + if (e >= static_cast(input_shapes[0].size()) || + e < -1 * static_cast(input_shapes[0].size())) { LOG(FATAL) << "axis is not in [-n_dim, n_dim), Please check."; } } } - CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again."; + CHECK_EQ(input_layouts.size(), 1U) + << "The input's layout size is not 1! Please check again."; return {input_layouts, input_layouts}; } -std::vector> InferLayoutForLayoutTransform(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 1U) << "The input's layouts size is not 1! Please check again."; +std::vector> InferLayoutForLayoutTransform( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 1U) + << "The input's layouts size is not 1! Please check again."; std::string dst_layout; std::string src_layout; if (attrs.attr_store.find("dst_layout") != attrs.attr_store.end()) { @@ -916,13 +1076,15 @@ std::vector> InferLayoutForLayoutTransform(const std::v return {{dst_layout}, {src_layout}}; } -std::shared_ptr StrategyForTranspose(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForTranspose( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { // check output shape - CHECK(!output_shapes.empty() && !output_shapes[0].empty()) << "Output shape is empty! Please check.\n"; + CHECK(!output_shapes.empty() && !output_shapes[0].empty()) + << "Output shape is empty! Please check.\n"; std::vector axis; auto input_shape = inputs[0]->shape; @@ -943,10 +1105,13 @@ std::shared_ptr StrategyForTranspose(const framework::NodeAttr &attr LOG(FATAL) << "axis is not be set! Please check."; } - framework::CINNCompute transpose_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of transpose compute is empty! Please check.\n"; + framework::CINNCompute transpose_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input argument of transpose compute is empty! 
Please check.\n"; CINNValuePack input_args = args[0]; - CHECK(!input_args.empty()) << "at least one input tensor for transpose compute\n"; + CHECK(!input_args.empty()) + << "at least one input tensor for transpose compute\n"; Expr A = input_args[0]; CHECK(A.as_tensor()); std::string tensor_name = UniqName("Transpose_output"); @@ -956,23 +1121,29 @@ std::shared_ptr StrategyForTranspose(const framework::NodeAttr &attr tensor_name = input_args[1].operator std::string(); } - auto out = pe::Transpose(A.as_tensor_ref(), axis, tensor_name); + auto out = pe::Transpose(A.as_tensor_ref(), axis, tensor_name); auto stages = CreateStages({out}); - *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; }); auto strategy = std::make_shared(); - strategy->AddImpl(transpose_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.transpose.x86", 1); + strategy->AddImpl(transpose_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.transpose.x86", + 1); return strategy; } -std::vector InferShapeForTranspose(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { +std::vector InferShapeForTranspose( + const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { std::vector result; - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) + << "The input's shape size is 0! Please check again."; if (attrs.find("axis") != attrs.end()) { auto axis = absl::get>(attrs.at("axis")); - CHECK_EQ(axis.size(), inputs_shape[0].size()) << "input size and axis size is not equal!"; + CHECK_EQ(axis.size(), inputs_shape[0].size()) + << "input size and axis size is not equal!"; std::vector output_shape; for (int idx = 0; idx < axis.size(); ++idx) { CHECK(axis[idx] >= 0 && axis[idx] < axis.size()); @@ -988,12 +1159,15 @@ std::vector InferShapeForTranspose(const std::vector> InferLayoutForTranspose(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_shapes.size(), 1U) << "The input's shape size is not 1! Please check again."; - CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again."; +std::vector> InferLayoutForTranspose( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_shapes.size(), 1U) + << "The input's shape size is not 1! Please check again."; + CHECK_EQ(input_layouts.size(), 1U) + << "The input's layout size is not 1! Please check again."; std::vector axis; if (attrs.attr_store.find("axis") != attrs.attr_store.end()) { @@ -1024,14 +1198,18 @@ std::vector> InferLayoutForTranspose(const std::vector< return {{output_layout}, new_input_layouts}; } -std::shared_ptr StrategyForGather(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - CHECK(!output_shapes.empty() && !output_shapes[0].empty()) << "The shape of output is empty! Please check again."; - VLOG(4) << "The output passed in StrategyForGather: " << utils::Join(output_shapes[0], ", "); - CHECK(!out_type.empty()) << "The output type of Gather is empty! 
Please check again.\n"; +std::shared_ptr StrategyForGather( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + CHECK(!output_shapes.empty() && !output_shapes[0].empty()) + << "The shape of output is empty! Please check again."; + VLOG(4) << "The output passed in StrategyForGather: " + << utils::Join(output_shapes[0], ", "); + CHECK(!out_type.empty()) + << "The output type of Gather is empty! Please check again.\n"; int axis = 0; if (attrs.attr_store.contains("axis")) { @@ -1046,12 +1224,14 @@ std::shared_ptr StrategyForGather(const framework::NodeAttr &attrs, } framework::CINNCompute gather_compute{ - [axis, output_shape = std::move(output_shape)](lang::Args args, lang::RetValue *ret) { + [axis, output_shape = std::move(output_shape)](lang::Args args, + lang::RetValue *ret) { VLOG(4) << "The axis value used in gather_compute: " << axis; CHECK(!args.empty()) << "The input args are empty! Please check again."; CINNValuePack input_args = args[0]; - int input_size = input_args.size(); - CHECK_GE(input_size, 2U) << "Require 2 input tensors for Gather compute."; + int input_size = input_args.size(); + CHECK_GE(input_size, 2U) + << "Require 2 input tensors for Gather compute."; Expr x = input_args[0]; CHECK(x.as_tensor()); Expr index = input_args[1]; @@ -1064,7 +1244,11 @@ std::shared_ptr StrategyForGather(const framework::NodeAttr &attrs, tensor_name = input_args[2].operator std::string(); } - auto out = pe::Gather(x.as_tensor_ref(), index.as_tensor_ref(), output_shape, axis, tensor_name); + auto out = pe::Gather(x.as_tensor_ref(), + index.as_tensor_ref(), + output_shape, + axis, + tensor_name); auto stages = CreateStages({x.as_tensor_ref(), index.as_tensor_ref()}); stages->InsertLazily(out); std::vector res{CINNValue(out), CINNValue(stages)}; @@ -1072,56 +1256,70 @@ std::shared_ptr StrategyForGather(const framework::NodeAttr &attrs, }}; auto strategy = std::make_shared(); - strategy->AddImpl(gather_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.gather.x86", 1); + strategy->AddImpl(gather_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.gather.x86", + 1); return strategy; } -std::vector> InferShapeForGather(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 2U) << "The inputs' shape size should be equal to 2! Please check again."; - std::vector x_shape = inputs_shape[0]; +std::vector> InferShapeForGather( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 2U) + << "The inputs' shape size should be equal to 2! Please check again."; + std::vector x_shape = inputs_shape[0]; std::vector index_shape = inputs_shape[1]; - int axis = absl::get(attrs.at("axis")); + int axis = absl::get(attrs.at("axis")); VLOG(4) << "The axis value used in Gather: " << axis; CHECK(axis >= 0 && axis < static_cast(x_shape.size())) - << "The attribute `axis` in Gather should be >= 0 and < the size of the first input shape! Please check again."; + << "The attribute `axis` in Gather should be >= 0 and < the size of the " + "first input shape! 
Please check again."; std::vector output_shape = x_shape; - output_shape[axis] = index_shape[axis]; + output_shape[axis] = index_shape[axis]; VLOG(4) << "The output shape of gather: " << utils::Join(output_shape, ", "); return {std::move(output_shape)}; } -std::vector InferDtypeForGather(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForGather(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; return {inputs_type[0]}; } -std::vector> InferLayoutForGather(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 2U) << "The input's layout size is not equal to 2! Please check again."; +std::vector> InferLayoutForGather( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 2U) + << "The input's layout size is not equal to 2! Please check again."; return {{input_layouts[0]}, input_layouts}; } -std::shared_ptr StrategyForScatterAssign(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForScatterAssign( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { int axis = 0; if (attrs.attr_store.find("axis") != attrs.attr_store.end()) { axis = absl::get(attrs.attr_store.at("axis")); } - framework::CINNCompute scatter_assign_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of ScatterAssign compute is empty! Please check.\n"; + framework::CINNCompute scatter_assign_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of ScatterAssign compute is " + "empty! Please check.\n"; CINNValuePack arg_pack = args[0]; - int input_size = arg_pack.size(); - CHECK_GE(input_size, 3U) << "at least 3 input tensors for ScatterAssign compute\n"; + int input_size = arg_pack.size(); + CHECK_GE(input_size, 3U) + << "at least 3 input tensors for ScatterAssign compute\n"; CHECK(!output_shapes.empty()); Expr expr_input = arg_pack[0]; @@ -1145,29 +1343,35 @@ std::shared_ptr StrategyForScatterAssign(const framework::NodeAttr & tensor_name = arg_pack[3].operator std::string(); } - auto out = pe::ScatterAssign(tensor_input, tensor_updates, tensor_index, target, axis, tensor_name); + auto out = pe::ScatterAssign( + tensor_input, tensor_updates, tensor_index, target, axis, tensor_name); std::vector res; stages->InsertLazily(out); res.push_back(CINNValue(out)); - CHECK(!out_type.empty()) << "Output type of ScatterAssign is empty! Please check.\n"; + CHECK(!out_type.empty()) + << "Output type of ScatterAssign is empty! 
Please check.\n"; res.push_back(CINNValue(stages)); *ret = CINNValuePack{res}; }); auto strategy = std::make_shared(); - strategy->AddImpl( - scatter_assign_compute, GetInjectiveScheduleFunc(output_shapes, target, false), "strategy.scatter_assign.x86", 1); + strategy->AddImpl(scatter_assign_compute, + GetInjectiveScheduleFunc(output_shapes, target, false), + "strategy.scatter_assign.x86", + 1); return strategy; } -std::vector> InferShapeForScatterAssign(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_GE(inputs_shape.size(), 3U) << "The input's shape size should be no less than 3! Please check again."; +std::vector> InferShapeForScatterAssign( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_GE(inputs_shape.size(), 3U) + << "The input's shape size should be no less than 3! Please check again."; - const auto &input_shape = inputs_shape[0]; + const auto &input_shape = inputs_shape[0]; const auto &assign_shape = inputs_shape[1]; - const auto &index_shape = inputs_shape[2]; + const auto &index_shape = inputs_shape[2]; int axis = 0; if (attrs.find("axis") != attrs.end()) { @@ -1177,56 +1381,72 @@ std::vector> InferShapeForScatterAssign(const std::vector= 0 && axis < input_shape.size()) - << "In ScatterAssign op, the attribute `axis` should be >= 0 and < input shape's size! Please check."; - CHECK_EQ(index_shape.size(), 1U) << "Dimensions of index tensor in ScatterAssign should be 1! Please check."; + << "In ScatterAssign op, the attribute `axis` should be >= 0 and < input " + "shape's size! Please check."; + CHECK_EQ(index_shape.size(), 1U) + << "Dimensions of index tensor in ScatterAssign should be 1! Please " + "check."; CHECK_EQ(input_shape.size(), assign_shape.size()) - << "Dimensions of inputs A and B in ScatterAssign should be equal! Please check."; + << "Dimensions of inputs A and B in ScatterAssign should be equal! " + "Please check."; CHECK_EQ(assign_shape[axis], index_shape[0]) - << "The first dimension of input B and index tensor in ScatterAssign should be equal! Please check."; + << "The first dimension of input B and index tensor in ScatterAssign " + "should be equal! Please check."; for (int i = 0; i < input_shape.size(); ++i) { if (i != axis) { CHECK_EQ(input_shape[i], assign_shape[i]) - << "The " << i << "-th dimension of input A and B in ScatterAssign should be equal! Please check."; + << "The " << i + << "-th dimension of input A and B in ScatterAssign should be equal! " + "Please check."; } } - VLOG(4) << "Each input tensor's shape of ScatterAssign: A(" << cinn::utils::Join(input_shape, ",") << "), B(" - << cinn::utils::Join(assign_shape, ",") << "), index(" << cinn::utils::Join(index_shape, ",") << ")" + VLOG(4) << "Each input tensor's shape of ScatterAssign: A(" + << cinn::utils::Join(input_shape, ",") << "), B(" + << cinn::utils::Join(assign_shape, ",") << "), index(" + << cinn::utils::Join(index_shape, ",") << ")" << " at axis (" << axis << ")"; return {input_shape}; } -std::vector InferDtypeForScatterAssign(const std::vector &inputs_type, - const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForScatterAssign( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! 
Please check again."; std::vector res{inputs_type[0]}; return res; } -std::vector> InferLayoutForScatterAssign(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_GE(input_layouts.size(), 3U) << "The input's layout size is less than 3! Please check again."; +std::vector> InferLayoutForScatterAssign( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_GE(input_layouts.size(), 3U) + << "The input's layout size is less than 3! Please check again."; return {{input_layouts[0]}, input_layouts}; } -std::shared_ptr StrategyForScatterAdd(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForScatterAdd( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { int axis = 0; if (attrs.attr_store.find("axis") != attrs.attr_store.end()) { axis = absl::get(attrs.attr_store.at("axis")); } - framework::CINNCompute scatter_add_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of ScatterAdd compute is empty! Please check.\n"; + framework::CINNCompute scatter_add_compute([=](lang::Args args, + lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of ScatterAdd compute is " + "empty! Please check.\n"; CINNValuePack arg_pack = args[0]; - int input_size = arg_pack.size(); - CHECK_GE(input_size, 3U) << "at least 3 input tensors for ScatterAdd compute\n"; + int input_size = arg_pack.size(); + CHECK_GE(input_size, 3U) + << "at least 3 input tensors for ScatterAdd compute\n"; CHECK(!output_shapes.empty()); Expr expr_input = arg_pack[0]; @@ -1250,29 +1470,35 @@ std::shared_ptr StrategyForScatterAdd(const framework::NodeAttr &att tensor_name = arg_pack[3].operator std::string(); } - auto out = pe::ScatterAdd(tensor_input, tensor_updates, tensor_index, target, axis, tensor_name); + auto out = pe::ScatterAdd( + tensor_input, tensor_updates, tensor_index, target, axis, tensor_name); std::vector res; stages->InsertLazily(out); res.push_back(CINNValue(out)); - CHECK(!out_type.empty()) << "Output type of ScatterAdd is empty! Please check.\n"; + CHECK(!out_type.empty()) + << "Output type of ScatterAdd is empty! Please check.\n"; res.push_back(CINNValue(stages)); *ret = CINNValuePack{res}; }); auto strategy = std::make_shared(); - strategy->AddImpl( - scatter_add_compute, GetInjectiveScheduleFunc(output_shapes, target, false), "strategy.scatter_add.x86", 1); + strategy->AddImpl(scatter_add_compute, + GetInjectiveScheduleFunc(output_shapes, target, false), + "strategy.scatter_add.x86", + 1); return strategy; } -std::vector> InferShapeForScatterAdd(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_GE(inputs_shape.size(), 3U) << "The input's shape size should be no less than 3! Please check again."; +std::vector> InferShapeForScatterAdd( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_GE(inputs_shape.size(), 3U) + << "The input's shape size should be no less than 3! 
Please check again."; - const auto &input_shape = inputs_shape[0]; + const auto &input_shape = inputs_shape[0]; const auto &updates_shape = inputs_shape[1]; - const auto &index_shape = inputs_shape[2]; + const auto &index_shape = inputs_shape[2]; int axis = 0; if (attrs.find("axis") != attrs.end()) { @@ -1282,45 +1508,58 @@ std::vector> InferShapeForScatterAdd(const std::vector= 0 && axis < input_shape.size()) - << "In ScatterAdd op, the attribute `axis` should be >= 0 and < input shape's size! Please check."; - CHECK_EQ(index_shape.size(), 1U) << "Dimensions of index tensor in ScatterAdd should be 1! Please check."; + << "In ScatterAdd op, the attribute `axis` should be >= 0 and < input " + "shape's size! Please check."; + CHECK_EQ(index_shape.size(), 1U) + << "Dimensions of index tensor in ScatterAdd should be 1! Please check."; CHECK_EQ(input_shape.size(), updates_shape.size()) - << "Dimensions of inputs A and B in ScatterAdd should be equal! Please check."; + << "Dimensions of inputs A and B in ScatterAdd should be equal! Please " + "check."; CHECK_EQ(updates_shape[axis], index_shape[0]) - << "The first dimension of input B and index tensor in ScatterAdd should be equal! Please check."; + << "The first dimension of input B and index tensor in ScatterAdd should " + "be equal! Please check."; for (int i = 0; i < input_shape.size(); ++i) { if (i != axis) { CHECK_EQ(input_shape[i], updates_shape[i]) - << "The " << i << "-th dimension of input A and B in ScatterAdd should be equal! Please check."; + << "The " << i + << "-th dimension of input A and B in ScatterAdd should be equal! " + "Please check."; } } - VLOG(4) << "Each input tensor's shape of ScatterAdd: A(" << cinn::utils::Join(input_shape, ",") << "), B(" - << cinn::utils::Join(updates_shape, ",") << "), index(" << cinn::utils::Join(index_shape, ",") << ")" + VLOG(4) << "Each input tensor's shape of ScatterAdd: A(" + << cinn::utils::Join(input_shape, ",") << "), B(" + << cinn::utils::Join(updates_shape, ",") << "), index(" + << cinn::utils::Join(index_shape, ",") << ")" << " at axis (" << axis << ")"; return {input_shape}; } -std::vector InferDtypeForScatterAdd(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForScatterAdd(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::vector> InferLayoutForScatterAdd(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_GE(input_layouts.size(), 3U) << "The input's layout size is less than 3! Please check again."; +std::vector> InferLayoutForScatterAdd( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_GE(input_layouts.size(), 3U) + << "The input's layout size is less than 3! 
Please check again."; return {{input_layouts[0]}, input_layouts}; } -std::shared_ptr StrategyForSlice(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { +std::shared_ptr StrategyForSlice( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { std::vector starts, ends, axes, strides, decrease_axis; if (attrs.attr_store.find("starts") != attrs.attr_store.end()) { starts = absl::get>(attrs.attr_store.at("starts")); @@ -1335,21 +1574,28 @@ std::shared_ptr StrategyForSlice(const framework::NodeAttr &attrs, strides = absl::get>(attrs.attr_store.at("strides")); } if (attrs.attr_store.find("decrease_axis") != attrs.attr_store.end()) { - decrease_axis = absl::get>(attrs.attr_store.at("decrease_axis")); + decrease_axis = + absl::get>(attrs.attr_store.at("decrease_axis")); } - CHECK(!starts.empty()) << "The Slice op doesn't find [starts] attrbute! It it a mandatory attribute, please check."; - CHECK(!ends.empty()) << "The Slice op doesn't find [ends] attrbute! It it a mandatory attribute, please check."; - CHECK_EQ(starts.size(), ends.size()) << "The size of [starts] and [ends] must be identical! Please check."; + CHECK(!starts.empty()) << "The Slice op doesn't find [starts] attrbute! It " + "it a mandatory attribute, please check."; + CHECK(!ends.empty()) << "The Slice op doesn't find [ends] attrbute! It it a " + "mandatory attribute, please check."; + CHECK_EQ(starts.size(), ends.size()) + << "The size of [starts] and [ends] must be identical! Please check."; if (!axes.empty()) { - CHECK_EQ(starts.size(), axes.size()) << "The size of [starts] and [axes] must be identical! Please check."; + CHECK_EQ(starts.size(), axes.size()) + << "The size of [starts] and [axes] must be identical! Please check."; } else { for (int i = 0; i < starts.size(); i++) { axes.push_back(i); } } if (!strides.empty()) { - CHECK_EQ(starts.size(), strides.size()) << "The size of [starts] and [strides] must be identical! Please check."; + CHECK_EQ(starts.size(), strides.size()) + << "The size of [starts] and [strides] must be identical! Please " + "check."; } else { for (int i = 0; i < starts.size(); i++) { strides.push_back(1); @@ -1361,35 +1607,44 @@ std::shared_ptr StrategyForSlice(const framework::NodeAttr &attrs, output_shape.push_back(Expr(i)); } - framework::CINNCompute slice_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of slice compute is empty! Please check."; - CINNValuePack arg_pack = args[0]; - CHECK(!arg_pack.empty()) << "The input tensors of slice compute is empty! Please check."; - Expr A_expr = arg_pack[0]; - CHECK(A_expr.as_tensor()); - ir::Tensor A = A_expr.as_tensor_ref(); - - std::string tensor_name = UniqName("Slice_output"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(arg_pack.size(), 2U); - CHECK(arg_pack[1].is_string()); - tensor_name = arg_pack[1].operator std::string(); - } + framework::CINNCompute slice_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of slice compute is empty! Please check."; + CINNValuePack arg_pack = args[0]; + CHECK(!arg_pack.empty()) + << "The input tensors of slice compute is empty! 
Please check."; + Expr A_expr = arg_pack[0]; + CHECK(A_expr.as_tensor()); + ir::Tensor A = A_expr.as_tensor_ref(); + + std::string tensor_name = UniqName("Slice_output"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(arg_pack.size(), 2U); + CHECK(arg_pack[1].is_string()); + tensor_name = arg_pack[1].operator std::string(); + } - auto out = pe::Slice(A, starts, axes, strides, decrease_axis, output_shape, tensor_name); - auto stages = CreateStages({out}); - *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; - }); + auto out = pe::Slice( + A, starts, axes, strides, decrease_axis, output_shape, tensor_name); + auto stages = CreateStages({out}); + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + }); auto strategy = std::make_shared(); - strategy->AddImpl(slice_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.slice.x86", 1); + strategy->AddImpl(slice_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.slice.x86", + 1); return strategy; } -std::vector> InferShapeForSlice(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; +std::vector> InferShapeForSlice( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) + << "The input's shape size is 0! Please check again."; std::vector starts, ends, axes, strides, decrease_axis, infer_flags; for (auto &iter : attrs) { if (iter.first == "starts") { @@ -1408,18 +1663,24 @@ std::vector> InferShapeForSlice(const std::vector> InferShapeForSlice(const std::vector 0) { - CHECK(ends[i] > starts[i]) << "[ends] should greater than [starts] when strides > 0 ! But here " << ends[i] - << " < " << starts[i] << ", Please Check."; - output_shape[axes[i]] = (ends[i] - starts[i] + strides[i] - 1) / strides[i]; + CHECK(ends[i] > starts[i]) + << "[ends] should greater than [starts] when strides > 0 ! But here " + << ends[i] << " < " << starts[i] << ", Please Check."; + output_shape[axes[i]] = + (ends[i] - starts[i] + strides[i] - 1) / strides[i]; } else { - CHECK(ends[i] < starts[i]) << "[ends] should less than [starts] when strides < 0 ! But here " << ends[i] << " > " - << starts[i] << ", Please Check."; - output_shape[axes[i]] = (starts[i] - ends[i] + (-strides[i]) - 1) / (-strides[i]); + CHECK(ends[i] < starts[i]) + << "[ends] should less than [starts] when strides < 0 ! But here " + << ends[i] << " > " << starts[i] << ", Please Check."; + output_shape[axes[i]] = + (starts[i] - ends[i] + (-strides[i]) - 1) / (-strides[i]); } } if (decrease_axis.size() > 0) { std::vector new_shape; for (int i = 0; i < output_shape.size(); ++i) { - if (std::find(decrease_axis.cbegin(), decrease_axis.cend(), i) != decrease_axis.cend()) { - CHECK_EQ(output_shape[i], 1) << "Decrease dim should be 1, but now received " << output_shape[i]; + if (std::find(decrease_axis.cbegin(), decrease_axis.cend(), i) != + decrease_axis.cend()) { + CHECK_EQ(output_shape[i], 1) + << "Decrease dim should be 1, but now received " << output_shape[i]; } else { new_shape.emplace_back(output_shape[i]); } @@ -1470,23 +1738,29 @@ std::vector> InferShapeForSlice(const std::vector> res{output_shape}; return res; } -std::vector InferDtypeForSlice(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! 
Please check again."; +std::vector InferDtypeForSlice(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; } -std::vector> InferLayoutForSlice(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again."; - CHECK_EQ(input_shapes.size(), 1U) << "The input's shape size is not 1! Please check again."; +std::vector> InferLayoutForSlice( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 1U) + << "The input's layout size is not 1! Please check again."; + CHECK_EQ(input_shapes.size(), 1U) + << "The input's shape size is not 1! Please check again."; std::vector starts; std::vector ends; std::vector axes; @@ -1500,7 +1774,7 @@ std::vector> InferLayoutForSlice(const std::vector 4) { for (int i = 0; i < axes.size(); i++) { if (axes[i] == 1) { @@ -1515,15 +1789,20 @@ std::vector> InferLayoutForSlice(const std::vector StrategyForSliceAssign(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - CHECK_EQ(inputs.size(), 2) << "the number of input tensors must be equal to 2"; - CHECK(!output_shapes.empty() && !output_shapes[0].empty()) << "The shape of output is empty! Please check again."; - VLOG(4) << "The output passed in StrategyForSliceAssign: " << utils::Join(output_shapes[0], ", "); - CHECK(!out_type.empty()) << "The output type of SliceAssign is empty! Please check again.\n"; +std::shared_ptr StrategyForSliceAssign( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + CHECK_EQ(inputs.size(), 2) + << "the number of input tensors must be equal to 2"; + CHECK(!output_shapes.empty() && !output_shapes[0].empty()) + << "The shape of output is empty! Please check again."; + VLOG(4) << "The output passed in StrategyForSliceAssign: " + << utils::Join(output_shapes[0], ", "); + CHECK(!out_type.empty()) + << "The output type of SliceAssign is empty! Please check again.\n"; std::vector starts, ends, axes, strides; if (attrs.attr_store.find("starts") != attrs.attr_store.end()) { @@ -1540,69 +1819,91 @@ std::shared_ptr StrategyForSliceAssign(const framework::NodeAttr &at } CHECK(!starts.empty()) - << "The SliceAssign op doesn't find [starts] attrbute! It it a mandatory attribute, please check."; - CHECK(!ends.empty()) << "The SliceAssign op doesn't find [ends] attrbute! It it a mandatory attribute, please check."; - CHECK_EQ(starts.size(), ends.size()) << "The size of [starts] and [ends] must be identical! Please check."; + << "The SliceAssign op doesn't find [starts] attrbute! It it a mandatory " + "attribute, please check."; + CHECK(!ends.empty()) << "The SliceAssign op doesn't find [ends] attrbute! It " + "it a mandatory attribute, please check."; + CHECK_EQ(starts.size(), ends.size()) + << "The size of [starts] and [ends] must be identical! Please check."; if (!axes.empty()) { - CHECK_EQ(starts.size(), axes.size()) << "The size of [starts] and [axes] must be identical! 
Please check."; + CHECK_EQ(starts.size(), axes.size()) + << "The size of [starts] and [axes] must be identical! Please check."; } else { for (int i = 0; i < starts.size(); i++) { axes.push_back(i); } } if (!strides.empty()) { - CHECK_EQ(starts.size(), strides.size()) << "The size of [starts] and [strides] must be identical! Please check."; + CHECK_EQ(starts.size(), strides.size()) + << "The size of [starts] and [strides] must be identical! Please " + "check."; } else { for (int i = 0; i < starts.size(); i++) { strides.push_back(1); } } - framework::CINNCompute slice_assign_compute{[=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input args are empty! Please check again."; - CINNValuePack arg_pack = args[0]; - int input_size = arg_pack.size(); - CHECK_GE(input_size, 2U) << "Require 2 input tensors for SliceAssign compute."; - Expr input = arg_pack[0]; - CHECK(input.as_tensor()); - Expr assign = arg_pack[1]; - CHECK(assign.as_tensor()); - - std::string tensor_name = UniqName("slice_assign_output"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(arg_pack.size(), 3U); - CHECK(arg_pack[2].is_string()); - tensor_name = arg_pack[2].operator std::string(); - } + framework::CINNCompute slice_assign_compute{ + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input args are empty! Please check again."; + CINNValuePack arg_pack = args[0]; + int input_size = arg_pack.size(); + CHECK_GE(input_size, 2U) + << "Require 2 input tensors for SliceAssign compute."; + Expr input = arg_pack[0]; + CHECK(input.as_tensor()); + Expr assign = arg_pack[1]; + CHECK(assign.as_tensor()); + + std::string tensor_name = UniqName("slice_assign_output"); + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(arg_pack.size(), 3U); + CHECK(arg_pack[2].is_string()); + tensor_name = arg_pack[2].operator std::string(); + } - auto out = pe::SliceAssign(input.as_tensor_ref(), assign.as_tensor_ref(), axes, starts, ends, strides, tensor_name); - auto stages = CreateStages({out}); - std::vector res{CINNValue(out), CINNValue(stages)}; - *ret = CINNValuePack{res}; - }}; + auto out = pe::SliceAssign(input.as_tensor_ref(), + assign.as_tensor_ref(), + axes, + starts, + ends, + strides, + tensor_name); + auto stages = CreateStages({out}); + std::vector res{CINNValue(out), CINNValue(stages)}; + *ret = CINNValuePack{res}; + }}; auto strategy = std::make_shared(); - strategy->AddImpl( - slice_assign_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.slice_assign.x86", 1); + strategy->AddImpl(slice_assign_compute, + GetInjectiveScheduleFunc(output_shapes, target), + "strategy.slice_assign.x86", + 1); return strategy; } -std::vector> InferShapeForSliceAssign(const std::vector> &inputs_shape, - const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 2U) << "The inputs' shape size should be equal to 2! Please check again."; +std::vector> InferShapeForSliceAssign( + const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 2U) + << "The inputs' shape size should be equal to 2! Please check again."; return {inputs_shape[0]}; } -std::vector InferDtypeForSliceAssign(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; +std::vector InferDtypeForSliceAssign( + const std::vector &inputs_type, const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) + << "The input's type size is 0! 
Please check again."; return {inputs_type[0]}; } -std::vector> InferLayoutForSliceAssign(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 2U) << "The input's layout size is not equal to 2! Please check again."; +std::vector> InferLayoutForSliceAssign( + const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(input_layouts.size(), 2U) + << "The input's layout size is not equal to 2! Please check again."; return {{input_layouts[0]}, {""}}; } @@ -1613,7 +1914,8 @@ std::vector> InferLayoutForSliceAssign(const std::vecto CINN_REGISTER_HELPER(transform_ops) { CINN_REGISTER_OP(matmul) .describe( - "This operator is used to perform (batched) matrix multiplication over the last two dimensions of the input " + "This operator is used to perform (batched) matrix multiplication " + "over the last two dimensions of the input " "tensors X and Y.") .set_num_inputs(2) #ifdef CINN_WITH_CUDA @@ -1621,78 +1923,112 @@ CINN_REGISTER_HELPER(transform_ops) { #else .set_num_outputs(2) #endif - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForMatMul) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForMatMul)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForMatMul)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForMatMul) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForMatMul)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForMatMul)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForMatMul)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForMatMul)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); CINN_REGISTER_OP(split) - .describe("This operator is used to split tensors X to 'sections' sub-tensor on specified axis.") + .describe( + "This operator is used to split tensors X to 'sections' sub-tensor " + "on specified axis.") .set_num_inputs(1) .set_num_outputs(0) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForSplit) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForSplit)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForSplit)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForSplit) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForSplit)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForSplit)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForSplit)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForSplit)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); CINN_REGISTER_OP(concat) - .describe("This operator is used to concat two input tensors X and Y on specified axis.") + .describe( + "This operator is used to concat two input tensors X and Y on " + "specified axis.") .set_num_inputs(2) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForConcat) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForConcat)) - .set_attr("inferdtype", 
MakeOpFunction(cinn::hlir::op::InferDtypeForConcat)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForConcat) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForConcat)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForConcat)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForConcat)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForConcat)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) .set_support_level(4); CINN_REGISTER_OP(reverse) .describe("This operator implements the meta op reverse.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForReverse) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForReverse)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForLayoutTransform)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForReverse) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForReverse)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForLayoutTransform)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForReverse)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForReverse)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) .set_support_level(4); CINN_REGISTER_OP(transpose) .describe("This operator implements the meta op transpose.") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForTranspose) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForTranspose)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForLayoutTransform)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForTranspose) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForTranspose)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForLayoutTransform)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForTranspose)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForTranspose)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) .set_support_level(4); CINN_REGISTER_OP(mul) - .describe("This operator is used to perform matrix multiplication for input X and Y.") + .describe( + "This operator is used to perform matrix multiplication for input X " + "and Y.") .set_num_inputs(2) .set_num_outputs(2) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForMul) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForMul) .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForMul)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForMul)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForMul)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForMul)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); #ifdef CINN_WITH_CUDA @@ -1700,20 +2036,28 @@ 
CINN_REGISTER_HELPER(transform_ops) { .describe("This operator uses cublas to compute the gemm.") .set_num_inputs(3) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForCublasGemm) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForCublasGemm)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForCublasGemm)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForCublasGemm) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForCublasGemm)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForCublasGemm)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); CINN_REGISTER_OP(cublas_matmul) .describe("This operator uses cublas to compute the matmul.") .set_num_inputs(2) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForMatMul) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForMatMul)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForMatMul)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForMatMul) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForMatMul)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForMatMul)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); #endif @@ -1721,76 +2065,113 @@ CINN_REGISTER_HELPER(transform_ops) { .describe("This operator is used to transform op's layouts") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForLayoutTransform) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForLayoutTransform)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForLayoutTransform)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForLayoutTransform) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForLayoutTransform)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForLayoutTransform)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForLayoutTransform)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForLayoutTransform)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) .set_support_level(4); CINN_REGISTER_OP(slice) .describe("This operator implements the slice layer") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForSlice) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForSlice)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForSlice)) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForSlice) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForSlice)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForSlice)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForSlice)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForSlice)) #endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) .set_support_level(4); 
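  // Usage sketch (illustrative only; mirrors the SliceAssign test in
  // transform_test.cc further below): how a registration in this helper is
  // consumed at runtime. `attrs`, `inputs`, `out_type`, `output_shape`,
  // `target` and `cinn_inputs` are assumed to be prepared by the caller.
  //
  //   auto op = Operator::Get("slice");
  //   auto strategy =
  //       Operator::GetAttrs<StrategyFunction>("CINNStrategy")[op];
  //   auto impl = OpStrategy::SelectImpl(
  //       strategy(attrs, inputs, out_type, {output_shape}, target));
  //   common::CINNValuePack rets = impl->fcompute(cinn_inputs);
  //   rets = impl->fschedule(rets);  // the last element is a StageMap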
CINN_REGISTER_OP(slice_assign) - .describe("This operator is used to perform slice assign for tensor input and tensor assign.") + .describe( + "This operator is used to perform slice assign for tensor input and " + "tensor assign.") .set_num_inputs(2) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForSliceAssign) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForSliceAssign)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForSliceAssign)) - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForSliceAssign)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForSliceAssign) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForSliceAssign)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForSliceAssign)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForSliceAssign)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); CINN_REGISTER_OP(gather) .describe( - "This operator is used to create a new tensor which indexes the `input` tensor along dimension `axis` using " + "This operator is used to create a new tensor which indexes the " + "`input` tensor along dimension `axis` using " "the entries in `index`.") .set_num_inputs(2) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForGather) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForGather)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForGather)) - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForGather)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForGather) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForGather)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForGather)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForGather)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) .set_support_level(4); CINN_REGISTER_OP(scatter_assign) - .describe("This operator is used to assign tensor B to tensor A by index.") + .describe( + "This operator is used to assign tensor B to tensor A by index.") .set_num_inputs(3) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForScatterAssign) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForScatterAssign)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForScatterAssign)) - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForScatterAssign)) - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForScatterAssign) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForScatterAssign)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForScatterAssign)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForScatterAssign)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kInjective) .set_support_level(4); CINN_REGISTER_OP(scatter_add) - .describe("This operator is used to add update tensor B into tensor A by index.") + .describe( + "This operator is used to add update tensor B into tensor A by " + "index.") .set_num_inputs(3) .set_num_outputs(1) - 
.set_attr("CINNStrategy", cinn::hlir::op::StrategyForScatterAdd) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForScatterAdd)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForScatterAdd)) - .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForScatterAdd)) - // Because the scatter_add operator calls the external function by passing pointers, - // the code generated by operator fusion will have out-of-bounds access. - // It should not fuse with any other injective operators, though scatter_add is injective. - // turn KNonFusible to kInjective will fail /Paddle/python/paddle/fluid/tests/unittests/test_index_select_op.py - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForScatterAdd) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForScatterAdd)) + .set_attr("inferdtype", + MakeOpFunction(cinn::hlir::op::InferDtypeForScatterAdd)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForScatterAdd)) + // Because the scatter_add operator calls the external function by passing + // pointers, the code generated by operator fusion will have out-of-bounds + // access. It should not fuse with any other injective operators, though + // scatter_add is injective. turn KNonFusible to kInjective will fail + // /Paddle/python/paddle/fluid/tests/unittests/test_index_select_op.py + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; diff --git a/paddle/cinn/hlir/op/transform_test.cc b/paddle/cinn/hlir/op/transform_test.cc index ea06fb03f4093..2d433e3ad3c50 100644 --- a/paddle/cinn/hlir/op/transform_test.cc +++ b/paddle/cinn/hlir/op/transform_test.cc @@ -56,7 +56,8 @@ using framework::StrategyFunction; TEST(SliceAssign, SliceAssign_Op) { // code gen auto slice_assign = Operator::Get("slice_assign"); - auto strategy = Operator::GetAttrs("CINNStrategy")[slice_assign]; + auto strategy = + Operator::GetAttrs("CINNStrategy")[slice_assign]; int m = 64; int n = 32; @@ -66,9 +67,9 @@ TEST(SliceAssign, SliceAssign_Op) { // set attrs NodeAttr attrs; - attrs.attr_store["axis"] = std::vector{0, 1}; - attrs.attr_store["starts"] = std::vector{16, 16}; - attrs.attr_store["ends"] = std::vector{32, 32}; + attrs.attr_store["axis"] = std::vector{0, 1}; + attrs.attr_store["starts"] = std::vector{16, 16}; + attrs.attr_store["ends"] = std::vector{32, 32}; attrs.attr_store["strides"] = std::vector{1, 1}; std::vector out_type{Float(32)}; @@ -80,26 +81,31 @@ TEST(SliceAssign, SliceAssign_Op) { #else auto target = common::DefaultHostTarget(); #endif - auto impl = OpStrategy::SelectImpl(strategy(attrs, inputs, out_type, {output_shape}, target)); + auto impl = OpStrategy::SelectImpl( + strategy(attrs, inputs, out_type, {output_shape}, target)); std::string func_name = "slice_assign"; if (FLAGS_cinn_ir_schedule) { - std::string out_name = "output"; - common::CINNValuePack cinn_input = common::CINNValuePack{ - {common::CINNValue(input.tensor()), common::CINNValue(assign.tensor()), common::CINNValue(out_name)}}; + std::string out_name = "output"; + common::CINNValuePack cinn_input = + common::CINNValuePack{{common::CINNValue(input.tensor()), + common::CINNValue(assign.tensor()), + common::CINNValue(out_name)}}; std::vector input_output_names{"input", "assign", out_name}; - auto funcs = framework::GetFuncFromImpl(impl, cinn_input, inputs, input_output_names, func_name, target); + auto funcs = 
framework::GetFuncFromImpl( + impl, cinn_input, inputs, input_output_names, func_name, target); for (auto func : funcs) { LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func; } } else { common::CINNValuePack cinn_input = - common::CINNValuePack{{common::CINNValue(input.tensor()), common::CINNValue(assign.tensor())}}; + common::CINNValuePack{{common::CINNValue(input.tensor()), + common::CINNValue(assign.tensor())}}; common::CINNValuePack rets = impl->fcompute(cinn_input); - rets = impl->fschedule(rets); + rets = impl->fschedule(rets); // the last element is a StageMap for (int i = 0; i < rets->size() - 1; i++) { @@ -109,7 +115,8 @@ TEST(SliceAssign, SliceAssign_Op) { } } - auto func = lang::LowerVec("slice_assign", rets.back(), inputs, {}, {}, nullptr, target); + auto func = lang::LowerVec( + "slice_assign", rets.back(), inputs, {}, {}, nullptr, target); for (auto& f : func) { LOG(INFO) << "Test Strategy Codegen:\n" << f; } diff --git a/paddle/cinn/hlir/pass/alterlayout.cc b/paddle/cinn/hlir/pass/alterlayout.cc index a89b8dd371e4c..3c8d775fc9bef 100644 --- a/paddle/cinn/hlir/pass/alterlayout.cc +++ b/paddle/cinn/hlir/pass/alterlayout.cc @@ -33,74 +33,87 @@ using framework::NodeData; using framework::Operator; using framework::OpValueType; -using InferShapeFunc = std::function(const std::vector&, - const framework::AttrMapType&)>; -using InferTypeFunc = std::function(const std::vector&, const framework::AttrMapType&)>; -using InferLayoutFunc = std::function>(const std::vector&, - const std::vector&, - const framework::NodeAttr&, - const Target&)>; +using InferShapeFunc = std::function( + const std::vector&, const framework::AttrMapType&)>; +using InferTypeFunc = std::function( + const std::vector&, const framework::AttrMapType&)>; +using InferLayoutFunc = std::function>( + const std::vector&, + const std::vector&, + const framework::NodeAttr&, + const Target&)>; // insert layout_transform after the input var -std::tuple InsertLayoutTransformNodeAfter(Graph* graph, - NodeData* input_data, - Node* dst_node, - int pos, - const std::string& src_layout, - const std::string& dst_layout, - const std::string& name) { +std::tuple InsertLayoutTransformNodeAfter( + Graph* graph, + NodeData* input_data, + Node* dst_node, + int pos, + const std::string& src_layout, + const std::string& dst_layout, + const std::string& name) { CHECK(graph); CHECK(input_data); - std::string op_type = "layout_transform"; - auto trans_node = new Node(Operator::Get(op_type), op_type, name); - trans_node->attrs.attr_store["src_layout"] = src_layout; - trans_node->attrs.attr_store["dst_layout"] = dst_layout; - auto output_data = InsertGraphOpNodeAfter(graph, trans_node, input_data, dst_node, pos); + std::string op_type = "layout_transform"; + auto trans_node = new Node(Operator::Get(op_type), op_type, name); + trans_node->attrs.attr_store["src_layout"] = src_layout; + trans_node->attrs.attr_store["dst_layout"] = dst_layout; + auto output_data = + InsertGraphOpNodeAfter(graph, trans_node, input_data, dst_node, pos); trans_node->attrs.attr_store["input_layouts"] = {src_layout}; - trans_node->attrs.attr_store["out_layouts"] = {dst_layout}; + trans_node->attrs.attr_store["out_layouts"] = {dst_layout}; return std::make_tuple(trans_node, output_data); } // insert layout_transform before the output var -std::tuple InsertLayoutTransformNodeBefore(Graph* graph, - Node* input_node, - NodeData* dst_data, - int pos, - const std::string& src_layout, - const std::string& dst_layout, - const std::string& name) { 
+std::tuple InsertLayoutTransformNodeBefore( + Graph* graph, + Node* input_node, + NodeData* dst_data, + int pos, + const std::string& src_layout, + const std::string& dst_layout, + const std::string& name) { CHECK(graph); CHECK(input_node); CHECK(dst_data); - std::string op_type = "layout_transform"; - auto trans_node = new Node(Operator::Get(op_type), op_type, name); - trans_node->attrs.attr_store["src_layout"] = src_layout; - trans_node->attrs.attr_store["dst_layout"] = dst_layout; - auto temp_outdata = InsertGraphOpNodeBefore(graph, trans_node, input_node, dst_data, pos); + std::string op_type = "layout_transform"; + auto trans_node = new Node(Operator::Get(op_type), op_type, name); + trans_node->attrs.attr_store["src_layout"] = src_layout; + trans_node->attrs.attr_store["dst_layout"] = dst_layout; + auto temp_outdata = + InsertGraphOpNodeBefore(graph, trans_node, input_node, dst_data, pos); trans_node->attrs.attr_store["input_layouts"] = {src_layout}; - trans_node->attrs.attr_store["out_layouts"] = {dst_layout}; + trans_node->attrs.attr_store["out_layouts"] = {dst_layout}; return std::make_tuple(trans_node, temp_outdata); } -std::vector UpdateInferInfos(Node* node, - const std::vector& input_shapes, - const std::vector& input_types, - const std::vector& input_layouts, - const common::Target& target, - const OpValueType& op_infershape, - const OpValueType& op_infertype, - const OpValueType& op_inferlayout, - absl::flat_hash_map* shape_dict, - absl::flat_hash_map* type_dict, - absl::flat_hash_map* layout_dict) { +std::vector UpdateInferInfos( + Node* node, + const std::vector& input_shapes, + const std::vector& input_types, + const std::vector& input_layouts, + const common::Target& target, + const OpValueType& op_infershape, + const OpValueType& op_infertype, + const OpValueType& op_inferlayout, + absl::flat_hash_map* shape_dict, + absl::flat_hash_map* type_dict, + absl::flat_hash_map* layout_dict) { CHECK(shape_dict); CHECK(type_dict); CHECK(layout_dict); - CHECK(op_infershape[node->op()]) << "find no InferShape function for op " << node->op()->name; - CHECK(op_infertype[node->op()]) << "find no InferDtype function for op " << node->op()->name; - CHECK(op_inferlayout[node->op()]) << "find no InferLayout function for op " << node->op()->name; - auto infershapes = op_infershape[node->op()](input_shapes, node->attrs.attr_store); - auto infertypes = op_infertype[node->op()](input_types, node->attrs.attr_store); - auto inferlayouts = op_inferlayout[node->op()](input_shapes, input_layouts, node->attrs, target); + CHECK(op_infershape[node->op()]) + << "find no InferShape function for op " << node->op()->name; + CHECK(op_infertype[node->op()]) + << "find no InferDtype function for op " << node->op()->name; + CHECK(op_inferlayout[node->op()]) + << "find no InferLayout function for op " << node->op()->name; + auto infershapes = + op_infershape[node->op()](input_shapes, node->attrs.attr_store); + auto infertypes = + op_infertype[node->op()](input_types, node->attrs.attr_store); + auto inferlayouts = op_inferlayout[node->op()]( + input_shapes, input_layouts, node->attrs, target); CHECK(!infershapes.empty()) << node->op()->name << " finds no infershape"; CHECK(!infertypes.empty()) << node->op()->name << " finds no infertype"; @@ -112,14 +125,15 @@ std::vector UpdateInferInfos(Node* node, CHECK_EQ(outlinks.size(), infershapes.size()); for (int i = 0; i < outlinks.size(); i++) { - auto* sink = outlinks[i]->sink(); - (*shape_dict)[sink->id()] = infershapes[i]; - (*type_dict)[sink->id()] = 
infertypes[i]; + auto* sink = outlinks[i]->sink(); + (*shape_dict)[sink->id()] = infershapes[i]; + (*type_dict)[sink->id()] = infertypes[i]; (*layout_dict)[sink->id()] = inferlayouts[0][i]; - VLOG(3) << "Infershape: " << node->op()->name << "'s " << i << "-th outlink " << sink->id() << ": " + VLOG(3) << "Infershape: " << node->op()->name << "'s " << i + << "-th outlink " << sink->id() << ": " << utils::Join(infershapes[i], ", "); } - node->attrs.attr_store["out_layouts"] = inferlayouts[0]; + node->attrs.attr_store["out_layouts"] = inferlayouts[0]; node->attrs.attr_store["input_layouts"] = inferlayouts[1]; return infershapes; } @@ -127,11 +141,14 @@ std::vector UpdateInferInfos(Node* node, void AlterLayoutPass(Graph* graph) { // alterlayout only in X86 for it's specific layout requirements if (graph->target_.arch == Target::Arch::X86) { - auto store_nodes = std::get<0>(graph->topological_order()); - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - auto& type_dict = graph->GetMutableAttrs>("inferdtype"); - auto& op_infershape = Operator::GetAttrs("infershape"); - auto& op_inferdtype = Operator::GetAttrs("inferdtype"); + auto store_nodes = std::get<0>(graph->topological_order()); + auto& shape_dict = graph->GetMutableAttrs< + absl::flat_hash_map>("infershape"); + auto& type_dict = + graph->GetMutableAttrs>( + "inferdtype"); + auto& op_infershape = Operator::GetAttrs("infershape"); + auto& op_inferdtype = Operator::GetAttrs("inferdtype"); auto& op_inferlayout = Operator::GetAttrs("inferlayout"); absl::flat_hash_map layout_dict; std::string model_name = ""; @@ -139,7 +156,8 @@ void AlterLayoutPass(Graph* graph) { model_name = graph->GetMutableAttrs("model_name"); VLOG(3) << "model_name: " << model_name; } - // collect all convs' original input config before altering layout for loading tune params afterwards + // collect all convs' original input config before altering layout for + // loading tune params afterwards int index = 0; for (int i = 0; i < store_nodes.size(); i++) { auto node = store_nodes[i]->safe_as(); @@ -147,25 +165,37 @@ void AlterLayoutPass(Graph* graph) { std::vector padding({0, 0}); std::vector stride({1, 1}); std::vector dilation({1, 1}); - if (node->attrs.attr_store.find("padding") != node->attrs.attr_store.end()) { - padding = absl::get>(node->attrs.attr_store.at("padding")); + if (node->attrs.attr_store.find("padding") != + node->attrs.attr_store.end()) { + padding = + absl::get>(node->attrs.attr_store.at("padding")); } - if (node->attrs.attr_store.find("stride") != node->attrs.attr_store.end()) { - stride = absl::get>(node->attrs.attr_store.at("stride")); + if (node->attrs.attr_store.find("stride") != + node->attrs.attr_store.end()) { + stride = + absl::get>(node->attrs.attr_store.at("stride")); } - if (node->attrs.attr_store.find("dilation") != node->attrs.attr_store.end()) { - dilation = absl::get>(node->attrs.attr_store.at("dilation")); + if (node->attrs.attr_store.find("dilation") != + node->attrs.attr_store.end()) { + dilation = absl::get>( + node->attrs.attr_store.at("dilation")); } const auto& conv_inlinks = node->inlinks_in_order(); CHECK_EQ(conv_inlinks.size(), 2U) << "conv2d should have 2 inputs"; std::vector> inputs_shape; for (auto& link : conv_inlinks) { auto* source = link->source(); - CHECK(shape_dict.count(source->id())) << source->id() << " finds no infershape"; + CHECK(shape_dict.count(source->id())) + << source->id() << " finds no infershape"; inputs_shape.push_back(shape_dict.at(source->id())); } - std::string key = - 
pe::GenerateX86ConvKey(inputs_shape[0], inputs_shape[1], stride, padding, dilation, index++, model_name); + std::string key = pe::GenerateX86ConvKey(inputs_shape[0], + inputs_shape[1], + stride, + padding, + dilation, + index++, + model_name); VLOG(3) << "key: " << key; node->attrs.attr_store["key"] = key; } @@ -176,18 +206,22 @@ void AlterLayoutPass(Graph* graph) { auto node = store_nodes[i]->safe_as(); if (node) { if (node->op()->name == "conv2d") { - CHECK(node->attrs.attr_store.count("data_format")) << node->op()->name << " op has no data_format attr"; - std::string data_format = absl::get(node->attrs.attr_store.at("data_format")); + CHECK(node->attrs.attr_store.count("data_format")) + << node->op()->name << " op has no data_format attr"; + std::string data_format = + absl::get(node->attrs.attr_store.at("data_format")); if (data_format != "NCHW") { // not NCHW such as NHWC or has already been altered layout continue; } - has_altered = true; + has_altered = true; std::string new_op_type = node->op()->name + "_NCHWc"; // alter conv2d op to conv2d_NCHWc - Node* new_node = new Node(Operator::Get(new_op_type), new_op_type, common::UniqName(new_op_type)); + Node* new_node = new Node(Operator::Get(new_op_type), + new_op_type, + common::UniqName(new_op_type)); new_node->attrs.attr_store = node->attrs.attr_store; - std::string new_data_format = "NCHWc"; + std::string new_data_format = "NCHWc"; new_node->attrs.attr_store["data_format"] = new_data_format; const auto& conv_inlinks = node->inlinks_in_order(); @@ -197,23 +231,29 @@ void AlterLayoutPass(Graph* graph) { input_nodes.push_back(source); } // get new layout: ic_bn, oc_bn - CHECK_EQ(input_nodes.size(), 2U) << "conv2d should have 2 input nodes"; - auto* input_node = input_nodes[0]; + CHECK_EQ(input_nodes.size(), 2U) + << "conv2d should have 2 input nodes"; + auto* input_node = input_nodes[0]; auto* weight_node = input_nodes[1]; - CHECK(shape_dict.count(input_node->id())) << input_node->id() << " has no infershape"; - CHECK(shape_dict.count(weight_node->id())) << weight_node->id() << " has no infershape"; - CHECK(type_dict.count(input_node->id())) << input_node->id() << " has no infertype"; - CHECK(type_dict.count(weight_node->id())) << weight_node->id() << " has no infertype"; - auto input_shape = shape_dict.at(input_node->id()); + CHECK(shape_dict.count(input_node->id())) + << input_node->id() << " has no infershape"; + CHECK(shape_dict.count(weight_node->id())) + << weight_node->id() << " has no infershape"; + CHECK(type_dict.count(input_node->id())) + << input_node->id() << " has no infertype"; + CHECK(type_dict.count(weight_node->id())) + << weight_node->id() << " has no infertype"; + auto input_shape = shape_dict.at(input_node->id()); auto weight_shape = shape_dict.at(weight_node->id()); - auto input_type = type_dict.at(input_node->id()); - auto weight_type = type_dict.at(weight_node->id()); + auto input_type = type_dict.at(input_node->id()); + auto weight_type = type_dict.at(weight_node->id()); Node* weight_trans_node; Node* input_trans_node; std::vector conv2d_NCHWc_inputshapes; std::vector conv2d_NCHWc_inputtypes; std::vector conv2d_NCHWc_inputlayouts; - CHECK(weight_shape.size() == 4) << "old conv2d's weight shape should be 4"; + CHECK(weight_shape.size() == 4) + << "old conv2d's weight shape should be 4"; absl::flat_hash_map conv2d_factors; int oc, fc, ic = 1; if (input_shape.size() == 4) { @@ -221,7 +261,9 @@ void AlterLayoutPass(Graph* graph) { } else if (input_shape.size() == 5) { ic = input_shape[1] * input_shape[4]; } else { 
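As a quick aside on the channel/factor bookkeeping in this hunk: the pass derives the total input channels (ic) from a 4D NCHW or 5D NCHWxc shape, oc/fc from a 4D OIHW or 6D packed weight shape, then asks pe::GetConv2dFactors for target-tuned block sizes used to build the packed layout strings. A minimal standalone sketch of that arithmetic follows; the helper names (InputChannels, WeightChannels) and the fixed ic_bn/oc_bn/fc_bn values are illustrative only, and the "NCHW{ic_bn}c" input-layout form is an inference from the kernel-layout construction visible here, not a confirmed detail.

```cpp
// Standalone sketch of the channel/layout derivation done in AlterLayoutPass.
// Names and hard-coded block factors are illustrative; the real pass queries
// pe::GetConv2dFactors for target-specific ic_bn/oc_bn/fc_bn values.
#include <cassert>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using shape_t = std::vector<int>;

// input: NCHW (4D) or NCHWxc (5D) -> total input channels
int InputChannels(const shape_t& s) {
  if (s.size() == 4) return s[1];
  if (s.size() == 5) return s[1] * s[4];
  assert(false && "conv2d's input shape should be 4D/5D");
  return -1;
}

// weight: OIHW (4D) or packed 6D -> {oc, fc}
std::pair<int, int> WeightChannels(const shape_t& w) {
  if (w.size() == 4) return {w[0], w[1]};
  if (w.size() == 6) return {w[0] * w[5], w[1] * w[4]};
  assert(false && "conv2d's weight shape should be 4D/6D");
  return {-1, -1};
}

int main() {
  shape_t input = {1, 64, 56, 56};   // NCHW
  shape_t weight = {128, 64, 3, 3};  // OIHW
  int ic = InputChannels(input);
  auto [oc, fc] = WeightChannels(weight);

  int ic_bn = 8, oc_bn = 16, fc_bn = 8;  // example factors, normally tuned
  // assumed form for the input layout; the kernel layout matches this hunk
  std::string dst_input_layout = "NCHW" + std::to_string(ic_bn) + "c";
  std::string dst_kernel_layout = "OIHW" + std::to_string(fc_bn) + "i" +
                                  std::to_string(oc_bn) + "o";
  std::cout << "ic=" << ic << " oc=" << oc << " fc=" << fc << "\n"
            << dst_input_layout << " / " << dst_kernel_layout << "\n";
}
```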
- LOG(FATAL) << "conv2d's input shape should be 4D/5D. Wrong input shape: " << utils::Join(input_shape, ", "); + LOG(FATAL) + << "conv2d's input shape should be 4D/5D. Wrong input shape: " + << utils::Join(input_shape, ", "); } if (weight_shape.size() == 4) { @@ -231,18 +273,29 @@ void AlterLayoutPass(Graph* graph) { oc = weight_shape[0] * weight_shape[5]; fc = weight_shape[1] * weight_shape[4]; } else { - LOG(FATAL) << "conv2d's weight shape should be 4D/6D. Wrong weight shape: " - << utils::Join(weight_shape, ", "); + LOG(FATAL) + << "conv2d's weight shape should be 4D/6D. Wrong weight shape: " + << utils::Join(weight_shape, ", "); } VLOG(3) << "oc: " << oc; VLOG(3) << "ic: " << ic; VLOG(3) << "fc: " << fc; // get the original conv config stored in the key attr - CHECK(new_node->attrs.attr_store.count("key")) << "conv2d finds no key attr"; - std::string key = absl::get(new_node->attrs.attr_store.at("key")); + CHECK(new_node->attrs.attr_store.count("key")) + << "conv2d finds no key attr"; + std::string key = + absl::get(new_node->attrs.attr_store.at("key")); VLOG(3) << "key: " << key; - pe::GetConv2dFactors(&conv2d_factors, oc, ic, fc, -1, -1, input_type, graph->target_, key); + pe::GetConv2dFactors(&conv2d_factors, + oc, + ic, + fc, + -1, + -1, + input_type, + graph->target_, + key); CHECK(conv2d_factors.count("oc_bn")); CHECK(conv2d_factors.count("ic_bn")); CHECK(conv2d_factors.count("fc_bn")); @@ -262,13 +315,15 @@ void AlterLayoutPass(Graph* graph) { CHECK(input_data); NodeData* output_data; std::tie(input_trans_node, output_data) = - InsertLayoutTransformNodeAfter(graph, - input_data, - node, - 0, - src_input_layout, - dst_input_layout, - common::UniqName(node->op()->name + "_input_layout_tranform")); + InsertLayoutTransformNodeAfter( + graph, + input_data, + node, + 0, + src_input_layout, + dst_input_layout, + common::UniqName(node->op()->name + + "_input_layout_tranform")); UpdateInferInfos(input_trans_node, {input_shape}, {input_type}, @@ -280,36 +335,43 @@ void AlterLayoutPass(Graph* graph) { &shape_dict, &type_dict, &layout_dict); - CHECK(shape_dict.count(output_data->id())) << output_data->id() << " finds no infershape in shape_dict."; - CHECK(type_dict.count(output_data->id())) << output_data->id() << " finds no infertype in shape_dict."; + CHECK(shape_dict.count(output_data->id())) + << output_data->id() << " finds no infershape in shape_dict."; + CHECK(type_dict.count(output_data->id())) + << output_data->id() << " finds no infertype in shape_dict."; auto trans_out_shapes = shape_dict[output_data->id()]; auto trans_out_dtypes = type_dict[output_data->id()]; conv2d_NCHWc_inputshapes.push_back(trans_out_shapes); conv2d_NCHWc_inputtypes.push_back(trans_out_dtypes); conv2d_NCHWc_inputlayouts.push_back(dst_input_layout); } else { - CHECK_EQ(input_shape.size(), 5U) << "conv2d_NCHWc op's input shape dim should be 5"; + CHECK_EQ(input_shape.size(), 5U) + << "conv2d_NCHWc op's input shape dim should be 5"; conv2d_NCHWc_inputshapes.push_back(input_shape); conv2d_NCHWc_inputtypes.push_back(input_type); - CHECK(layout_dict.count(input_node->id())) << input_node->id() << " should have out_layout attr"; + CHECK(layout_dict.count(input_node->id())) + << input_node->id() << " should have out_layout attr"; conv2d_NCHWc_inputlayouts.push_back(layout_dict[input_node->id()]); } if (weight_shape.size() == 4) { std::string src_kernel_layout = "OIHW"; - std::string dst_kernel_layout = "OIHW" + std::to_string(fc_bn) + "i" + std::to_string(oc_bn) + "o"; + std::string dst_kernel_layout = "OIHW" + 
std::to_string(fc_bn) + + "i" + std::to_string(oc_bn) + "o"; VLOG(3) << "dst_kernel_layout: " << dst_kernel_layout; // insert weight layout_transform auto weight_data = weight_node->safe_as(); CHECK(weight_data); NodeData* output_data; std::tie(weight_trans_node, output_data) = - InsertLayoutTransformNodeAfter(graph, - weight_data, - node, - 1, - src_kernel_layout, - dst_kernel_layout, - common::UniqName(node->op()->name + "_weight_layout_tranform")); + InsertLayoutTransformNodeAfter( + graph, + weight_data, + node, + 1, + src_kernel_layout, + dst_kernel_layout, + common::UniqName(node->op()->name + + "_weight_layout_tranform")); UpdateInferInfos(weight_trans_node, {weight_shape}, {weight_type}, @@ -321,22 +383,27 @@ void AlterLayoutPass(Graph* graph) { &shape_dict, &type_dict, &layout_dict); - CHECK(shape_dict.count(output_data->id())) << output_data->id() << " finds no infershape in shape_dict."; - CHECK(type_dict.count(output_data->id())) << output_data->id() << " finds no infertype in shape_dict."; + CHECK(shape_dict.count(output_data->id())) + << output_data->id() << " finds no infershape in shape_dict."; + CHECK(type_dict.count(output_data->id())) + << output_data->id() << " finds no infertype in shape_dict."; auto trans_out_shapes = shape_dict[output_data->id()]; auto trans_out_dtypes = type_dict[output_data->id()]; conv2d_NCHWc_inputshapes.push_back(trans_out_shapes); conv2d_NCHWc_inputtypes.push_back(trans_out_dtypes); conv2d_NCHWc_inputlayouts.push_back(dst_kernel_layout); } else { - CHECK_EQ(weight_shape.size(), 6U) << weight_node->id() << " shape dim should be 6"; + CHECK_EQ(weight_shape.size(), 6U) + << weight_node->id() << " shape dim should be 6"; conv2d_NCHWc_inputshapes.push_back(weight_shape); conv2d_NCHWc_inputtypes.push_back(weight_type); - CHECK(layout_dict.count(weight_node->id())) << weight_node->id() << " should have out_layout attr"; + CHECK(layout_dict.count(weight_node->id())) + << weight_node->id() << " should have out_layout attr"; conv2d_NCHWc_inputlayouts.push_back(layout_dict[weight_node->id()]); } // replace conv2d to conv2d_NCHWc - auto infershapes = op_infershape[new_node->op()](conv2d_NCHWc_inputshapes, new_node->attrs.attr_store); + auto infershapes = op_infershape[new_node->op()]( + conv2d_NCHWc_inputshapes, new_node->attrs.attr_store); const auto& old_inlinks = node->inlinks_in_order(); const auto& old_outlinks = node->outlinks_in_order(); for (auto& link : old_inlinks) { @@ -360,13 +427,17 @@ void AlterLayoutPass(Graph* graph) { count++; } for (int i = 1; i < infershapes.size(); i++) { - auto* new_out = - new NodeData(node_ptr, i, 0, common::UniqName(new_node->id() + "_out_" + std::to_string(i))); + auto* new_out = new NodeData( + node_ptr, + i, + 0, + common::UniqName(new_node->id() + "_out_" + std::to_string(i))); graph->RegisterNode(new_out->id(), new_out); new_node->as()->LinkTo(new_out); } graph->RegisterNode(new_node->id(), new_node); - // update conv2d_NCHWc's infershape, infertype, inferlayout and set attrs + // update conv2d_NCHWc's infershape, infertype, inferlayout and set + // attrs UpdateInferInfos(new_node, conv2d_NCHWc_inputshapes, conv2d_NCHWc_inputtypes, @@ -385,8 +456,10 @@ void AlterLayoutPass(Graph* graph) { std::vector input_layouts; for (auto& link : node->inlinks_in_order()) { auto* source = link->source(); - CHECK(shape_dict.count(source->id())) << source->id() << " finds no infershape"; - CHECK(type_dict.count(source->id())) << source->id() << " finds no infertype"; + CHECK(shape_dict.count(source->id())) + << source->id() << 
" finds no infershape"; + CHECK(type_dict.count(source->id())) + << source->id() << " finds no infertype"; input_shapes.push_back(shape_dict[source->id()]); input_types.push_back(type_dict[source->id()]); if (layout_dict.count(source->id())) { @@ -395,12 +468,15 @@ void AlterLayoutPass(Graph* graph) { input_layouts.push_back(""); } } - CHECK(op_inferlayout[node->op()]) << "find no InferLayout function for op " << node->op()->name; - auto inferlayouts = op_inferlayout[node->op()](input_shapes, input_layouts, node->attrs, graph->target_); - // if input inferred layouts is different from original's, expand dims or do transformation. + CHECK(op_inferlayout[node->op()]) + << "find no InferLayout function for op " << node->op()->name; + auto inferlayouts = op_inferlayout[node->op()]( + input_shapes, input_layouts, node->attrs, graph->target_); + // if input inferred layouts is different from original's, expand dims + // or do transformation. CHECK_EQ(inferlayouts.size(), 2U); auto new_input_layouts = inferlayouts[1]; - auto inlinks = node->inlinks_in_order(); + auto inlinks = node->inlinks_in_order(); CHECK_EQ(input_layouts.size(), inlinks.size()); CHECK_EQ(input_layouts.size(), new_input_layouts.size()); CHECK_EQ(input_layouts.size(), input_shapes.size()); @@ -410,11 +486,14 @@ void AlterLayoutPass(Graph* graph) { // expand dims or do transformation int input_shape_size = input_shapes[i].size(); if (input_shape_size == 1 && new_input_layouts[i].size() > 4) { - // C -> NCHWxc: 1. C -> NCHW 2. layout transform from NCHW to NCHWxc + // C -> NCHWxc: 1. C -> NCHW 2. layout transform from NCHW to + // NCHWxc int axis = -1; - CHECK(node->attrs.attr_store.count("axis")) << node->id() << " find no axis attr"; + CHECK(node->attrs.attr_store.count("axis")) + << node->id() << " find no axis attr"; axis = absl::get(node->attrs.attr_store["axis"]); - CHECK(new_input_layouts[i].substr(0, 4) == "NCHW") << "only support NCHWxc"; + CHECK(new_input_layouts[i].substr(0, 4) == "NCHW") + << "only support NCHWxc"; if (axis == -1) { axis += 4; } @@ -427,19 +506,22 @@ void AlterLayoutPass(Graph* graph) { } } // C -> NCHW, insert layout tranfrom - auto source = inlinks[i]->source(); - std::string src_layout = "C"; + auto source = inlinks[i]->source(); + std::string src_layout = "C"; layout_dict[source->id()] = src_layout; - auto input_data = source->safe_as(); + auto input_data = source->safe_as(); CHECK(input_data); VLOG(3) << source->id() << " do layout_tranform from C to NCHW"; std::string op_type = "broadcast_to"; auto trans_node = - new Node(Operator::Get(op_type), op_type, common::UniqName(source->id() + "_broadcastto")); - trans_node->attrs.attr_store["out_shape"] = new_shapes; - std::vector broadcast_axes = {1}; + new Node(Operator::Get(op_type), + op_type, + common::UniqName(source->id() + "_broadcastto")); + trans_node->attrs.attr_store["out_shape"] = new_shapes; + std::vector broadcast_axes = {1}; trans_node->attrs.attr_store["broadcast_axes"] = broadcast_axes; - auto output_data = InsertGraphOpNodeAfter(graph, trans_node, input_data, node, i); + auto output_data = InsertGraphOpNodeAfter( + graph, trans_node, input_data, node, i); UpdateInferInfos(trans_node, {input_shapes[i]}, {input_types[i]}, @@ -453,21 +535,24 @@ void AlterLayoutPass(Graph* graph) { &layout_dict); std::string new_src_layout = "NCHW"; - reset_axis = true; + reset_axis = true; // insert layout tranfrom auto new_input_data = output_data->safe_as(); CHECK(new_input_data); NodeData* new_output_data; Node* new_trans_node; - VLOG(3) << 
new_input_data->id() << " do layout_tranform from NCHW to NCHWxc"; + VLOG(3) << new_input_data->id() + << " do layout_tranform from NCHW to NCHWxc"; std::tie(new_trans_node, new_output_data) = - InsertLayoutTransformNodeAfter(graph, - new_input_data, - node, - i, - new_src_layout, - new_input_layouts[i], - common::UniqName(new_input_data->id() + "_layout_tranform")); + InsertLayoutTransformNodeAfter( + graph, + new_input_data, + node, + i, + new_src_layout, + new_input_layouts[i], + common::UniqName(new_input_data->id() + + "_layout_tranform")); UpdateInferInfos(new_trans_node, {shape_dict[new_input_data->id()]}, {input_types[i]}, @@ -479,25 +564,28 @@ void AlterLayoutPass(Graph* graph) { &shape_dict, &type_dict, &layout_dict); - } else if (input_shape_size == 4 && new_input_layouts[i].size() > 4) { + } else if (input_shape_size == 4 && + new_input_layouts[i].size() > 4) { // NCHW -> NCHWxc // insert layout tranfrom - auto source = inlinks[i]->source(); - auto src_layout = "NCHW"; + auto source = inlinks[i]->source(); + auto src_layout = "NCHW"; layout_dict[source->id()] = src_layout; - auto input_data = source->safe_as(); + auto input_data = source->safe_as(); CHECK(input_data); NodeData* output_data; Node* trans_node; - VLOG(3) << source->id() << " do layout_tranform from NCHW to NCHWxc"; + VLOG(3) << source->id() + << " do layout_tranform from NCHW to NCHWxc"; std::tie(trans_node, output_data) = - InsertLayoutTransformNodeAfter(graph, - input_data, - node, - i, - src_layout, - new_input_layouts[i], - common::UniqName(source->id() + "_layout_tranform")); + InsertLayoutTransformNodeAfter( + graph, + input_data, + node, + i, + src_layout, + new_input_layouts[i], + common::UniqName(source->id() + "_layout_tranform")); UpdateInferInfos(trans_node, {input_shapes[i]}, {input_types[i]}, @@ -509,25 +597,28 @@ void AlterLayoutPass(Graph* graph) { &shape_dict, &type_dict, &layout_dict); - } else if (input_shape_size == 5 && new_input_layouts[i].size() == 4) { + } else if (input_shape_size == 5 && + new_input_layouts[i].size() == 4) { // NCHWxc -> NCHW // insert layout tranfrom - auto source = inlinks[i]->source(); - auto src_layout = input_layouts[i]; + auto source = inlinks[i]->source(); + auto src_layout = input_layouts[i]; layout_dict[source->id()] = src_layout; - auto input_data = source->safe_as(); + auto input_data = source->safe_as(); CHECK(input_data); NodeData* output_data; Node* trans_node; - VLOG(3) << source->id() << " do layout_tranform from NCHWxc to NCHW"; + VLOG(3) << source->id() + << " do layout_tranform from NCHWxc to NCHW"; std::tie(trans_node, output_data) = - InsertLayoutTransformNodeAfter(graph, - input_data, - node, - i, - src_layout, - new_input_layouts[i], - common::UniqName(source->id() + "_layout_tranform")); + InsertLayoutTransformNodeAfter( + graph, + input_data, + node, + i, + src_layout, + new_input_layouts[i], + common::UniqName(source->id() + "_layout_tranform")); UpdateInferInfos(trans_node, {input_shapes[i]}, {input_types[i]}, @@ -550,8 +641,10 @@ void AlterLayoutPass(Graph* graph) { input_layouts.clear(); for (auto& link : node->inlinks_in_order()) { auto* source = link->source(); - CHECK(shape_dict.count(source->id())) << source->id() << " finds no infershape"; - CHECK(type_dict.count(source->id())) << source->id() << " finds no infertype"; + CHECK(shape_dict.count(source->id())) + << source->id() << " finds no infershape"; + CHECK(type_dict.count(source->id())) + << source->id() << " finds no infertype"; input_shapes.push_back(shape_dict[source->id()]); 
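The three branches around this point dispatch on the input's rank versus the length of the inferred layout string: a 1-D input is first broadcast to NCHW and then packed to NCHWxc, a 4-D NCHW input gets a single NCHW-to-NCHWxc transform, and a 5-D packed input feeding an NCHW op is unpacked. A self-contained sketch of that dispatch is below; PlanLayoutFix is a hypothetical stand-in for the broadcast_to/layout_transform node insertion and the UpdateInferInfos bookkeeping the pass actually performs.

```cpp
// Minimal sketch of the per-input dispatch AlterLayoutPass performs when an
// op's inferred input layout disagrees with the actual one. The helper name
// is hypothetical; the real pass inserts broadcast_to / layout_transform
// graph nodes and refreshes the shape/type/layout dicts via UpdateInferInfos.
#include <iostream>
#include <string>

std::string PlanLayoutFix(int input_rank, const std::string& new_layout) {
  const bool packed = new_layout.size() > 4;  // e.g. "NCHW8c"
  if (input_rank == 1 && packed) {
    // C -> NCHWxc: first broadcast C to NCHW, then pack NCHW -> NCHWxc.
    return "broadcast_to(C->NCHW) + layout_transform(NCHW->" + new_layout + ")";
  }
  if (input_rank == 4 && packed) {
    return "layout_transform(NCHW->" + new_layout + ")";
  }
  if (input_rank == 5 && new_layout.size() == 4) {
    return "layout_transform(NCHWxc->NCHW)";
  }
  return "no transform needed";
}

int main() {
  std::cout << PlanLayoutFix(1, "NCHW8c") << "\n";   // bias-like 1-D input
  std::cout << PlanLayoutFix(4, "NCHW16c") << "\n";  // plain NCHW tensor
  std::cout << PlanLayoutFix(5, "NCHW") << "\n";     // packed tensor, NCHW op
}
```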
input_types.push_back(type_dict[source->id()]); if (layout_dict.count(source->id())) { @@ -580,35 +673,41 @@ void AlterLayoutPass(Graph* graph) { for (int i = store_nodes.size() - 1; i >= 0; i--) { auto* node = store_nodes[i]->safe_as(); if (node) { - CHECK(node->attrs.attr_store.count("out_layouts")) << node->id() << " finds no out_layouts attr"; - auto out_layouts = absl::get>(node->attrs.attr_store.at("out_layouts")); + CHECK(node->attrs.attr_store.count("out_layouts")) + << node->id() << " finds no out_layouts attr"; + auto out_layouts = absl::get>( + node->attrs.attr_store.at("out_layouts")); CHECK(!out_layouts.empty()); if (out_layouts[0].size() > 4) { // recover the layout finally, NCHWxc->NCHW, only first output auto outlinks = node->outlinks_in_order(); CHECK(!outlinks.empty()); - auto* out_node = outlinks[0]->sink(); + auto* out_node = outlinks[0]->sink(); std::string dst_layout = "NCHW"; - CHECK(layout_dict.count(out_node->id())) << out_node->id() << " finds no out_layout"; + CHECK(layout_dict.count(out_node->id())) + << out_node->id() << " finds no out_layout"; std::string src_layout = layout_dict[out_node->id()]; // insert layout_transform NodeData* temp_out; Node* trans_node; - CHECK(shape_dict.count(out_node->id())) << out_node->id() << " finds no infershape"; - CHECK(type_dict.count(out_node->id())) << out_node->id() << " finds no infertype"; + CHECK(shape_dict.count(out_node->id())) + << out_node->id() << " finds no infershape"; + CHECK(type_dict.count(out_node->id())) + << out_node->id() << " finds no infertype"; auto shape = shape_dict[out_node->id()]; - auto type = type_dict[out_node->id()]; - // insert layout transform before the output var to keep the final original output var - std::tie(trans_node, temp_out) = - InsertLayoutTransformNodeBefore(graph, - node, - out_node->safe_as(), - 0, - src_layout, - dst_layout, - common::UniqName(node->op()->name + "_final_layout_tranform")); - shape_dict[temp_out->id()] = shape; - type_dict[temp_out->id()] = type; + auto type = type_dict[out_node->id()]; + // insert layout transform before the output var to keep the final + // original output var + std::tie(trans_node, temp_out) = InsertLayoutTransformNodeBefore( + graph, + node, + out_node->safe_as(), + 0, + src_layout, + dst_layout, + common::UniqName(node->op()->name + "_final_layout_tranform")); + shape_dict[temp_out->id()] = shape; + type_dict[temp_out->id()] = type; layout_dict[temp_out->id()] = src_layout; UpdateInferInfos(trans_node, {shape}, @@ -626,8 +725,8 @@ void AlterLayoutPass(Graph* graph) { } } graph->ClearUnlinkedNodes(&shape_dict, &type_dict, &layout_dict); - graph->attrs["infershape"] = std::make_shared(shape_dict); - graph->attrs["inferdtype"] = std::make_shared(type_dict); + graph->attrs["infershape"] = std::make_shared(shape_dict); + graph->attrs["inferdtype"] = std::make_shared(type_dict); graph->attrs["inferlayout"] = std::make_shared(layout_dict); } } @@ -639,7 +738,8 @@ void AlterLayoutPass(Graph* graph) { CINN_REGISTER_HELPER(AlterLayout) { CINN_REGISTER_PASS(AlterLayout) .describe( - "This pass alters ops' data layouts in the graph(e.g. NCHW -> NCHWxc, OIHW -> OIHWxoxi) and saves to " + "This pass alters ops' data layouts in the graph(e.g. 
NCHW -> " + "NCHWxc, OIHW -> OIHWxoxi) and saves to " "g.attrs[\"inferlayout\"]") .set_change_structure(true) .provide_graph_attr("infershape") diff --git a/paddle/cinn/hlir/pass/alterlayout_test.cc b/paddle/cinn/hlir/pass/alterlayout_test.cc index b979af059512d..45ec29d061c74 100755 --- a/paddle/cinn/hlir/pass/alterlayout_test.cc +++ b/paddle/cinn/hlir/pass/alterlayout_test.cc @@ -57,11 +57,11 @@ TEST(conv, conv) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; auto c = program.conv2d(A, B, attrs); @@ -101,11 +101,11 @@ TEST(conv_relu_conv, conv_relu_conv) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; auto c = program.conv2d(A, B, attrs); auto d = program.relu(c); @@ -150,11 +150,11 @@ TEST(conv_add_conv, conv_add_conv) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; auto c = program.conv2d(A, B, attrs); auto d = program.elementwise_add(c, C, 1); @@ -203,11 +203,11 @@ TEST(conv_bn_conv, conv_bn_conv) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; absl::flat_hash_map attrs1; attrs1["epsilon"] = (float)0.001; @@ -255,18 +255,18 @@ TEST(conv_pool2d_conv, conv_pool2d_conv) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; absl::flat_hash_map attrs2; - attrs2["stride_size"] = std::vector({2, 2}); + attrs2["stride_size"] = std::vector({2, 2}); attrs2["padding_size"] = std::vector({1, 1, 1, 1}); - attrs2["kernel_size"] = std::vector({3, 3}); - std::string pool_type = "max"; - attrs2["pool_type"] = pool_type; + attrs2["kernel_size"] = std::vector({3, 3}); + std::string pool_type = "max"; + attrs2["pool_type"] = pool_type; auto c = program.conv2d(A, B, attrs); auto d = program.pool2d(c, attrs2); @@ -310,11 +310,11 @@ TEST(conv_softmax_conv, conv_softmax_conv) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - 
attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; absl::flat_hash_map attrs1; attrs1["axis"] = (int)-1; @@ -361,11 +361,11 @@ TEST(conv_sigmoid_conv, conv_sigmoid_conv) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; auto c = program.conv2d(A, B, attrs); auto d = program.sigmoid(c); @@ -410,11 +410,11 @@ TEST(conv_mul_conv, conv_mul_conv) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; absl::flat_hash_map attrs1; attrs1["axis"] = (int)-1; diff --git a/paddle/cinn/hlir/pass/check_fusion_accuracy_pass.cc b/paddle/cinn/hlir/pass/check_fusion_accuracy_pass.cc index 3194a1b06b47d..c2d6249c28a47 100644 --- a/paddle/cinn/hlir/pass/check_fusion_accuracy_pass.cc +++ b/paddle/cinn/hlir/pass/check_fusion_accuracy_pass.cc @@ -41,7 +41,7 @@ using cinn::hlir::framework::GenerateAccCheckNodeId; using common::GraphEdge; using common::GraphNode; -using GroupPtr = std::shared_ptr; +using GroupPtr = std::shared_ptr; using GroupList = std::vector; using ShapeDict = absl::flat_hash_map; @@ -52,10 +52,13 @@ class AssertMsg { public: AssertMsg(int group_id) : group_id_(group_id) {} - void SetMsg(const std::string& title, const std::string& msg) { msg_info_[title] = msg; } + void SetMsg(const std::string& title, const std::string& msg) { + msg_info_[title] = msg; + } const std::string& GetMsg(const std::string& title) const { - CHECK(msg_info_.count(title)) << "Msg of group " << group_id_ << " not has title: " << title; + CHECK(msg_info_.count(title)) + << "Msg of group " << group_id_ << " not has title: " << title; return msg_info_.at(title); } @@ -113,15 +116,20 @@ class CheckFusionAccuracyPass { std::pair CreateAllNode(const std::string& node_id); - std::pair CreateAssertNode(const std::string& node_id, utils::AssertMsg* assert_msg); + std::pair CreateAssertNode(const std::string& node_id, + utils::AssertMsg* assert_msg); // the AssertAllClose operator are composed of isclose+all+assert - std::vector CreateAssertAllClose(const std::string& node_id, - utils::AssertMsg* assert_msg, - const std::vector& inputs); + std::vector CreateAssertAllClose( + const std::string& node_id, + utils::AssertMsg* assert_msg, + const std::vector& inputs); - // link origin group's output and pass group's output to the AssertAllClose nodes - GroupList LinkToAssertAllClose(const std::unordered_set& group_outputs, utils::AssertMsg* msg); + // link origin group's output and pass group's output to the AssertAllClose + // nodes + GroupList LinkToAssertAllClose( + const std::unordered_set& group_outputs, + utils::AssertMsg* msg); // skip check some op and var, now only support 
check float dtype bool IsSkipVar(const NodeData* var); @@ -139,16 +147,20 @@ class CheckFusionAccuracyPass { std::atomic_int CheckFusionAccuracyPass::key_count_{0}; -bool CheckFusionAccuracyPass::IsSkipVar(const NodeData* var) { return !dtype_dict_.at(var->id()).is_float(); } +bool CheckFusionAccuracyPass::IsSkipVar(const NodeData* var) { + return !dtype_dict_.at(var->id()).is_float(); +} std::string CheckFusionAccuracyPass::DebugNodeData(NodeData* node) { std::stringstream ss; - ss << node->id() << "{shape=[" << cinn::utils::Join(shape_dict_.at(node->id()), ", ") + ss << node->id() << "{shape=[" + << cinn::utils::Join(shape_dict_.at(node->id()), ", ") << "], dtype=" << dtype_dict_.at(node->id()) << "}"; return ss.str(); } -NodeData* CheckFusionAccuracyPass::CreateOutputNode(NodePtr node, const std::string& output_id) { +NodeData* CheckFusionAccuracyPass::CreateOutputNode( + NodePtr node, const std::string& output_id) { // create node's output data node auto node_id = output_id; if (node_id.empty()) { @@ -156,7 +168,8 @@ NodeData* CheckFusionAccuracyPass::CreateOutputNode(NodePtr node, const std::str } CHECK(graph_->RetrieveNode(node_id) == nullptr) - << "The node " << node->op()->name << "'s output " << node_id << " had been registered in graph! Please check."; + << "The node " << node->op()->name << "'s output " << node_id + << " had been registered in graph! Please check."; auto* output_data = new NodeData(node, 0, 0, node_id); node->LinkTo(output_data); @@ -165,23 +178,26 @@ NodeData* CheckFusionAccuracyPass::CreateOutputNode(NodePtr node, const std::str return output_data; } -void CheckFusionAccuracyPass::CreateCheckNodeOutputs(Node* old_node, NodePtr new_node) { +void CheckFusionAccuracyPass::CreateCheckNodeOutputs(Node* old_node, + NodePtr new_node) { const auto& outlinks = old_node->outlinks_in_order(); for (const auto& out_edge : outlinks) { auto out_node = out_edge->sink()->safe_as(); - CHECK(out_node) << "Node " << old_node->id() << "'s output node is nullptr! Please check."; + CHECK(out_node) << "Node " << old_node->id() + << "'s output node is nullptr! Please check."; const auto& out_node_id = out_node->id(); // If the check node's output variable node not created if (!FusionHelperBase::IsConstOp(old_node)) { - // note the const op will recompute in group, so that the op may appear in many group - // CHECK_EQ(old2new_nodedata_map_.count(out_node), 0) - // << "Var " << out_node_id << " repeated! The graph is not a SSA graph! Please check."; + // note the const op will recompute in group, so that the op may appear in + // many group CHECK_EQ(old2new_nodedata_map_.count(out_node), 0) + // << "Var " << out_node_id << " repeated! The graph is not a SSA + // graph! 
Please check."; } const auto& check_out_node_id = GenerateAccCheckNodeId(out_node_id); - auto check_out_node = CreateOutputNode(new_node, check_out_node_id); + auto check_out_node = CreateOutputNode(new_node, check_out_node_id); check_out_node->output_index = out_node->output_index; auto check_out_shape = shape_dict_.at(out_node_id); @@ -190,18 +206,21 @@ void CheckFusionAccuracyPass::CreateCheckNodeOutputs(Node* old_node, NodePtr new auto check_out_dtype = dtype_dict_.at(out_node_id); dtype_dict_.emplace(check_out_node_id, std::move(check_out_dtype)); - VLOG(4) << "Create the check fusion accuracy node of node " << old_node->id() << "'s output node " - << DebugNodeData(out_node) << " success, which is " << DebugNodeData(check_out_node); + VLOG(4) << "Create the check fusion accuracy node of node " + << old_node->id() << "'s output node " << DebugNodeData(out_node) + << " success, which is " << DebugNodeData(check_out_node); old2new_nodedata_map_[out_node] = check_out_node; } } -void CheckFusionAccuracyPass::RelinkNodeInputs(Node* old_node, NodePtr new_node) { +void CheckFusionAccuracyPass::RelinkNodeInputs(Node* old_node, + NodePtr new_node) { const auto& inlinks = old_node->inlinks_in_order(); for (const auto& in_edge : inlinks) { auto in_node = in_edge->source()->safe_as(); - CHECK(in_node) << "Node " << old_node->id() << "'s input node is nullptr! Please check."; + CHECK(in_node) << "Node " << old_node->id() + << "'s input node is nullptr! Please check."; if (old2new_nodedata_map_.count(in_node)) { old2new_nodedata_map_[in_node]->LinkTo(new_node.get()); @@ -212,15 +231,17 @@ void CheckFusionAccuracyPass::RelinkNodeInputs(Node* old_node, NodePtr new_node) } NodePtr CheckFusionAccuracyPass::CreateCheckNode(Node* node) { - CHECK(node->op()) << "Node " << node->id() << " is not operator! Please check."; + CHECK(node->op()) << "Node " << node->id() + << " is not operator! Please check."; const auto& check_node_id = GenerateAccCheckNodeId(node->id()); CHECK(graph_->RetrieveNode(check_node_id) == nullptr) - << "The node " << node->id() << "'s check fusion accuracy node" << check_node_id - << " had been registered in graph! Please check."; + << "The node " << node->id() << "'s check fusion accuracy node" + << check_node_id << " had been registered in graph! 
Please check."; - auto check_node = Node::Create(node->op(), GenerateAccCheckNodeId(node->attrs.node_name), check_node_id); + auto check_node = Node::Create( + node->op(), GenerateAccCheckNodeId(node->attrs.node_name), check_node_id); check_node->attrs.attr_store = node->attrs.attr_store; graph_->RegisterNode(check_node_id, check_node.get()); @@ -228,19 +249,23 @@ NodePtr CheckFusionAccuracyPass::CreateCheckNode(Node* node) { CreateCheckNodeOutputs(node, check_node); RelinkNodeInputs(node, check_node); - VLOG(4) << "Create node " << framework::DebugString(node) << "'s check fusion accuracy node success, which is " + VLOG(4) << "Create node " << framework::DebugString(node) + << "'s check fusion accuracy node success, which is " << framework::DebugString(check_node.get()); return check_node; } OpPatternKind CheckFusionAccuracyPass::GetOpKind(const framework::Node* node) { - auto op_pattern_dict_ = &framework::Operator::GetAttrs("OpPattern"); - CHECK(op_pattern_dict_->Find(node->op())) << "Don't find the pattern of op : " << node->id(); + auto op_pattern_dict_ = + &framework::Operator::GetAttrs("OpPattern"); + CHECK(op_pattern_dict_->Find(node->op())) + << "Don't find the pattern of op : " << node->id(); auto kind = op_pattern_dict_[0][node->op()]; if (kind == framework::kBroadcast) { - // As binary op was defined as broadcast, actually it should be element-wise. + // As binary op was defined as broadcast, actually it should be + // element-wise. if (node->op()->name != "broadcast_to") { return framework::kElementWise; } @@ -250,7 +275,7 @@ OpPatternKind CheckFusionAccuracyPass::GetOpKind(const framework::Node* node) { } GroupPtr CheckFusionAccuracyPass::CreateSingleNodeGroup(NodePtr node_ptr) { - auto node = node_ptr.get(); + auto node = node_ptr.get(); auto group = std::make_shared(); // init group group->nodes.push_back(node); @@ -259,7 +284,7 @@ GroupPtr CheckFusionAccuracyPass::CreateSingleNodeGroup(NodePtr node_ptr) { // input node for (auto& edge : node->inlinks()) { auto input_graph_node = edge->source(); - auto input_node_data = input_graph_node->safe_as(); + auto input_node_data = input_graph_node->safe_as(); CHECK(input_node_data); // input data has no source node if (input_node_data->source_node.get()) { @@ -276,16 +301,22 @@ GroupPtr CheckFusionAccuracyPass::CreateSingleNodeGroup(NodePtr node_ptr) { return group; } -std::pair CheckFusionAccuracyPass::CreateIsCloseNode(const std::string& node_id) { +std::pair CheckFusionAccuracyPass::CreateIsCloseNode( + const std::string& node_id) { const auto& is_close_node_id = "isclose_" + node_id; - auto is_close_node = Node::Create(Operator::Get("isclose"), GenerateAccCheckNodeId("isclose"), is_close_node_id); + auto is_close_node = Node::Create(Operator::Get("isclose"), + GenerateAccCheckNodeId("isclose"), + is_close_node_id); is_close_node->attrs.attr_store["rtol"] = - cinn::runtime::utils::AssertTrueMsgTool::GetInstance()->GetFlagValue("rtol"); + cinn::runtime::utils::AssertTrueMsgTool::GetInstance() + ->GetFlagValue("rtol"); is_close_node->attrs.attr_store["atol"] = - cinn::runtime::utils::AssertTrueMsgTool::GetInstance()->GetFlagValue("atol"); + cinn::runtime::utils::AssertTrueMsgTool::GetInstance() + ->GetFlagValue("atol"); is_close_node->attrs.attr_store["equal_nan"] = - cinn::runtime::utils::AssertTrueMsgTool::GetInstance()->GetFlagValue("equal_nan"); + cinn::runtime::utils::AssertTrueMsgTool::GetInstance() + ->GetFlagValue("equal_nan"); graph_->RegisterNode(is_close_node_id, is_close_node.get()); @@ -296,23 +327,27 @@ std::pair 
CheckFusionAccuracyPass::CreateIsCloseNode(const s shape_dict_.emplace(output_data->id(), std::move(check_out_shape)); dtype_dict_.emplace(output_data->id(), common::Bool()); - VLOG(4) << "Create node " << node_id << "'s isclose node success, whose id is " << is_close_node_id + VLOG(4) << "Create node " << node_id + << "'s isclose node success, whose id is " << is_close_node_id << ", whose output is " << DebugNodeData(output_data); return {is_close_node, output_data}; } -std::pair CheckFusionAccuracyPass::CreateAllNode(const std::string& node_id) { +std::pair CheckFusionAccuracyPass::CreateAllNode( + const std::string& node_id) { const auto& all_node_id = "all_" + node_id; - auto all_node = Node::Create(Operator::Get("reduce_all"), GenerateAccCheckNodeId("reduce_all"), all_node_id); + auto all_node = Node::Create(Operator::Get("reduce_all"), + GenerateAccCheckNodeId("reduce_all"), + all_node_id); int shape_size = shape_dict_[node_id].size(); std::vector axes(shape_size); for (int i = 0; i < shape_size; ++i) { axes[i] = i; } - all_node->attrs.attr_store["dim"] = axes; + all_node->attrs.attr_store["dim"] = axes; all_node->attrs.attr_store["keep_dim"] = false; graph_->RegisterNode(all_node_id, all_node.get()); @@ -323,23 +358,28 @@ std::pair CheckFusionAccuracyPass::CreateAllNode(const std:: shape_dict_.emplace(output_data->id(), framework::shape_t{1}); dtype_dict_.emplace(output_data->id(), common::Bool()); - VLOG(4) << "Create node " << node_id << "'s all node success, whose id is " << all_node_id << ", whose output is " - << DebugNodeData(output_data); + VLOG(4) << "Create node " << node_id << "'s all node success, whose id is " + << all_node_id << ", whose output is " << DebugNodeData(output_data); return {all_node, output_data}; } -std::pair CheckFusionAccuracyPass::CreateAssertNode(const std::string& node_id, - utils::AssertMsg* assert_msg) { +std::pair CheckFusionAccuracyPass::CreateAssertNode( + const std::string& node_id, utils::AssertMsg* assert_msg) { const auto& assert_node_id = "assert_" + node_id; - auto assert_node = Node::Create(Operator::Get("assert_true"), GenerateAccCheckNodeId("assert_true"), assert_node_id); - // TODO(thisjiang): change type from 'int' to 'std::string' when custom call support 'std::string' type - int msg_key = key_count_.fetch_add(1); + auto assert_node = Node::Create(Operator::Get("assert_true"), + GenerateAccCheckNodeId("assert_true"), + assert_node_id); + // TODO(thisjiang): change type from 'int' to 'std::string' when custom call + // support 'std::string' type + int msg_key = key_count_.fetch_add(1); assert_node->attrs.attr_store["msg"] = msg_key; - cinn::runtime::utils::AssertTrueMsgTool::GetInstance()->SetMsg(msg_key, assert_msg->str()); + cinn::runtime::utils::AssertTrueMsgTool::GetInstance()->SetMsg( + msg_key, assert_msg->str()); assert_node->attrs.attr_store["only_warning"] = - cinn::runtime::utils::AssertTrueMsgTool::GetInstance()->GetFlagValue("only_warning"); + cinn::runtime::utils::AssertTrueMsgTool::GetInstance() + ->GetFlagValue("only_warning"); graph_->RegisterNode(assert_node_id, assert_node.get()); @@ -349,15 +389,17 @@ std::pair CheckFusionAccuracyPass::CreateAssertNode(const st shape_dict_.emplace(output_data->id(), framework::shape_t{1}); dtype_dict_.emplace(output_data->id(), common::Bool()); - VLOG(4) << "Create node " << node_id << "'s assert node success, whose id is " << assert_node_id - << ", whose output is " << DebugNodeData(output_data); + VLOG(4) << "Create node " << node_id << "'s assert node success, whose id is " + << 
assert_node_id << ", whose output is " + << DebugNodeData(output_data); return {assert_node, output_data}; } -std::vector CheckFusionAccuracyPass::CreateAssertAllClose(const std::string& node_id, - utils::AssertMsg* assert_msg, - const std::vector& inputs) { +std::vector CheckFusionAccuracyPass::CreateAssertAllClose( + const std::string& node_id, + utils::AssertMsg* assert_msg, + const std::vector& inputs) { std::vector group_nodes; // create isclose + all + assert nodes // create isclose node and link inputs to the node @@ -372,7 +414,8 @@ std::vector CheckFusionAccuracyPass::CreateAssertAllClose(const std::st // check and create all node auto in_shape = shape_dict_[node_id]; - int prod_size = std::accumulate(in_shape.begin(), in_shape.end(), 1, std::multiplies()); + int prod_size = std::accumulate( + in_shape.begin(), in_shape.end(), 1, std::multiplies()); if (prod_size > 1) { // need reduce const auto& all_node = CreateAllNode(node_id); @@ -390,29 +433,35 @@ std::vector CheckFusionAccuracyPass::CreateAssertAllClose(const std::st return group_nodes; } -GroupList CheckFusionAccuracyPass::LinkToAssertAllClose(const std::unordered_set& group_outputs, - utils::AssertMsg* msg) { +GroupList CheckFusionAccuracyPass::LinkToAssertAllClose( + const std::unordered_set& group_outputs, utils::AssertMsg* msg) { GroupList assert_groups; for (auto* group_out : group_outputs) { const auto& out_node_id = group_out->id(); if (IsSkipVar(group_out)) { - LOG(WARNING) << "The CheckFusionAccuracyPass only support check float point dtype data now, skip check node \"" - << out_node_id << "\", who's dtype=" << dtype_dict_.at(out_node_id); + LOG(WARNING) << "The CheckFusionAccuracyPass only support check float " + "point dtype data now, skip check node \"" + << out_node_id + << "\", who's dtype=" << dtype_dict_.at(out_node_id); continue; } - CHECK(old2new_nodedata_map_.count(group_out)) << "The check fusion accuracy's node corresponding to " << out_node_id - << " had not been created! Please check."; - auto pass_out = old2new_nodedata_map_.at(group_out); + CHECK(old2new_nodedata_map_.count(group_out)) + << "The check fusion accuracy's node corresponding to " << out_node_id + << " had not been created! 
Please check."; + auto pass_out = old2new_nodedata_map_.at(group_out); const auto& acc_check_out_id = pass_out->id(); msg->SetMsg("Var Name", out_node_id); - msg->SetMsg("Suggestion", - cinn::utils::StringFormat("You can check the value by set FLAGS_cinn_self_check_accuracy and compare " - "the result between \"%s\" and \"%s\"", - out_node_id.c_str(), - acc_check_out_id.c_str())); + msg->SetMsg( + "Suggestion", + cinn::utils::StringFormat("You can check the value by set " + "FLAGS_cinn_self_check_accuracy and compare " + "the result between \"%s\" and \"%s\"", + out_node_id.c_str(), + acc_check_out_id.c_str())); - const auto& nodes = CreateAssertAllClose(acc_check_out_id, msg, {group_out, pass_out}); + const auto& nodes = + CreateAssertAllClose(acc_check_out_id, msg, {group_out, pass_out}); for (const auto& node : nodes) { assert_groups.emplace_back(CreateSingleNodeGroup(node)); @@ -421,9 +470,12 @@ GroupList CheckFusionAccuracyPass::LinkToAssertAllClose(const std::unordered_set return assert_groups; } -std::vector CheckFusionAccuracyPass::TopologicalOrder(const std::vector& nodes) { +std::vector CheckFusionAccuracyPass::TopologicalOrder( + const std::vector& nodes) { struct NodeCompare { - bool operator()(Node* lhs, Node* rhs) const { return lhs->id() < rhs->id(); } + bool operator()(Node* lhs, Node* rhs) const { + return lhs->id() < rhs->id(); + } }; std::set node_set(nodes.begin(), nodes.end()); @@ -469,14 +521,16 @@ std::vector CheckFusionAccuracyPass::TopologicalOrder(const std::vectorsink()->safe_as(); if (indegree.count(out_node) && (--indegree[out_node]) == 0) { - // if the output node in group and its input nodes are all visited, push + // if the output node in group and its input nodes are all visited, + // push queue.push_back(out_node); } } } } - CHECK_EQ(ordered_nodes.size(), nodes.size()) << "There has circle in group! Please check."; + CHECK_EQ(ordered_nodes.size(), nodes.size()) + << "There has circle in group! 
Please check."; return ordered_nodes; } @@ -497,17 +551,20 @@ GroupList CheckFusionAccuracyPass::Apply() { // fusion group only has one node, do not need check, skip if (group_nodes.size() <= 1) { - VLOG(4) << "The Group " << group->GetFuncName() << " just has one node, skip."; + VLOG(4) << "The Group " << group->GetFuncName() + << " just has one node, skip."; continue; } // split orign group and create group for each node const auto& ordered_nodes = TopologicalOrder(group_nodes); - VLOG(4) << "Check the accuracy of group " << graph_->DebugGroupedGraph(ordered_nodes); + VLOG(4) << "Check the accuracy of group " + << graph_->DebugGroupedGraph(ordered_nodes); for (auto* node : ordered_nodes) { if (node->is_variable()) { - VLOG(4) << "The node " << node->id() << " is variable, skip check fusion accuracy."; + VLOG(4) << "The node " << node->id() + << " is variable, skip check fusion accuracy."; continue; } @@ -521,18 +578,25 @@ GroupList CheckFusionAccuracyPass::Apply() { msg.SetMsg("Group id", std::to_string(i)); msg.SetMsg( "Group structure", - cinn::utils::StringFormat("\nGroup %d {\n%s}", i, graph_->DebugGroupedGraph(ordered_nodes, fetch_ids).c_str())); + cinn::utils::StringFormat( + "\nGroup %d {\n%s}", + i, + graph_->DebugGroupedGraph(ordered_nodes, fetch_ids).c_str())); // link the group's output data node to assert all close node - const auto& assert_group = LinkToAssertAllClose(group->GetOutputNodeDatas(), &msg); - check_fusion_groups.insert(check_fusion_groups.end(), assert_group.begin(), assert_group.end()); + const auto& assert_group = + LinkToAssertAllClose(group->GetOutputNodeDatas(), &msg); + check_fusion_groups.insert( + check_fusion_groups.end(), assert_group.begin(), assert_group.end()); i++; } return check_fusion_groups; } -void CheckFusionAccuracyPassImpl(Graph* graph) { graph->fusion_groups = CheckFusionAccuracyPass(graph).Apply(); } +void CheckFusionAccuracyPassImpl(Graph* graph) { + graph->fusion_groups = CheckFusionAccuracyPass(graph).Apply(); +} } // namespace cinn::hlir::pass diff --git a/paddle/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc b/paddle/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc index d3c063f1a03dc..dbad4bd00bfe6 100644 --- a/paddle/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc +++ b/paddle/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc @@ -37,11 +37,14 @@ int CountAfterPassNodeSize(Graph* graph) { output_size += group->GetOutputNodeDatas().size(); } - // CheckFusionAccuracyPass will split each group, and add isclose+all+assert node for each output + // CheckFusionAccuracyPass will split each group, and add isclose+all+assert + // node for each output return node_size + output_size * 3; } -void RunTest(const Target& target, const std::shared_ptr& graph, const std::vector& input_names) { +void RunTest(const Target& target, + const std::shared_ptr& graph, + const std::vector& input_names) { auto scope = BuildScope(target, graph); hlir::framework::GraphCompiler gc(target, scope, graph); @@ -76,16 +79,21 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + 
CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -108,17 +116,22 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion_1) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -144,17 +157,22 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion_2) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -180,17 +198,22 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion_3) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + 
hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -216,17 +239,22 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion_4) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -245,17 +273,22 @@ TEST(CheckFusionAccuracyPass, ElementWise_Fusion_5) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -277,17 +310,22 @@ TEST(CheckFusionAccuracyPass, Broadcast_Test_0) { } auto program = net_builder.Build(); - 
auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -309,17 +347,22 @@ TEST(CheckFusionAccuracyPass, Broadcast_Test_2) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -343,17 +386,22 @@ TEST(CheckFusionAccuracyPass, Broadcast_Test_4) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << 
graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -377,17 +425,22 @@ TEST(CheckFusionAccuracyPass, Broadcast_Test_5) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -408,17 +461,22 @@ TEST(CheckFusionAccuracyPass, Reduce_Test_0) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -438,17 +496,22 @@ TEST(CheckFusionAccuracyPass, Reduce_Test_1) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << 
graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -471,17 +534,22 @@ TEST(CheckFusionAccuracyPass, Reduce_Test_2) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -504,17 +572,22 @@ TEST(CheckFusionAccuracyPass, Reduce_Test_3) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -538,17 +611,22 @@ TEST(CheckFusionAccuracyPass, Reduce_Test_4) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), 
{"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); @@ -569,17 +647,22 @@ TEST(CheckFusionAccuracyPass, Reduce_Test_5) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); - hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass", "FusionMergePass"}); + hlir::framework::ApplyPasses(graph.get(), + {"OpFusionPass", "FusionMergePass"}); - int group_size_after = graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); + int group_size_after = + graph->fusion_groups.size() + CountAfterPassNodeSize(graph.get()); - VLOG(1) << "Before CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); - hlir::framework::ApplyPasses(graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); - VLOG(1) << "After CheckFusionAccuracyPass:\n" << graph->DebugGroupedGraph(std::unordered_set{}); + VLOG(1) << "Before CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); + hlir::framework::ApplyPasses( + graph.get(), {"CheckFusionAccuracyPass", "TransToCustomCallPass"}); + VLOG(1) << "After CheckFusionAccuracyPass:\n" + << graph->DebugGroupedGraph(std::unordered_set{}); CHECK_EQ(graph->fusion_groups.size(), group_size_after); diff --git a/paddle/cinn/hlir/pass/common_subexpression_elimination.cc b/paddle/cinn/hlir/pass/common_subexpression_elimination.cc index f5aa1e58e8d3a..3c14e1d03b680 100644 --- a/paddle/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/paddle/cinn/hlir/pass/common_subexpression_elimination.cc @@ -34,8 +34,9 @@ using framework::NodeData; using common::GraphEdge; using common::GraphNode; -using InputToNodeMap = std::unordered_map>; -using shape_dict_t = absl::flat_hash_map; +using InputToNodeMap = + std::unordered_map>; +using shape_dict_t = absl::flat_hash_map; std::unordered_set unordered_ops = { "elementwise_add", @@ -52,7 +53,8 @@ std::unordered_set unordered_ops = { "bitwise_and", }; -// When all the inputs are the same, those ops just ensure that all the outputs shape is the same. +// When all the inputs are the same, those ops just ensure that all the outputs +// shape is the same. std::unordered_set reshape_ops = { "reshape", "concat", @@ -72,31 +74,37 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { // Get the number of input edges for op1 and op2 auto op1_inputs_size = op1_in_edges.size(); auto op2_inputs_size = op2_in_edges.size(); - // If the number of input edges is not the same, the subexpression is not the same. + // If the number of input edges is not the same, the subexpression is not the + // same. if (op1_inputs_size != op2_inputs_size) { return false; } // Get the number of attributes for op1 and op2. auto op1_attrs_size = op1->attrs.attr_store.size(); auto op2_attrs_size = op2->attrs.attr_store.size(); - // If the number of attributes is not the same, the subexpression is not the same. + // If the number of attributes is not the same, the subexpression is not the + // same. 
if (op1_attrs_size != op2_attrs_size) { return false; } // Check if the input nodes match. if (unordered_ops.count(op1->op()->name)) { - // For unordered ops, check if any input node of op2 matches any input node of op1. + // For unordered ops, check if any input node of op2 matches any input node + // of op1. for (auto& op1_edge : op1_in_edges) { auto* op1_source_node = op1_edge->source()->safe_as(); CHECK(op1_source_node); - bool op1_equal_op2 = std::any_of(op2_in_edges.begin(), op2_in_edges.end(), [&](common::Shared& edge) { - auto* op2_source_node = edge->source()->safe_as(); - CHECK(op2_source_node); - if (op1_source_node->id() == op2_source_node->id()) { - return true; - } - return false; - }); + bool op1_equal_op2 = std::any_of( + op2_in_edges.begin(), + op2_in_edges.end(), + [&](common::Shared& edge) { + auto* op2_source_node = edge->source()->safe_as(); + CHECK(op2_source_node); + if (op1_source_node->id() == op2_source_node->id()) { + return true; + } + return false; + }); if (!op1_equal_op2) { return false; } @@ -117,7 +125,8 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { // Check if the number of dimensions is the same. auto* op1_sink_node = GetNodeData(op1); auto* op2_sink_node = GetNodeData(op2); - if (shape_dict[op1_sink_node->id()].size() != shape_dict[op2_sink_node->id()].size()) { + if (shape_dict[op1_sink_node->id()].size() != + shape_dict[op2_sink_node->id()].size()) { return false; } if (reshape_ops.count(op1->op()->name)) { @@ -125,53 +134,56 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { return shape_dict[op1_sink_node->id()] == shape_dict[op2_sink_node->id()]; } else { // For non-reshape ops, check if the attributes is the same. - return std::all_of(op1->attrs.attr_store.begin(), op1->attrs.attr_store.end(), [&](auto attr) { - if (!op2->attrs.attr_store.count(attr.first)) { - return false; - } - auto& attr1 = attr.second; - auto& attr2 = op2->attrs.attr_store[attr.first]; - auto ndim = static_cast(shape_dict[op1_sink_node->id()].size()); - if (special_attrs.count(attr.first)) { - switch (special_attrs[attr.first]) { - case 1: { - auto op1_axis = absl::get(attr1); - auto op2_axis = absl::get(attr2); - if (op1_axis < 0) { - op1_axis += ndim; - } - if (op2_axis < 0) { - op2_axis += ndim; - } - return op2_axis == op1_axis; + return std::all_of( + op1->attrs.attr_store.begin(), + op1->attrs.attr_store.end(), + [&](auto attr) { + if (!op2->attrs.attr_store.count(attr.first)) { + return false; } - case 2: { - auto& op1_axes = absl::get>(attr1); - auto& op2_axes = absl::get>(attr2); - auto op1_size = op1_axes.size(); - auto op2_size = op2_axes.size(); - if (op1_size != op2_size) { - return false; - } - for (int i = 0; i < op1_axes.size(); ++i) { - int op1_axis = op1_axes[i]; - int op2_axis = op2_axes[i]; - if (op1_axis < 0) { - op1_axis += ndim; - } - if (op2_axis < 0) { - op2_axis += ndim; + auto& attr1 = attr.second; + auto& attr2 = op2->attrs.attr_store[attr.first]; + auto ndim = static_cast(shape_dict[op1_sink_node->id()].size()); + if (special_attrs.count(attr.first)) { + switch (special_attrs[attr.first]) { + case 1: { + auto op1_axis = absl::get(attr1); + auto op2_axis = absl::get(attr2); + if (op1_axis < 0) { + op1_axis += ndim; + } + if (op2_axis < 0) { + op2_axis += ndim; + } + return op2_axis == op1_axis; } - if (op2_axis != op1_axis) { - return false; + case 2: { + auto& op1_axes = absl::get>(attr1); + auto& op2_axes = absl::get>(attr2); + auto op1_size = op1_axes.size(); + auto op2_size = 
op2_axes.size(); + if (op1_size != op2_size) { + return false; + } + for (int i = 0; i < op1_axes.size(); ++i) { + int op1_axis = op1_axes[i]; + int op2_axis = op2_axes[i]; + if (op1_axis < 0) { + op1_axis += ndim; + } + if (op2_axis < 0) { + op2_axis += ndim; + } + if (op2_axis != op1_axis) { + return false; + } + } + return true; } } - return true; } - } - } - return attr1 == attr2; - }); + return attr1 == attr2; + }); } } @@ -197,7 +209,8 @@ void RemoveNodes(framework::Graph* graph, std::vector& nodes) { void RemoveNodes(framework::Graph* graph, std::vector& nodes_data) { for (auto* data : nodes_data) { - if (std::find(graph->outputs.begin(), graph->outputs.end(), data) != graph->outputs.end()) { + if (std::find(graph->outputs.begin(), graph->outputs.end(), data) != + graph->outputs.end()) { return; } RemoveNodes(graph, data); @@ -215,9 +228,13 @@ void ReplaceNode(NodeData* src_new, NodeData* src_old, Node* trt) { } } -void CommonSubexpressionElimination(Graph* graph, std::vector store_nodes, InputToNodeMap in2node) { +void CommonSubexpressionElimination(Graph* graph, + std::vector store_nodes, + InputToNodeMap in2node) { std::unordered_map> candidates_map; - auto shape_dict = graph->GetAttrs>("infershape"); + auto shape_dict = + graph->GetAttrs>( + "infershape"); std::vector remove_nodes; std::vector remove_nodes_data; @@ -227,9 +244,9 @@ void CommonSubexpressionElimination(Graph* graph, std::vector store_ VLOG(4) << "size of store_nodes is " << store_nodes.size(); auto node = graph_node->safe_as(); if (node) { - auto& node_type = node->op()->name; + auto& node_type = node->op()->name; auto& candidates = candidates_map[node_type]; - bool found = false; + bool found = false; for (auto* candidate_node : candidates) { // If node is same with candidate_node, continue the next. if (node->id() == candidate_node->id()) continue; @@ -237,25 +254,29 @@ void CommonSubexpressionElimination(Graph* graph, std::vector store_ if (!IsSameSubexpression(node, candidate_node, shape_dict)) continue; found = true; for (int k = 0; k < node->outlinks_in_order().size(); ++k) { - const auto& out_links = node->outlinks_in_order(); + const auto& out_links = node->outlinks_in_order(); const auto& candidate_out_links = candidate_node->outlinks_in_order(); CHECK(out_links.size() == candidate_out_links.size()); - auto* sink_node = out_links[k]->sink()->safe_as(); - auto* candidate_sink_node = candidate_out_links[k]->sink()->safe_as(); + auto* sink_node = out_links[k]->sink()->safe_as(); + auto* candidate_sink_node = + candidate_out_links[k]->sink()->safe_as(); CHECK(sink_node); CHECK(candidate_sink_node); - auto iter_sink_node = std::find(graph->outputs.begin(), graph->outputs.end(), sink_node); + auto iter_sink_node = std::find( + graph->outputs.begin(), graph->outputs.end(), sink_node); if (iter_sink_node != graph->outputs.end()) { // If sink node in outputs, the node cannot be removed. continue; } remove_nodes_data.push_back(sink_node); - // Replace sink_node with candidate_sink_node in nodes linked by sink_node. + // Replace sink_node with candidate_sink_node in nodes linked by + // sink_node. auto out_nodes = in2node[sink_node->id()]; for (auto out_node : out_nodes) { ReplaceNode(candidate_sink_node, sink_node, out_node); // The changed out node will be detected again. 
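// Illustrative note on the worklist behaviour (an editorial gloss): ReplaceNode
// has just rewired out_node's input edge from sink_node to
// candidate_sink_node, so out_node may now duplicate some other consumer;
// re-inserting it at the front of store_nodes (below) makes the traversal
// visit it again, and the pass finishes once store_nodes is exhausted.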
- if (std::find(store_nodes.begin(), store_nodes.end(), out_node) == store_nodes.end()) { + if (std::find(store_nodes.begin(), store_nodes.end(), out_node) == + store_nodes.end()) { store_nodes.insert(store_nodes.begin(), out_node); } } diff --git a/paddle/cinn/hlir/pass/common_subexpression_elimination_test.cc b/paddle/cinn/hlir/pass/common_subexpression_elimination_test.cc index 5c01561d9b3f1..73e4bcf8b8b5f 100644 --- a/paddle/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/paddle/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -50,14 +50,14 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { Placeholder B(Float(32), {32, 1, 1}, "B", true); Program program; - auto add_1 = program.add(A, B); - auto add_2 = program.add(B, A); - auto add = program.add(add_1, add_2); - auto t_1 = program.transpose(add, {2, 1, 0}); // {1, 16, 32} - auto t_2 = program.transpose(add, {2, 1, 0}); // {1, 16, 32} - auto t_3 = program.transpose(add, {2, 1, 0}); // {1, 16, 32} + auto add_1 = program.add(A, B); + auto add_2 = program.add(B, A); + auto add = program.add(add_1, add_2); + auto t_1 = program.transpose(add, {2, 1, 0}); // {1, 16, 32} + auto t_2 = program.transpose(add, {2, 1, 0}); // {1, 16, 32} + auto t_3 = program.transpose(add, {2, 1, 0}); // {1, 16, 32} auto concat = program.concat({t_1, t_2, t_3}); - auto max = program.reduce_max(concat, {0}, true); + auto max = program.reduce_max(concat, {0}, true); Target target = common::DefaultTarget(); program.SetInputs({A, B}); @@ -73,8 +73,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { hlir::framework::GraphCompiler gc(target, scope, graph); auto runtime_program = gc.Build(); - auto& prerun_instrs = runtime_program->GetPreRunInstructions(); - auto& run_instrs = runtime_program->GetRunInstructions(); + auto& prerun_instrs = runtime_program->GetPreRunInstructions(); + auto& run_instrs = runtime_program->GetRunInstructions(); ASSERT_EQ(prerun_instrs.size(), 0); ASSERT_EQ(run_instrs.size(), 5); @@ -95,13 +95,13 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { Placeholder B(Float(32), {32, 1}, "B", true); Program program; - auto add_1 = program.add(A, A); - auto add_2 = program.add(A, A); + auto add_1 = program.add(A, A); + auto add_2 = program.add(A, A); auto reshape_1 = program.reshape(B, {4, -1}); auto reshape_2 = program.reshape(B, {4, 8}); - auto concat_1 = program.concat({reshape_1, reshape_2}); - auto concat_2 = program.concat({reshape_1, reshape_2}); - auto concat_3 = program.concat({reshape_1, reshape_2}, 1); + auto concat_1 = program.concat({reshape_1, reshape_2}); + auto concat_2 = program.concat({reshape_1, reshape_2}); + auto concat_3 = program.concat({reshape_1, reshape_2}, 1); Target target = common::DefaultTarget(); program.SetInputs({A, B}); @@ -117,8 +117,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { hlir::framework::GraphCompiler gc(target, scope, graph); auto runtime_program = gc.Build(); - auto& prerun_instrs = runtime_program->GetPreRunInstructions(); - auto& run_instrs = runtime_program->GetRunInstructions(); + auto& prerun_instrs = runtime_program->GetPreRunInstructions(); + auto& run_instrs = runtime_program->GetRunInstructions(); ASSERT_EQ(prerun_instrs.size(), 0); ASSERT_EQ(run_instrs.size(), 4); @@ -136,26 +136,28 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { #ifdef CINN_WITH_CUDA TEST(common_subexpression_elimination, 
common_subexpression_elimination_case3) { - auto strides = std::vector({2, 2}); - auto dilations = std::vector({1, 1}); - auto paddings = std::vector({3, 3}); + auto strides = std::vector({2, 2}); + auto dilations = std::vector({1, 1}); + auto paddings = std::vector({3, 3}); auto data_format = "NCHW"; NetBuilder builder("CSE"); - auto A = builder.CreateInput(Float(32), {1, 3, 224, 224}, "A"); - auto B = builder.CreateInput(Float(32), {1, 1, 224, 224}, "B"); - auto add_1 = builder.Add(A, B); + auto A = builder.CreateInput(Float(32), {1, 3, 224, 224}, "A"); + auto B = builder.CreateInput(Float(32), {1, 1, 224, 224}, "B"); + auto add_1 = builder.Add(A, B); auto weight_1 = builder.FillConstant({64, 3, 7, 7}, 1.0f, "w1", false); auto weight_2 = builder.FillConstant({64, 3, 7, 7}, 1.0f, "w2", false); - auto bias = builder.FillConstant({1, 64, 112, 112}, 2.0f, "b1", false); - auto conv_1 = builder.Conv2d(add_1, weight_1, strides, paddings, dilations, 1, data_format); - auto add_2 = builder.Add(conv_1, bias); - auto relu_1 = builder.Relu(add_2); - auto conv_2 = builder.Conv2d(add_1, weight_2, strides, paddings, dilations, 1, data_format); - auto add_3 = builder.Add(conv_2, bias); - auto relu_2 = builder.Relu(add_3); - auto out1 = builder.Add(relu_1, add_2); - auto out2 = builder.Add(add_2, relu_2); + auto bias = builder.FillConstant({1, 64, 112, 112}, 2.0f, "b1", false); + auto conv_1 = builder.Conv2d( + add_1, weight_1, strides, paddings, dilations, 1, data_format); + auto add_2 = builder.Add(conv_1, bias); + auto relu_1 = builder.Relu(add_2); + auto conv_2 = builder.Conv2d( + add_1, weight_2, strides, paddings, dilations, 1, data_format); + auto add_3 = builder.Add(conv_2, bias); + auto relu_2 = builder.Relu(add_3); + auto out1 = builder.Add(relu_1, add_2); + auto out2 = builder.Add(add_2, relu_2); auto program = builder.Build(); LOG(INFO) << "Program:\n" << program; @@ -165,7 +167,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { fetch_list.insert(out2->id); Target target = common::DefaultNVGPUTarget(); - auto graph = std::make_shared(program, fetch_list, target); + auto graph = + std::make_shared(program, fetch_list, target); LOG(INFO) << "graph:\n" << graph->DebugGroupedGraph(fetch_list); hlir::framework::ApplyPass(graph.get(), "InferShape"); @@ -179,8 +182,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { hlir::framework::GraphCompiler gc(target, scope, graph); auto runtime_program = gc.Build(); - auto& prerun_instrs = runtime_program->GetPreRunInstructions(); - auto& run_instrs = runtime_program->GetRunInstructions(); + auto& prerun_instrs = runtime_program->GetPreRunInstructions(); + auto& run_instrs = runtime_program->GetRunInstructions(); ASSERT_EQ(prerun_instrs.size(), 0); ASSERT_EQ(run_instrs.size(), 7); scope->Var("A"); diff --git a/paddle/cinn/hlir/pass/const_propagate.cc b/paddle/cinn/hlir/pass/const_propagate.cc index fec0602797001..3db1c17422294 100644 --- a/paddle/cinn/hlir/pass/const_propagate.cc +++ b/paddle/cinn/hlir/pass/const_propagate.cc @@ -65,7 +65,8 @@ void ConstPropagatePass(Graph* graph) { CINN_REGISTER_HELPER(ConstPropagate) { CINN_REGISTER_PASS(ConstPropagate) .describe( - "This pass will propagate const node_datas and mark the op_node with the attr[\"pre_run\"] if inputs are all " + "This pass will propagate const node_datas and mark the op_node with " + "the attr[\"pre_run\"] if inputs are all " "constants;") .set_change_structure(false) .provide_graph_attr("pre_run") diff --git 
a/paddle/cinn/hlir/pass/const_propagate_test.cc b/paddle/cinn/hlir/pass/const_propagate_test.cc index 4ad813fc8461a..0259dea31c3f0 100644 --- a/paddle/cinn/hlir/pass/const_propagate_test.cc +++ b/paddle/cinn/hlir/pass/const_propagate_test.cc @@ -39,13 +39,13 @@ TEST(const_conv, const_conv) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; - auto c = program.conv2d(A, B, attrs); + auto c = program.conv2d(A, B, attrs); Target target = common::DefaultTarget(); program.SetInputs({A, B}); program.Validate(); @@ -59,8 +59,8 @@ TEST(const_conv, const_conv) { hlir::framework::GraphCompiler gc(target, scope, graph); auto runtime_program = gc.Build(); - auto& prerun_instrs = runtime_program->GetPreRunInstructions(); - auto& run_instrs = runtime_program->GetRunInstructions(); + auto& prerun_instrs = runtime_program->GetPreRunInstructions(); + auto& run_instrs = runtime_program->GetRunInstructions(); ASSERT_EQ(prerun_instrs.size(), 0); ASSERT_EQ(run_instrs.size(), 1); @@ -87,7 +87,8 @@ TEST(const_bn, const_bn) { Program program; absl::flat_hash_map attrs; attrs["epsilon"] = static_cast(0.001); - auto a = program.fused_batchnorm_inference(A, Scale, Bias, Mean, Variance, attrs); + auto a = + program.fused_batchnorm_inference(A, Scale, Bias, Mean, Variance, attrs); Target target = common::DefaultTarget(); program.SetInputs({A, Scale, Bias, Mean, Variance}); @@ -102,8 +103,8 @@ TEST(const_bn, const_bn) { hlir::framework::GraphCompiler gc(target, scope, graph); auto runtime_program = gc.Build(); - auto& prerun_instrs = runtime_program->GetPreRunInstructions(); - auto& run_instrs = runtime_program->GetRunInstructions(); + auto& prerun_instrs = runtime_program->GetPreRunInstructions(); + auto& run_instrs = runtime_program->GetRunInstructions(); // Revert changes in PR #990 to pass the model unittests ASSERT_EQ(run_instrs.size(), 1); @@ -113,10 +114,10 @@ TEST(const_bn, const_bn) { scope->Var("Mean"); scope->Var("Variance"); - auto A1 = scope->GetTensor("A"); - auto Scale1 = scope->GetTensor("Scale"); - auto Bias1 = scope->GetTensor("Bias"); - auto Mean1 = scope->GetTensor("Mean"); + auto A1 = scope->GetTensor("A"); + auto Scale1 = scope->GetTensor("Scale"); + auto Bias1 = scope->GetTensor("Bias"); + auto Mean1 = scope->GetTensor("Mean"); auto Variance1 = scope->GetTensor("Variance"); SetRandData(A1, target); SetRandData(Scale1, target); diff --git a/paddle/cinn/hlir/pass/constant_folding_pass.cc b/paddle/cinn/hlir/pass/constant_folding_pass.cc index 483147a792dd7..e0396b137395d 100644 --- a/paddle/cinn/hlir/pass/constant_folding_pass.cc +++ b/paddle/cinn/hlir/pass/constant_folding_pass.cc @@ -28,18 +28,22 @@ using framework::shape_t; using common::GraphEdge; using common::GraphNode; -using AlterFunction = std::function; +using AlterFunction = + std::function; // Constant Fold Pass // class ConstantFoldingPassHelper : public FusionHelperBase { public: - ConstantFoldingPassHelper(Graph* graph) : FusionHelperBase(graph), graph_(graph) { RegisterAlterFunction(); } + ConstantFoldingPassHelper(Graph* graph) + : FusionHelperBase(graph), graph_(graph) { + RegisterAlterFunction(); + } void operator()() { bool update = false; do { - update = false; + update = false; 
auto nodes_inorder = std::get<0>(graph_->topological_order()); for (auto node : nodes_inorder) { if (!node->safe_as()) { @@ -60,9 +64,11 @@ class ConstantFoldingPassHelper : public FusionHelperBase { // if producer's output in graph_->outputs, then will not fold for (auto& edge : producer->outlinks()) { auto graph_node = edge->sink(); - auto data = graph_node->safe_as(); + auto data = graph_node->safe_as(); CHECK(data); - if (std::find(graph_->outputs.begin(), graph_->outputs.end(), data) != graph_->outputs.end()) { + if (std::find(graph_->outputs.begin(), + graph_->outputs.end(), + data) != graph_->outputs.end()) { can_fold = false; break; } @@ -83,15 +89,16 @@ class ConstantFoldingPassHelper : public FusionHelperBase { private: void RegisterAlterFunction() { - alter_function_ = {{"broadcast_to_const_scalar", fold_broadcast_to_constant}, - {"broadcast_to_fill_constant", fold_broadcast_to_constant}, - {"reshape_fill_constant", fold_reshape_fill_constant}, - {"squeeze_fill_constant", fold_squeeze_fill_constant}, - {"expand_dims_fill_constant", fold_expand_dims_fill_constant}}; + alter_function_ = { + {"broadcast_to_const_scalar", fold_broadcast_to_constant}, + {"broadcast_to_fill_constant", fold_broadcast_to_constant}, + {"reshape_fill_constant", fold_reshape_fill_constant}, + {"squeeze_fill_constant", fold_squeeze_fill_constant}, + {"expand_dims_fill_constant", fold_expand_dims_fill_constant}}; } std::string GetTypeName(Node* node) { - auto producers = GetProducerNode(node->safe_as()); + auto producers = GetProducerNode(node->safe_as()); std::string key = node->op()->name; for (auto producer : producers) { key += std::string("_") + producer->op()->name; diff --git a/paddle/cinn/hlir/pass/constant_folding_pass_test.cc b/paddle/cinn/hlir/pass/constant_folding_pass_test.cc index 39293b7e8d407..2698d035d30d9 100644 --- a/paddle/cinn/hlir/pass/constant_folding_pass_test.cc +++ b/paddle/cinn/hlir/pass/constant_folding_pass_test.cc @@ -19,13 +19,17 @@ namespace cinn { namespace frontend { -int GetSize(std::vector& shape) { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); } +int GetSize(std::vector& shape) { + return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); +} -std::unordered_map> GetInputRandom(const std::vector&& inputs) { +std::unordered_map> GetInputRandom( + const std::vector&& inputs) { std::unordered_map> input_data; for (auto input : inputs) { input_data[input->id] = std::vector(GetSize(input->shape)); - InitRandomVector(&input_data[input->id], input_data[input->id].size(), 0.0f, 1.0f, 1e-3); + InitRandomVector( + &input_data[input->id], input_data[input->id].size(), 0.0f, 1.0f, 1e-3); } return input_data; @@ -37,7 +41,8 @@ std::unordered_map> RunModelTest( const std::unordered_map>& input_data, const std::unordered_set& fetch_ids) { auto target = common::DefaultTarget(); - auto graph = std::make_shared(program, fetch_ids, target); + auto graph = + std::make_shared(program, fetch_ids, target); hlir::framework::ApplyPasses(graph.get(), passes); auto scope = BuildScope(target, graph); @@ -71,11 +76,16 @@ TEST(Constant_Folding, fold_broadcast_to_const_scalar_1) { auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); auto D = net_builder.Add(B, C); - auto fetch_ids = {D->id}; + auto fetch_ids = {D->id}; auto input_data = GetInputRandom({C}); - auto program = net_builder.Build(); - auto output0 = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); - auto output1 = RunModelTest(program, {"ConstantFolding", 
"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto program = net_builder.Build(); + auto output0 = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto output1 = + RunModelTest(program, + {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, + input_data, + fetch_ids); for (auto& output : output0) { CHECK(output1.count(output.first)); @@ -94,11 +104,16 @@ TEST(Constant_Folding, fold_broadcast_to_const_scalar_2) { auto E = net_builder.Add(B, C); auto F = net_builder.Add(A, D); - auto fetch_ids = {E->id, F->id}; + auto fetch_ids = {E->id, F->id}; auto input_data = GetInputRandom({C, D}); - auto program = net_builder.Build(); - auto output0 = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); - auto output1 = RunModelTest(program, {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto program = net_builder.Build(); + auto output0 = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto output1 = + RunModelTest(program, + {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, + input_data, + fetch_ids); for (auto& output : output0) { CHECK(output1.count(output.first)); @@ -118,11 +133,16 @@ TEST(Constant_Folding, fold_broadcast_to_const_scalar_3) { auto F = net_builder.Add(B, C); auto G = net_builder.Add(D, E); - auto fetch_ids = {G->id, F->id}; + auto fetch_ids = {G->id, F->id}; auto input_data = GetInputRandom({C, E}); - auto program = net_builder.Build(); - auto output0 = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); - auto output1 = RunModelTest(program, {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto program = net_builder.Build(); + auto output0 = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto output1 = + RunModelTest(program, + {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, + input_data, + fetch_ids); for (auto& output : output0) { CHECK(output1.count(output.first)); @@ -139,11 +159,16 @@ TEST(Constant_Folding, fold_broadcast_to_fill_constant_1) { auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); auto D = net_builder.Add(B, C); - auto fetch_ids = {D->id}; + auto fetch_ids = {D->id}; auto input_data = GetInputRandom({C}); - auto program = net_builder.Build(); - auto output0 = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); - auto output1 = RunModelTest(program, {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto program = net_builder.Build(); + auto output0 = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto output1 = + RunModelTest(program, + {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, + input_data, + fetch_ids); for (auto& output : output0) { CHECK(output1.count(output.first)); @@ -162,11 +187,16 @@ TEST(Constant_Folding, fold_broadcast_to_fill_constant_2) { auto E = net_builder.Add(B, C); auto F = net_builder.Add(A, D); - auto fetch_ids = {E->id, F->id}; + auto fetch_ids = {E->id, F->id}; auto input_data = GetInputRandom({C, D}); - auto program = net_builder.Build(); - auto output0 = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); - auto output1 = RunModelTest(program, {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto program = net_builder.Build(); + auto output0 = RunModelTest( + program, 
{"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto output1 = + RunModelTest(program, + {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, + input_data, + fetch_ids); for (auto& output : output0) { CHECK(output1.count(output.first)); @@ -183,11 +213,16 @@ TEST(Constant_Folding, fold_reshape_fill_constant_1) { auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); auto D = net_builder.Add(B, C); - auto fetch_ids = {D->id}; + auto fetch_ids = {D->id}; auto input_data = GetInputRandom({C}); - auto program = net_builder.Build(); - auto output0 = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); - auto output1 = RunModelTest(program, {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto program = net_builder.Build(); + auto output0 = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto output1 = + RunModelTest(program, + {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, + input_data, + fetch_ids); for (auto& output : output0) { CHECK(output1.count(output.first)); @@ -206,11 +241,16 @@ TEST(Constant_Folding, fold_reshape_fill_constant_2) { auto E = net_builder.Add(B, C); auto F = net_builder.Add(A, D); - auto fetch_ids = {E->id, F->id}; + auto fetch_ids = {E->id, F->id}; auto input_data = GetInputRandom({C, D}); - auto program = net_builder.Build(); - auto output0 = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); - auto output1 = RunModelTest(program, {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto program = net_builder.Build(); + auto output0 = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto output1 = + RunModelTest(program, + {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, + input_data, + fetch_ids); for (auto& output : output0) { CHECK(output1.count(output.first)); @@ -227,11 +267,16 @@ TEST(Constant_Folding, fold_squeeze_fill_constant_1) { auto C = net_builder.CreateInput(Float(32), {h, w}, "C"); auto D = net_builder.Add(B, C); - auto fetch_ids = {D->id}; + auto fetch_ids = {D->id}; auto input_data = GetInputRandom({C}); - auto program = net_builder.Build(); - auto output0 = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); - auto output1 = RunModelTest(program, {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto program = net_builder.Build(); + auto output0 = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto output1 = + RunModelTest(program, + {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, + input_data, + fetch_ids); for (auto& output : output0) { CHECK(output1.count(output.first)); @@ -250,11 +295,16 @@ TEST(Constant_Folding, fold_squeeze_fill_constant_2) { auto E = net_builder.Add(B, C); auto F = net_builder.Add(A, D); - auto fetch_ids = {E->id, F->id}; + auto fetch_ids = {E->id, F->id}; auto input_data = GetInputRandom({C, D}); - auto program = net_builder.Build(); - auto output0 = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); - auto output1 = RunModelTest(program, {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto program = net_builder.Build(); + auto output0 = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto output1 = + RunModelTest(program, + {"ConstantFolding", "OpFusionPass", 
"FusionMergePass"}, + input_data, + fetch_ids); for (auto& output : output0) { CHECK(output1.count(output.first)); @@ -271,11 +321,16 @@ TEST(Constant_Folding, fold_expand_dims_to_fill_constant_1) { auto C = net_builder.CreateInput(Float(32), {h, 1, w, 1}, "C"); auto D = net_builder.Add(B, C); - auto fetch_ids = {D->id}; + auto fetch_ids = {D->id}; auto input_data = GetInputRandom({C}); - auto program = net_builder.Build(); - auto output0 = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); - auto output1 = RunModelTest(program, {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto program = net_builder.Build(); + auto output0 = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto output1 = + RunModelTest(program, + {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, + input_data, + fetch_ids); for (auto& output : output0) { CHECK(output1.count(output.first)); @@ -294,11 +349,16 @@ TEST(Constant_Folding, fold_expand_dims_to_fill_constant_2) { auto E = net_builder.Add(B, C); auto F = net_builder.Add(A, D); - auto fetch_ids = {E->id, F->id}; + auto fetch_ids = {E->id, F->id}; auto input_data = GetInputRandom({C, D}); - auto program = net_builder.Build(); - auto output0 = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); - auto output1 = RunModelTest(program, {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto program = net_builder.Build(); + auto output0 = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto output1 = + RunModelTest(program, + {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, + input_data, + fetch_ids); for (auto& output : output0) { CHECK(output1.count(output.first)); @@ -317,11 +377,16 @@ TEST(Constant_Folding, fold_expand_dims_to_fill_constant_3) { auto E = net_builder.Add(B, C); auto F = net_builder.Add(A, D); - auto fetch_ids = {E->id, F->id}; + auto fetch_ids = {E->id, F->id}; auto input_data = GetInputRandom({C, D}); - auto program = net_builder.Build(); - auto output0 = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); - auto output1 = RunModelTest(program, {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto program = net_builder.Build(); + auto output0 = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, input_data, fetch_ids); + auto output1 = + RunModelTest(program, + {"ConstantFolding", "OpFusionPass", "FusionMergePass"}, + input_data, + fetch_ids); for (auto& output : output0) { CHECK(output1.count(output.first)); diff --git a/paddle/cinn/hlir/pass/constant_folding_pass_util.cc b/paddle/cinn/hlir/pass/constant_folding_pass_util.cc index ed6bf86ee5190..90aab2144065f 100644 --- a/paddle/cinn/hlir/pass/constant_folding_pass_util.cc +++ b/paddle/cinn/hlir/pass/constant_folding_pass_util.cc @@ -33,14 +33,24 @@ using cinn::utils::ShapeType; namespace utils { class ConstantFoldingHelper { public: - ConstantFoldingHelper(const FusionHelperBase* helper, Graph* graph, Node* node) - : helper_(helper), graph_(graph), consumer_(node), producer_(helper->GetProducerNode(node)[0]) {} - - const AttributeMap& GetProducerAttrs() const { return producer_->attrs.attr_store; } - const AttributeMap& GetConsumerAttrs() const { return consumer_->attrs.attr_store; } + ConstantFoldingHelper(const FusionHelperBase* helper, + Graph* graph, + Node* node) + : helper_(helper), + graph_(graph), + 
consumer_(node), + producer_(helper->GetProducerNode(node)[0]) {} + + const AttributeMap& GetProducerAttrs() const { + return producer_->attrs.attr_store; + } + const AttributeMap& GetConsumerAttrs() const { + return consumer_->attrs.attr_store; + } // fold consumer node and producer node into a new op node - void operator()(const AttributeMap& attrs_map, const std::string& new_op_name) { + void operator()(const AttributeMap& attrs_map, + const std::string& new_op_name) { auto* new_fold_node = CreateNewNode(new_op_name, attrs_map); // create new link. @@ -48,11 +58,15 @@ class ConstantFoldingHelper { } // fold consumer node into producer node - void operator()(const AttributeMap& attrs_map) { this->operator()(attrs_map, producer_->op()->name); } + void operator()(const AttributeMap& attrs_map) { + this->operator()(attrs_map, producer_->op()->name); + } private: - Node* CreateNewNode(const std::string& op_name, const AttributeMap& attrs_map) { - auto* node = new Node(Operator::Get(op_name), op_name, common::UniqName(op_name)); + Node* CreateNewNode(const std::string& op_name, + const AttributeMap& attrs_map) { + auto* node = + new Node(Operator::Get(op_name), op_name, common::UniqName(op_name)); node->attrs.attr_store = attrs_map; graph_->RegisterNode(node->id(), node); return node; @@ -137,44 +151,50 @@ class ConstantFoldingHelper { } // namespace utils // fold fill_constant/const_scalar->broadcast_to ==> fill_constant -void fold_broadcast_to_constant(const FusionHelperBase* helper, Graph* graph, Node* node) { +void fold_broadcast_to_constant(const FusionHelperBase* helper, + Graph* graph, + Node* node) { utils::ConstantFoldingHelper fold_helper(helper, graph, node); const auto& broadcast_to_attrs = fold_helper.GetConsumerAttrs(); - const auto& constant_attrs = fold_helper.GetProducerAttrs(); + const auto& constant_attrs = fold_helper.GetProducerAttrs(); auto shape = GetAttr(broadcast_to_attrs, "out_shape"); AttributeMap new_attrs; - new_attrs["dtype"] = constant_attrs.at("dtype"); - new_attrs["shape"] = GetAttr(broadcast_to_attrs, "out_shape"); - new_attrs["value"] = constant_attrs.at("value"); + new_attrs["dtype"] = constant_attrs.at("dtype"); + new_attrs["shape"] = GetAttr(broadcast_to_attrs, "out_shape"); + new_attrs["value"] = constant_attrs.at("value"); new_attrs["force_cpu"] = false; fold_helper(new_attrs, "fill_constant"); } // fold fill_constant->reshape ==> fill_constant -void fold_reshape_fill_constant(const FusionHelperBase* helper, Graph* graph, Node* node) { +void fold_reshape_fill_constant(const FusionHelperBase* helper, + Graph* graph, + Node* node) { utils::ConstantFoldingHelper fold_helper(helper, graph, node); const auto& reshape_attrs = fold_helper.GetConsumerAttrs(); AttributeMap new_attrs = fold_helper.GetProducerAttrs(); - new_attrs["shape"] = GetAttr(reshape_attrs, "shape"); + new_attrs["shape"] = GetAttr(reshape_attrs, "shape"); fold_helper(new_attrs); } // fold fill_constant->squeeze ==> fill_constant -void fold_squeeze_fill_constant(const FusionHelperBase* helper, Graph* graph, Node* node) { +void fold_squeeze_fill_constant(const FusionHelperBase* helper, + Graph* graph, + Node* node) { utils::ConstantFoldingHelper fold_helper(helper, graph, node); - const auto& squeeze_attrs = fold_helper.GetConsumerAttrs(); + const auto& squeeze_attrs = fold_helper.GetConsumerAttrs(); const auto& constant_attrs = fold_helper.GetProducerAttrs(); const auto& shape = GetAttr(constant_attrs, "shape"); - const auto& axes = GetAttr(squeeze_attrs, "axes"); + const auto& axes = 
GetAttr(squeeze_attrs, "axes"); // set node attr std::vector n_shape; if (axes.size() == 0) { @@ -192,28 +212,30 @@ void fold_squeeze_fill_constant(const FusionHelperBase* helper, Graph* graph, No } AttributeMap new_attrs = constant_attrs; - new_attrs["shape"] = n_shape; + new_attrs["shape"] = n_shape; fold_helper(new_attrs); } // fold fill_constant->expand_dims ==> fill_constant -void fold_expand_dims_fill_constant(const FusionHelperBase* helper, Graph* graph, Node* node) { +void fold_expand_dims_fill_constant(const FusionHelperBase* helper, + Graph* graph, + Node* node) { utils::ConstantFoldingHelper fold_helper(helper, graph, node); const auto& expand_dims_attrs = fold_helper.GetConsumerAttrs(); - const auto& constant_attrs = fold_helper.GetProducerAttrs(); + const auto& constant_attrs = fold_helper.GetProducerAttrs(); const auto& shape = GetAttr(constant_attrs, "shape"); - auto axes = GetAttr(expand_dims_attrs, "axes"); + auto axes = GetAttr(expand_dims_attrs, "axes"); int shape_size = shape.size(); - int axes_size = axes.size(); + int axes_size = axes.size(); int total_size = shape_size + axes_size; - axes = cinn::utils::GetPositiveAxes(axes, total_size); + axes = cinn::utils::GetPositiveAxes(axes, total_size); - // check axes whether in range [-total_size, total_size-1] and convert all to [0, total_size-1]. - // check axes can't repeat. + // check axes whether in range [-total_size, total_size-1] and convert all to + // [0, total_size-1]. check axes can't repeat. std::sort(axes.begin(), axes.end(), std::less()); for (int idx = 0; idx < axes_size - 1; ++idx) { CHECK_NE(axes[idx], axes[idx + 1]); @@ -227,7 +249,7 @@ void fold_expand_dims_fill_constant(const FusionHelperBase* helper, Graph* graph } AttributeMap new_attrs = constant_attrs; - new_attrs["shape"] = n_shape; + new_attrs["shape"] = n_shape; fold_helper(new_attrs); } diff --git a/paddle/cinn/hlir/pass/constant_folding_pass_util.h b/paddle/cinn/hlir/pass/constant_folding_pass_util.h index b253af2b9ebaa..3ba07b8c26d14 100644 --- a/paddle/cinn/hlir/pass/constant_folding_pass_util.h +++ b/paddle/cinn/hlir/pass/constant_folding_pass_util.h @@ -23,16 +23,24 @@ namespace hlir { namespace pass { // fold fill_constant/const_scalar->broadcast_to ==> fill_constant -void fold_broadcast_to_constant(const FusionHelperBase* helper, Graph* graph, Node* node); +void fold_broadcast_to_constant(const FusionHelperBase* helper, + Graph* graph, + Node* node); // fold fill_constant->reshape ==> fill_constant -void fold_reshape_fill_constant(const FusionHelperBase* helper, Graph* graph, Node* node); +void fold_reshape_fill_constant(const FusionHelperBase* helper, + Graph* graph, + Node* node); // fold fill_constant->squeeze ==> fill_constant -void fold_squeeze_fill_constant(const FusionHelperBase* helper, Graph* graph, Node* node); +void fold_squeeze_fill_constant(const FusionHelperBase* helper, + Graph* graph, + Node* node); // fold fill_constant->expand_dims ==> fill_constant -void fold_expand_dims_fill_constant(const FusionHelperBase* helper, Graph* graph, Node* node); +void fold_expand_dims_fill_constant(const FusionHelperBase* helper, + Graph* graph, + Node* node); } // namespace pass } // namespace hlir diff --git a/paddle/cinn/hlir/pass/custom_call_pass.cc b/paddle/cinn/hlir/pass/custom_call_pass.cc index 2af0d68c5a599..2d47a211c6b14 100644 --- a/paddle/cinn/hlir/pass/custom_call_pass.cc +++ b/paddle/cinn/hlir/pass/custom_call_pass.cc @@ -32,31 +32,36 @@ class GraphAlterHelper { public: GraphAlterHelper(Graph* graph) : graph_(graph) { if 
(!FLAGS_cinn_custom_call_deny_ops.empty()) { - auto splited_names = cinn::utils::Split(FLAGS_cinn_custom_call_deny_ops, ";"); - deny_ops_ = {splited_names.begin(), splited_names.end()}; + auto splited_names = + cinn::utils::Split(FLAGS_cinn_custom_call_deny_ops, ";"); + deny_ops_ = {splited_names.begin(), splited_names.end()}; } } void TransToCustomCall(const common::Target& target) { // collect candidate nodes - auto mark_nodes = graph_->CollectNodes([this, &target](const common::GraphNode* graph_node) -> bool { - if (graph_node->safe_as()) { - auto node = graph_node->safe_as(); - auto&& op_name = node->op()->name; - // a op with external_api registered and not excluded explicitly will be selected - if (!IsExcluded(op_name) && ExternalApiRegistry::Global()->Has(op_name, target)) { - VLOG(4) << "Op:" << op_name << " will use custom_call"; - return true; - } - } + auto mark_nodes = graph_->CollectNodes( + [this, &target](const common::GraphNode* graph_node) -> bool { + if (graph_node->safe_as()) { + auto node = graph_node->safe_as(); + auto&& op_name = node->op()->name; + // a op with external_api registered and not excluded explicitly + // will be selected + if (!IsExcluded(op_name) && + ExternalApiRegistry::Global()->Has(op_name, target)) { + VLOG(4) << "Op:" << op_name << " will use custom_call"; + return true; + } + } - return false; - }); + return false; + }); for (auto* graph_node : mark_nodes) { auto* node = graph_node->safe_as(); // revise the output edges for conv2d because the compute implement of // codegen-registered is not consistent with cudnn - if ((node->op()->name == "conv2d" || node->op()->name == "depthwise_conv2d") && + if ((node->op()->name == "conv2d" || + node->op()->name == "depthwise_conv2d") && target == common::DefaultNVGPUTarget()) { auto out_links = node->outlinks_in_order(); for (int idx = 1; idx < out_links.size(); ++idx) { @@ -68,7 +73,7 @@ class GraphAlterHelper { } node->attrs.attr_store["original_op"] = node->op()->name; - node->attrs.op = framework::Operator::Get("custom_call"); + node->attrs.op = framework::Operator::Get("custom_call"); } } @@ -76,7 +81,9 @@ class GraphAlterHelper { Graph* graph_; std::unordered_set deny_ops_; - bool IsExcluded(const std::string& op_name) { return deny_ops_.count(op_name); } + bool IsExcluded(const std::string& op_name) { + return deny_ops_.count(op_name); + } }; void TransToCustomCallInternal(Graph* graph) { @@ -92,7 +99,8 @@ void TransToCustomCallInternal(Graph* graph) { CINN_REGISTER_HELPER(TransToCustomCallPass) { CINN_REGISTER_PASS(TransToCustomCallPass) .describe( - "This pass replaces every op with external_api registered on the specified target to be custom_call op, " + "This pass replaces every op with external_api registered on the " + "specified target to be custom_call op, " "except the blacklist specified by FLAGS_cinn_custom_call_deny_ops") .set_change_structure(false) .set_body(cinn::hlir::pass::TransToCustomCallInternal); diff --git a/paddle/cinn/hlir/pass/dce_pass.cc b/paddle/cinn/hlir/pass/dce_pass.cc index 6865c2048d06d..32e6e952ed6a8 100644 --- a/paddle/cinn/hlir/pass/dce_pass.cc +++ b/paddle/cinn/hlir/pass/dce_pass.cc @@ -30,10 +30,11 @@ using framework::shape_t; using common::GraphEdge; using common::GraphNode; -using GroupPtr = std::shared_ptr; +using GroupPtr = std::shared_ptr; using GroupList = std::vector; -using ConditionFunction = std::function; +using ConditionFunction = + std::function; class DceHelper : public FusionHelperBase { public: @@ -85,7 +86,7 @@ class DceHelper : public 
FusionHelperBase { if (nodes_set_.count(node)) { continue; } - auto& inlinks = node->inlinks(); auto& outlinks = node->outlinks(); // remove others link to node. @@ -97,7 +98,7 @@ class DceHelper : public FusionHelperBase { // remove node data link to others. for (auto link : outlinks) { // node data - auto ndata = link->sink(); + auto ndata = link->sink(); auto& links = ndata->outlinks(); for (auto link_ : links) { auto dest = link_->sink(); diff --git a/paddle/cinn/hlir/pass/dce_pass_test.cc b/paddle/cinn/hlir/pass/dce_pass_test.cc index 1d81962c5680c..7f5c3355b0067 100644 --- a/paddle/cinn/hlir/pass/dce_pass_test.cc +++ b/paddle/cinn/hlir/pass/dce_pass_test.cc @@ -29,10 +29,11 @@ TEST(DCE, Test_0) { auto D = net_builder.Multiply(A, B); auto fetch_ids = {D->id}; - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); - auto graph = std::make_shared(program, fetch_ids, target); + auto graph = + std::make_shared(program, fetch_ids, target); hlir::framework::ApplyPass(graph.get(), "DCE"); CHECK_EQ(graph->nodes().size(), 4); @@ -52,10 +53,11 @@ TEST(DCE, Test_1) { auto H = net_builder.Add(E, G); auto fetch_ids = {F->id}; - auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto program = net_builder.Build(); + auto target = common::DefaultTarget(); - auto graph = std::make_shared(program, fetch_ids, target); + auto graph = + std::make_shared(program, fetch_ids, target); hlir::framework::ApplyPass(graph.get(), "DCE"); CHECK_EQ(graph->nodes().size(), 8); } diff --git a/paddle/cinn/hlir/pass/dense_merge_pass.cc b/paddle/cinn/hlir/pass/dense_merge_pass.cc index 1c867a84a098c..3ffd7fb369e71 100644 --- a/paddle/cinn/hlir/pass/dense_merge_pass.cc +++ b/paddle/cinn/hlir/pass/dense_merge_pass.cc @@ -25,11 +25,9 @@ using framework::Graph; using framework::Node; using framework::NodeAttr; -// Dense Merge Pass: merge those gemm which has same var as input into a batched cubals call op. -// A * B, A * C, A * D,... -// after -// A * [B, C, D,...] -// Using cublas batched gemm can avoid do concat and slice. +// Dense Merge Pass: merge those gemm ops which share the same input var into a +// batched cublas call op: A * B, A * C, A * D, ... becomes A * [B, C, D, ...]. +// Using cublas batched gemm avoids doing concat and slice.
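// Illustrative example of the merge (names follow Test_Matmul_0 later in this
// diff; the rewrite shown is an editorial sketch, not code from the pass):
//   auto D = net_builder.Matmul(A, B);
//   auto E = net_builder.Matmul(A, C);
// Both matmuls share the left operand A and carry identical gemm attributes,
// so GenOpSign assigns them the same key and DoMerge folds them into a single
// custom_call node with attrs["custom_call"] = "cinn_call_batched_cublas",
// which evaluates both products in one batched launch:
//   [D, E] = batched_cublas(A, [B, C])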
class DenseMergePassHelper : public FusionHelperBase { public: @@ -73,9 +71,11 @@ class DenseMergePassHelper : public FusionHelperBase { std::vector dense_ops; for (auto link : node->outlinks()) { auto sink = link->sink()->safe_as(); - if (sink->op()->name == "matmul" || sink->op()->name == "mul" || sink->op()->name == "cublas_gemm" || + if (sink->op()->name == "matmul" || sink->op()->name == "mul" || + sink->op()->name == "cublas_gemm" || sink->op()->name == "cublas_matmul") { - if (std::find(dense_ops.begin(), dense_ops.end(), sink) == dense_ops.end()) { + if (std::find(dense_ops.begin(), dense_ops.end(), sink) == + dense_ops.end()) { dense_ops.push_back(sink); } } @@ -83,17 +83,25 @@ class DenseMergePassHelper : public FusionHelperBase { return dense_ops; } - void LeftMerge(NodeData* node, std::vector dense_ops) { DoMerge(node, dense_ops, 1, "left"); } + void LeftMerge(NodeData* node, std::vector dense_ops) { + DoMerge(node, dense_ops, 1, "left"); + } - void RightMerge(NodeData* node, std::vector dense_ops) { DoMerge(node, dense_ops, 0, "right"); } + void RightMerge(NodeData* node, std::vector dense_ops) { + DoMerge(node, dense_ops, 0, "right"); + } - void DoMerge(NodeData* node, std::vector dense_ops, int pos, std::string side) { + void DoMerge(NodeData* node, + std::vector dense_ops, + int pos, + std::string side) { // split dense op by it's attr std::unordered_map> dense_op_map; for (auto dense_op : dense_ops) { const auto& in_links = dense_op->inlinks_in_order(); CHECK_GT(in_links.size(), pos); - auto sign = GenOpSign(in_links[pos]->source()->safe_as(), dense_op->attrs); + auto sign = GenOpSign(in_links[pos]->source()->safe_as(), + dense_op->attrs); if (dense_op_map.count(sign)) { dense_op_map[sign].push_back(dense_op); } else { @@ -107,11 +115,14 @@ class DenseMergePassHelper : public FusionHelperBase { } // create custom call node - Node* node_tmp = new Node(Operator::Get("custom_call"), "custom_call", common::UniqName("custom_call")); + Node* node_tmp = new Node(Operator::Get("custom_call"), + "custom_call", + common::UniqName("custom_call")); graph_->RegisterNode(node_tmp->id(), node_tmp); - node_tmp->attrs.attr_store = dense_op.second[0]->attrs.attr_store; - node_tmp->attrs.attr_store["side"] = side; - node_tmp->attrs.attr_store["custom_call"] = std::string("cinn_call_batched_cublas"); + node_tmp->attrs.attr_store = dense_op.second[0]->attrs.attr_store; + node_tmp->attrs.attr_store["side"] = side; + node_tmp->attrs.attr_store["custom_call"] = + std::string("cinn_call_batched_cublas"); // update inlink. node->LinkTo(node_tmp); @@ -137,14 +148,28 @@ class DenseMergePassHelper : public FusionHelperBase { } std::string GenOpSign(const NodeData* node, const NodeAttr& attrs) { - auto attr_store = attrs.attr_store; - bool trans_a = attr_store.count("trans_a") ? absl::get(attr_store.at("trans_a")) : false; - bool trans_b = attr_store.count("trans_b") ? absl::get(attr_store.at("trans_b")) : false; - bool trans_out = attr_store.count("trans_out") ? absl::get(attr_store.at("trans_out")) : false; - float alpha = attr_store.count("alpha") ? absl::get(attr_store.at("alpha")) : 1.0f; - float beta = attr_store.count("beta") ? absl::get(attr_store.at("beta")) : 0.0f; - int x_num_col_dims = attr_store.count("x_num_col_dims") ? absl::get(attr_store.at("x_num_col_dims")) : 0; - int y_num_col_dims = attr_store.count("y_num_col_dims") ? absl::get(attr_store.at("y_num_col_dims")) : 0; + auto attr_store = attrs.attr_store; + bool trans_a = attr_store.count("trans_a") + ? 
absl::get(attr_store.at("trans_a")) + : false; + bool trans_b = attr_store.count("trans_b") + ? absl::get(attr_store.at("trans_b")) + : false; + bool trans_out = attr_store.count("trans_out") + ? absl::get(attr_store.at("trans_out")) + : false; + float alpha = attr_store.count("alpha") + ? absl::get(attr_store.at("alpha")) + : 1.0f; + float beta = attr_store.count("beta") + ? absl::get(attr_store.at("beta")) + : 0.0f; + int x_num_col_dims = attr_store.count("x_num_col_dims") + ? absl::get(attr_store.at("x_num_col_dims")) + : 0; + int y_num_col_dims = attr_store.count("y_num_col_dims") + ? absl::get(attr_store.at("y_num_col_dims")) + : 0; std::string sign = ""; sign += std::to_string(trans_a); diff --git a/paddle/cinn/hlir/pass/dense_merge_pass_test.cc b/paddle/cinn/hlir/pass/dense_merge_pass_test.cc index fa66ec2d1240f..05ee12558f7ca 100644 --- a/paddle/cinn/hlir/pass/dense_merge_pass_test.cc +++ b/paddle/cinn/hlir/pass/dense_merge_pass_test.cc @@ -19,7 +19,9 @@ namespace cinn { namespace frontend { -int GetSize(std::vector& shape) { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); } +int GetSize(std::vector& shape) { + return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); +} void RunModelTest(Program& program, const std::vector&& inputs, @@ -28,13 +30,17 @@ void RunModelTest(Program& program, std::vector> inputs_data; for (auto input : inputs) { inputs_data.emplace_back(GetSize(input->shape)); - InitRandomVector(&inputs_data.back(), inputs_data.back().size(), 0.0f, 1.0f, 1e-3); + InitRandomVector( + &inputs_data.back(), inputs_data.back().size(), 0.0f, 1.0f, 1e-3); } auto target = common::DefaultTarget(); - std::unordered_map, std::vector>> outputs; + std::unordered_map, std::vector>> + outputs; { - auto graph = std::make_shared(program, fetch_ids, target); + auto graph = + std::make_shared(program, fetch_ids, target); hlir::framework::ApplyPass(graph.get(), "TransToCustomCallPass"); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); @@ -46,7 +52,7 @@ void RunModelTest(Program& program, for (int idx = 0; idx < inputs.size(); ++idx) { scope->Var(inputs[idx]->id); auto tensor = scope->GetTensor(inputs[idx]->id); - auto* data = tensor->mutable_data(target); + auto* data = tensor->mutable_data(target); CopyFromVector(inputs_data[idx], tensor, target); } run_program->Execute(); @@ -54,11 +60,13 @@ void RunModelTest(Program& program, auto tensor = scope->GetTensor(id); std::vector data(tensor->shape().numel()); CopyToVector(tensor, &data); - outputs[id] = std::pair, std::vector>(data, std::vector()); + outputs[id] = std::pair, std::vector>( + data, std::vector()); } } { - auto graph = std::make_shared(program, fetch_ids, target); + auto graph = + std::make_shared(program, fetch_ids, target); hlir::framework::ApplyPass(graph.get(), "DenseMergePass"); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); @@ -70,7 +78,7 @@ void RunModelTest(Program& program, for (int idx = 0; idx < inputs.size(); ++idx) { scope->Var(inputs[idx]->id); auto tensor = scope->GetTensor(inputs[idx]->id); - auto* data = tensor->mutable_data(target); + auto* data = tensor->mutable_data(target); CopyFromVector(inputs_data[idx], tensor, target); } run_program->Execute(); @@ -97,7 +105,7 @@ TEST(DenseMergePass, Test_Matmul_0) { auto E = net_builder.Matmul(A, C); auto fetch_ids = {D->id, E->id}; - auto program = net_builder.Build(); + auto program = 
net_builder.Build(); RunModelTest(program, {A, B, C}, fetch_ids); } @@ -110,7 +118,7 @@ TEST(DenseMergePass, Test_Matmul_1) { auto E = net_builder.Matmul(B, C); auto fetch_ids = {D->id, E->id}; - auto program = net_builder.Build(); + auto program = net_builder.Build(); RunModelTest(program, {A, B, C}, fetch_ids); } @@ -127,7 +135,7 @@ TEST(DenseMergePass, Test_Matmul_2) { auto I = net_builder.Matmul(D, E); auto fetch_ids = {F->id, G->id, H->id, I->id}; - auto program = net_builder.Build(); + auto program = net_builder.Build(); RunModelTest(program, {A, B, C, D, E}, fetch_ids); } @@ -144,7 +152,7 @@ TEST(DenseMergePass, Test_Matmul_3) { auto I = net_builder.Matmul(C, E); auto fetch_ids = {F->id, G->id, H->id, I->id}; - auto program = net_builder.Build(); + auto program = net_builder.Build(); RunModelTest(program, {A, B, C, D, E}, fetch_ids); } @@ -160,7 +168,7 @@ TEST(DenseMergePass, Test_Matmul_4) { auto I = net_builder.Matmul(B, D); auto fetch_ids = {F->id, G->id, H->id, I->id}; - auto program = net_builder.Build(); + auto program = net_builder.Build(); RunModelTest(program, {A, B, C, D}, fetch_ids); } diff --git a/paddle/cinn/hlir/pass/dot_merger.cc b/paddle/cinn/hlir/pass/dot_merger.cc index 5ee1d79365c5c..d241d78815cf3 100644 --- a/paddle/cinn/hlir/pass/dot_merger.cc +++ b/paddle/cinn/hlir/pass/dot_merger.cc @@ -28,10 +28,11 @@ using framework::NodeData; using framework::Operator; template -using OpValueType = cinn::hlir::framework::OpValueType; -using infershape_t = std::function(const std::vector&, - const framework::AttrMapType&)>; -using inferdtype_t = std::function(const std::vector&, const framework::AttrMapType&)>; +using OpValueType = cinn::hlir::framework::OpValueType; +using infershape_t = std::function( + const std::vector&, const framework::AttrMapType&)>; +using inferdtype_t = std::function( + const std::vector&, const framework::AttrMapType&)>; using dtype_dict_t = absl::flat_hash_map; using shape_dict_t = absl::flat_hash_map; @@ -57,8 +58,12 @@ T get_attr(Node* instr, const std::string& attr, T def) { return absl::get(instr->attrs.attr_store.at(attr)); } -NodeData* input_operand(Node* instr, int idx) { return instr->inlinks_in_order()[idx]->source()->safe_as(); } -NodeData* output_operand(Node* instr, int idx) { return instr->outlinks_in_order()[idx]->sink()->safe_as(); } +NodeData* input_operand(Node* instr, int idx) { + return instr->inlinks_in_order()[idx]->source()->safe_as(); +} +NodeData* output_operand(Node* instr, int idx) { + return instr->outlinks_in_order()[idx]->sink()->safe_as(); +} void remove_node(framework::Graph* graph, GraphNode* node) { auto inlinks = node->inlinks(); @@ -85,7 +90,7 @@ bool all_equal(const T& arg, const Args&... 
args) { void PrintAllMatmulOps(framework::Graph* graph, const std::string& dot_type) { auto& dtype_dict{graph->GetMutableAttrs("inferdtype")}; auto& shape_dict{graph->GetMutableAttrs("infershape")}; - auto nodes = std::get<0>(graph->topological_order()); + auto nodes = std::get<0>(graph->topological_order()); auto print_shape = [](const std::vector& shape) -> std::string { std::stringstream ss; for (auto i : shape) { @@ -96,14 +101,16 @@ void PrintAllMatmulOps(framework::Graph* graph, const std::string& dot_type) { for (auto* n : nodes) { auto* op_node = n->safe_as(); if (op_node && op_node->op()->name == dot_type) { - auto a_id = input_operand(op_node, 0)->id(); - auto b_id = input_operand(op_node, 1)->id(); + auto a_id = input_operand(op_node, 0)->id(); + auto b_id = input_operand(op_node, 1)->id(); auto a_shape = shape_dict.at(a_id); auto b_shape = shape_dict.at(b_id); LOG(INFO) << "Find op: " << dot_type; LOG(INFO) << "Attrs: " - << "trans_a = " << get_attr(op_node, "trans_a", false) << ", " - << "trans_b = " << get_attr(op_node, "trans_b", false) << ", " + << "trans_a = " << get_attr(op_node, "trans_a", false) + << ", " + << "trans_b = " << get_attr(op_node, "trans_b", false) + << ", " << "a: " << a_id << ", " << print_shape(a_shape) << " " << "b: " << b_id << ", " << print_shape(b_shape); } @@ -134,7 +141,8 @@ class DotBuilder { NodeData* Concat(int axis, std::vector inputs) { const std::string type{"concat"}; - auto instr = common::Shared(new Node(framework::Operator::Get(type), type, node_name(type))); + auto instr = common::Shared( + new Node(framework::Operator::Get(type), type, node_name(type))); instr->attrs.attr_store["axis"] = axis; for (auto* in : inputs) { in->LinkTo(instr.get()); @@ -143,29 +151,39 @@ class DotBuilder { return output; } - NodeData* Matmul(bool trans_a, bool trans_b, bool trans_out, float alpha, NodeData* lhs, NodeData* rhs) { + NodeData* Matmul(bool trans_a, + bool trans_b, + bool trans_out, + float alpha, + NodeData* lhs, + NodeData* rhs) { const std::string type{dot_type_}; - auto instr = common::Shared(new Node(framework::Operator::Get(type), type, node_name(type))); - matmul_ = instr.get(); - instr->attrs.attr_store["trans_a"] = trans_a; - instr->attrs.attr_store["trans_b"] = trans_b; + auto instr = common::Shared( + new Node(framework::Operator::Get(type), type, node_name(type))); + matmul_ = instr.get(); + instr->attrs.attr_store["trans_a"] = trans_a; + instr->attrs.attr_store["trans_b"] = trans_b; instr->attrs.attr_store["trans_out"] = trans_out; - instr->attrs.attr_store["alpha"] = alpha; + instr->attrs.attr_store["alpha"] = alpha; lhs->LinkTo(instr.get()); rhs->LinkTo(instr.get()); auto* output = Var(instr); return output; } - NodeData* Slice( - std::vector axes, std::vector starts, std::vector ends, NodeData* input, NodeData* output) { + NodeData* Slice(std::vector axes, + std::vector starts, + std::vector ends, + NodeData* input, + NodeData* output) { const std::string type{"slice"}; - auto instr = common::Shared(new Node(framework::Operator::Get(type), type, node_name(type))); - instr->attrs.attr_store["axes"] = std::move(axes); - instr->attrs.attr_store["starts"] = std::move(starts); - instr->attrs.attr_store["ends"] = std::move(ends); - instr->attrs.attr_store["infer_flags"] = std::vector{}; - instr->attrs.attr_store["strides"] = std::vector{}; + auto instr = common::Shared( + new Node(framework::Operator::Get(type), type, node_name(type))); + instr->attrs.attr_store["axes"] = std::move(axes); + instr->attrs.attr_store["starts"] = 
std::move(starts); + instr->attrs.attr_store["ends"] = std::move(ends); + instr->attrs.attr_store["infer_flags"] = std::vector{}; + instr->attrs.attr_store["strides"] = std::vector{}; instr->attrs.attr_store["decrease_axis"] = std::vector{}; input->LinkTo(instr.get()); instr->LinkTo(output); @@ -176,7 +194,8 @@ class DotBuilder { } std::string node_name(std::string prefix) const { - return std::move(prefix.append("__dot_merger_").append(std::to_string(idx_++))); + return std::move( + prefix.append("__dot_merger_").append(std::to_string(idx_++))); } Node* matmul_op() const { return matmul_; } @@ -215,9 +234,11 @@ class DotMergerPass { merge_nodes.push_back(a); for (size_t j = i + 1; j < dots.size(); ++j) { auto* b = dots[j]; - if (!b || nodes_to_remove.count(a) || nodes_to_remove.count(b) || accessible(a, b) || accessible(b, a)) { + if (!b || nodes_to_remove.count(a) || nodes_to_remove.count(b) || + accessible(a, b) || accessible(b, a)) { VLOG(5) << "Because nodes `" << a->id() << "` and `" << b->id() - << " have data dependencies or have been deleted, they cannot be merged."; + << " have data dependencies or have been deleted, they " + "cannot be merged."; continue; } if (!is_merge(&builder, a, b)) { @@ -247,7 +268,8 @@ class DotMergerPass { } private: - static std::map> GetClusters(framework::Graph* graph, const std::string& op_type) { + static std::map> GetClusters( + framework::Graph* graph, const std::string& op_type) { std::map> clusters; auto nodes = std::get<0>(graph->topological_order()); for (auto* n : nodes) { @@ -276,43 +298,52 @@ class DotMergerPass { static bool is_merge(DotBuilder* builder, Node* a, Node* b) { CHECK(a && b) << "The pointer of node is illegal."; - const std::array trans_a{get_attr(a, "trans_a", false), get_attr(b, "trans_a", false)}; - const std::array trans_b{get_attr(a, "trans_b", false), get_attr(b, "trans_b", false)}; - const std::array trans_out{get_attr(a, "trans_out", false), get_attr(b, "trans_out", false)}; - const std::array alpha{get_attr(a, "alpha", 1.f), get_attr(b, "alpha", 1.f)}; + const std::array trans_a{get_attr(a, "trans_a", false), + get_attr(b, "trans_a", false)}; + const std::array trans_b{get_attr(a, "trans_b", false), + get_attr(b, "trans_b", false)}; + const std::array trans_out{get_attr(a, "trans_out", false), + get_attr(b, "trans_out", false)}; + const std::array alpha{get_attr(a, "alpha", 1.f), + get_attr(b, "alpha", 1.f)}; if (!all_equal(trans_a, trans_b, trans_out, alpha)) { return false; } NodeData *shared_input{}, *input_a{}, *input_b{}; if (input_operand(a, 1) == input_operand(b, 1)) { shared_input = input_operand(a, 1); - input_a = input_operand(a, 0); - input_b = input_operand(b, 0); + input_a = input_operand(a, 0); + input_b = input_operand(b, 0); } else if (input_operand(a, 0) == input_operand(b, 0)) { shared_input = input_operand(a, 0); - input_a = input_operand(a, 1); - input_b = input_operand(b, 1); + input_a = input_operand(a, 1); + input_b = input_operand(b, 1); } else { return false; } - auto* output_a = output_operand(a, 0); - auto* output_b = output_operand(b, 0); + auto* output_a = output_operand(a, 0); + auto* output_b = output_operand(b, 0); auto& graph_outs = builder->graph()->outputs; for (auto* n : {shared_input, input_a, input_b}) { - if (std::find(graph_outs.begin(), graph_outs.end(), n) != graph_outs.end()) { + if (std::find(graph_outs.begin(), graph_outs.end(), n) != + graph_outs.end()) { return false; } } return true; } - static Node* NewMergeDots(DotBuilder* builder, std::vector merge_nodes) { - const 
std::array trans_a{get_attr(merge_nodes[0], "trans_a", false), - get_attr(merge_nodes[1], "trans_a", false)}; - const std::array trans_b{get_attr(merge_nodes[0], "trans_b", false), - get_attr(merge_nodes[1], "trans_b", false)}; - const std::array alpha{get_attr(merge_nodes[0], "alpha", 1.f), - get_attr(merge_nodes[1], "alpha", 1.f)}; + static Node* NewMergeDots(DotBuilder* builder, + std::vector merge_nodes) { + const std::array trans_a{ + get_attr(merge_nodes[0], "trans_a", false), + get_attr(merge_nodes[1], "trans_a", false)}; + const std::array trans_b{ + get_attr(merge_nodes[0], "trans_b", false), + get_attr(merge_nodes[1], "trans_b", false)}; + const std::array alpha{ + get_attr(merge_nodes[0], "alpha", 1.f), + get_attr(merge_nodes[1], "alpha", 1.f)}; bool lhs{true}; int axis{1}; @@ -320,7 +351,7 @@ class DotMergerPass { if (input_operand(merge_nodes[0], 1) == input_operand(merge_nodes[1], 1)) { shared_input = input_operand(merge_nodes[0], 1); - lhs = false; + lhs = false; if (!trans_a[0]) { axis = 0; } else if (trans_b[0]) { @@ -333,24 +364,34 @@ class DotMergerPass { auto shape_shared = builder->shape_dict().at(shared_input->id()); concat_nodes.push_back(input_operand(merge_nodes[0], axis)); for (size_t i = 1; i < merge_nodes.size(); ++i) { - auto shape_a = builder->shape_dict().at(input_operand(merge_nodes[i - 1], axis)->id()); - auto shape_b = builder->shape_dict().at(input_operand(merge_nodes[i], axis)->id()); + auto shape_a = builder->shape_dict().at( + input_operand(merge_nodes[i - 1], axis)->id()); + auto shape_b = + builder->shape_dict().at(input_operand(merge_nodes[i], axis)->id()); CHECK_EQ(shape_a[1 - axis], shape_b[1 - axis]) - << "The shape of matmul is error. " << shape_a.size() << ", " << shape_b.size(); + << "The shape of matmul is error. 
" << shape_a.size() << ", " + << shape_b.size(); concat_nodes.push_back(input_operand(merge_nodes[i], axis)); } auto* concat_out = builder->Concat(axis, concat_nodes); NodeData* matmul_out{}; if (!lhs) { - matmul_out = builder->Matmul(trans_a[0], trans_b[0], false, alpha[0], concat_out, shared_input); + matmul_out = builder->Matmul( + trans_a[0], trans_b[0], false, alpha[0], concat_out, shared_input); } else { - matmul_out = builder->Matmul(trans_a[0], trans_b[0], false, alpha[0], shared_input, concat_out); + matmul_out = builder->Matmul( + trans_a[0], trans_b[0], false, alpha[0], shared_input, concat_out); } auto start_shape = 0; for (size_t i = 0; i < concat_nodes.size(); ++i) { - auto shape = builder->shape_dict().at(input_operand(merge_nodes[i], axis)->id()); + auto shape = + builder->shape_dict().at(input_operand(merge_nodes[i], axis)->id()); auto* output = output_operand(merge_nodes[i], 0); - builder->Slice({axis}, {start_shape}, {start_shape + shape[axis]}, matmul_out, output); + builder->Slice({axis}, + {start_shape}, + {start_shape + shape[axis]}, + matmul_out, + output); start_shape += shape[axis]; } return builder->matmul_op(); @@ -358,10 +399,14 @@ class DotMergerPass { static Node* MergeDots(DotBuilder* builder, Node* a, Node* b) { CHECK(a && b) << "The pointer of node is illegal."; - const std::array trans_a{get_attr(a, "trans_a", false), get_attr(b, "trans_a", false)}; - const std::array trans_b{get_attr(a, "trans_b", false), get_attr(b, "trans_b", false)}; - const std::array trans_out{get_attr(a, "trans_out", false), get_attr(b, "trans_out", false)}; - const std::array alpha{get_attr(a, "alpha", 1.f), get_attr(b, "alpha", 1.f)}; + const std::array trans_a{get_attr(a, "trans_a", false), + get_attr(b, "trans_a", false)}; + const std::array trans_b{get_attr(a, "trans_b", false), + get_attr(b, "trans_b", false)}; + const std::array trans_out{get_attr(a, "trans_out", false), + get_attr(b, "trans_out", false)}; + const std::array alpha{get_attr(a, "alpha", 1.f), + get_attr(b, "alpha", 1.f)}; if (!all_equal(trans_a, trans_b, trans_out, alpha)) { return nullptr; } @@ -370,9 +415,9 @@ class DotMergerPass { NodeData *shared_input{}, *input_a{}, *input_b{}; if (input_operand(a, 1) == input_operand(b, 1)) { shared_input = input_operand(a, 1); - input_a = input_operand(a, 0); - input_b = input_operand(b, 0); - lhs = false; + input_a = input_operand(a, 0); + input_b = input_operand(b, 0); + lhs = false; if (!trans_a[0]) { axis = 0; } else if (trans_b[0]) { @@ -380,34 +425,43 @@ class DotMergerPass { } } else if (input_operand(a, 0) == input_operand(b, 0)) { shared_input = input_operand(a, 0); - input_a = input_operand(a, 1); - input_b = input_operand(b, 1); + input_a = input_operand(a, 1); + input_b = input_operand(b, 1); } else { return nullptr; } - auto* output_a = output_operand(a, 0); - auto* output_b = output_operand(b, 0); + auto* output_a = output_operand(a, 0); + auto* output_b = output_operand(b, 0); auto& graph_outs = builder->graph()->outputs; for (auto* n : {shared_input, input_a, input_b}) { - if (std::find(graph_outs.begin(), graph_outs.end(), n) != graph_outs.end()) { + if (std::find(graph_outs.begin(), graph_outs.end(), n) != + graph_outs.end()) { return nullptr; } } - CHECK(shared_input && input_a && input_b) << "The input node type must be variable."; + CHECK(shared_input && input_a && input_b) + << "The input node type must be variable."; auto shape_shared = builder->shape_dict().at(shared_input->id()); - auto shape_a = builder->shape_dict().at(input_a->id()); - auto 
shape_b = builder->shape_dict().at(input_b->id()); + auto shape_a = builder->shape_dict().at(input_a->id()); + auto shape_b = builder->shape_dict().at(input_b->id()); CHECK_EQ(shape_a[1 - axis], shape_b[1 - axis]) - << "The shape of matmul is error. " << shape_a.size() << ", " << shape_b.size(); + << "The shape of matmul is error. " << shape_a.size() << ", " + << shape_b.size(); auto* concat_out = builder->Concat(axis, {input_a, input_b}); NodeData* matmul_out{}; if (!lhs) { - matmul_out = builder->Matmul(trans_a[0], trans_b[0], false, alpha[0], concat_out, shared_input); + matmul_out = builder->Matmul( + trans_a[0], trans_b[0], false, alpha[0], concat_out, shared_input); } else { - matmul_out = builder->Matmul(trans_a[0], trans_b[0], false, alpha[0], shared_input, concat_out); + matmul_out = builder->Matmul( + trans_a[0], trans_b[0], false, alpha[0], shared_input, concat_out); } builder->Slice({axis}, {0}, {shape_a[axis]}, matmul_out, output_a); - builder->Slice({axis}, {shape_a[axis]}, {shape_a[axis] + shape_b[axis]}, matmul_out, output_b); + builder->Slice({axis}, + {shape_a[axis]}, + {shape_a[axis] + shape_b[axis]}, + matmul_out, + output_b); return builder->matmul_op(); } }; @@ -418,7 +472,8 @@ void DotMergerPassFunc(framework::Graph* graph) { // The cublas gemm is not yet supported. for (auto& dot_type : {"matmul", "cublas_matmul"}) { int n = DotMergerPass::Apply(graph, dot_type); - VLOG(3) << "The fusion of `" << dot_type << "` was performed " << n << " times."; + VLOG(3) << "The fusion of `" << dot_type << "` was performed " << n + << " times."; } } diff --git a/paddle/cinn/hlir/pass/fusion_helper_base.h b/paddle/cinn/hlir/pass/fusion_helper_base.h index 6f45fe5dc03d6..4b4a3a7ccaf7c 100644 --- a/paddle/cinn/hlir/pass/fusion_helper_base.h +++ b/paddle/cinn/hlir/pass/fusion_helper_base.h @@ -33,9 +33,12 @@ using namespace framework; class FusionHelperBase { public: FusionHelperBase(const framework::Graph* graph) - : shape_dict_(graph->GetAttrs>("infershape")), target_(graph->target_) { + : shape_dict_(graph->GetAttrs>( + "infershape")), + target_(graph->target_) { // get op pattern dict - op_pattern_dict_ = &framework::Operator::GetAttrs("OpPattern"); + op_pattern_dict_ = + &framework::Operator::GetAttrs("OpPattern"); // output node set for (auto node_data : graph->outputs) { CHECK(node_data->source_node.get()); @@ -45,11 +48,13 @@ class FusionHelperBase { public: OpPatternKind GetOpKind(const framework::Node* node) const { - CHECK(op_pattern_dict_->Find(node->op())) << "Don't find the pattern of op : " << node->id(); + CHECK(op_pattern_dict_->Find(node->op())) + << "Don't find the pattern of op : " << node->id(); auto kind = op_pattern_dict_[0][node->op()]; if (kind == framework::kBroadcast) { - // As binary op was defined as broadcast, actually it should be element-wise. + // As binary op was defined as broadcast, actually it should be + // element-wise. 
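// The branch below, isolated as a hedged sketch (illustrative enum and
// helper, not the real OpPatternKind): ops registered as broadcast are
// downgraded to element-wise unless they are an actual broadcast_to.
#include <string>

enum SketchKind { kSketchElementWise, kSketchBroadcast };

inline SketchKind ReclassifyBroadcast(SketchKind registered,
                                      const std::string& op_name) {
  return (registered == kSketchBroadcast && op_name != "broadcast_to")
             ? kSketchElementWise  // binary ops registered as broadcast
             : registered;
}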
if (node->op()->name != "broadcast_to") { return framework::kElementWise; } @@ -59,7 +64,8 @@ class FusionHelperBase { } static bool IsConstOp(const framework::Node* node) { - static std::unordered_set const_op_type = {"const_scalar", "fill_constant", "arange"}; + static std::unordered_set const_op_type = { + "const_scalar", "fill_constant", "arange"}; if (const_op_type.count(node->op()->name)) { return true; } else { @@ -71,7 +77,8 @@ class FusionHelperBase { std::vector consumer_node_data; for (auto& edge : node->outlinks_in_order()) { auto output = edge->sink()->safe_as(); - CHECK(output) << "The op \"" << node->id() << "\" output should not be empty!"; + CHECK(output) << "The op \"" << node->id() + << "\" output should not be empty!"; consumer_node_data.push_back(output); } return consumer_node_data; @@ -85,21 +92,23 @@ class FusionHelperBase { shape_t GetNodeDataShape(const Node* node) const { auto* node_data = GetNodeData(node); - CHECK(shape_dict_.count(node_data->id())) << "Can't find " << node_data->id() << " 's shape!"; + CHECK(shape_dict_.count(node_data->id())) + << "Can't find " << node_data->id() << " 's shape!"; return shape_dict_.at(node_data->id()); } shape_t GetNodeInputShape(const Node* node) const { auto node_datas = GetProducerNodeData(node); CHECK_GT(node_datas.size(), 0); - CHECK(shape_dict_.count(node_datas[0]->id())) << "Can't find " << node_datas[0]->id() << " 's shape!"; + CHECK(shape_dict_.count(node_datas[0]->id())) + << "Can't find " << node_datas[0]->id() << " 's shape!"; return shape_dict_.at(node_datas[0]->id()); } static std::vector GetProducerNodeData(const Node* node) { std::vector producer_node_data; for (auto& edge : node->inlinks_in_order()) { - auto graph_node = edge->source(); + auto graph_node = edge->source(); auto producer_data = graph_node->safe_as(); CHECK(producer_data); producer_node_data.push_back(producer_data); @@ -110,7 +119,7 @@ class FusionHelperBase { std::vector GetProducerNode(const Node* node) const { std::vector producer_node; for (auto& edge : node->inlinks_in_order()) { - auto graph_node = edge->source(); + auto graph_node = edge->source(); auto producer_data = graph_node->safe_as(); CHECK(producer_data); auto producer = producer_data->source_node.get(); @@ -132,7 +141,8 @@ class FusionHelperBase { return consumer_nodes; } - bool WithoutLastDimInReduce(const std::vector& inshape, const std::vector& axes) const { + bool WithoutLastDimInReduce(const std::vector& inshape, + const std::vector& axes) const { // if last axis is in reduce. if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || std::find(axes.begin(), axes.end(), -1) != axes.end()) { @@ -155,7 +165,7 @@ class FusionHelperBase { auto producers = GetProducerNodeData(node); CHECK_GT(producers.size(), 0); auto inshape = shape_dict_.at(producers[0]->id()); - auto axes = absl::get>(node->attrs.attr_store.at("dim")); + auto axes = absl::get>(node->attrs.attr_store.at("dim")); if (WithoutLastDimInReduce(inshape, axes)) { int lane = 1; for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { @@ -175,14 +185,17 @@ class FusionHelperBase { break; } } - // if lane > (max_num_threads / 2),the loop break from lane > max_num_threads / 2. + // if lane > (max_num_threads / 2),the loop break from lane > + // max_num_threads / 2. int axis = lane > (max_num_threads / 2) ? 
axes[index] : axes[index + 1]; if (lane <= max_num_threads) { return lane * sizeof(float); } else { int prefix = inshape[axis]; - int tail = lane / prefix; - for (int idx = max_num_threads / tail; idx > ((max_num_threads / 2) / tail); --idx) { + int tail = lane / prefix; + for (int idx = max_num_threads / tail; + idx > ((max_num_threads / 2) / tail); + --idx) { if (prefix % idx == 0) { return idx * tail * sizeof(float); } diff --git a/paddle/cinn/hlir/pass/fusion_merge_pass.cc b/paddle/cinn/hlir/pass/fusion_merge_pass.cc index 50095cc762a05..af28654703702 100644 --- a/paddle/cinn/hlir/pass/fusion_merge_pass.cc +++ b/paddle/cinn/hlir/pass/fusion_merge_pass.cc @@ -30,12 +30,13 @@ using common::GraphEdge; using common::GraphNode; using Comparator = Graph::Group::SharedGroupComparator; -using Hasher = Graph::Group::SharedGroupHasher; +using Hasher = Graph::Group::SharedGroupHasher; -using GroupPtr = std::shared_ptr; +using GroupPtr = std::shared_ptr; using GroupList = std::vector; -using ConditionFunction = std::function; +using ConditionFunction = std::function; // Op Fusion Pass which performs Ops fusion, Ops are fused // "vertically", meaning producing Ops are fused into their consumers @@ -173,7 +174,9 @@ class FusionMergePassHelper : public FusionHelperBase { } } - bool HorizontalFusion(GroupPtr producer, std::unordered_set& consumers) { + bool HorizontalFusion( + GroupPtr producer, + std::unordered_set& consumers) { VLOG(3) << "HorizontalFusion...!"; if (consumers.size() <= 1) { return false; @@ -194,12 +197,14 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& candidate : candidates) { // check dependency if (IsDependencySimplify(producer, candidate, candidates)) { - VLOG(4) << "IsDependencySimplify, Can't fuse " << candidate->group_id << ", As it depency others!"; + VLOG(4) << "IsDependencySimplify, Can't fuse " << candidate->group_id + << ", As it depency others!"; continue; } if (IsDependency(producer, candidate, candidates)) { - VLOG(4) << "IsDependency, Can't fuse " << candidate->group_id << ", As it depency others!"; + VLOG(4) << "IsDependency, Can't fuse " << candidate->group_id + << ", As it depency others!"; continue; } @@ -210,14 +215,15 @@ class FusionMergePassHelper : public FusionHelperBase { // check each fusionable groups bool fusionable = false; - auto& relation = fusion_relation_map_[candidate->op_pattern_kind]; + auto& relation = fusion_relation_map_[candidate->op_pattern_kind]; for (auto& groups : fusionable_consumers) { auto& last = groups.back(); if (!relation.horizontal_relation.count(last->op_pattern_kind)) { continue; } - if (!relation.horizontal_relation[last->op_pattern_kind](this, candidate, last)) { + if (!relation.horizontal_relation[last->op_pattern_kind]( + this, candidate, last)) { continue; } @@ -256,8 +262,10 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& consumer : consumers) { VLOG(3) << "fuse consumer " << consumer->group_id << " into fused_group!"; // update depth - fused_group->max_depth = std::max(fused_group->max_depth, consumer->max_depth); - fused_group->min_depth = std::min(fused_group->min_depth, consumer->min_depth); + fused_group->max_depth = + std::max(fused_group->max_depth, consumer->max_depth); + fused_group->min_depth = + std::min(fused_group->min_depth, consumer->min_depth); // update group id if (fused_group->group_id.size()) { fused_group->group_id += "_" + consumer->group_id; @@ -266,7 +274,8 @@ class FusionMergePassHelper : public FusionHelperBase { } // set op pattern kind 
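// What the ternary below computes, as a small sketch: the fused group adopts
// the numerically larger ("stronger") pattern kind of producer and consumer.
// Illustrative enum with the same ordering idea as OpPatternKind.
#include <algorithm>

enum SketchPattern { kSkElementWise = 0, kSkBroadcast, kSkInjective, kSkReduction };

inline SketchPattern PromoteKind(SketchPattern a, SketchPattern b) {
  return static_cast<SketchPattern>(
      std::max(static_cast<int>(a), static_cast<int>(b)));
}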
fused_group->op_pattern_kind = - static_cast(fused_group->op_pattern_kind) >= static_cast(consumer->op_pattern_kind) + static_cast(fused_group->op_pattern_kind) >= + static_cast(consumer->op_pattern_kind) ? fused_group->op_pattern_kind : consumer->op_pattern_kind; // input nodes @@ -333,9 +342,11 @@ class FusionMergePassHelper : public FusionHelperBase { // find the first consumer. CHECK(fusion_groups_index_.count(consumer)) - << "Can't find consumer " << consumer->group_id << " index in fusion_groups_index_!"; + << "Can't find consumer " << consumer->group_id + << " index in fusion_groups_index_!"; if (first_consumer.get()) { - if (fusion_groups_index_[consumer] < fusion_groups_index_[first_consumer]) { + if (fusion_groups_index_[consumer] < + fusion_groups_index_[first_consumer]) { first_consumer = consumer; } } else { @@ -354,14 +365,16 @@ class FusionMergePassHelper : public FusionHelperBase { } } - if (static_cast(framework::kReduction) > static_cast((consumers.back())->op_pattern_kind)) { + if (static_cast(framework::kReduction) > + static_cast((consumers.back())->op_pattern_kind)) { auto consumer = consumers.back(); for (auto& node : consumer->master_nodes) { fused_group->master_nodes.insert(node); } } else { - for (auto consumer = consumers.rbegin(); consumer != consumers.rend(); ++consumer) { + for (auto consumer = consumers.rbegin(); consumer != consumers.rend(); + ++consumer) { Node* master_node = nullptr; for (auto& node : (*consumer)->master_nodes) { if (GetOpKind(node) != framework::kReduction) { @@ -370,21 +383,26 @@ class FusionMergePassHelper : public FusionHelperBase { } } if (master_node) { - VLOG(3) << "Insert Master node : " << master_node->id() << " into group : " << fused_group->group_id; + VLOG(3) << "Insert Master node : " << master_node->id() + << " into group : " << fused_group->group_id; fused_group->master_nodes.insert(master_node); break; } } } - auto postion = fusion_groups_index_[first_consumer]; - fusion_groups_[postion] = fused_group; + auto postion = fusion_groups_index_[first_consumer]; + fusion_groups_[postion] = fused_group; fusion_groups_index_[fused_group] = postion; - CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; + CHECK(fused_group->output_nodes.size()) + << "No output node is found, " << fused_group->group_id; } - bool VerticalFusion(GroupPtr& producer, std::unordered_set& consumers, bool recompute) { + bool VerticalFusion( + GroupPtr& producer, + std::unordered_set& consumers, + bool recompute) { VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); auto& relation = fusion_relation_map_[producer->op_pattern_kind]; // if producer can't fuse others @@ -395,36 +413,44 @@ class FusionMergePassHelper : public FusionHelperBase { std::unordered_set fuse_consumers_unsafe; std::unordered_set fuse_consumers; for (auto& consumer : consumers) { - VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; + VLOG(4) << "Check consuemr " << consumer->group_id + << " can fuse to producer " << producer->group_id; // if can't fuse if (!relation.vertical_relation.count(consumer->op_pattern_kind)) { - VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer " << consumer->group_id; + VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer " + << consumer->group_id; continue; } // if condition function is false - if (!relation.vertical_relation[consumer->op_pattern_kind](this, producer, consumer)) { - VLOG(4) << "Can't fuse 
producer " << producer->group_id << " consumer " << consumer->group_id; + if (!relation.vertical_relation[consumer->op_pattern_kind]( + this, producer, consumer)) { + VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer " + << consumer->group_id; continue; } fuse_consumers_unsafe.insert(consumer); if (IsDependencySimplify(producer, consumer, consumers)) { - VLOG(4) << "IsDependencySimplify, Consumer " << consumer->group_id << " can't be master fused group!"; + VLOG(4) << "IsDependencySimplify, Consumer " << consumer->group_id + << " can't be master fused group!"; continue; } if (IsDependency(producer, consumer, consumers)) { - VLOG(4) << "IsDependency, Consumer " << consumer->group_id << " can't be master fused group!"; + VLOG(4) << "IsDependency, Consumer " << consumer->group_id + << " can't be master fused group!"; continue; } fuse_consumers.insert(consumer); } - VLOG(3) << "VerticalFusion, Number of fuse Consumers : " << fuse_consumers.size(); - VLOG(3) << "VerticalFusion, Number of unsafe fuse Consumers : " << fuse_consumers.size(); + VLOG(3) << "VerticalFusion, Number of fuse Consumers : " + << fuse_consumers.size(); + VLOG(3) << "VerticalFusion, Number of unsafe fuse Consumers : " + << fuse_consumers.size(); if (fuse_consumers.size() == 0) { return false; @@ -456,21 +482,27 @@ class FusionMergePassHelper : public FusionHelperBase { return false; } - void VerticalFuse(GroupPtr& producer, std::unordered_set& fusionable_consumers) { + void VerticalFuse( + GroupPtr& producer, + std::unordered_set& fusionable_consumers) { VLOG(3) << "VerticalFuse...!"; GroupList fused_groups; GroupPtr master_fuesd_group(nullptr); for (auto& consumer : fusionable_consumers) { auto fused_group = std::make_shared(); // update depth using consumer depth. - fused_group->max_depth = std::max(producer->max_depth, consumer->max_depth); - fused_group->min_depth = std::min(producer->min_depth, consumer->min_depth); + fused_group->max_depth = + std::max(producer->max_depth, consumer->max_depth); + fused_group->min_depth = + std::min(producer->min_depth, consumer->min_depth); // update group id fused_group->group_id = producer->group_id + "_" + consumer->group_id; - VLOG(3) << "fuse producer " << producer->group_id << " into consumer " << consumer->group_id; + VLOG(3) << "fuse producer " << producer->group_id << " into consumer " + << consumer->group_id; // fuse producer into fusion group fused_group->op_pattern_kind = - static_cast(producer->op_pattern_kind) >= static_cast(consumer->op_pattern_kind) + static_cast(producer->op_pattern_kind) >= + static_cast(consumer->op_pattern_kind) ? 
producer->op_pattern_kind : consumer->op_pattern_kind; // input nodes @@ -568,8 +600,9 @@ class FusionMergePassHelper : public FusionHelperBase { // sub group if (consumer->fused_sub_groups.size()) { for (auto& sub_group : consumer->fused_sub_groups) { - if (std::find(fused_group->fused_sub_groups.begin(), fused_group->fused_sub_groups.end(), sub_group) == - fused_group->fused_sub_groups.end()) { + if (std::find(fused_group->fused_sub_groups.begin(), + fused_group->fused_sub_groups.end(), + sub_group) == fused_group->fused_sub_groups.end()) { fused_group->fused_sub_groups.push_back(sub_group); } // update belong group @@ -583,15 +616,17 @@ class FusionMergePassHelper : public FusionHelperBase { fused_groups.push_back(fused_group); CHECK(fusion_groups_index_.count(consumer)) - << "Can't find consumer " << consumer->group_id << " index in fusion_groups_index_!"; - auto postion = fusion_groups_index_[consumer]; - fusion_groups_[postion] = fused_group; + << "Can't find consumer " << consumer->group_id + << " index in fusion_groups_index_!"; + auto postion = fusion_groups_index_[consumer]; + fusion_groups_[postion] = fused_group; fusion_groups_index_[fused_group] = postion; if (!master_fuesd_group.get()) { master_fuesd_group = fused_group; } - CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; + CHECK(fused_group->output_nodes.size()) + << "No output node is found, " << fused_group->group_id; } for (auto& node : producer->output_nodes) { @@ -617,7 +652,8 @@ class FusionMergePassHelper : public FusionHelperBase { } if (be_output) { - VLOG(4) << "Insert Id " << node->id() << " Into Group " << master_fuesd_group->group_id; + VLOG(4) << "Insert Id " << node->id() << " Into Group " + << master_fuesd_group->group_id; master_fuesd_group->output_nodes.insert(node); } } @@ -633,15 +669,17 @@ class FusionMergePassHelper : public FusionHelperBase { } } - void RecomputeEleGraph(const GroupPtr& producer, - std::unordered_set& fusionable_consumers) { + void RecomputeEleGraph( + const GroupPtr& producer, + std::unordered_set& fusionable_consumers) { if (producer->op_pattern_kind != framework::kElementWise) { SelectConsumerToFuse(producer, fusionable_consumers); } } - void SelectConsumerToFuse(const GroupPtr& producer, - std::unordered_set& fusionable_consumers) { + void SelectConsumerToFuse( + const GroupPtr& producer, + std::unordered_set& fusionable_consumers) { // if is const op if (is_const_group(this, producer)) { std::unordered_set candidates; @@ -650,12 +688,14 @@ class FusionMergePassHelper : public FusionHelperBase { if (is_same_shape(this, producer, consumer)) { candidates.insert(consumer); } else { - VLOG(4) << "Fuse Producer : " << producer->group_id << " into Consumer : " << consumer->group_id; + VLOG(4) << "Fuse Producer : " << producer->group_id + << " into Consumer : " << consumer->group_id; consumer->group_id = producer->group_id + "_" + consumer->group_id; // just merge the node into group. - auto& sub_group = consumer->fused_sub_groups.front(); + auto& sub_group = consumer->fused_sub_groups.front(); sub_group->group_id = producer->group_id + "_" + sub_group->group_id; - sub_group->nodes.insert(sub_group->nodes.begin(), producer->CollectNodes()[0]); + sub_group->nodes.insert(sub_group->nodes.begin(), + producer->CollectNodes()[0]); sub_group->nodes_set.insert(producer->CollectNodes()[0]); // remove depency. 
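// The candidate filtering below compares flattened element counts via
// std::accumulate. A minimal numel helper capturing that reduction, assuming
// shapes are vectors of ints as elsewhere in these passes:
#include <functional>
#include <numeric>
#include <vector>

inline int Numel(const std::vector<int>& shape) {
  // Product of all dimensions; 1 for an empty (scalar) shape.
  return std::accumulate(
      shape.begin(), shape.end(), 1, std::multiplies<int>());
}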
consumer->input_nodes.erase(producer->CollectNodes()[0]); @@ -686,28 +726,43 @@ class FusionMergePassHelper : public FusionHelperBase { continue; } - auto producer_output_shape = this->GetNodeDataShape(*producer->output_nodes.begin()); - auto consumer_output_shape = this->GetNodeDataShape(*consumer->output_nodes.begin()); - auto consumer_master_input_shape = this->GetNodeInputShape(*(consumer->master_nodes.begin())); + auto producer_output_shape = + this->GetNodeDataShape(*producer->output_nodes.begin()); + auto consumer_output_shape = + this->GetNodeDataShape(*consumer->output_nodes.begin()); + auto consumer_master_input_shape = + this->GetNodeInputShape(*(consumer->master_nodes.begin())); int producer_output_numel = - std::accumulate(producer_output_shape.begin(), producer_output_shape.end(), 1, std::multiplies()); + std::accumulate(producer_output_shape.begin(), + producer_output_shape.end(), + 1, + std::multiplies()); int consumer_output_numel = - std::accumulate(consumer_output_shape.begin(), consumer_output_shape.end(), 1, std::multiplies()); - int consumer_master_input_numel = std::accumulate( - consumer_master_input_shape.begin(), consumer_master_input_shape.end(), 1, std::multiplies()); + std::accumulate(consumer_output_shape.begin(), + consumer_output_shape.end(), + 1, + std::multiplies()); + int consumer_master_input_numel = + std::accumulate(consumer_master_input_shape.begin(), + consumer_master_input_shape.end(), + 1, + std::multiplies()); if (producer_output_numel == consumer_output_numel) { candidates.push_back(consumer); continue; } - if (producer->op_pattern_kind != framework::kInjective && consumer->op_pattern_kind == framework::kReduction && + if (producer->op_pattern_kind != framework::kInjective && + consumer->op_pattern_kind == framework::kReduction && producer_output_numel == consumer_master_input_numel) { candidates.push_back(consumer); } } - sort(candidates.begin(), candidates.end(), [](const auto& lhs, const auto& rhs) { - return lhs->op_pattern_kind < rhs->op_pattern_kind; - }); + sort(candidates.begin(), + candidates.end(), + [](const auto& lhs, const auto& rhs) { + return lhs->op_pattern_kind < rhs->op_pattern_kind; + }); fusionable_consumers.clear(); if (candidates.size()) { @@ -724,8 +779,10 @@ class FusionMergePassHelper : public FusionHelperBase { auto shape0 = this->GetNodeDataShape(*producer->output_nodes.begin()); auto shape1 = this->GetNodeDataShape(*consumer->output_nodes.begin()); - if (std::accumulate(shape0.begin(), shape0.end(), 1, std::multiplies()) == - std::accumulate(shape1.begin(), shape1.end(), 1, std::multiplies())) { + if (std::accumulate( + shape0.begin(), shape0.end(), 1, std::multiplies()) == + std::accumulate( + shape1.begin(), shape1.end(), 1, std::multiplies())) { candidates.insert(consumer); } } @@ -737,9 +794,10 @@ class FusionMergePassHelper : public FusionHelperBase { } } - bool IsDependency(const GroupPtr& producer_g, - const GroupPtr& consumer, - const std::unordered_set& consumers) { + bool IsDependency( + const GroupPtr& producer_g, + const GroupPtr& consumer, + const std::unordered_set& consumers) { std::queue candidates; candidates.push(consumer); @@ -763,9 +821,10 @@ class FusionMergePassHelper : public FusionHelperBase { return false; } - bool IsDependencySimplify(const GroupPtr& producer_g, - const GroupPtr& consumer, - const std::unordered_set& consumers) { + bool IsDependencySimplify( + const GroupPtr& producer_g, + const GroupPtr& consumer, + const std::unordered_set& consumers) { std::queue candidates; 
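// The loop that follows is a breadth-first reachability check. As a hedged,
// self-contained sketch (illustrative SketchGroup, not the real Group type):
// can `start` reach `target` by walking producer links?
#include <queue>
#include <unordered_set>
#include <vector>

struct SketchGroup {
  std::vector<SketchGroup*> producers;
};

inline bool Reaches(SketchGroup* start, SketchGroup* target) {
  std::queue<SketchGroup*> work;
  std::unordered_set<SketchGroup*> visited;
  work.push(start);
  while (!work.empty()) {
    SketchGroup* cur = work.front();
    work.pop();
    if (cur == target) return true;
    for (SketchGroup* p : cur->producers) {
      if (visited.insert(p).second) work.push(p);  // enqueue unseen groups
    }
  }
  return false;
}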
candidates.push(consumer); // check upper. @@ -865,16 +924,16 @@ class FusionMergePassHelper : public FusionHelperBase { VLOG(3) << "InitFusionGroupsAndIndex...!"; // init the postion of groups in fusion groups. for (int idx = 0; idx < fusion_groups_.size(); ++idx) { - auto group = fusion_groups_[idx]; + auto group = fusion_groups_[idx]; auto belong_group = std::make_shared(); // copy from group. - belong_group->max_depth = group->depth; - belong_group->min_depth = group->depth; - belong_group->group_id = group->group_id; - belong_group->input_nodes = group->input_nodes; - belong_group->output_nodes = group->output_nodes; + belong_group->max_depth = group->depth; + belong_group->min_depth = group->depth; + belong_group->group_id = group->group_id; + belong_group->input_nodes = group->input_nodes; + belong_group->output_nodes = group->output_nodes; belong_group->op_pattern_kind = group->op_pattern_kind; - belong_group->master_nodes = group->master_nodes; + belong_group->master_nodes = group->master_nodes; belong_group->producer_groups = group->producer_groups; belong_group->consumer_groups = group->consumer_groups; belong_group->fused_sub_groups.push_back(group); @@ -911,97 +970,109 @@ class FusionMergePassHelper : public FusionHelperBase { { auto& relation = fusion_relation_map_[OpPatternKind::kElementWise]; // horizontal - relation.horizontal_relation = {{framework::kElementWise, is_same_size}, - // element-wise and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // element-wise and injective op must be horizontal relation. - {OpPatternKind::kInjective, is_same_size}, - // element-wise and reduce op must be horizontal relation. - {OpPatternKind::kReduction, honrizontal_elementwise_fuse_reduce}}; + relation.horizontal_relation = { + {framework::kElementWise, is_same_size}, + // element-wise and broadcast op must be horizontal relation. + {OpPatternKind::kBroadcast, is_same_size}, + // element-wise and injective op must be horizontal relation. + {OpPatternKind::kInjective, is_same_size}, + // element-wise and reduce op must be horizontal relation. + {OpPatternKind::kReduction, honrizontal_elementwise_fuse_reduce}}; // vertical - relation.vertical_relation = {{OpPatternKind::kElementWise, is_same_size}, - // element-wise and broadcast can be vertical/horizontal relation. - {OpPatternKind::kBroadcast, elementwise_fuse_broadcast}, - // element-wise and injective op must be horizontal relation. - {OpPatternKind::kInjective, horizontal_with_injective}, - // element-wise and reduce can be vertical/horizontal relation. - {OpPatternKind::kReduction, elementwise_fuse_reduce}}; + relation.vertical_relation = { + {OpPatternKind::kElementWise, is_same_size}, + // element-wise and broadcast can be vertical/horizontal relation. + {OpPatternKind::kBroadcast, elementwise_fuse_broadcast}, + // element-wise and injective op must be horizontal relation. + {OpPatternKind::kInjective, horizontal_with_injective}, + // element-wise and reduce can be vertical/horizontal relation. + {OpPatternKind::kReduction, elementwise_fuse_reduce}}; } // kBroadcast { auto& relation = fusion_relation_map_[OpPatternKind::kBroadcast]; // horizontal - relation.horizontal_relation = {// broadcast and element-wise op must be horizontal relation. - {framework::kElementWise, is_same_size}, - // broadcast and broadcast op must be horizontal relation. - {framework::kBroadcast, is_same_size}, - // broadcast and injective op must be horizontal relation. 
- {OpPatternKind::kInjective, is_same_size}, - // broadcast and reduce op must be horizontal relation. - {OpPatternKind::kReduction, is_same_size}}; + relation.horizontal_relation = { + // broadcast and element-wise op must be horizontal relation. + {framework::kElementWise, is_same_size}, + // broadcast and broadcast op must be horizontal relation. + {framework::kBroadcast, is_same_size}, + // broadcast and injective op must be horizontal relation. + {OpPatternKind::kInjective, is_same_size}, + // broadcast and reduce op must be horizontal relation. + {OpPatternKind::kReduction, is_same_size}}; // vertical - relation.vertical_relation = {// broadcast and element-wise op must be vertical relation. - {OpPatternKind::kElementWise, is_same_size}, - // broadcast and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // broadcast and injective op must be horizontal relation. - {OpPatternKind::kInjective, horizontal_with_injective}, - // broadcast and reduce must be vertical relation. - {OpPatternKind::kReduction, broadcast_fuse_reduce}}; + relation.vertical_relation = { + // broadcast and element-wise op must be vertical relation. + {OpPatternKind::kElementWise, is_same_size}, + // broadcast and broadcast op must be horizontal relation. + {OpPatternKind::kBroadcast, is_same_size}, + // broadcast and injective op must be horizontal relation. + {OpPatternKind::kInjective, horizontal_with_injective}, + // broadcast and reduce must be vertical relation. + {OpPatternKind::kReduction, broadcast_fuse_reduce}}; } // kInjective { auto& relation = fusion_relation_map_[OpPatternKind::kInjective]; // horizontal - relation.horizontal_relation = {// injective and element-wise op must be horizontal relation. - {OpPatternKind::kElementWise, is_same_size}, - // injective and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // injective and injective op must be horizontal relation. - {OpPatternKind::kInjective, is_same_size}, - // injective and reduce must be horizontal relation. - {OpPatternKind::kReduction, is_same_size}}; + relation.horizontal_relation = { + // injective and element-wise op must be horizontal relation. + {OpPatternKind::kElementWise, is_same_size}, + // injective and broadcast op must be horizontal relation. + {OpPatternKind::kBroadcast, is_same_size}, + // injective and injective op must be horizontal relation. + {OpPatternKind::kInjective, is_same_size}, + // injective and reduce must be horizontal relation. + {OpPatternKind::kReduction, is_same_size}}; // vertical - relation.vertical_relation = {// injective and element-wise op must be horizontal relation. - {OpPatternKind::kElementWise, is_same_size}, - // injective and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // injective and injective op must be horizontal relation. - {OpPatternKind::kInjective, horizontal_with_injective}, - // injective and reduce can be horizontal/vertical relation. - {OpPatternKind::kReduction, injective_horizontal_with_reduce}}; + relation.vertical_relation = { + // injective and element-wise op must be horizontal relation. + {OpPatternKind::kElementWise, is_same_size}, + // injective and broadcast op must be horizontal relation. + {OpPatternKind::kBroadcast, is_same_size}, + // injective and injective op must be horizontal relation. + {OpPatternKind::kInjective, horizontal_with_injective}, + // injective and reduce can be horizontal/vertical relation. 
+ {OpPatternKind::kReduction, injective_horizontal_with_reduce}}; } // kReduction { auto& relation = fusion_relation_map_[OpPatternKind::kReduction]; // horizontal - relation.horizontal_relation = {// reduce and element-wise op must be horizontal relation. - {OpPatternKind::kElementWise, honrizontal_elementwise_fuse_reduce}, - // reduce and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // reduce and injective op must be horizontal relation. - {OpPatternKind::kInjective, is_same_size}, - // reduce and reduce must be horizontal relation. - {OpPatternKind::kReduction, reduce_fuse_reduce}}; + relation.horizontal_relation = { + // reduce and element-wise op must be horizontal relation. + {OpPatternKind::kElementWise, honrizontal_elementwise_fuse_reduce}, + // reduce and broadcast op must be horizontal relation. + {OpPatternKind::kBroadcast, is_same_size}, + // reduce and injective op must be horizontal relation. + {OpPatternKind::kInjective, is_same_size}, + // reduce and reduce must be horizontal relation. + {OpPatternKind::kReduction, reduce_fuse_reduce}}; // vertical - relation.vertical_relation = {// reduce and elementwise can be horizontal/vertical relation. - {OpPatternKind::kElementWise, reduce_fuse_elementwise}, - // reduce and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, reduce_fuse_broadcast}, - // reduce and injective op must be horizontal relation. - {OpPatternKind::kInjective, horizontal_with_injective}, - // reduce and reduce must be horizontal relation. - {OpPatternKind::kReduction, reduce_fuse_reduce}}; + relation.vertical_relation = { + // reduce and elementwise can be horizontal/vertical relation. + {OpPatternKind::kElementWise, reduce_fuse_elementwise}, + // reduce and broadcast op must be horizontal relation. + {OpPatternKind::kBroadcast, reduce_fuse_broadcast}, + // reduce and injective op must be horizontal relation. + {OpPatternKind::kInjective, horizontal_with_injective}, + // reduce and reduce must be horizontal relation. 
+ {OpPatternKind::kReduction, reduce_fuse_reduce}}; } } GroupList fusion_groups_; std::unordered_map fusion_groups_index_; - std::unordered_map> input_to_consumers_; + std::unordered_map> + input_to_consumers_; struct Relation { - std::unordered_map vertical_relation; - std::unordered_map horizontal_relation; + std::unordered_map + vertical_relation; + std::unordered_map + horizontal_relation; }; std::unordered_map fusion_relation_map_; }; @@ -1023,7 +1094,8 @@ void FusionMergePassInternal(Graph* graph) { CINN_REGISTER_HELPER(FusionMergePass) { CINN_REGISTER_PASS(FusionMergePass) .describe( - "Fusion Merge Pass which performs Fusion-Ops fusion, Producer Fusion-Ops are fused into Consumer Fusion-Ops " + "Fusion Merge Pass which performs Fusion-Ops fusion, Producer " + "Fusion-Ops are fused into Consumer Fusion-Ops " "with certain conditions.") .set_change_structure(false) .set_body(cinn::hlir::pass::FusionMergePassInternal); diff --git a/paddle/cinn/hlir/pass/fusion_merge_pass_test.cc b/paddle/cinn/hlir/pass/fusion_merge_pass_test.cc index 81658fcd6e019..f4582a5ce65be 100755 --- a/paddle/cinn/hlir/pass/fusion_merge_pass_test.cc +++ b/paddle/cinn/hlir/pass/fusion_merge_pass_test.cc @@ -34,7 +34,7 @@ TEST(FusionMergePass, ElementWise_Fusion_0) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -60,7 +60,7 @@ TEST(FusionMergePass, ElementWise_Fusion_1) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -89,7 +89,7 @@ TEST(FusionMergePass, ElementWise_Fusion_2) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -118,7 +118,7 @@ TEST(FusionMergePass, ElementWise_Fusion_3) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -147,7 +147,7 @@ TEST(FusionMergePass, ElementWise_Fusion_4) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -169,7 +169,7 @@ TEST(FusionMergePass, ElementWise_Fusion_5) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -194,7 +194,7 @@ TEST(FusionMergePass, Broadcast_Test_0) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -219,7 +219,7 @@ TEST(FusionMergePass, Broadcast_Test_1) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -244,7 +244,7 @@ TEST(FusionMergePass, Broadcast_Test_2) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = 
std::make_shared(program, target); @@ -269,7 +269,7 @@ TEST(FusionMergePass, Broadcast_Test_3) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -296,7 +296,7 @@ TEST(FusionMergePass, Broadcast_Test_4) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -323,7 +323,7 @@ TEST(FusionMergePass, Broadcast_Test_5) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -347,7 +347,7 @@ TEST(FusionMergePass, Reduce_Test_0) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -370,7 +370,7 @@ TEST(FusionMergePass, Reduce_Test_1) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -396,7 +396,7 @@ TEST(FusionMergePass, Reduce_Test_2) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -422,7 +422,7 @@ TEST(FusionMergePass, Reduce_Test_3) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -449,7 +449,7 @@ TEST(FusionMergePass, Reduce_Test_4) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -473,7 +473,7 @@ TEST(FusionMergePass, Reduce_Test_5) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); diff --git a/paddle/cinn/hlir/pass/fusion_merge_pass_util.h b/paddle/cinn/hlir/pass/fusion_merge_pass_util.h index c50379d8079ad..108623eb2a400 100644 --- a/paddle/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/paddle/cinn/hlir/pass/fusion_merge_pass_util.h @@ -66,13 +66,17 @@ CONDITION_FUNC(is_same_size) { return true; } - auto size_0 = std::accumulate(output_var_0.begin(), output_var_0.end(), 1, std::multiplies()); - auto size_1 = std::accumulate(output_var_1.begin(), output_var_1.end(), 1, std::multiplies()); + auto size_0 = std::accumulate( + output_var_0.begin(), output_var_0.end(), 1, std::multiplies()); + auto size_1 = std::accumulate( + output_var_1.begin(), output_var_1.end(), 1, std::multiplies()); return size_0 == size_1; } -bool is_const_group(const FusionHelperBase* helper, const std::shared_ptr& group) { - return group->CollectNodes().size() == 1 && helper->IsConstOp(group->CollectNodes()[0]); +bool is_const_group(const FusionHelperBase* helper, + const std::shared_ptr& group) { + return group->CollectNodes().size() == 1 && + helper->IsConstOp(group->CollectNodes()[0]); }; CONDITION_FUNC(elementwise_fuse_broadcast) { @@ -104,10 +108,10 @@ 
CONDITION_FUNC(elementwise_fuse_broadcast) { CONDITION_FUNC(honrizontal_elementwise_fuse_reduce) { std::shared_ptr ele_group, reduce_group; if (first->op_pattern_kind == framework::kReduction) { - ele_group = second; + ele_group = second; reduce_group = first; } else { - ele_group = first; + ele_group = first; reduce_group = second; } // if same shape with horizontal relation @@ -115,12 +119,16 @@ CONDITION_FUNC(honrizontal_elementwise_fuse_reduce) { return true; } - shape_t ele_node_shape = helper->GetNodeDataShape(*ele_group->master_nodes.begin()); - int32_t size_ele = std::accumulate(ele_node_shape.begin(), ele_node_shape.end(), 1, std::multiplies()); + shape_t ele_node_shape = + helper->GetNodeDataShape(*ele_group->master_nodes.begin()); + int32_t size_ele = std::accumulate( + ele_node_shape.begin(), ele_node_shape.end(), 1, std::multiplies()); for (Node* master : reduce_group->master_nodes) { shape_t master_node_shape = helper->GetNodeDataShape(master); - int32_t size_master = - std::accumulate(master_node_shape.begin(), master_node_shape.end(), 1, std::multiplies()); + int32_t size_master = std::accumulate(master_node_shape.begin(), + master_node_shape.end(), + 1, + std::multiplies()); if (size_ele == size_master) { return true; } @@ -140,7 +148,7 @@ CONDITION_FUNC(elementwise_fuse_reduce) { // if reduce nodes not in consumers of first group std::queue candidates; - std::unordered_set first_node_set = first->NodeSet(); + std::unordered_set first_node_set = first->NodeSet(); std::unordered_set second_node_set = second->NodeSet(); for (const auto& pair : second->input_nodes) { if (first_node_set.find(pair.first) != first_node_set.end()) { @@ -169,13 +177,19 @@ CONDITION_FUNC(elementwise_fuse_reduce) { } } if (!masters_in_consumers.empty()) { - bool flag = true; - shape_t first_node_shape = helper->GetNodeDataShape(*first->master_nodes.begin()); - int32_t size_first = std::accumulate(first_node_shape.begin(), first_node_shape.end(), 1, std::multiplies()); + bool flag = true; + shape_t first_node_shape = + helper->GetNodeDataShape(*first->master_nodes.begin()); + int32_t size_first = std::accumulate(first_node_shape.begin(), + first_node_shape.end(), + 1, + std::multiplies()); for (Node* master : masters_in_consumers) { shape_t second_node_shape = helper->GetNodeDataShape(master); - int32_t size_second = - std::accumulate(second_node_shape.begin(), second_node_shape.end(), 1, std::multiplies()); + int32_t size_second = std::accumulate(second_node_shape.begin(), + second_node_shape.end(), + 1, + std::multiplies()); if (size_first != size_second) { flag = false; break; @@ -196,15 +210,18 @@ CONDITION_FUNC(elementwise_fuse_reduce) { } CHECK(reducer) << "Can't find reduce op in group " << second->group_id; - // If the elementwise's output should be fetched, the output var cannot be computed inline - // into reduce's loop, in other words, the elementwise's cannot fused into reduce's loop - // Like: group1 = {cast_0}, group2={broadcast_0 -> elementwise_0 -> cast_1 -> reduce_max_0} + // If the elementwise's output should be fetched, the output var cannot be + // computed inline into reduce's loop, in other words, the elementwise's + // cannot fused into reduce's loop Like: group1 = {cast_0}, + // group2={broadcast_0 -> elementwise_0 -> cast_1 -> reduce_max_0} if (helper->output_nodes_set_.count(*first->master_nodes.begin())) { return false; } - auto input_shape = helper->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); - auto reduce_axes = 
absl::get>(reducer->attrs.attr_store.at("dim")); + auto input_shape = + helper->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); + auto reduce_axes = + absl::get>(reducer->attrs.attr_store.at("dim")); int max_num_threads = helper->target_.max_num_threads(); // if without last dimension in reduce. @@ -220,7 +237,8 @@ CONDITION_FUNC(elementwise_fuse_reduce) { int index = reduce_axes.size() - 1; for (; index >= 0; --index) { - if (index + 1 < reduce_axes.size() && reduce_axes[index] + 1 != reduce_axes[index + 1]) { + if (index + 1 < reduce_axes.size() && + reduce_axes[index] + 1 != reduce_axes[index + 1]) { break; } lane *= input_shape[reduce_axes[index]]; @@ -233,8 +251,9 @@ CONDITION_FUNC(elementwise_fuse_reduce) { return true; } else { int prefix = input_shape[reduce_axes[index]]; - int tail = lane / prefix; - for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; --idx) { + int tail = lane / prefix; + for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; + --idx) { if (prefix % idx == 0) { return true; } @@ -257,11 +276,14 @@ CONDITION_FUNC(broadcast_fuse_reduce) { } CHECK(reducer) << "Can't find reduce op in group " << second->group_id; - auto input_shape = helper->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); - auto input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + auto input_shape = + helper->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); + auto input_size = std::accumulate( + input_shape.begin(), input_shape.end(), 1, std::multiplies()); auto output_shape = helper->GetNodeDataShape(*first->master_nodes.begin()); - auto output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + auto output_size = std::accumulate( + output_shape.begin(), output_shape.end(), 1, std::multiplies()); if (input_size == output_size) { return elementwise_fuse_reduce(helper, first, second); @@ -279,22 +301,25 @@ CONDITION_FUNC(reduce_fuse_elementwise) { return true; } -inline bool horizontal_relation(const FusionHelperBase* helper, - const std::shared_ptr& first, - const std::shared_ptr& second, - const framework::OpPatternKind op_pattern_kind) { +inline bool horizontal_relation( + const FusionHelperBase* helper, + const std::shared_ptr& first, + const std::shared_ptr& second, + const framework::OpPatternKind op_pattern_kind) { // merge injective auto merge_nodes_set = [](const std::shared_ptr& group) { std::unordered_set nodes_set = group->nodes_set; for (auto& sub_group : group->fused_sub_groups) { - nodes_set.insert(sub_group->nodes_set.begin(), sub_group->nodes_set.end()); + nodes_set.insert(sub_group->nodes_set.begin(), + sub_group->nodes_set.end()); } return nodes_set; }; - auto first_set = merge_nodes_set(first); + auto first_set = merge_nodes_set(first); auto second_set = merge_nodes_set(second); - auto select_node_set = [helper](const std::unordered_set& nodes, framework::OpPatternKind kind) { + auto select_node_set = [helper](const std::unordered_set& nodes, + framework::OpPatternKind kind) { std::unordered_set selected; for (auto node : nodes) { if (helper->GetOpKind(node) == kind) { @@ -351,12 +376,14 @@ CONDITION_FUNC(horizontal_with_injective) { if (!is_same_size(helper, first, second)) { return false; } - return horizontal_relation(helper, first, second, framework::OpPatternKind::kInjective); + return horizontal_relation( + helper, first, second, framework::OpPatternKind::kInjective); } 
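// A simplified reconstruction of the thread-budget check inside
// elementwise_fuse_reduce above: `lane` accumulates the contiguous trailing
// reduce axes; if it overflows the budget, the pass looks for a divisor of
// the axis that tipped it over (`prefix`) so that divisor * tail lands in
// (max_num_threads/2, max_num_threads]. The hunk elides part of the loop, so
// the early break on a half-spent budget is an assumption of this sketch,
// not confirmed production logic. max_num_threads = 1024 mirrors the usual
// NVGPU value.
#include <cstdio>
#include <vector>

bool CanMapToThreads(const std::vector<int>& input_shape,
                     const std::vector<int>& reduce_axes,
                     int max_num_threads) {
  int lane = 1;
  int index = static_cast<int>(reduce_axes.size()) - 1;
  for (; index >= 0; --index) {
    if (index + 1 < static_cast<int>(reduce_axes.size()) &&
        reduce_axes[index] + 1 != reduce_axes[index + 1]) {
      break;  // reduce axes stop being contiguous here
    }
    lane *= input_shape[reduce_axes[index]];
    if (lane > max_num_threads / 2) break;  // assumed: stop once half-spent
  }
  if (lane <= max_num_threads) return true;
  // index >= 0 here: the loop only leaves lane oversized via a break.
  int prefix = input_shape[reduce_axes[index]];
  int tail = lane / prefix;
  for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail;
       --idx) {
    if (prefix % idx == 0) return true;  // prefix splits into idx chunks
  }
  return false;
}

int main() {
  // 1024 threads, reducing the trailing axes of [64, 32, 128]:
  // lane grows 128 -> 4096, prefix = 32, tail = 128, and idx = 8
  // divides 32, so the check passes.
  std::printf("%d\n", CanMapToThreads({64, 32, 128}, {1, 2}, 1024));
  return 0;
}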
CONDITION_FUNC(injective_horizontal_with_reduce) { // check injective with injective. - if (!horizontal_relation(helper, first, second, framework::OpPatternKind::kInjective)) { + if (!horizontal_relation( + helper, first, second, framework::OpPatternKind::kInjective)) { return false; } return elementwise_fuse_reduce(helper, first, second); @@ -368,11 +395,12 @@ CONDITION_FUNC(reduce_fuse_broadcast) { return true; } - // Traversing all reducers in all producers requires two types of conditions to be met. - // The first type is the condition that the reducer itself needs to meet, - // and the second type is the condition that the relationship between each reducer and its consumers with type of - // Broadcast needs to meet. It is required that each consumer of type Broadcast meet the same shape after broadcast as - // before reduce. + // Traversing all reducers in all producers requires two types of conditions + // to be met. The first type is the condition that the reducer itself needs to + // meet, and the second type is the condition that the relationship between + // each reducer and its consumers with type of Broadcast needs to meet. It is + // required that each consumer of type Broadcast meet the same shape after + // broadcast as before reduce. for (auto& node_in_master : first->master_nodes) { if (helper->GetOpKind(node_in_master) != OpPatternKind::kReduction) { continue; @@ -380,10 +408,11 @@ CONDITION_FUNC(reduce_fuse_broadcast) { Node* reducer = node_in_master; // First type conditions // Get some reduce information - auto reducer_input_shape = helper->GetNodeInputShape(reducer); + auto reducer_input_shape = helper->GetNodeInputShape(reducer); auto reducer_output_shape = helper->GetNodeDataShape(reducer); - auto reduce_axes = absl::get>(reducer->attrs.attr_store.at("dim")); - auto keep_dim = absl::get(reducer->attrs.attr_store.at("keep_dim")); + auto reduce_axes = + absl::get>(reducer->attrs.attr_store.at("dim")); + auto keep_dim = absl::get(reducer->attrs.attr_store.at("keep_dim")); for (auto& axis : reduce_axes) { if (axis == -1) { axis = reducer_input_shape.size() - 1; @@ -398,13 +427,16 @@ CONDITION_FUNC(reduce_fuse_broadcast) { reduce_size *= reducer_input_shape[idx - 1]; } // Check if the reduce size exceeds the hardware limit - if (helper->target_ == common::DefaultNVGPUTarget() && reduce_size > helper->target_.max_num_threads()) { + if (helper->target_ == common::DefaultNVGPUTarget() && + reduce_size > helper->target_.max_num_threads()) { return false; } // Second type conditions - // Find directly or indirectly consumers with type of Broadcast in the second group - auto find_broadcasters_in_descendants = [&](const Node* producer) -> std::unordered_set { + // Find directly or indirectly consumers with type of Broadcast in the + // second group + auto find_broadcasters_in_descendants = + [&](const Node* producer) -> std::unordered_set { std::queue candidates; std::unordered_set visited_set; std::unordered_set broadcasters; @@ -430,10 +462,13 @@ CONDITION_FUNC(reduce_fuse_broadcast) { }; // Check if each broadcast node meets the conditions - std::unordered_set broadcasters_in_consumers = find_broadcasters_in_descendants(reducer); + std::unordered_set broadcasters_in_consumers = + find_broadcasters_in_descendants(reducer); for (auto broadcaster : broadcasters_in_consumers) { - auto broadcaster_output_shape = absl::get>(broadcaster->attrs.attr_store.at("out_shape")); - auto broadcast_axes = absl::get>(broadcaster->attrs.attr_store.at("broadcast_axes")); + auto 
broadcaster_output_shape = absl::get>( + broadcaster->attrs.attr_store.at("out_shape")); + auto broadcast_axes = absl::get>( + broadcaster->attrs.attr_store.at("broadcast_axes")); for (auto& axis : broadcast_axes) { if (axis == -1) { axis = broadcaster_output_shape.size() - 1; @@ -453,8 +488,10 @@ CONDITION_FUNC(reduce_fuse_broadcast) { } // check union [reduce_axes, broadcast_axes] = reducer_input_shape for (int idx = 0; idx < reducer_input_shape.size(); ++idx) { - if (!(std::find(broadcast_axes.begin(), broadcast_axes.end(), idx) == broadcast_axes.end()) ^ - std::find(reduce_axes.begin(), reduce_axes.end(), idx) == reduce_axes.end()) { + if (!(std::find(broadcast_axes.begin(), broadcast_axes.end(), idx) == + broadcast_axes.end()) ^ + std::find(reduce_axes.begin(), reduce_axes.end(), idx) == + reduce_axes.end()) { return false; } } @@ -488,14 +525,20 @@ CONDITION_FUNC(reduce_fuse_reduce) { CHECK(reducer_1) << "Can't find reduce op in group " << second->group_id; // check reduce has same input shape and output shape - auto reducer_0_input_shape = helper->shape_dict_.at(reducer_0->inlinks_in_order()[0]->source()->id()); - auto reducer_0_output_shape = helper->shape_dict_.at(reducer_0->outlinks_in_order()[0]->sink()->id()); + auto reducer_0_input_shape = + helper->shape_dict_.at(reducer_0->inlinks_in_order()[0]->source()->id()); + auto reducer_0_output_shape = + helper->shape_dict_.at(reducer_0->outlinks_in_order()[0]->sink()->id()); - auto reducer_1_input_shape = helper->shape_dict_.at(reducer_1->inlinks_in_order()[0]->source()->id()); - auto reducer_1_output_shape = helper->shape_dict_.at(reducer_1->outlinks_in_order()[0]->sink()->id()); + auto reducer_1_input_shape = + helper->shape_dict_.at(reducer_1->inlinks_in_order()[0]->source()->id()); + auto reducer_1_output_shape = + helper->shape_dict_.at(reducer_1->outlinks_in_order()[0]->sink()->id()); - auto reducer_0_reduce_dim = absl::get>(reducer_0->attrs.attr_store.at("dim")); - auto reducer_1_reduce_dim = absl::get>(reducer_1->attrs.attr_store.at("dim")); + auto reducer_0_reduce_dim = + absl::get>(reducer_0->attrs.attr_store.at("dim")); + auto reducer_1_reduce_dim = + absl::get>(reducer_1->attrs.attr_store.at("dim")); for (auto& dim : reducer_0_reduce_dim) { // if dim = -1, set as shape.size() - 1 @@ -512,7 +555,8 @@ CONDITION_FUNC(reduce_fuse_reduce) { } // check shape is same - if (reducer_0_input_shape == reducer_1_input_shape && reducer_0_output_shape == reducer_1_output_shape && + if (reducer_0_input_shape == reducer_1_input_shape && + reducer_0_output_shape == reducer_1_output_shape && reducer_0_reduce_dim == reducer_1_reduce_dim) { auto shared_size = 0; for (auto& fusion_group : {first, second}) { @@ -531,9 +575,12 @@ CONDITION_FUNC(reduce_fuse_reduce) { return true; } - if (helper->WithoutLastDimInReduce(reducer_0_input_shape, reducer_0_reduce_dim) && - helper->WithoutLastDimInReduce(reducer_1_input_shape, reducer_1_reduce_dim) && - reducer_0_output_shape == reducer_1_output_shape && reducer_0_reduce_dim == reducer_1_reduce_dim) { + if (helper->WithoutLastDimInReduce(reducer_0_input_shape, + reducer_0_reduce_dim) && + helper->WithoutLastDimInReduce(reducer_1_input_shape, + reducer_1_reduce_dim) && + reducer_0_output_shape == reducer_1_output_shape && + reducer_0_reduce_dim == reducer_1_reduce_dim) { auto shared_size = 0; for (auto& fusion_group : {first, second}) { for (auto* master : fusion_group->master_nodes) { diff --git a/paddle/cinn/hlir/pass/infershape.cc b/paddle/cinn/hlir/pass/infershape.cc index 
890ed7329d191..87d8e263567c0 100755 --- a/paddle/cinn/hlir/pass/infershape.cc +++ b/paddle/cinn/hlir/pass/infershape.cc @@ -30,13 +30,16 @@ using framework::Node; using framework::NodeData; using framework::Operator; -using infershape_t = std::function(const std::vector&, - const framework::AttrMapType&)>; -using inferdtype_t = std::function(const std::vector&, const framework::AttrMapType&)>; +using infershape_t = std::function( + const std::vector&, const framework::AttrMapType&)>; +using inferdtype_t = std::function( + const std::vector&, const framework::AttrMapType&)>; using dtype_dict_t = absl::flat_hash_map; using shape_dict_t = absl::flat_hash_map; -void InferShape(Node* node, dtype_dict_t& dtype_dict, shape_dict_t& shape_dict) { +void InferShape(Node* node, + dtype_dict_t& dtype_dict, + shape_dict_t& shape_dict) { VLOG(3) << "Begin InferShape of node " << node->id(); auto op_infershape = Operator::GetAttrs("infershape"); auto op_inferdtype = Operator::GetAttrs("inferdtype"); @@ -44,7 +47,9 @@ void InferShape(Node* node, dtype_dict_t& dtype_dict, shape_dict_t& shape_dict) auto product = [](const framework::shape_t& shape) { framework::dim_t numel = 1; - std::for_each(shape.begin(), shape.end(), [&numel](framework::dim_t dim) { numel *= dim; }); + std::for_each(shape.begin(), shape.end(), [&numel](framework::dim_t dim) { + numel *= dim; + }); return numel; }; @@ -53,40 +58,49 @@ void InferShape(Node* node, dtype_dict_t& dtype_dict, shape_dict_t& shape_dict) for (auto& in_edge : node->inlinks_in_order()) { auto* source_node = in_edge->source()->safe_as(); CHECK(source_node); - CHECK(shape_dict.count(source_node->id())) << "No shape for " << source_node->id(); - CHECK(dtype_dict.count(source_node->id())) << "No dtype for " << source_node->id(); + CHECK(shape_dict.count(source_node->id())) + << "No shape for " << source_node->id(); + CHECK(dtype_dict.count(source_node->id())) + << "No dtype for " << source_node->id(); inputs_shape.push_back(shape_dict[source_node->id()]); inputs_dtype.push_back(dtype_dict[source_node->id()]); - CHECK(product(inputs_shape.back())) << node->id() << " 's Input Node " << source_node->id() << "[" - << utils::Join(inputs_shape.back(), ",") - << "]'s size should not zero ! Please check."; + CHECK(product(inputs_shape.back())) + << node->id() << " 's Input Node " << source_node->id() << "[" + << utils::Join(inputs_shape.back(), ",") + << "]'s size should not zero ! Please check."; } - auto out_shape = op_infershape[node->op()](inputs_shape, node->attrs.attr_store); - auto out_dtype = op_inferdtype[node->op()](inputs_dtype, node->attrs.attr_store); + auto out_shape = + op_infershape[node->op()](inputs_shape, node->attrs.attr_store); + auto out_dtype = + op_inferdtype[node->op()](inputs_dtype, node->attrs.attr_store); CHECK_GE(node->outlinks_in_order().size(), out_shape.size()) - << "The output number of node " << node->id() << " is " << node->outlinks_in_order().size() - << " , which is smaller than the output shape size " << out_shape.size() << " . And the op type is " - << node->op()->name; + << "The output number of node " << node->id() << " is " + << node->outlinks_in_order().size() + << " , which is smaller than the output shape size " << out_shape.size() + << " . And the op type is " << node->op()->name; CHECK_GE(node->outlinks_in_order().size(), out_dtype.size()) - << "The output number of node " << node->id() << " is " << node->outlinks_in_order().size() - << " , which is smaller than the output dtype size " << out_dtype.size() << " . 
And the op type is " - << node->op()->name; + << "The output number of node " << node->id() << " is " + << node->outlinks_in_order().size() + << " , which is smaller than the output dtype size " << out_dtype.size() + << " . And the op type is " << node->op()->name; int counter = 0; for (auto& out_edge : node->outlinks_in_order()) { auto* sink_node = out_edge->sink()->safe_as(); CHECK(sink_node); - VLOG(3) << "Infershape: " << sink_node->id() << " " << utils::Join(out_shape[counter], ","); + VLOG(3) << "Infershape: " << sink_node->id() << " " + << utils::Join(out_shape[counter], ","); shape_dict[sink_node->id()] = out_shape[counter]; dtype_dict[sink_node->id()] = out_dtype[counter]; - CHECK(product(out_shape[counter])) << node->id() << " 's Output Node " << sink_node->id() << "[" - << utils::Join(out_shape[counter], ",") - << "]'s size should not zero ! Please check."; + CHECK(product(out_shape[counter])) + << node->id() << " 's Output Node " << sink_node->id() << "[" + << utils::Join(out_shape[counter], ",") + << "]'s size should not zero ! Please check."; counter++; } @@ -94,13 +108,18 @@ void InferShape(Node* node, dtype_dict_t& dtype_dict, shape_dict_t& shape_dict) void InferShapePass(Graph* graph) { VLOG(3) << "Begin InferShapePass"; - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); + auto& shape_dict = graph->GetMutableAttrs< + absl::flat_hash_map>("infershape"); + auto& dtype_dict = + graph->GetMutableAttrs>( + "inferdtype"); auto store_nodes = std::get<0>(graph->topological_order()); auto product = [](const framework::shape_t& shape) { framework::dim_t numel = 1; - std::for_each(shape.begin(), shape.end(), [&numel](framework::dim_t dim) { numel *= dim; }); + std::for_each(shape.begin(), shape.end(), [&numel](framework::dim_t dim) { + numel *= dim; + }); return numel; }; @@ -118,7 +137,8 @@ void InferShapePass(Graph* graph) { CINN_REGISTER_HELPER(InferShape) { CINN_REGISTER_PASS(InferShape) .describe( - "This pass infer the shape and data type of tensor and save to g.attrs[\"infershape\"] and " + "This pass infer the shape and data type of tensor and save to " + "g.attrs[\"infershape\"] and " "g.attrs[\"inferdtype\"].") .set_change_structure(false) .provide_graph_attr("infershape") diff --git a/paddle/cinn/hlir/pass/infershape.h b/paddle/cinn/hlir/pass/infershape.h index 8a3c77687cc97..a9b9c8a528ee7 100644 --- a/paddle/cinn/hlir/pass/infershape.h +++ b/paddle/cinn/hlir/pass/infershape.h @@ -20,9 +20,10 @@ namespace cinn { namespace hlir { namespace pass { -void InferShape(framework::Node* node, - absl::flat_hash_map& dtype_dict, - absl::flat_hash_map& shape_dict); +void InferShape( + framework::Node* node, + absl::flat_hash_map& dtype_dict, + absl::flat_hash_map& shape_dict); } // namespace pass } // namespace hlir diff --git a/paddle/cinn/hlir/pass/op_fusion_pass.cc b/paddle/cinn/hlir/pass/op_fusion_pass.cc index 1337fabbde18e..6648cb036131c 100644 --- a/paddle/cinn/hlir/pass/op_fusion_pass.cc +++ b/paddle/cinn/hlir/pass/op_fusion_pass.cc @@ -28,10 +28,11 @@ using framework::shape_t; using common::GraphEdge; using common::GraphNode; -using GroupPtr = std::shared_ptr; +using GroupPtr = std::shared_ptr; using GroupList = std::vector; -using ConditionFunction = std::function; +using ConditionFunction = + std::function; // Op Fusion Pass which performs Ops fusion, Ops are fused // "vertically", meaning producing Ops are fused into their consumers @@ -56,7 +57,7 @@ class OpFusionPassHelper : public FusionHelperBase { 
// input node for (auto& edge : node->inlinks()) { auto input_graph_node = edge->source(); - auto input_node_data = input_graph_node->safe_as(); + auto input_node_data = input_graph_node->safe_as(); CHECK(input_node_data); // input data has no source node if (input_node_data->source_node.get()) { @@ -68,7 +69,7 @@ class OpFusionPassHelper : public FusionHelperBase { group->op_pattern_kind = GetOpKind(node); // use current node as master node for schedule group->master_nodes.insert(node); - group->group_id = node->id(); + group->group_id = node->id(); fusion_groups_[node] = group; } } @@ -132,7 +133,7 @@ class OpFusionPassHelper : public FusionHelperBase { auto consumer_fusion = fusion_groups_[consumer]; // // check all linkin node for (auto& edge : consumer->inlinks()) { - auto graph_node = edge->source(); + auto graph_node = edge->source(); auto producer_data = graph_node->safe_as(); CHECK(producer_data); @@ -152,30 +153,36 @@ class OpFusionPassHelper : public FusionHelperBase { if (producer_kind == framework::kNonFusible) { continue; } - VLOG(3) << "Producer Op: " << producer->id() << ", Op Pattern: " << producer_kind - << " -> Consumer Op: " << consumer->id() << ", Op Pattern: " << consumer_kind; + VLOG(3) << "Producer Op: " << producer->id() + << ", Op Pattern: " << producer_kind + << " -> Consumer Op: " << consumer->id() + << ", Op Pattern: " << consumer_kind; bool can_fuse = true; // checkout producer node outputs are all in fusion op for (auto& link : producer_data->outlinks()) { auto consumer_node = link->sink()->safe_as(); CHECK(consumer_node); // if fusion group can't find node, can't merge - if (consumer_fusion->nodes_set.find(consumer_node) == consumer_fusion->nodes_set.end()) { + if (consumer_fusion->nodes_set.find(consumer_node) == + consumer_fusion->nodes_set.end()) { can_fuse = false; break; } } if (!can_fuse || !CanFuse(producer, consumer)) continue; - VLOG(3) << "Fuse Op " << producer->id() << " into Op " << consumer->id(); + VLOG(3) << "Fuse Op " << producer->id() << " into Op " + << consumer->id(); // fuse producer to fusion group - consumer_fusion->group_id = producer->id() + "_" + consumer_fusion->group_id; + consumer_fusion->group_id = + producer->id() + "_" + consumer_fusion->group_id; consumer_fusion->nodes.push_back(producer); consumer_fusion->nodes_set.insert(producer); consumer_fusion->input_nodes.erase(producer); consumer_fusion->op_pattern_kind = - static_cast(consumer_fusion->op_pattern_kind) > static_cast(producer_kind) + static_cast(consumer_fusion->op_pattern_kind) > + static_cast(producer_kind) ? consumer_fusion->op_pattern_kind : producer_kind; @@ -186,7 +193,8 @@ class OpFusionPassHelper : public FusionHelperBase { if (this->output_nodes_set_.count(producer)) { VLOG(3) << "Insert Global Output Node : " << producer->id(); consumer_fusion->output_nodes.insert(producer); - } else if (producer_data->outlinks().size() > 1 && producer->inlinks().size() > 0 && + } else if (producer_data->outlinks().size() > 1 && + producer->inlinks().size() > 0 && is_same_size(this, producer, consumer_fusion)) { // producer is not a const value node. 
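// A toy sketch of two ingredients of the producer -> consumer fusion step
// above: (1) a producer is only fusible if every consumer of its output data
// already lives in the candidate group, and (2) the fused group keeps the
// "heavier" of the two pattern kinds, i.e. the max over the enum order. The
// enum follows the kinds named in this pass (only the relative order matters
// for the merge); Node is a hypothetical stand-in for the graph node type.
#include <algorithm>
#include <iostream>
#include <unordered_set>
#include <vector>

enum OpPatternKind {
  kElementWise = 0,
  kBroadcast = 1,
  kInjective = 2,
  kReduction = 3,
  kOutFusible = 4,
  kNonFusible = 8,
};

struct Node {
  int id;
  std::vector<Node*> consumers;  // sinks of this node's output links
};

bool AllConsumersInGroup(const Node& producer,
                         const std::unordered_set<Node*>& group_nodes) {
  // If any consumer of the producer's output sits outside the group, the
  // output would still have to be materialized, so the fusion is rejected.
  return std::all_of(producer.consumers.begin(), producer.consumers.end(),
                     [&](Node* c) { return group_nodes.count(c) > 0; });
}

OpPatternKind MergePattern(OpPatternKind group, OpPatternKind producer) {
  return static_cast<int>(group) > static_cast<int>(producer) ? group
                                                              : producer;
}

int main() {
  Node consumer{1, {}};
  Node producer{0, {&consumer}};
  std::unordered_set<Node*> group{&consumer};
  std::cout << AllConsumersInGroup(producer, group) << "\n";    // 1
  std::cout << MergePattern(kElementWise, kReduction) << "\n";  // 3
  return 0;
}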
consumer_fusion->internal_nodes.insert(producer); @@ -214,23 +222,32 @@ class OpFusionPassHelper : public FusionHelperBase { { FusionRelation relation; // producer -> consumer - relation.op_kind = {framework::kElementWise, framework::kBroadcast, framework::kReduction, framework::kInjective}; + relation.op_kind = {framework::kElementWise, + framework::kBroadcast, + framework::kReduction, + framework::kInjective}; // producer -> fusion relation.fusion_op_kind = { - // horizontal or vertical relation(Elementwise + *Elementwise*). As has same output shape, can always fuse. + // horizontal or vertical relation(Elementwise + *Elementwise*). As + // has same output shape, can always fuse. {framework::kElementWise, always_fuse}, - // must be horizontal, as Elementwise + Broadcast is left to fusion merge pass. + // must be horizontal, as Elementwise + Broadcast is left to fusion + // merge pass. {framework::kBroadcast, - [](const FusionHelperBase* helper, const Node* producer, const GroupPtr& consumer) -> bool { + [](const FusionHelperBase* helper, + const Node* producer, + const GroupPtr& consumer) -> bool { if (is_same_size(helper, producer, consumer)) { return true; } return !helper->output_nodes_set_.count(producer); }}, - // horizontal or vertical relation, check with same output shape with horizontal relation or with last + // horizontal or vertical relation, check with same output shape with + // horizontal relation or with last // successive dimension less than 1024 for gpu. {framework::kReduction, horizontal_or_vertical_reduce_relation}, - // can be horizontal or can compute inline, check with same output shape or can compute inline. + // can be horizontal or can compute inline, check with same output + // shape or can compute inline. {framework::kInjective, horizontal_or_can_inline}, // must be horizontal, check with same output shape. {framework::kOutFusible, is_same_shape}}; @@ -240,16 +257,20 @@ class OpFusionPassHelper : public FusionHelperBase { { FusionRelation relation; // producer -> consumer - relation.op_kind = {framework::kElementWise, framework::kReduction, framework::kInjective}; + relation.op_kind = {framework::kElementWise, + framework::kReduction, + framework::kInjective}; // producer -> fusion relation.fusion_op_kind = { - // horizontal or vertical relation(Broadcast + *Elementwise*), check with same output shape. + // horizontal or vertical relation(Broadcast + *Elementwise*), check + // with same output shape. {framework::kElementWise, is_same_size}, // must be horizontal, as Broadcast + Broadcast is not allowed. {framework::kBroadcast, is_same_size}, // horizontal or vertical relation(Broadcast + Reduce). {framework::kReduction, horizontal_or_vertical_reduce_relation}, - // can be horizontal or can compute inline, check with same output shape or just one consumer. + // can be horizontal or can compute inline, check with same output + // shape or just one consumer. {framework::kInjective, horizontal_or_can_inline}, // must be horizontal, check with same output shape. {framework::kOutFusible, is_same_shape}}; @@ -262,9 +283,11 @@ class OpFusionPassHelper : public FusionHelperBase { relation.op_kind = {framework::kElementWise, framework::kBroadcast}; // producer -> fusion relation.fusion_op_kind = { - // horizontal or vertical relation(Reduce + Elementwise*), check without last dimension in reduce. + // horizontal or vertical relation(Reduce + Elementwise*), check + // without last dimension in reduce. 
{framework::kElementWise, is_same_size}, - // must be horizontal relation, check with same output shape and without last dimension in reduce. + // must be horizontal relation, check with same output shape and + // without last dimension in reduce. {framework::kBroadcast, reduce_fuse_broadcast}, // must be horizontal relation and with same reduce attr. {framework::kReduction, reduce_fuse_reduce}, @@ -281,7 +304,8 @@ class OpFusionPassHelper : public FusionHelperBase { relation.op_kind = {framework::kElementWise, framework::kInjective}; // producer -> fusion relation.fusion_op_kind = { - // can be horizontal or vertical(Injective + Elementwise), check with same output shape. + // can be horizontal or vertical(Injective + Elementwise), check with + // same output shape. {framework::kElementWise, is_same_size}, // must be horizontal relation, check with same output shape. {framework::kBroadcast, horizontal_with_same_size}, @@ -322,9 +346,11 @@ class OpFusionPassHelper : public FusionHelperBase { if (relation.op_kind.count(GetOpKind(consumer))) { auto& consumer_group = fusion_groups_[consumer]; // second step: check producer can be fused into consumer group - VLOG(3) << "Call ConditionFunction, Producer Op Pattern : " << GetOpKind(producer) - << " , Consumer Group Pattern : " << consumer_group->op_pattern_kind; - return relation.fusion_op_kind[consumer_group->op_pattern_kind](this, producer, fusion_groups_[consumer]); + VLOG(3) << "Call ConditionFunction, Producer Op Pattern : " + << GetOpKind(producer) << " , Consumer Group Pattern : " + << consumer_group->op_pattern_kind; + return relation.fusion_op_kind[consumer_group->op_pattern_kind]( + this, producer, fusion_groups_[consumer]); } return false; @@ -336,15 +362,17 @@ class OpFusionPassHelper : public FusionHelperBase { // producer -> consumer std::unordered_set op_kind = {}; // producer -> fusion sonsumer - std::unordered_map fusion_op_kind = {}; + std::unordered_map + fusion_op_kind = {}; }; - std::unordered_map fusion_relation_map_; + std::unordered_map + fusion_relation_map_; }; void OpFusionPassInternal(Graph* graph) { VLOG(3) << "OpFusionPass...!"; auto op_fusion_helper = OpFusionPassHelper(graph); - graph->fusion_groups = op_fusion_helper(); + graph->fusion_groups = op_fusion_helper(); for (auto& group : graph->fusion_groups) { VLOG(3) << "Group Id : " << group->group_id; @@ -371,7 +399,8 @@ void BuildNonFusedGroupsPassInternal(framework::Graph* graph) { CINN_REGISTER_HELPER(OpFusionPass) { CINN_REGISTER_PASS(OpFusionPass) .describe( - "Op Fusion Pass which performs Ops fusion, Producer Ops are fused into Consumer Ops with certain conditions.") + "Op Fusion Pass which performs Ops fusion, Producer Ops are fused " + "into Consumer Ops with certain conditions.") .set_change_structure(false) .set_body(cinn::hlir::pass::OpFusionPassInternal); diff --git a/paddle/cinn/hlir/pass/op_fusion_pass_test.cc b/paddle/cinn/hlir/pass/op_fusion_pass_test.cc index 97eb048346b70..f433cac8ca43d 100755 --- a/paddle/cinn/hlir/pass/op_fusion_pass_test.cc +++ b/paddle/cinn/hlir/pass/op_fusion_pass_test.cc @@ -34,7 +34,7 @@ TEST(OpFusionPass, ElementWise_Fusion_0) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -58,7 +58,7 @@ TEST(OpFusionPass, ElementWise_Fusion_1) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); 
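// Stepping back from the tests for a moment: the FusionRelation tables above
// drive CanFuse as a two-step dispatch -- filter by the consumer's op kind,
// then evaluate the predicate registered for the consumer group's pattern.
// A pared-down sketch of that dispatch; the int stand-ins replace the real
// (helper, producer, group) predicate arguments.
#include <functional>
#include <iostream>
#include <unordered_map>
#include <unordered_set>

enum OpPatternKind { kElementWise, kBroadcast, kInjective, kReduction };

using ConditionFunction = std::function<bool(int /*producer*/, int /*group*/)>;

struct FusionRelation {
  // Which consumer op kinds this producer may fuse with at all.
  std::unordered_set<int> op_kind;
  // Per consumer-group pattern, the predicate that decides the fusion.
  std::unordered_map<int, ConditionFunction> fusion_op_kind;
};

int main() {
  auto always_fuse = [](int, int) { return true; };
  auto never_fuse = [](int, int) { return false; };

  std::unordered_map<int, FusionRelation> relations;
  relations[kElementWise] = {
      {kElementWise, kBroadcast, kReduction},
      {{kElementWise, always_fuse}, {kBroadcast, never_fuse}}};

  // CanFuse, reduced to its two steps: kind filter, then table lookup.
  auto can_fuse = [&](int producer_kind, int group_kind) {
    auto& rel = relations[producer_kind];
    if (!rel.op_kind.count(group_kind)) return false;
    auto it = rel.fusion_op_kind.find(group_kind);
    return it != rel.fusion_op_kind.end() && it->second(0, 0);
  };

  std::cout << can_fuse(kElementWise, kElementWise) << "\n";  // 1
  std::cout << can_fuse(kElementWise, kBroadcast) << "\n";    // 0
  return 0;
}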
RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -81,7 +81,7 @@ TEST(OpFusionPass, Brodcast_Test_0) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -106,7 +106,7 @@ TEST(OpFusionPass, Brodcast_Test_1) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -126,7 +126,7 @@ TEST(OpFusionPass, Brodcast_Test_2) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -150,7 +150,7 @@ TEST(OpFusionPass, Reduce_Test_0) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -175,7 +175,7 @@ TEST(OpFusionPass, Reduce_Test_1) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -200,7 +200,7 @@ TEST(OpFusionPass, Reduce_Test_2) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -224,7 +224,7 @@ TEST(OpFusionPass, Injective_Test_0) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -242,7 +242,7 @@ TEST(OP_LOWERING, Injective_Test_1) { auto F = net_builder.Add(D, E); auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); RunDecomposer(&program, target); auto graph = std::make_shared(program, target); @@ -264,7 +264,7 @@ TEST(OpFusionPass, Test_Insert_BroadcastTo) { } auto program = net_builder.Build(); - auto target = common::DefaultTarget(); + auto target = common::DefaultTarget(); auto graph = std::make_shared(program, target); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); diff --git a/paddle/cinn/hlir/pass/op_fusion_pass_util.h b/paddle/cinn/hlir/pass/op_fusion_pass_util.h index b04d7557d28bc..34845299e4e79 100644 --- a/paddle/cinn/hlir/pass/op_fusion_pass_util.h +++ b/paddle/cinn/hlir/pass/op_fusion_pass_util.h @@ -21,8 +21,10 @@ namespace cinn { namespace hlir { namespace pass { -#define CONDITION_FUNC(func) \ - inline bool func(const FusionHelperBase* helper, const Node* producer, const std::shared_ptr& consumer) +#define CONDITION_FUNC(func) \ + inline bool func(const FusionHelperBase* helper, \ + const Node* producer, \ + const std::shared_ptr& consumer) CONDITION_FUNC(always_fuse) { return true; } @@ -30,24 +32,29 @@ CONDITION_FUNC(no_fuse) { return false; } CONDITION_FUNC(is_same_shape) { auto master_node = consumer->master_nodes.begin(); - return helper->GetNodeDataShape(producer) == helper->GetNodeDataShape(*master_node); + return helper->GetNodeDataShape(producer) == + helper->GetNodeDataShape(*master_node); } CONDITION_FUNC(is_same_size) { - auto master_node = consumer->master_nodes.begin(); + auto master_node = 
consumer->master_nodes.begin(); auto producer_shape = helper->GetNodeDataShape(producer); auto consumer_shape = helper->GetNodeDataShape(*master_node); if (producer_shape == consumer_shape) { return true; } - auto psize = std::accumulate(producer_shape.begin(), producer_shape.end(), 1, std::multiplies()); - auto csize = std::accumulate(consumer_shape.begin(), consumer_shape.end(), 1, std::multiplies()); + auto psize = std::accumulate( + producer_shape.begin(), producer_shape.end(), 1, std::multiplies()); + auto csize = std::accumulate( + consumer_shape.begin(), consumer_shape.end(), 1, std::multiplies()); return psize == csize; } CONDITION_FUNC(without_last_dimension_in_reduce) { - auto in_shape = helper->shape_dict_.at(producer->inlinks_in_order()[0]->source()->id()); - auto reduce_axes = absl::get>(producer->attrs.attr_store.at("dim")); + auto in_shape = + helper->shape_dict_.at(producer->inlinks_in_order()[0]->source()->id()); + auto reduce_axes = + absl::get>(producer->attrs.attr_store.at("dim")); return helper->WithoutLastDimInReduce(in_shape, reduce_axes); } @@ -60,14 +67,20 @@ CONDITION_FUNC(reduce_fuse_reduce) { } } // check reduce has same input shape and output shape - auto producer_input_shape = helper->shape_dict_.at(producer->inlinks_in_order()[0]->source()->id()); - auto producer_output_shape = helper->shape_dict_.at(producer->outlinks_in_order()[0]->sink()->id()); + auto producer_input_shape = + helper->shape_dict_.at(producer->inlinks_in_order()[0]->source()->id()); + auto producer_output_shape = + helper->shape_dict_.at(producer->outlinks_in_order()[0]->sink()->id()); - auto reducer_input_shape = helper->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); - auto reducer_output_shape = helper->shape_dict_.at(reducer->outlinks_in_order()[0]->sink()->id()); + auto reducer_input_shape = + helper->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); + auto reducer_output_shape = + helper->shape_dict_.at(reducer->outlinks_in_order()[0]->sink()->id()); - auto producer_reduce_dim = absl::get>(producer->attrs.attr_store.at("dim")); - auto reducer_reduce_dim = absl::get>(reducer->attrs.attr_store.at("dim")); + auto producer_reduce_dim = + absl::get>(producer->attrs.attr_store.at("dim")); + auto reducer_reduce_dim = + absl::get>(reducer->attrs.attr_store.at("dim")); for (auto& dim : producer_reduce_dim) { // if dim = -1, set as shape.size() - 1 @@ -83,10 +96,13 @@ CONDITION_FUNC(reduce_fuse_reduce) { } } - if (producer_output_shape == reducer_output_shape && producer_reduce_dim == reducer_reduce_dim) { + if (producer_output_shape == reducer_output_shape && + producer_reduce_dim == reducer_reduce_dim) { bool input_shape_same = producer_input_shape == reducer_input_shape; - bool without_last_dim = helper->WithoutLastDimInReduce(producer_input_shape, producer_reduce_dim) && - helper->WithoutLastDimInReduce(reducer_input_shape, reducer_reduce_dim); + bool without_last_dim = + helper->WithoutLastDimInReduce(producer_input_shape, + producer_reduce_dim) && + helper->WithoutLastDimInReduce(reducer_input_shape, reducer_reduce_dim); // check shape is same if (input_shape_same || without_last_dim) { auto shared_size = helper->GetSharedSize(producer); @@ -165,8 +181,10 @@ CONDITION_FUNC(horizontal_or_vertical_reduce_relation) { } // check producer has same shape with reducer node. 
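// A small sketch of two idioms that recur in the reduce/broadcast conditions
// here and in reduce_fuse_broadcast below: negative reduce axes are first
// normalized against the input rank, and the pass then demands that
// reduce_axes and broadcast_axes jointly cover every input dimension exactly
// once, which it encodes as an XOR over two std::find results. Standalone
// and simplified; the helper/attr plumbing is elided.
#include <algorithm>
#include <iostream>
#include <vector>

void NormalizeAxes(std::vector<int>* axes, int rank) {
  for (int& axis : *axes) {
    if (axis < 0) axis += rank;  // e.g. -1 on a rank-3 shape becomes 2
  }
}

bool AxesPartitionInput(const std::vector<int>& reduce_axes,
                        const std::vector<int>& broadcast_axes,
                        int rank) {
  for (int idx = 0; idx < rank; ++idx) {
    bool in_broadcast =
        std::find(broadcast_axes.begin(), broadcast_axes.end(), idx) !=
        broadcast_axes.end();
    bool in_reduce = std::find(reduce_axes.begin(), reduce_axes.end(), idx) !=
                     reduce_axes.end();
    // Exactly one of the two must hold for every dimension.
    if (in_broadcast == in_reduce) return false;
  }
  return true;
}

int main() {
  std::vector<int> reduce_axes{-1};
  NormalizeAxes(&reduce_axes, 3);  // {-1} -> {2}
  std::cout << AxesPartitionInput(reduce_axes, {0, 1}, 3) << "\n";  // 1
  std::cout << AxesPartitionInput(reduce_axes, {0}, 3) << "\n";     // 0
  return 0;
}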
- auto reduce_shape = helper->shape_dict_.at(helper->GetProducerNodeData(reducer)[0]->id()); - auto reduce_axes = absl::get>(reducer->attrs.attr_store.at("dim")); + auto reduce_shape = + helper->shape_dict_.at(helper->GetProducerNodeData(reducer)[0]->id()); + auto reduce_axes = + absl::get>(reducer->attrs.attr_store.at("dim")); for (auto& axis : reduce_axes) { // if axis = -1, set as shape.size() - 1 if (axis < 0) { @@ -174,9 +192,11 @@ CONDITION_FUNC(horizontal_or_vertical_reduce_relation) { } } - auto node_shape = helper->GetNodeDataShape(producer); - auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); - auto reduce_size = std::accumulate(reduce_shape.begin(), reduce_shape.end(), 1, std::multiplies()); + auto node_shape = helper->GetNodeDataShape(producer); + auto node_size = std::accumulate( + node_shape.begin(), node_shape.end(), 1, std::multiplies()); + auto reduce_size = std::accumulate( + reduce_shape.begin(), reduce_shape.end(), 1, std::multiplies()); // is not same size with reduce size. if (node_size != reduce_size) { @@ -197,7 +217,9 @@ CONDITION_FUNC(horizontal_or_vertical_reduce_relation) { } return helper->target_ == common::DefaultNVGPUTarget() - ? (succesive_reduce_dimension <= helper->target_.max_num_threads() ? true : false) + ? (succesive_reduce_dimension <= helper->target_.max_num_threads() + ? true + : false) : true; } @@ -212,7 +234,8 @@ CONDITION_FUNC(horizontal_or_can_inline) { } } // vertical relation: 1.can compute inline - if (helper->GetNodeData(producer)->outlinks().size() == 1 && helper->output_nodes_set_.count(producer) == 0) { + if (helper->GetNodeData(producer)->outlinks().size() == 1 && + helper->output_nodes_set_.count(producer) == 0) { return true; } @@ -228,7 +251,8 @@ CONDITION_FUNC(horizontal_or_can_inline) { } CONDITION_FUNC(horizontal_with_same_size) { - return is_horizontal_relation(helper, producer, consumer) && is_same_size(helper, producer, consumer); + return is_horizontal_relation(helper, producer, consumer) && + is_same_size(helper, producer, consumer); } CONDITION_FUNC(reduce_fuse_broadcast) { @@ -244,8 +268,9 @@ CONDITION_FUNC(reduce_fuse_broadcast) { } auto rinput_shape = helper->GetNodeInputShape(producer); - auto reduce_axes = absl::get>(producer->attrs.attr_store.at("dim")); - auto keep_dim = absl::get(producer->attrs.attr_store.at("keep_dim")); + auto reduce_axes = + absl::get>(producer->attrs.attr_store.at("dim")); + auto keep_dim = absl::get(producer->attrs.attr_store.at("keep_dim")); for (auto& axis : reduce_axes) { if (axis < 0) { axis += rinput_shape.size(); @@ -265,7 +290,9 @@ CONDITION_FUNC(reduce_fuse_broadcast) { } auto routput_shape = helper->GetNodeDataShape(producer); - auto find_reducer = [&](const Node* node, const Node* reducer, const std::unordered_set& nodes_set) { + auto find_reducer = [&](const Node* node, + const Node* reducer, + const std::unordered_set& nodes_set) { std::queue candidates; candidates.push(node); @@ -296,8 +323,10 @@ CONDITION_FUNC(reduce_fuse_broadcast) { continue; } - auto broadcast_shape = absl::get>(node->attrs.attr_store.at("out_shape")); - auto broadcast_axes = absl::get>(node->attrs.attr_store.at("broadcast_axes")); + auto broadcast_shape = + absl::get>(node->attrs.attr_store.at("out_shape")); + auto broadcast_axes = absl::get>( + node->attrs.attr_store.at("broadcast_axes")); for (auto& axis : broadcast_axes) { if (axis < 0) { axis += broadcast_shape.size(); @@ -318,8 +347,10 @@ CONDITION_FUNC(reduce_fuse_broadcast) { // check [reduce_axes, axes] = 
{0, 1, 2, 3, 4, 5, 6, ...} for (int idx = 0; idx < rinput_shape.size(); ++idx) { // note: !x ^ y == (!x) ^ y == !(x ^ y) - if ((std::find(broadcast_axes.begin(), broadcast_axes.end(), idx) != broadcast_axes.end()) ^ - std::find(reduce_axes.begin(), reduce_axes.end(), idx) == reduce_axes.end()) { + if ((std::find(broadcast_axes.begin(), broadcast_axes.end(), idx) != + broadcast_axes.end()) ^ + std::find(reduce_axes.begin(), reduce_axes.end(), idx) == + reduce_axes.end()) { return false; } } diff --git a/paddle/cinn/hlir/pass/opfusion.cc b/paddle/cinn/hlir/pass/opfusion.cc index fd8a693216e94..f95eed9873d95 100644 --- a/paddle/cinn/hlir/pass/opfusion.cc +++ b/paddle/cinn/hlir/pass/opfusion.cc @@ -41,11 +41,12 @@ struct DomNode { int depth{0}; }; -void GetBroadcastPattern(Node* op_node, - OpPatternKind* pattern, - const absl::flat_hash_map& shape_dict) { +void GetBroadcastPattern( + Node* op_node, + OpPatternKind* pattern, + const absl::flat_hash_map& shape_dict) { if (*pattern == framework::kBroadcast) { - auto inlinks = op_node->inlinks(); + auto inlinks = op_node->inlinks(); auto outlinks = op_node->outlinks(); CHECK_EQ(inlinks.size(), 2U); CHECK_EQ(outlinks.size(), 1U); @@ -56,8 +57,9 @@ void GetBroadcastPattern(Node* op_node, input_shapes.push_back(shape_dict.at(source->id())); } int small_index = input_shapes[0].size() <= input_shapes[1].size() ? 0 : 1; - auto begin = std::find( - input_shapes[1 - small_index].begin(), input_shapes[1 - small_index].end(), input_shapes[small_index][0]); + auto begin = std::find(input_shapes[1 - small_index].begin(), + input_shapes[1 - small_index].end(), + input_shapes[small_index][0]); bool is_same = true; for (int i = 0; i < input_shapes[small_index].size(); i++) { if (input_shapes[small_index][i] != (*begin)) { @@ -77,17 +79,20 @@ void GetBroadcastPattern(Node* op_node, class DomTree { public: - std::vector& CreatePostDomTree(const std::vector& nodes) { + std::vector& CreatePostDomTree( + const std::vector& nodes) { int size = nodes.size(); dom_nodes_.resize(nodes.size()); // construct postdom tree, reverse topological_order for (int i = size - 1; i >= 0; i--) { auto* dom_node = CreateDomNode(nodes[i]); CHECK(dom_node); - VLOG(2) << "dom_node: " << dom_node->ref_node->id() << ", pattern: " << dom_node->pattern + VLOG(2) << "dom_node: " << dom_node->ref_node->id() + << ", pattern: " << dom_node->pattern << ", depth: " << dom_node->depth; if (dom_node->parent) { - VLOG(2) << dom_node->ref_node->id() << " parent: " << dom_node->parent->ref_node->id(); + VLOG(2) << dom_node->ref_node->id() + << " parent: " << dom_node->parent->ref_node->id(); } dom_nodes_[i] = dom_node; } @@ -97,44 +102,48 @@ class DomTree { std::vector dom_nodes_; private: - OpPatternKind FusePattern(OpPatternKind p0, OpPatternKind p1) { return p0 > p1 ? p0 : p1; } + OpPatternKind FusePattern(OpPatternKind p0, OpPatternKind p1) { + return p0 > p1 ? 
p0 : p1; + } DomNode* LCA(DomNode* l, DomNode* r, OpPatternKind* pattern) { while (l != r) { if (!l || !r) return nullptr; if (l->depth < r->depth) { *pattern = FusePattern(*pattern, r->pattern); - r = r->parent; + r = r->parent; } else if (l->depth > r->depth) { *pattern = FusePattern(*pattern, l->pattern); - l = l->parent; + l = l->parent; } else { *pattern = FusePattern(*pattern, l->pattern); *pattern = FusePattern(*pattern, r->pattern); - l = l->parent; - r = r->parent; + l = l->parent; + r = r->parent; } } return l; } DomNode* FindLCA(GraphNode* graph_node, OpPatternKind* pattern) { - static auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + static auto& op_pattern_dict = + Operator::GetAttrs("OpPattern"); CHECK(graph_node); CHECK(pattern); DomNode* parent = nullptr; - int count = 0; + int count = 0; if (graph_node->safe_as()) { - auto* node = graph_node->safe_as(); + auto* node = graph_node->safe_as(); const auto& out_links = node->outlinks_in_order(); for (int i = 0; i < out_links.size(); i++) { - auto sink = out_links[i]->sink(); + auto sink = out_links[i]->sink(); bool has_no_links = sink->outlinks().empty(); if (i) { - // CHECK(has_no_links) << "only the first out_var of " << node->id() << " links to other op node!"; + // CHECK(has_no_links) << "only the first out_var of " << node->id() + // << " links to other op node!"; } else { int index = sink->get_index(); // the first out_var is the parent of the op node - parent = dom_nodes_[index]; + parent = dom_nodes_[index]; *pattern = FusePattern(*pattern, parent->pattern); } } @@ -142,23 +151,25 @@ class DomTree { auto* node_data = graph_node->safe_as(); CHECK(node_data); auto out_links = node_data->outlinks(); - int count = 0; + int count = 0; for (auto link : out_links) { - auto sink = link->sink(); - int index = sink->get_index(); + auto sink = link->sink(); + int index = sink->get_index(); auto dom_node = dom_nodes_[index]; if (!count) { parent = dom_node; CHECK(parent); } else { - // if the out_var links to more than one opnode, then we need to find the LCA + // if the out_var links to more than one opnode, then we need to find + // the LCA parent = LCA(parent, dom_node, pattern); } auto* op_node = sink->safe_as(); CHECK(op_node); auto op_pattern = op_pattern_dict[op_node->op()]; VLOG(2) << sink->id() << "'s op pattern is " << op_pattern; - if (op_node->attrs.attr_store.count("pre_run") && absl::get(op_node->attrs.attr_store["pre_run"])) { + if (op_node->attrs.attr_store.count("pre_run") && + absl::get(op_node->attrs.attr_store["pre_run"])) { // not fuse pre_run opnode op_pattern = framework::kNonFusible; VLOG(3) << op_node->op()->name << " do pre_run and not fuse"; @@ -171,20 +182,20 @@ class DomTree { } DomNode* CreateDomNode(GraphNode* graph_node) { CHECK(graph_node); - DomNode* dom_node = new DomNode(); + DomNode* dom_node = new DomNode(); dom_node->ref_node = graph_node; if (graph_node->inlinks().empty() && graph_node->safe_as()) { CHECK(graph_node->safe_as()); // extern input vars - dom_node->parent = nullptr; + dom_node->parent = nullptr; dom_node->pattern = framework::kNonFusible; - dom_node->depth = 0; + dom_node->depth = 0; } else { OpPatternKind pattern{framework::kElementWise}; - auto* parent = FindLCA(graph_node, &pattern); - dom_node->parent = parent; + auto* parent = FindLCA(graph_node, &pattern); + dom_node->parent = parent; dom_node->pattern = pattern; - dom_node->depth = parent ? parent->depth + 1 : 0; + dom_node->depth = parent ? 
parent->depth + 1 : 0; } return dom_node; } @@ -209,16 +220,19 @@ struct GroupNode { while (node != root_node) { auto* parent = node->parent; node->parent = root_node; - node = parent; + node = parent; } return root_node; } }; class GraphPartition { public: - GraphPartition(const absl::flat_hash_map& shape_dict) : shape_dict_(shape_dict) {} - std::vector> Partition(const std::vector& graph_nodes, - const std::vector& dom_nodes) { + GraphPartition( + const absl::flat_hash_map& shape_dict) + : shape_dict_(shape_dict) {} + std::vector> Partition( + const std::vector& graph_nodes, + const std::vector& dom_nodes) { CHECK_EQ(graph_nodes.size(), dom_nodes.size()); InitGroups(graph_nodes); for (int i = 0; i < 2; i++) { @@ -237,22 +251,24 @@ class GraphPartition { std::unordered_set visited_nodes_; const absl::flat_hash_map& shape_dict_; void InitGroups(const std::vector& graph_nodes) { - static auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + static auto& op_pattern_dict = + Operator::GetAttrs("OpPattern"); for (int i = 0; i < graph_nodes.size(); i++) { GroupNode* group_node = new GroupNode(); GraphNode* graph_node = graph_nodes[i]; CHECK(graph_node); - auto op_node = graph_node->safe_as(); + auto op_node = graph_node->safe_as(); group_node->ref_node = graph_node; - group_node->index = graph_node->get_index(); + group_node->index = graph_node->get_index(); if (op_node) { auto pattern = op_pattern_dict[op_node->op()]; - if (op_node->attrs.attr_store.count("pre_run") && absl::get(op_node->attrs.attr_store["pre_run"])) { + if (op_node->attrs.attr_store.count("pre_run") && + absl::get(op_node->attrs.attr_store["pre_run"])) { // not fuse pre_run opnode pattern = framework::kNonFusible; VLOG(3) << op_node->op()->name << " do pre_run and not fuse"; } - group_node->pattern = pattern; + group_node->pattern = pattern; group_node->op_nodes_count = 1; if (pattern == framework::kOutFusible) { group_node->master_node = graph_node; @@ -268,7 +284,8 @@ class GraphPartition { group_nodes_.push_back(group_node); } } - bool IsSameShape(const std::vector& shape1, const std::vector& shape2) { + bool IsSameShape(const std::vector& shape1, + const std::vector& shape2) { if (shape1.size() != shape2.size()) return false; for (int i = 0; i < shape1.size(); i++) { if (shape1[i] != shape2[i]) return false; @@ -326,7 +343,9 @@ class GraphPartition { if (!i) { if (!CanFuse(new_source, sink, fn)) return false; } else { - CHECK(new_source->outlinks().empty()) << "only the first out_var of the op node can link to other op node"; + CHECK(new_source->outlinks().empty()) + << "only the first out_var of the op node can link to other op " + "node"; } } } else { @@ -341,15 +360,16 @@ class GraphPartition { // check all the nodes between source and sink meet the function of fusion. template bool VerifyFuse(GraphNode* source, GraphNode* sink, T fn) { - static auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - auto op_node = source->safe_as(); + static auto& op_pattern_dict = + Operator::GetAttrs("OpPattern"); + auto op_node = source->safe_as(); visited_nodes_.clear(); CHECK(source != sink); auto sink_op_node = sink->safe_as(); if (sink_op_node && GetRootPattern(source) == framework::kOutFusible && op_pattern_dict[sink_op_node->op()] >= framework::kBroadcast) { - // verify conv and sink out shape. If sink's out shape is different from conv's, then no fuse for computeAt - // lowering incompatible + // verify conv and sink out shape. 
If sink's out shape is different from + // conv's, then no fuse for computeAt lowering incompatible if (!VerifyOutShape(source, sink)) { VLOG(2) << "source node: " << source->id(); VLOG(2) << "sink node: " << sink->id(); @@ -364,7 +384,8 @@ class GraphPartition { // verify all the nodes in the fuse path recursively if (!CanFuse(new_source, sink, fn)) return false; } else { - CHECK(new_source->outlinks().empty()) << "only the first out_var of op_node links to other op_node"; + CHECK(new_source->outlinks().empty()) + << "only the first out_var of op_node links to other op_node"; } } } else { @@ -378,7 +399,7 @@ class GraphPartition { return true; } void MergeNodes(GroupNode* child, GroupNode* parent) { - child = child->GetRootNode(); + child = child->GetRootNode(); parent = parent->GetRootNode(); CHECK(child); CHECK(parent); @@ -389,10 +410,12 @@ class GraphPartition { if (child->master_node) { CHECK(!parent->master_node); parent->master_node = child->master_node; - if (child->pattern > framework::kBroadcast && parent->pattern > framework::kBroadcast) { + if (child->pattern > framework::kBroadcast && + parent->pattern > framework::kBroadcast) { LOG(FATAL) << "can't fuse 2 groups both with complex pattern"; } else { - parent->pattern = child->pattern > parent->pattern ? child->pattern : parent->pattern; + parent->pattern = + child->pattern > parent->pattern ? child->pattern : parent->pattern; } } } @@ -411,7 +434,8 @@ class GraphPartition { if (!i) { Fuse(new_source, sink, target); } else { - CHECK(new_source->outlinks().empty()) << "only the first out_var of op_node links to other op_node"; + CHECK(new_source->outlinks().empty()) + << "only the first out_var of op_node links to other op_node"; } } } else { @@ -429,28 +453,35 @@ class GraphPartition { CHECK(source != sink); Fuse(source, sink, group_node); } - void FuseGroups(const std::vector& graph_nodes, const std::vector& dom_nodes, int phase) { + void FuseGroups(const std::vector& graph_nodes, + const std::vector& dom_nodes, + int phase) { CHECK_EQ(graph_nodes.size(), dom_nodes.size()); CHECK_EQ(group_nodes_.size(), dom_nodes.size()); for (int i = 0; i < graph_nodes.size(); i++) { auto* graph_node = graph_nodes[i]; - auto* dom_node = dom_nodes[i]; + auto* dom_node = dom_nodes[i]; auto* group_node = group_nodes_[i]; CHECK(graph_node); CHECK(dom_node); CHECK(group_node); if (!dom_node->parent) continue; if (group_node->pattern == framework::kNonFusible) continue; - int parent_index = dom_node->parent->ref_node->get_index(); + int parent_index = dom_node->parent->ref_node->get_index(); auto parent_group_node = group_nodes_[parent_index]; - if (parent_group_node && parent_group_node->GetRootNode() == group_node->GetRootNode()) continue; + if (parent_group_node && + parent_group_node->GetRootNode() == group_node->GetRootNode()) + continue; if (group_node->pattern == framework::kOutFusible) { if (dom_node->pattern <= framework::kBroadcast) { - auto fn = [](OpPatternKind pattern, bool is_sink) { return pattern <= framework::kBroadcast; }; + auto fn = [](OpPatternKind pattern, bool is_sink) { + return pattern <= framework::kBroadcast; + }; auto lca_node = dom_node->parent->ref_node; if (VerifyFuse(graph_node, lca_node, fn)) { - VLOG(2) << "fuse between " << graph_node->id() << " and " << lca_node->id(); + VLOG(2) << "fuse between " << graph_node->id() << " and " + << lca_node->id(); DoFuse(graph_node, lca_node); } } @@ -465,17 +496,22 @@ class GraphPartition { }; auto lca_node = dom_node->parent->ref_node; if (VerifyFuse(graph_node, lca_node, 
fn)) { - VLOG(2) << "fuse between " << graph_node->id() << " and " << lca_node->id(); + VLOG(2) << "fuse between " << graph_node->id() << " and " + << lca_node->id(); DoFuse(graph_node, lca_node); } } } else if (group_node->pattern == framework::kInjective && phase == 1) { - // fuse injective ops in the second phase so that conv2d can always finish fusing + // fuse injective ops in the second phase so that conv2d can always + // finish fusing if (dom_node->pattern <= framework::kInjective) { - auto fn = [](OpPatternKind pattern, bool is_sink) { return pattern <= framework::kInjective; }; + auto fn = [](OpPatternKind pattern, bool is_sink) { + return pattern <= framework::kInjective; + }; auto lca_node = dom_node->parent->ref_node; if (VerifyFuse(graph_node, lca_node, fn)) { - VLOG(2) << "fuse between " << graph_node->id() << " and " << lca_node->id(); + VLOG(2) << "fuse between " << graph_node->id() << " and " + << lca_node->id(); DoFuse(graph_node, lca_node); } } @@ -512,12 +548,13 @@ class GraphPartition { void OpFusionPass(Graph* graph) { auto store_nodes = std::get<0>(graph->topological_order()); - int node_size = store_nodes.size(); + int node_size = store_nodes.size(); // construct postdom tree, reverse topological_order DomTree tree; auto& dom_nodes = tree.CreatePostDomTree(store_nodes); // graph partition - auto& shape_dict = graph->GetMutableAttrs>("infershape"); + auto& shape_dict = graph->GetMutableAttrs< + absl::flat_hash_map>("infershape"); GraphPartition partition(shape_dict); graph->groups = partition.Partition(store_nodes, dom_nodes); } diff --git a/paddle/cinn/hlir/pass/opfusion_test.cc b/paddle/cinn/hlir/pass/opfusion_test.cc index ae8546055a0b6..870f7e4bc591e 100755 --- a/paddle/cinn/hlir/pass/opfusion_test.cc +++ b/paddle/cinn/hlir/pass/opfusion_test.cc @@ -55,11 +55,11 @@ TEST(complex2, complex2) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; absl::flat_hash_map attrs1; attrs1["epsilon"] = static_cast(0.001); @@ -110,11 +110,11 @@ TEST(complex1, complex1) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; absl::flat_hash_map attrs1; attrs1["epsilon"] = static_cast(0.001); @@ -242,11 +242,11 @@ TEST(conv_bn_conv, conv_bn_conv) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; absl::flat_hash_map attrs1; attrs1["epsilon"] = static_cast(0.001); @@ -299,11 +299,11 @@ TEST(fuse_conv_add, fuse_conv_add) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 
1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; auto c = program.conv2d(A, B, attrs); auto d = program.elementwise_add(c, C, 1); @@ -353,11 +353,11 @@ TEST(conv_add_mul, conv_add_mul) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; absl::flat_hash_map attrs1; attrs1["epsilon"] = static_cast(0.001); @@ -405,11 +405,11 @@ TEST(fuse_conv_add1, fuse_conv_add1) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({1, 1}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({0, 0}); + attrs["stride"] = std::vector({1, 1}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({0, 0}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; auto c = program.conv2d(A, B, attrs); auto d = program.elementwise_add(c, C); @@ -493,17 +493,18 @@ TEST(conv_bn, conv_bn) { Program program; absl::flat_hash_map attrs; - attrs["stride"] = std::vector({2, 2}); - attrs["dilation"] = std::vector({1, 1}); - attrs["padding"] = std::vector({3, 3}); + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); std::string src_layout = "NCHW"; - attrs["data_format"] = src_layout; + attrs["data_format"] = src_layout; absl::flat_hash_map attrs1; attrs1["epsilon"] = static_cast(0.001); auto c = program.conv2d(A, B, attrs); - auto d = program.fused_batchnorm_inference(c, Scale, Bias, Mean, Variance, attrs1); + auto d = + program.fused_batchnorm_inference(c, Scale, Bias, Mean, Variance, attrs1); Target target = common::DefaultTarget(); program.SetInputs({A, B, Scale, Bias, Mean, Variance}); diff --git a/paddle/cinn/hlir/pass/reduce_split_pass.cc b/paddle/cinn/hlir/pass/reduce_split_pass.cc index 513930d87d33b..39cf77f7ab32e 100644 --- a/paddle/cinn/hlir/pass/reduce_split_pass.cc +++ b/paddle/cinn/hlir/pass/reduce_split_pass.cc @@ -31,8 +31,12 @@ using framework::Operator; using framework::shape_t; bool IsReduceOp(const framework::Node* node) { - static std::unordered_set reduce_op_type = { - "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_all", "reduce_any"}; + static std::unordered_set reduce_op_type = {"reduce_sum", + "reduce_mean", + "reduce_max", + "reduce_min", + "reduce_all", + "reduce_any"}; if (reduce_op_type.count(node->op()->name)) { return true; } else { @@ -69,12 +73,16 @@ class ReduceSplitPass { public: // Find the reduce op with nwhc format and large shape, split it into two ops static int Apply(framework::Graph* graph) { - int MAX_NUM_THREADS = common::DefaultNVGPUTarget().max_num_threads(); + int MAX_NUM_THREADS = common::DefaultNVGPUTarget().max_num_threads(); constexpr int MAX_ITER_PER_THREAD = 32; // empirical value - int cnt = 0; - auto& shape_dict = graph->GetMutableAttrs>("infershape"); - auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); + int cnt = 0; + auto& shape_dict = + graph->GetMutableAttrs>( + "infershape"); 
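For reference, the `GraphPartition::MergeNodes` hunk earlier in this patch is a union-find merge over fusion groups: both nodes are resolved to their root representative via `GetRootNode`, two groups whose patterns both exceed `kBroadcast` refuse to merge, and otherwise the surviving root keeps the stronger pattern. A minimal standalone sketch of that rule, under simplifying assumptions — the real `GroupNode` also carries `master_node` (which gates the pattern update) and graph links:

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>

// Pattern lattice assumed to mirror framework::OpPatternKind's ordering.
enum OpPatternKind { kElementWise = 0, kBroadcast, kInjective,
                     kReduction, kOutFusible, kNonFusible };

struct GroupNode {
  GroupNode* parent = nullptr;  // union-find parent; nullptr means root
  OpPatternKind pattern = kElementWise;

  GroupNode* GetRootNode() {
    GroupNode* n = this;
    while (n->parent) n = n->parent;  // walk up to the representative
    return n;
  }
};

// Merge child's group into parent's group, as MergeNodes does above.
void MergeNodes(GroupNode* child, GroupNode* parent) {
  child = child->GetRootNode();
  parent = parent->GetRootNode();
  if (child == parent) return;  // already the same group
  if (child->pattern > kBroadcast && parent->pattern > kBroadcast)
    throw std::runtime_error("can't fuse 2 groups both with complex pattern");
  parent->pattern = std::max(child->pattern, parent->pattern);
  child->parent = parent;  // union: child's root now points at parent's root
}

int main() {
  GroupNode a, b;
  a.pattern = kBroadcast;
  b.pattern = kReduction;
  MergeNodes(&b, &a);  // fine: only one side carries a "complex" pattern
  assert(b.GetRootNode() == &a && a.pattern == kReduction);
}
```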
+ auto& dtype_dict = + graph->GetMutableAttrs>( + "inferdtype"); // loop the nodes in graph and find reduce_xx op auto nodes_inorder = std::get<0>(graph->topological_order()); @@ -84,15 +92,15 @@ class ReduceSplitPass { } auto n = node->safe_as(); if (IsReduceOp(n)) { - auto* op = n->op(); + auto* op = n->op(); auto name = op->name; - auto dims = absl::get>(n->attrs.attr_store.at("dim")); + auto dims = absl::get>(n->attrs.attr_store.at("dim")); bool keep_dim = absl::get(n->attrs.attr_store.at("keep_dim")); - auto in = (*n->inlinks().begin())->source()->safe_as(); - auto out = (*n->outlinks().begin())->sink()->safe_as(); + auto in = (*n->inlinks().begin())->source()->safe_as(); + auto out = (*n->outlinks().begin())->sink()->safe_as(); - auto in_shape = shape_dict.at(in->id()); + auto in_shape = shape_dict.at(in->id()); auto out_shape = shape_dict.at(out->id()); // all preceding reduced CHECK(in_shape.size() > 1); @@ -104,92 +112,127 @@ class ReduceSplitPass { } } bool reduce_all = - all_preceding_dim_reduced && std::find(dims.begin(), dims.end(), in_shape.size() - 1) != dims.end(); + all_preceding_dim_reduced && + std::find(dims.begin(), dims.end(), in_shape.size() - 1) != + dims.end(); if (!all_preceding_dim_reduced || reduce_all) { continue; } - int numel = std::accumulate(in_shape.begin(), in_shape.end(), 1, std::multiplies()); - int reduce_numel = std::accumulate(in_shape.begin(), in_shape.end() - 1, 1, std::multiplies()); + int numel = std::accumulate( + in_shape.begin(), in_shape.end(), 1, std::multiplies()); + int reduce_numel = std::accumulate( + in_shape.begin(), in_shape.end() - 1, 1, std::multiplies()); CHECK(reduce_numel > 0); // if the numel is not large enough, it is no need to split // if loop times is too large with reduce optimize - int size = std::accumulate(in_shape.begin(), (in_shape.end() - 1), 1, std::multiplies()); - int tail = 0; + int size = std::accumulate( + in_shape.begin(), (in_shape.end() - 1), 1, std::multiplies()); + int tail = 0; bool bound = true; - auto shape = pe::GetFirstStepReduceShape({size, in_shape.back()}, {0}, bound, tail); + auto shape = pe::GetFirstStepReduceShape( + {size, in_shape.back()}, {0}, bound, tail); CHECK(bound); CHECK_EQ(shape.size(), 3); - auto res = DivideToClosetNum(reduce_numel); + auto res = DivideToClosetNum(reduce_numel); int reduce_numel0 = std::get<0>(res), reduce_numel1 = std::get<1>(res); VLOG(3) << "InShape -> " << std::accumulate( - in_shape.begin(), in_shape.end(), std::string(""), [](const std::string& left, const int right) { + in_shape.begin(), + in_shape.end(), + std::string(""), + [](const std::string& left, const int right) { return left + std::to_string(right) + " "; }); - VLOG(3) << " reduce split : " << reduce_numel0 << " " << reduce_numel1 << " " << in_shape.back(); + VLOG(3) << " reduce split : " << reduce_numel0 << " " << reduce_numel1 + << " " << in_shape.back(); VLOG(3) << " reshape split : " - << std::accumulate(shape.begin(), shape.end(), std::string(""), [](std::string left, int right) { - return left + std::to_string(right) + " "; - }); + << std::accumulate(shape.begin(), + shape.end(), + std::string(""), + [](std::string left, int right) { + return left + std::to_string(right) + " "; + }); // Two do reduce split: // 1. reshape_loop > split_loop // 2. reshape thread > max_threads. 
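`DivideToClosetNum` above factors the flattened reduce extent into two nearly equal factors, so the first reduce step has enough parallelism while the second stays short. Its exact tie-breaking is not visible in this patch; a plausible minimal sketch (hypothetical `DivideToClosestNum`, taking the largest divisor not exceeding the square root):

```cpp
#include <cmath>
#include <cstdio>
#include <tuple>

// Split n into {a, n / a} with a * (n / a) == n and a as close to
// sqrt(n) as an exact divisor allows.
std::tuple<int, int> DivideToClosestNum(int n) {
  int a = static_cast<int>(std::sqrt(static_cast<double>(n)));
  while (a > 1 && n % a != 0) --a;  // largest divisor <= sqrt(n)
  return {a, n / a};
}

int main() {
  // E.g. reducing a [64, 14, 14, C] input over its first three axes:
  // reduce_numel = 64 * 14 * 14 = 12544 splits into 112 x 112.
  auto [r0, r1] = DivideToClosestNum(64 * 14 * 14);
  std::printf("%d x %d\n", r0, r1);
}
```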
- if (shape[0] <= reduce_numel0 && shape[1] * shape[2] <= common::GetMaxThreads()) { + if (shape[0] <= reduce_numel0 && + shape[1] * shape[2] <= common::GetMaxThreads()) { VLOG(3) << " Don't Do Reduce Split!"; continue; } VLOG(3) << " Do Reduce Split!"; /* - if ((!all_preceding_dim_reduced) || numel <= MAX_NUM_THREADS * MAX_ITER_PER_THREAD || reduce_all) { - continue; + if ((!all_preceding_dim_reduced) || numel <= MAX_NUM_THREADS * + MAX_ITER_PER_THREAD || reduce_all) { continue; } */ // create reshape node0 - Node* reshape0 = new Node(Operator::Get("reshape"), "reshape", common::UniqName("reshape_split")); - reshape0->attrs.attr_store["shape"] = - std::vector{reduce_numel0, reduce_numel1, in_shape[in_shape.size() - 1]}; + Node* reshape0 = new Node(Operator::Get("reshape"), + "reshape", + common::UniqName("reshape_split")); + reshape0->attrs.attr_store["shape"] = std::vector{ + reduce_numel0, reduce_numel1, in_shape[in_shape.size() - 1]}; graph->RegisterNode(reshape0->id(), reshape0); in->LinkTo(reshape0); in->UnLinkSingleTo(node); node->UnLinkSingleTo(out); - auto reshape0_data = new NodeData(Shared(reshape0), 0, 0, common::UniqName("var"), false); + auto reshape0_data = new NodeData( + Shared(reshape0), 0, 0, common::UniqName("var"), false); graph->RegisterNode(reshape0_data->id(), reshape0_data); reshape0->LinkTo(reshape0_data); - shape_dict[reshape0_data->id()] = absl::get>(reshape0->attrs.attr_store.at("shape")); - dtype_dict[reshape0_data->id()] = common::Str2Type(common::Type2Str(dtype_dict[in->id()])); + shape_dict[reshape0_data->id()] = + absl::get>(reshape0->attrs.attr_store.at("shape")); + dtype_dict[reshape0_data->id()] = + common::Str2Type(common::Type2Str(dtype_dict[in->id()])); // create reduce node0 - Node* reduce0 = new Node(Operator::Get(name), name, common::UniqName(name + "_split")); - reduce0->attrs.attr_store["dim"] = std::vector{0}; - reduce0->attrs.attr_store["keep_dim"] = absl::get(n->attrs.attr_store.at("keep_dim")); + Node* reduce0 = new Node( + Operator::Get(name), name, common::UniqName(name + "_split")); + reduce0->attrs.attr_store["dim"] = std::vector{0}; + reduce0->attrs.attr_store["keep_dim"] = + absl::get(n->attrs.attr_store.at("keep_dim")); graph->RegisterNode(reduce0->id(), reduce0); reshape0_data->LinkTo(reduce0); - auto reduce0_data = new NodeData(Shared(reduce0), 0, 0, common::UniqName("var"), false); + auto reduce0_data = new NodeData( + Shared(reduce0), 0, 0, common::UniqName("var"), false); graph->RegisterNode(reduce0_data->id(), reduce0_data); reduce0->LinkTo(reduce0_data); - shape_dict[reduce0_data->id()] = keep_dim ? std::vector{1, reduce_numel1, in_shape[in_shape.size() - 1]} - : std::vector{reduce_numel1, in_shape[in_shape.size() - 1]}; - dtype_dict[reduce0_data->id()] = common::Str2Type(common::Type2Str(dtype_dict[in->id()])); + shape_dict[reduce0_data->id()] = + keep_dim ? std::vector{1, + reduce_numel1, + in_shape[in_shape.size() - 1]} + : std::vector{reduce_numel1, + in_shape[in_shape.size() - 1]}; + dtype_dict[reduce0_data->id()] = + common::Str2Type(common::Type2Str(dtype_dict[in->id()])); // create reduce node1 - Node* reduce1 = new Node(Operator::Get(name), name, common::UniqName(name + "_split")); - reduce1->attrs.attr_store["dim"] = keep_dim ? std::vector{0, 1} : std::vector{0}; - reduce1->attrs.attr_store["keep_dim"] = absl::get(n->attrs.attr_store.at("keep_dim")); + Node* reduce1 = new Node( + Operator::Get(name), name, common::UniqName(name + "_split")); + reduce1->attrs.attr_store["dim"] = + keep_dim ? 
std::vector{0, 1} : std::vector{0}; + reduce1->attrs.attr_store["keep_dim"] = + absl::get(n->attrs.attr_store.at("keep_dim")); graph->RegisterNode(reduce1->id(), reduce1); reduce0_data->LinkTo(reduce1); - auto reduce1_data = new NodeData(Shared(reduce1), 0, 0, common::UniqName("var"), false); + auto reduce1_data = new NodeData( + Shared(reduce1), 0, 0, common::UniqName("var"), false); graph->RegisterNode(reduce1_data->id(), reduce1_data); reduce1->LinkTo(reduce1_data); - shape_dict[reduce1_data->id()] = keep_dim ? std::vector{1, 1, in_shape[in_shape.size() - 1]} - : std::vector{in_shape[in_shape.size() - 1]}; - dtype_dict[reduce1_data->id()] = common::Str2Type(common::Type2Str(dtype_dict[in->id()])); + shape_dict[reduce1_data->id()] = + keep_dim ? std::vector{1, 1, in_shape[in_shape.size() - 1]} + : std::vector{in_shape[in_shape.size() - 1]}; + dtype_dict[reduce1_data->id()] = + common::Str2Type(common::Type2Str(dtype_dict[in->id()])); // create reshape node1 - Node* reshape1 = new Node(Operator::Get("reshape"), "reshape", common::UniqName("reshape_split")); + Node* reshape1 = new Node(Operator::Get("reshape"), + "reshape", + common::UniqName("reshape_split")); reshape1->attrs.attr_store["shape"] = out_shape; graph->RegisterNode(reshape1->id(), reshape1); reduce1_data->LinkTo(reshape1); diff --git a/paddle/cinn/hlir/pass/reduce_split_pass_test.cc b/paddle/cinn/hlir/pass/reduce_split_pass_test.cc index eec8e861e5a1a..688d6a0bd607c 100644 --- a/paddle/cinn/hlir/pass/reduce_split_pass_test.cc +++ b/paddle/cinn/hlir/pass/reduce_split_pass_test.cc @@ -25,7 +25,8 @@ std::unordered_map> RunModelTest( const std::unordered_map>& input_data, const std::unordered_set& fetch_ids) { auto target = common::DefaultTarget(); - auto graph = std::make_shared(program, fetch_ids, target); + auto graph = + std::make_shared(program, fetch_ids, target); hlir::framework::ApplyPasses(graph.get(), passes); auto scope = BuildScope(target, graph); @@ -54,16 +55,21 @@ TEST(ReduceSplit, reduce_mean_nhwc) { NetBuilder net_builder("reduce_sum_nhwc"); // create model int N = 64, H = 14, W = 14, C = 256; - auto in = net_builder.CreateInput(Float(32), {N, H, W, C}, "in"); + auto in = net_builder.CreateInput(Float(32), {N, H, W, C}, "in"); auto out = net_builder.ReduceSum(in, {0, 1, 2}); auto fetch_ids = {out->id}; std::vector input_data(N * H * W * C); InitRandomVector(&input_data, input_data.size(), 0.0f, 1.0f, 1e-3); - std::unordered_map> feeds = {{"in", input_data}}; - auto program = net_builder.Build(); - auto output = RunModelTest(program, {"ReduceSplit", "OpFusionPass", "FusionMergePass"}, feeds, fetch_ids); - auto output_expect = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, feeds, fetch_ids); + std::unordered_map> feeds = { + {"in", input_data}}; + auto program = net_builder.Build(); + auto output = RunModelTest(program, + {"ReduceSplit", "OpFusionPass", "FusionMergePass"}, + feeds, + fetch_ids); + auto output_expect = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, feeds, fetch_ids); for (auto& out : output) { CheckOutput(out.second, output_expect[out.first], 1e-8, 1e-4); @@ -74,16 +80,21 @@ TEST(ReduceSplit, reduce_mean_nhwc_small_size) { NetBuilder net_builder("reduce_sum_nhwc"); // create model int N = 32, H = 2, W = 2, C = 256; - auto in = net_builder.CreateInput(Float(32), {N, H, W, C}, "in"); + auto in = net_builder.CreateInput(Float(32), {N, H, W, C}, "in"); auto out = net_builder.ReduceSum(in, {0, 1, 2}); auto fetch_ids = {out->id}; std::vector input_data(N * H * W * C); 
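The node chain assembled above — reshape0 -> reduce0(dim 0) -> reduce1(dim 0, or {0, 1} with keep_dim) -> reshape1 — is value-preserving because a sum over all leading axes can be regrouped into two staged sums. A tiny standalone check of that equivalence on a [R0 * R1, C] problem (integer data, to sidestep floating-point reassociation; not CINN code):

```cpp
#include <cassert>
#include <vector>

int main() {
  const int R0 = 4, R1 = 3, C = 2;  // reduce_numel = R0 * R1 = 12
  std::vector<int> in(R0 * R1 * C);
  for (int i = 0; i < static_cast<int>(in.size()); ++i) in[i] = i;

  // One-step: reduce all leading elements per trailing channel.
  std::vector<int> ref(C, 0);
  for (int r = 0; r < R0 * R1; ++r)
    for (int c = 0; c < C; ++c) ref[c] += in[r * C + c];

  // Two-step: view as [R0, R1, C]; reduce axis 0, then axis 0 again.
  std::vector<int> mid(R1 * C, 0);
  for (int r0 = 0; r0 < R0; ++r0)
    for (int r1 = 0; r1 < R1; ++r1)
      for (int c = 0; c < C; ++c) mid[r1 * C + c] += in[(r0 * R1 + r1) * C + c];
  std::vector<int> out(C, 0);
  for (int r1 = 0; r1 < R1; ++r1)
    for (int c = 0; c < C; ++c) out[c] += mid[r1 * C + c];

  for (int c = 0; c < C; ++c) assert(out[c] == ref[c]);
}
```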
InitRandomVector(&input_data, input_data.size(), 0.0f, 1.0f, 1e-3); - std::unordered_map> feeds = {{"in", input_data}}; - auto program = net_builder.Build(); - auto output = RunModelTest(program, {"ReduceSplit", "OpFusionPass", "FusionMergePass"}, feeds, fetch_ids); - auto output_expect = RunModelTest(program, {"OpFusionPass", "FusionMergePass"}, feeds, fetch_ids); + std::unordered_map> feeds = { + {"in", input_data}}; + auto program = net_builder.Build(); + auto output = RunModelTest(program, + {"ReduceSplit", "OpFusionPass", "FusionMergePass"}, + feeds, + fetch_ids); + auto output_expect = RunModelTest( + program, {"OpFusionPass", "FusionMergePass"}, feeds, fetch_ids); for (auto& out : output) { // should be equal, since ReduceSplit is not affected diff --git a/paddle/cinn/hlir/pass/single_group_optimize_pass.cc b/paddle/cinn/hlir/pass/single_group_optimize_pass.cc index b1f932427436d..c79dbfcb5ef24 100644 --- a/paddle/cinn/hlir/pass/single_group_optimize_pass.cc +++ b/paddle/cinn/hlir/pass/single_group_optimize_pass.cc @@ -88,9 +88,11 @@ std::vector> SingleGroupOptimizePass::Apply() { optimized_groups.emplace_back(group); continue; } - CHECK(node_to_groups_.count(nodes.front())) << "Can't find node " << nodes.front()->id() << " in node_to_groups_!"; - // NOTE(jeff41404): if a node used by more than one group, then will not be optimized to avoid unexpected changes to - // other group which has multiple nodes. + CHECK(node_to_groups_.count(nodes.front())) + << "Can't find node " << nodes.front()->id() << " in node_to_groups_!"; + // NOTE(jeff41404): if a node used by more than one group, then will not be + // optimized to avoid unexpected changes to other group which has multiple + // nodes. if (node_to_groups_[nodes.front()] > 1) { optimized_groups.emplace_back(group); continue; @@ -127,33 +129,39 @@ bool SingleGroupOptimizePass::TryReplaceNodeToCustomCall(Node* node) const { if (can_replace) { // replace single node group to custom call function const auto& op_name = node->op()->name; - VLOG(4) << "Replaced node " << framework::DebugString(node) << " by \"custom_call\""; + VLOG(4) << "Replaced node " << framework::DebugString(node) + << " by \"custom_call\""; node->attrs.attr_store["original_op"] = op_name; - node->attrs.op = framework::Operator::Get("custom_call"); + node->attrs.op = framework::Operator::Get("custom_call"); } if (can_replace_to_memset) { - node->attrs.attr_store["custom_call"] = std::string("cinn_call_cuda_memset"); + node->attrs.attr_store["custom_call"] = + std::string("cinn_call_cuda_memset"); } if (can_replace_to_memcpy) { - node->attrs.attr_store["custom_call"] = std::string("cinn_call_cuda_memcpy"); + node->attrs.attr_store["custom_call"] = + std::string("cinn_call_cuda_memcpy"); } return can_replace; } bool SingleGroupOptimizePass::CanReplaceToMemset(Node* node) const { - const auto& op_name = node->op()->name; + const auto& op_name = node->op()->name; const auto& attr_store = node->attrs.attr_store; if (op_name == "fill_constant" || op_name == "const_scalar") { - CHECK(attr_store.count("dtype")) << "Missing attribute \"dtype\" in op \"fill_constant\""; + CHECK(attr_store.count("dtype")) + << "Missing attribute \"dtype\" in op \"fill_constant\""; CHECK(absl::holds_alternative(attr_store.at("dtype"))); // if the value is 0, the op can always replaced by memset const auto& value_attr = attr_store.at("value"); - bool is_value_zero = utils::IsValueZero(value_attr) || utils::IsValueZero(value_attr) || - utils::IsValueZero(value_attr) || utils::IsValueZero(value_attr) 
|| + bool is_value_zero = utils::IsValueZero(value_attr) || + utils::IsValueZero(value_attr) || + utils::IsValueZero(value_attr) || + utils::IsValueZero(value_attr) || utils::IsValueZero(value_attr); return is_value_zero; // can support memset non-0 ? diff --git a/paddle/cinn/hlir/pass/test_dot_merger.cc b/paddle/cinn/hlir/pass/test_dot_merger.cc index 8d490192970b4..ee4586571ec06 100644 --- a/paddle/cinn/hlir/pass/test_dot_merger.cc +++ b/paddle/cinn/hlir/pass/test_dot_merger.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "gtest/gtest.h" #include "paddle/cinn/frontend/net_builder.h" #include "paddle/cinn/frontend/pass/pass_test_helper.h" #include "paddle/cinn/runtime/flags.h" -#include "gtest/gtest.h" namespace cinn::frontend::pass { @@ -39,27 +39,31 @@ TEST(DotMerger, lhs) { } int m = 2, k = 10201, n1 = 100, n2 = 100, n3 = 100, axis = 1; NetBuilder builder("net_builder"); - auto a = builder.CreateInput(Float(32), {m, k}, "A"); - auto b = builder.CreateInput(Float(32), {k, n1}, "B"); - auto c = builder.CreateInput(Float(32), {k, n2}, "C"); + auto a = builder.CreateInput(Float(32), {m, k}, "A"); + auto b = builder.CreateInput(Float(32), {k, n1}, "B"); + auto c = builder.CreateInput(Float(32), {k, n2}, "C"); auto c1 = builder.CreateInput(Float(32), {k, n3}, "E"); - auto d = builder.Matmul(a, b); - auto e = builder.Matmul(a, c); + auto d = builder.Matmul(a, b); + auto e = builder.Matmul(a, c); auto e1 = builder.Matmul(a, c1); - auto f = builder.CreateInput(Float(32), {m, n1}, "D"); - auto g = builder.Add(d, f); - auto h = builder.Add(e, g); + auto f = builder.CreateInput(Float(32), {m, n1}, "D"); + auto g = builder.Add(d, f); + auto h = builder.Add(e, g); auto h1 = builder.Add(e1, h); - auto p = builder.Build(); + auto p = builder.Build(); Target target = common::DefaultNVGPUTarget(); std::vector input_ids; - absl::c_transform(std::vector{a.id(), b.id(), c.id(), c1.id()}, - std::back_inserter(input_ids), - [](absl::string_view id) { return std::string(id); }); - OptimizeConfig passes({{"Decomposer", "RemoveIdentity", "TransposeFoldingInput"}, {}}, - {{"OpFusionPass", "FusionMergePass"}, {"DotMerger", "OpFusionPass", "FusionMergePass"}}); - CompareResult(&p, target, input_ids, {h1->id}, 0, std::move(passes), 123, true); + absl::c_transform( + std::vector{a.id(), b.id(), c.id(), c1.id()}, + std::back_inserter(input_ids), + [](absl::string_view id) { return std::string(id); }); + OptimizeConfig passes( + {{"Decomposer", "RemoveIdentity", "TransposeFoldingInput"}, {}}, + {{"OpFusionPass", "FusionMergePass"}, + {"DotMerger", "OpFusionPass", "FusionMergePass"}}); + CompareResult( + &p, target, input_ids, {h1->id}, 0, std::move(passes), 123, true); } /* @@ -81,20 +85,26 @@ TEST(DotMerger, rhs) { } NetBuilder builder("net_builder"); int m1 = 50, m2 = 50, k = 10201, n = 2, axis = 0; - auto a = builder.CreateInput(Float(32), {m1, k}, "A"); - auto b = builder.CreateInput(Float(32), {m2, k}, "B"); - auto c = builder.CreateInput(Float(32), {k, n}, "C"); - auto d = builder.Matmul(a, c); - auto e = builder.Matmul(b, c); - auto f = builder.Concat({d, e}, axis); - auto p = builder.Build(); + auto a = builder.CreateInput(Float(32), {m1, k}, "A"); + auto b = builder.CreateInput(Float(32), {m2, k}, "B"); + auto c = builder.CreateInput(Float(32), {k, n}, "C"); + auto d = builder.Matmul(a, c); + auto e = builder.Matmul(b, c); + auto f = builder.Concat({d, e}, axis); + auto p = builder.Build(); Target target = 
common::DefaultNVGPUTarget(); std::vector input_ids; absl::c_transform(std::vector{a.id(), b.id(), c.id()}, std::back_inserter(input_ids), [](absl::string_view id) { return std::string(id); }); - OptimizeConfig passes({{"Decomposer", "RemoveIdentity", "TransposeFoldingInput", "GemmRewriter"}, {}}, - {{"OpFusionPass", "FusionMergePass"}, {"DotMerger", "OpFusionPass", "FusionMergePass"}}); - CompareResult(&p, target, input_ids, {f->id}, 0, std::move(passes), 123, true); + OptimizeConfig passes({{"Decomposer", + "RemoveIdentity", + "TransposeFoldingInput", + "GemmRewriter"}, + {}}, + {{"OpFusionPass", "FusionMergePass"}, + {"DotMerger", "OpFusionPass", "FusionMergePass"}}); + CompareResult( + &p, target, input_ids, {f->id}, 0, std::move(passes), 123, true); } } // namespace cinn::frontend::pass diff --git a/paddle/cinn/hlir/pass/test_primitive_ops.cc b/paddle/cinn/hlir/pass/test_primitive_ops.cc index 2a7276baa45be..99dbc09ad0dbe 100755 --- a/paddle/cinn/hlir/pass/test_primitive_ops.cc +++ b/paddle/cinn/hlir/pass/test_primitive_ops.cc @@ -49,7 +49,8 @@ TEST(batch_norm_meta, batch_norm_meta) { auto a = program.batchnorm(A, Scale, Bias, Mean, Variance, attrs); - auto b = program.fused_batchnorm_inference(A, Scale, Bias, Mean, Variance, attrs); + auto b = + program.fused_batchnorm_inference(A, Scale, Bias, Mean, Variance, attrs); Target target = common::DefaultTarget(); program.SetInputs({A}); @@ -82,7 +83,7 @@ TEST(reduction, reduce) { Program program; std::unordered_map attrs; std::vector axis = {1, 2}; - bool keep_dim = false; + bool keep_dim = false; auto a = program.reduce_max(A, axis, keep_dim); auto b = program.reduce_min(A, axis, keep_dim); diff --git a/paddle/cinn/hlir/pe/broadcast.cc b/paddle/cinn/hlir/pe/broadcast.cc index 7992e61d97c30..2658fd3fd3980 100644 --- a/paddle/cinn/hlir/pe/broadcast.cc +++ b/paddle/cinn/hlir/pe/broadcast.cc @@ -51,13 +51,16 @@ void GetBroadcastShape(const std::vector& shape1, CHECK_GE(axis_val, -1) << "wrong axis: " << axis_val << std::endl; if (shape1.size() >= shape2.size()) { CHECK_LE(axis_val, static_cast(shape1.size() - shape2.size())) - << "wrong axis: " << axis_val << " is not <= " << shape1.size() - shape2.size() << std::endl; + << "wrong axis: " << axis_val + << " is not <= " << shape1.size() - shape2.size() << std::endl; if (axis_val >= 0) { *axis_offset = shape1.size() - shape2.size() - axis_val; for (int i = 1; i <= *axis_offset; ++i) { - // specified axis to align, we insert Expr one in tensor B so as to align right with tensor A. + // specified axis to align, we insert Expr one in tensor B so as to + // align right with tensor A. shape2_new.emplace_back(Expr(1)); - common_shape->insert(common_shape->begin(), shape1[static_cast(shape1.size() - i)]); + common_shape->insert(common_shape->begin(), + shape1[static_cast(shape1.size() - i)]); // flag is used to indicate whether to include the indice or not. broadcast_flag1->emplace_back(true); broadcast_flag2->emplace_back(false); @@ -65,13 +68,16 @@ void GetBroadcastShape(const std::vector& shape1, } } else { CHECK_LE(axis_val, static_cast(shape2.size() - shape1.size())) - << "wrong axis: " << axis_val << " is not <= " << shape2.size() - shape1.size() << std::endl; + << "wrong axis: " << axis_val + << " is not <= " << shape2.size() - shape1.size() << std::endl; if (axis_val >= 0) { *axis_offset = shape2.size() - shape1.size() - axis_val; for (int i = 1; i <= *axis_offset; ++i) { - // specified axis to align, we insert Expr one in tensor B so as to align right with tensor A. 
+ // specified axis to align, we insert Expr one in tensor B so as to + // align right with tensor A. shape1_new.emplace_back(Expr(1)); - common_shape->insert(common_shape->begin(), shape2[static_cast(shape2.size() - i)]); + common_shape->insert(common_shape->begin(), + shape2[static_cast(shape2.size() - i)]); // flag is used to indicate whether to include the indice or not. broadcast_flag2->emplace_back(true); broadcast_flag1->emplace_back(false); @@ -106,7 +112,8 @@ void GetBroadcastShape(const std::vector& shape1, broadcast_flag1->emplace_back(true); broadcast_flag2->emplace_back(false); } else if (var1 && var2) { - Expr max_var = ir::Max::Make(shape1_new[size1 - i], shape2_new[size2 - i]); + Expr max_var = + ir::Max::Make(shape1_new[size1 - i], shape2_new[size2 - i]); common_shape->insert(common_shape->begin(), max_var); broadcast_flag1->emplace_back(true); broadcast_flag2->emplace_back(true); @@ -137,16 +144,17 @@ void GetBroadcastShape(const std::vector& shape1, broadcast_flag1->emplace_back(true); broadcast_flag2->emplace_back(false); } else { - LOG(FATAL) << "Incompatible broadcast dims " << shape1_new[size1 - i] << " and " << shape2_new[size2 - i] - << " in: " << shape1_new << " and " << shape2_new << std::endl; + LOG(FATAL) << "Incompatible broadcast dims " << shape1_new[size1 - i] + << " and " << shape2_new[size2 - i] << " in: " << shape1_new + << " and " << shape2_new << std::endl; } } } if (size1 != size2) { int max_size = std::max(size1, size2); - auto& shape = (size1 > size2) ? shape1_new : shape2_new; - auto var_l = (size1 > size2) ? broadcast_flag1 : broadcast_flag2; - auto var_s = (size1 > size2) ? broadcast_flag2 : broadcast_flag1; + auto& shape = (size1 > size2) ? shape1_new : shape2_new; + auto var_l = (size1 > size2) ? broadcast_flag1 : broadcast_flag2; + auto var_s = (size1 > size2) ? broadcast_flag2 : broadcast_flag1; for (; i <= max_size; ++i) { common_shape->insert(common_shape->begin(), shape[max_size - i]); var_l->emplace_back(true); @@ -161,7 +169,8 @@ void GetBroadcastOutShape(const std::vector& input_shape1, int axis) { std::vector shape1; std::vector shape2; - auto fn_expr = [](const std::vector& input_shape, std::vector* shape) { + auto fn_expr = [](const std::vector& input_shape, + std::vector* shape) { for (int i = 0; i < input_shape.size(); i++) { shape->push_back(Expr(input_shape[i])); } @@ -172,7 +181,13 @@ void GetBroadcastOutShape(const std::vector& input_shape1, std::vector broadcast_flags2; int axis_offset = 0; std::vector out_shape; - GetBroadcastShape(shape1, shape2, &out_shape, &broadcast_flags1, &broadcast_flags2, &axis_offset, Expr(axis)); + GetBroadcastShape(shape1, + shape2, + &out_shape, + &broadcast_flags1, + &broadcast_flags2, + &axis_offset, + Expr(axis)); CHECK(common_shape); for (auto& shape : out_shape) { common_shape->push_back(shape.as_int32()); @@ -205,8 +220,8 @@ void GetBroadcastIndice(const std::vector& indice, broadcast_indice2->push_back(indice[i]); } else if (flag_size - i <= tensor_b->shape.size() + axis_offset && broadcast_indice2->size() < tensor_b->shape.size()) { - // insert indice 0 when have not yet reached the dimension of tensor. Meanwhile we have to consider the case of - // axis alignment. + // insert indice 0 when have not yet reached the dimension of tensor. + // Meanwhile we have to consider the case of axis alignment. 
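The two `GetBroadcastShape` branches above implement a right-aligned (NumPy-style) broadcast once the optional `axis` offset has padded the shorter shape with ones. A scalar-shape sketch of the default axis = -1 rule — hypothetical `BroadcastShape`; the real routine also handles symbolic `Expr` dims and emits the per-dimension broadcast flags:

```cpp
#include <algorithm>
#include <stdexcept>
#include <vector>

// Right-aligned broadcast of two static shapes: compare dims from the right,
// a 1 broadcasts, equal extents pass through, anything else is an error.
std::vector<int> BroadcastShape(const std::vector<int>& s1,
                                const std::vector<int>& s2) {
  std::vector<int> out(std::max(s1.size(), s2.size()));
  for (size_t i = 1; i <= out.size(); ++i) {
    int d1 = i <= s1.size() ? s1[s1.size() - i] : 1;  // pad missing dims with 1
    int d2 = i <= s2.size() ? s2[s2.size() - i] : 1;
    if (d1 != d2 && d1 != 1 && d2 != 1)
      throw std::runtime_error("Incompatible broadcast dims");
    out[out.size() - i] = std::max(d1, d2);
  }
  return out;
}

int main() {
  // E.g. shape(A) = (2, 3, 4, 5) with shape(B) = (4, 5) -> (2, 3, 4, 5).
  return BroadcastShape({2, 3, 4, 5}, {4, 5}) == std::vector<int>{2, 3, 4, 5}
             ? 0 : 1;
}
```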
broadcast_indice2->push_back(Expr(0)); } } @@ -218,7 +233,7 @@ Tensor Broadcast(const FuncOp& op, const Tensor& a, const Tensor& b, const std::string& output_name = "", - const Expr& axis = Expr(-1)) { + const Expr& axis = Expr(-1)) { std::vector common_shape; std::vector broadcast_flags1; std::vector broadcast_flags2; @@ -226,22 +241,37 @@ Tensor Broadcast(const FuncOp& op, // the counts of left-shift of tensor b so as to right alignment int axis_offset = 0; - GetBroadcastShape(a->shape, b->shape, &common_shape, &broadcast_flags1, &broadcast_flags2, &axis_offset, axis); + GetBroadcastShape(a->shape, + b->shape, + &common_shape, + &broadcast_flags1, + &broadcast_flags2, + &axis_offset, + axis); auto fn = [=](const std::vector& indice) { std::vector broadcast_indice1; std::vector broadcast_indice2; - GetBroadcastIndice( - indice, a, b, axis_offset, &broadcast_indice1, &broadcast_indice2, broadcast_flags1, broadcast_flags2); + GetBroadcastIndice(indice, + a, + b, + axis_offset, + &broadcast_indice1, + &broadcast_indice2, + broadcast_flags1, + broadcast_flags2); return op(a(broadcast_indice1), b(broadcast_indice2)); }; Tensor output = Compute(common_shape, fn, output_name); return output; } -#define HLIR_IMP_BC_PE(name__, compute__) \ - Tensor name__(const Tensor& A, const Tensor& B, const std::string& output_name, const Expr& axis) { \ - auto fn = [&](const Expr& a, const Expr& b) { compute__ }; \ - return Broadcast(fn, A, B, output_name, axis); \ +#define HLIR_IMP_BC_PE(name__, compute__) \ + Tensor name__(const Tensor& A, \ + const Tensor& B, \ + const std::string& output_name, \ + const Expr& axis) { \ + auto fn = [&](const Expr& a, const Expr& b) { compute__ }; \ + return Broadcast(fn, A, B, output_name, axis); \ } HLIR_IMP_BC_PE(Add, return a + b;); @@ -256,11 +286,14 @@ HLIR_IMP_BC_PE(Minimum, return ir::Min::Make(a, b);); HLIR_IMP_BC_PE(LeftShift, return a << b;); HLIR_IMP_BC_PE(RightShift, return a >> b;); HLIR_IMP_BC_PE(LogicalRightShift, return lang::LogicalRightShift(a, b);); -HLIR_IMP_BC_PE(LogicalAnd, return ir::Cast::Make(Bool(), a) && ir::Cast::Make(Bool(), b);); -HLIR_IMP_BC_PE(LogicalOr, return ir::Cast::Make(Bool(), a) || ir::Cast::Make(Bool(), b);); -HLIR_IMP_BC_PE(LogicalXOr, - return (ir::Cast::Make(Bool(), a) || ir::Cast::Make(Bool(), b)) && - !(ir::Cast::Make(Bool(), a) && ir::Cast::Make(Bool(), b));); +HLIR_IMP_BC_PE(LogicalAnd, + return ir::Cast::Make(Bool(), a) && ir::Cast::Make(Bool(), b);); +HLIR_IMP_BC_PE(LogicalOr, + return ir::Cast::Make(Bool(), a) || ir::Cast::Make(Bool(), b);); +HLIR_IMP_BC_PE( + LogicalXOr, + return (ir::Cast::Make(Bool(), a) || ir::Cast::Make(Bool(), b)) && + !(ir::Cast::Make(Bool(), a) && ir::Cast::Make(Bool(), b));); HLIR_IMP_BC_PE(BitwiseAnd, return a & b;); HLIR_IMP_BC_PE(BitwiseOr, return a | b;); HLIR_IMP_BC_PE(BitwiseXor, return a ^ b;); @@ -272,20 +305,28 @@ HLIR_IMP_BC_PE(GreaterEqual, return a >= b;); HLIR_IMP_BC_PE(LessEqual, return a <= b;); HLIR_IMP_BC_PE(Pow, return lang::Pow(a, b);); -Tensor Atan2(const Tensor& A, const Tensor& B, const std::string& output_name, const Expr& axis) { +Tensor Atan2(const Tensor& A, + const Tensor& B, + const std::string& output_name, + const Expr& axis) { constexpr double PI = 3.14159265358979323846; auto fn = [&](const Expr& elem_a, const Expr& elem_b) { - auto atan = lang::Atan(elem_a / elem_b); - auto pi = common::make_const(atan->type(), PI); + auto atan = lang::Atan(elem_a / elem_b); + auto pi = common::make_const(atan->type(), PI); auto half_pi = common::make_const(atan->type(), PI / 2); - 
auto zero = ir::Zero(atan->type()); + auto zero = ir::Zero(atan->type()); return ir::Select::Make( ir::EQ::Make(elem_b, zero), ir::Select::Make( - ir::EQ::Make(elem_a, zero), zero, ir::Select::Make(ir::GT::Make(elem_a, zero), half_pi, -half_pi)), + ir::EQ::Make(elem_a, zero), + zero, + ir::Select::Make(ir::GT::Make(elem_a, zero), half_pi, -half_pi)), ir::Select::Make( - ir::GT::Make(elem_b, zero), atan, ir::Select::Make(ir::GE::Make(elem_a, zero), atan + pi, atan - pi))); + ir::GT::Make(elem_b, zero), + atan, + ir::Select::Make( + ir::GE::Make(elem_a, zero), atan + pi, atan - pi))); }; return Broadcast(fn, A, B, output_name, axis); } @@ -295,8 +336,10 @@ Tensor BroadcastTo(const Tensor& A, const std::vector& broadcast_axes, const std::string& out_name) { auto A_shape = A->shape; - CHECK_EQ(A_shape.size(), broadcast_axes.size()) << "broadcast_axes's size should be same with the input shape's size"; - CHECK_GE(out_shape.size(), broadcast_axes.size()) << "broadcast_axes's size should be no more than out_shape's size"; + CHECK_EQ(A_shape.size(), broadcast_axes.size()) + << "broadcast_axes's size should be same with the input shape's size"; + CHECK_GE(out_shape.size(), broadcast_axes.size()) + << "broadcast_axes's size should be no more than out_shape's size"; auto axes = broadcast_axes; for (auto& axis : axes) { // if axis < 0, plus out_shape.size @@ -318,7 +361,8 @@ Tensor BroadcastTo(const Tensor& A, } else if (a_shape_i == out_shape[axes[idx]]) { broadcast_indice.push_back(indice[axes[idx]]); } else { - LOG(FATAL) << "fail to broad cast input shape " << a_shape_i << " to output shape " << out_shape[axes[idx]]; + LOG(FATAL) << "fail to broad cast input shape " << a_shape_i + << " to output shape " << out_shape[axes[idx]]; } } return A(broadcast_indice); @@ -350,21 +394,25 @@ ir::Tensor IsClose(const ir::Tensor& x, auto check_y_nan = lang::IsNan(b); // out = equal_nan && isnan(a) == isnan(b); - auto check_nan_same = Expr(equal_nan) && ir::EQ::Make(check_x_nan, check_y_nan); + auto check_nan_same = + Expr(equal_nan) && ir::EQ::Make(check_x_nan, check_y_nan); // check whether x and y are close // T left = (a > b ? a - b : b - a); auto left = ir::Select::Make(a > b, a - b, b - a); // T right = atol + (b > 0 ? rtol * b : (-rtol) * b); - auto right = ir::Cast::Make(x->type(), atol) + ir::Select::Make(b > ir::Zero(b->type()), - ir::Cast::Make(x->type(), rtol) * b, - ir::Cast::Make(x->type(), -rtol) * b); + auto right = ir::Cast::Make(x->type(), atol) + + ir::Select::Make(b > ir::Zero(b->type()), + ir::Cast::Make(x->type(), rtol) * b, + ir::Cast::Make(x->type(), -rtol) * b); // T diff = (left > right ? 
left - right : right - left); auto diff = ir::Select::Make(left > right, left - right, right - left); // out = a == b || left <= right || diff <= 1e-15; - auto check_diff = (ir::EQ::Make(a, b) || (left <= right)) || (diff <= lang::Epsilon(diff->type())); + auto check_diff = (ir::EQ::Make(a, b) || (left <= right)) || + (diff <= lang::Epsilon(diff->type())); - return ir::Select::Make(check_x_nan || check_y_nan, check_nan_same, check_diff); + return ir::Select::Make( + check_x_nan || check_y_nan, check_nan_same, check_diff); }; return Broadcast(fn, x, y, out_name, Expr(axis)); } diff --git a/paddle/cinn/hlir/pe/broadcast.h b/paddle/cinn/hlir/pe/broadcast.h index f3ce590ff1dc9..bc7a7da0e3d69 100644 --- a/paddle/cinn/hlir/pe/broadcast.h +++ b/paddle/cinn/hlir/pe/broadcast.h @@ -31,7 +31,8 @@ void GetBroadcastOutShape(const std::vector& input_shape1, * * @param A The first Tensor or Expr * @param B The second Tensor or Expr - * @param axis Tensor B's beginning position of Tensor A. Default is -1(right align) and then axis = rank(X)-rank(Y). + * @param axis Tensor B's beginning position of Tensor A. Default is -1(right + * align) and then axis = rank(X)-rank(Y). * @param out_name The name of the output Tensor * * @return The result Tensor or Expr. @@ -42,11 +43,12 @@ void GetBroadcastOutShape(const std::vector& input_shape1, * shape(A) = (2, 3, 4, 5), shape(B) = (2), with axis=0 * shape(A) = (2, 3, 4, 5), shape(B) = (2, 1), with axis=0 */ -#define HLIR_DCL_BC_PE(name__) \ - ir::Tensor name__(const ir::Tensor& A, \ - const ir::Tensor& B, \ - const std::string& out_name = common::UniqName("T_" #name__ "_out"), \ - const Expr& axis = Expr()); +#define HLIR_DCL_BC_PE(name__) \ + ir::Tensor name__( \ + const ir::Tensor& A, \ + const ir::Tensor& B, \ + const std::string& out_name = common::UniqName("T_" #name__ "_out"), \ + const Expr& axis = Expr()); //! Compute A + B with auto-broadcasting. 
HLIR_DCL_BC_PE(Add); @@ -107,19 +109,22 @@ ir::Tensor Pow(const ir::Tensor& A, const Expr& axis, const common::Target& target); -ir::Tensor BroadcastTo(const ir::Tensor& A, - const std::vector& out_shape, - const std::vector& broadcast_axes, - const std::string& out_name = common::UniqName("T_broadcast_to_out")); +ir::Tensor BroadcastTo( + const ir::Tensor& A, + const std::vector& out_shape, + const std::vector& broadcast_axes, + const std::string& out_name = common::UniqName("T_broadcast_to_out")); -// This operator checks if all x and y satisfy the condition: |x - y| <= atol + rtol * |y| -ir::Tensor IsClose(const ir::Tensor& x, - const ir::Tensor& y, - int axis = -1, - float rtol = 1e-05f, - float atol = 1e-08f, - bool equal_nan = false, - const std::string& out_name = common::UniqName("IsClose_output")); +// This operator checks if all x and y satisfy the condition: |x - y| <= atol + +// rtol * |y| +ir::Tensor IsClose( + const ir::Tensor& x, + const ir::Tensor& y, + int axis = -1, + float rtol = 1e-05f, + float atol = 1e-08f, + bool equal_nan = false, + const std::string& out_name = common::UniqName("IsClose_output")); } // namespace pe } // namespace hlir diff --git a/paddle/cinn/hlir/pe/elementwise.cc b/paddle/cinn/hlir/pe/elementwise.cc index fe11303b98e70..d26dba6a04dc1 100644 --- a/paddle/cinn/hlir/pe/elementwise.cc +++ b/paddle/cinn/hlir/pe/elementwise.cc @@ -30,21 +30,31 @@ using ir::Expr; using ir::Tensor; using lang::Compute; -#define HLIR_IMP_UNARY_PE(name__) \ - std::vector name__(const Tensor& A, const std::string& output_name) { \ - return {Compute( \ - A->shape, [=](const std::vector& indice) { return lang::name__(A(indice)); }, output_name)}; \ +#define HLIR_IMP_UNARY_PE(name__) \ + std::vector name__(const Tensor& A, \ + const std::string& output_name) { \ + return {Compute( \ + A->shape, \ + [=](const std::vector& indice) { \ + return lang::name__(A(indice)); \ + }, \ + output_name)}; \ } -#define HLIR_MKL_IMP_UNARY_PE(name__, ex_name__) \ - std::vector name__##MKL(const Tensor& A, const std::string& output_name) { \ - CHECK(A->type().is_float()) << "type should be float or double but get " << A->type(); \ - std::string fn_name = "cinn_mkl_" #ex_name__ "_v_fp" + std::to_string(A->type().bits()); \ - auto call = Compute( \ - {Expr(1)}, [=]() -> Expr { return lang::CallExtern(fn_name, {A}); }, output_name); \ - auto out = call->TupleGet(0); \ - out->WithBuffer(A->type()); \ - return {out, call}; \ +#define HLIR_MKL_IMP_UNARY_PE(name__, ex_name__) \ + std::vector name__##MKL(const Tensor& A, \ + const std::string& output_name) { \ + CHECK(A->type().is_float()) \ + << "type should be float or double but get " << A->type(); \ + std::string fn_name = \ + "cinn_mkl_" #ex_name__ "_v_fp" + std::to_string(A->type().bits()); \ + auto call = Compute( \ + {Expr(1)}, \ + [=]() -> Expr { return lang::CallExtern(fn_name, {A}); }, \ + output_name); \ + auto out = call->TupleGet(0); \ + out->WithBuffer(A->type()); \ + return {out, call}; \ } HLIR_MKL_IMP_UNARY_PE(Exp, exp); @@ -108,7 +118,9 @@ HLIR_IMP_UNARY_PE(Cbrt); HLIR_IMP_UNARY_PE(Clz); HLIR_IMP_UNARY_PE(Popc); -ir::Tensor Squeeze(const ir::Tensor& A, const std::vector& axes, const std::string& output_name) { +ir::Tensor Squeeze(const ir::Tensor& A, + const std::vector& axes, + const std::string& output_name) { std::vector position; std::vector output_shape; if (axes.size()) { @@ -170,17 +182,20 @@ ir::Tensor ExpandDims(const ir::Tensor& A, idx.push_back(indice[i]); } } - CHECK_EQ(idx.size(), A->shape.size()) << "The index size not 
equal with the input rank."; + CHECK_EQ(idx.size(), A->shape.size()) + << "The index size not equal with the input rank."; return A(idx); }, UniqName(output_name)); } -ir::Tensor Reshape(const ir::Tensor& A, const std::vector& new_shape, const std::string& name) { +ir::Tensor Reshape(const ir::Tensor& A, + const std::vector& new_shape, + const std::string& name) { std::vector new_expr_shape; std::vector A_expr_shape = A->shape; - int input_total_size = 1; - int output_total_size = 1; + int input_total_size = 1; + int output_total_size = 1; for (auto& i : A_expr_shape) { CHECK(i.is_constant()) << "Input tensor's shape should be constant value."; input_total_size *= static_cast(i.get_constant()); @@ -190,7 +205,8 @@ ir::Tensor Reshape(const ir::Tensor& A, const std::vector& new_shape, const new_expr_shape.push_back(Expr(i)); } CHECK_EQ(input_total_size, output_total_size) - << "In op reshape, the input tensor and output tensor's total size should be equal, please check!"; + << "In op reshape, the input tensor and output tensor's total size " + "should be equal, please check!"; auto res = Compute( new_expr_shape, [=](const std::vector& indice) { @@ -210,19 +226,31 @@ ir::Tensor Reshape(const ir::Tensor& A, const std::vector& new_shape, const return res; } -ir::Tensor Cast(const ir::Tensor& A, const Type& dtype, const std::string& name) { +ir::Tensor Cast(const ir::Tensor& A, + const Type& dtype, + const std::string& name) { auto res = Compute( - A->shape, [=](const std::vector& indices) { return ir::Cast::Make(dtype, A(indices)); }, name); + A->shape, + [=](const std::vector& indices) { + return ir::Cast::Make(dtype, A(indices)); + }, + name); return res; } -ir::Tensor Arange( - const float start, const float stop, const float step, const Type& dtype, const std::string& output_name) { - int num = static_cast(std::ceil((stop - start) / step)); +ir::Tensor Arange(const float start, + const float stop, + const float step, + const Type& dtype, + const std::string& output_name) { + int num = static_cast(std::ceil((stop - start) / step)); ir::Tensor res = lang::Compute( {Expr(num)}, [=](const std::vector& indices) { - return ir::Cast::Make(dtype, Expr(start) + Expr(step) * ir::Cast::Make(common::F32(), indices[0])); + return ir::Cast::Make( + dtype, + Expr(start) + + Expr(step) * ir::Cast::Make(common::F32(), indices[0])); }, output_name); return res; diff --git a/paddle/cinn/hlir/pe/elementwise.h b/paddle/cinn/hlir/pe/elementwise.h index efe773cb7d01b..95e93c39d5c27 100644 --- a/paddle/cinn/hlir/pe/elementwise.h +++ b/paddle/cinn/hlir/pe/elementwise.h @@ -33,9 +33,13 @@ namespace pe { * * @return The result Tensor. */ -#define HLIR_DCL_UNARY_PE(name__) \ - std::vector name__(const ir::Tensor& A, const std::string& output_name = "T_" #name__ "_out"); \ - std::vector name__##MKL(const ir::Tensor& A, const std::string& output_name = "T_" #name__ "_mkl_out"); +#define HLIR_DCL_UNARY_PE(name__) \ + std::vector name__( \ + const ir::Tensor& A, \ + const std::string& output_name = "T_" #name__ "_out"); \ + std::vector name__##MKL( \ + const ir::Tensor& A, \ + const std::string& output_name = "T_" #name__ "_mkl_out"); HLIR_DCL_UNARY_PE(Exp); HLIR_DCL_UNARY_PE(Erf); @@ -81,20 +85,26 @@ HLIR_DCL_UNARY_PE(Popc); template ir::Tensor AssignValue(const std::vector& values, - const common::Type& type = common::type_of(), + const common::Type& type = common::type_of(), const std::string& output_name = "T_assign_value_out") { - CHECK(!values.empty()) << "The input of pe::AssignValue should not empty! 
Please check."; + CHECK(!values.empty()) + << "The input of pe::AssignValue should not empty! Please check."; auto out = lang::Compute( {ir::Expr(static_cast(values.size()))}, [=](const std::vector& indice) { - auto init_value = - (type == common::type_of()) ? ir::Expr(values[0]) : common::cast(ir::Expr(values[0]), type); - ir::Expr previous = ir::Select::Make(ir::EQ::Make(indice[0], ir::Expr(0)), init_value, lang::Zero(type)); + auto init_value = (type == common::type_of()) + ? ir::Expr(values[0]) + : common::cast(ir::Expr(values[0]), type); + ir::Expr previous = ir::Select::Make( + ir::EQ::Make(indice[0], ir::Expr(0)), init_value, lang::Zero(type)); for (int i = 1; i < values.size(); ++i) { - auto val = (type == common::type_of()) ? ir::Expr(values[i]) : common::cast(ir::Expr(values[i]), type); - previous = ir::Select::Make(ir::EQ::Make(indice[0], ir::Expr(i)), val, previous); + auto val = (type == common::type_of()) + ? ir::Expr(values[i]) + : common::cast(ir::Expr(values[i]), type); + previous = ir::Select::Make( + ir::EQ::Make(indice[0], ir::Expr(i)), val, previous); } return previous; }, @@ -103,26 +113,32 @@ ir::Tensor AssignValue(const std::vector& values, return out; } -ir::Tensor Squeeze(const ir::Tensor& A, - const std::vector& axes = {}, - const std::string& output_name = UniqName("T_Elementwise_Squeeze_out")); - -ir::Tensor ExpandDims(const ir::Tensor& A, - const std::vector& axes, - const std::vector& out_shape, - const std::string& output_name = UniqName("T_Elementwise_ExpandDims_out")); - -ir::Tensor Reshape(const ir::Tensor& A, - const std::vector& new_shape, - const std::string& name = UniqName("T_Elementwise_Reshape_out")); - -ir::Tensor Cast(const ir::Tensor& A, const Type& dtype, const std::string& name = UniqName("T_Elementwise_Cast_out")); - -ir::Tensor Arange(const float start, - const float stop, - const float step, - const Type& dtype, - const std::string& name = UniqName("T_Elementwise_Arange_out")); +ir::Tensor Squeeze( + const ir::Tensor& A, + const std::vector& axes = {}, + const std::string& output_name = UniqName("T_Elementwise_Squeeze_out")); + +ir::Tensor ExpandDims( + const ir::Tensor& A, + const std::vector& axes, + const std::vector& out_shape, + const std::string& output_name = UniqName("T_Elementwise_ExpandDims_out")); + +ir::Tensor Reshape( + const ir::Tensor& A, + const std::vector& new_shape, + const std::string& name = UniqName("T_Elementwise_Reshape_out")); + +ir::Tensor Cast(const ir::Tensor& A, + const Type& dtype, + const std::string& name = UniqName("T_Elementwise_Cast_out")); + +ir::Tensor Arange( + const float start, + const float stop, + const float step, + const Type& dtype, + const std::string& name = UniqName("T_Elementwise_Arange_out")); } // namespace pe } // namespace hlir diff --git a/paddle/cinn/hlir/pe/ir_schedule_pe.cc b/paddle/cinn/hlir/pe/ir_schedule_pe.cc index 4f8f99d8874d0..784f7ff85d5ab 100644 --- a/paddle/cinn/hlir/pe/ir_schedule_pe.cc +++ b/paddle/cinn/hlir/pe/ir_schedule_pe.cc @@ -39,14 +39,18 @@ namespace cinn { namespace hlir { namespace pe { -void IRElementwiseSchedule(ir::IRSchedule &ir_sch, const std::vector &output_shape, const common::Target &target) { - VLOG(3) << "Before IRElementwiseSchedule, new ir is : " << ir_sch.GetModule().GetExprs().at(0); +void IRElementwiseSchedule(ir::IRSchedule &ir_sch, + const std::vector &output_shape, + const common::Target &target) { + VLOG(3) << "Before IRElementwiseSchedule, new ir is : " + << ir_sch.GetModule().GetExprs().at(0); if (target == common::DefaultNVGPUTarget()) { 
auto blocks = ir_sch.GetAllBlocks(); ir_sch.FlattenLoops(ir_sch.GetLoops(blocks[0]), true); auto loops = ir_sch.GetLoops(blocks[0]); - auto size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + auto size = std::accumulate( + output_shape.begin(), output_shape.end(), 1, std::multiplies()); if (size <= target.max_num_threads()) { ir_sch.Bind(loops[0], "threadIdx.x"); } else { @@ -59,17 +63,22 @@ void IRElementwiseSchedule(ir::IRSchedule &ir_sch, const std::vector &outpu auto blocks = ir_sch.GetAllBlocks(); ir_sch.FlattenLoops(ir_sch.GetLoops(blocks[0]), true); } - VLOG(3) << "After IRElementwiseSchedule, new ir is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRElementwiseSchedule, new ir is : " + << ir_sch.GetModule().GetExprs().at(0); } -void IRInjectiveSchedule(ir::IRSchedule &ir_sch, const std::vector &output_shape, const common::Target &target) { - VLOG(3) << "Before IRInjectiveSchedule, new ir is : " << ir_sch.GetModule().GetExprs().at(0); +void IRInjectiveSchedule(ir::IRSchedule &ir_sch, + const std::vector &output_shape, + const common::Target &target) { + VLOG(3) << "Before IRInjectiveSchedule, new ir is : " + << ir_sch.GetModule().GetExprs().at(0); if (target == common::DefaultNVGPUTarget()) { auto blocks = ir_sch.GetAllBlocks(); ir_sch.FlattenLoops(ir_sch.GetLoops(blocks[0]), false); auto loops = ir_sch.GetLoops(blocks[0]); - auto size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + auto size = std::accumulate( + output_shape.begin(), output_shape.end(), 1, std::multiplies()); if (size <= target.max_num_threads()) { ir_sch.Bind(loops[0], "threadIdx.x"); } else { @@ -82,27 +91,29 @@ void IRInjectiveSchedule(ir::IRSchedule &ir_sch, const std::vector &output_ auto blocks = ir_sch.GetAllBlocks(); ir_sch.FlattenLoops(ir_sch.GetLoops(blocks[0]), false); } - VLOG(3) << "After IRInjectiveSchedule, new ir is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRInjectiveSchedule, new ir is : " + << ir_sch.GetModule().GetExprs().at(0); } void IRScheduleInjectiveCPU(ir::IRSchedule &ir_sch, const std::vector &output_shape, const common::Target &target, bool vectorizable) { - VLOG(3) << "Begin IRScheduleInjectiveCPU" << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "Begin IRScheduleInjectiveCPU" + << ir_sch.GetModule().GetExprs().at(0); auto all_blocks = ir_sch.GetAllBlocks(); - auto loops = ir_sch.GetLoops(all_blocks[0]); - int dims = output_shape.size(); - int factor = GetBasicFactor(GetTensor(all_blocks[0])->type(), target); - auto fused = loops[0]; + auto loops = ir_sch.GetLoops(all_blocks[0]); + int dims = output_shape.size(); + int factor = GetBasicFactor(GetTensor(all_blocks[0])->type(), target); + auto fused = loops[0]; if (dims >= 5) { CHECK_GE(loops.size(), 3U); fused = ir_sch.Fuse({loops[0], loops[1], loops[2]}); - dims = dims - 2; + dims = dims - 2; } else if (dims >= 3) { CHECK_GE(loops.size(), 2U); fused = ir_sch.Fuse({loops[0], loops[1]}); - dims = dims - 1; + dims = dims - 1; } // This part needs to be fixed. 
@Haoze /* ir_sch.Parallel(fused); @@ -117,7 +128,8 @@ void IRScheduleInjectiveCPU(ir::IRSchedule &ir_sch, ir_sch.Parallel(splited[0]); } } */ - VLOG(3) << "After IRScheduleInjectiveCPU, new ir is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRScheduleInjectiveCPU, new ir is : " + << ir_sch.GetModule().GetExprs().at(0); } void IRCudaScheduleInjective(ir::IRSchedule &ir_sch, @@ -125,12 +137,13 @@ void IRCudaScheduleInjective(ir::IRSchedule &ir_sch, const common::Target &target) { VLOG(3) << "Begin IRCudaScheduleInjective "; auto all_blocks = ir_sch.GetAllBlocks(); - auto loops = ir_sch.GetLoops(all_blocks[0]); - auto fused = ir_sch.Fuse(loops); + auto loops = ir_sch.GetLoops(all_blocks[0]); + auto fused = ir_sch.Fuse(loops); - int num_thread = target.max_num_threads(); + int num_thread = target.max_num_threads(); int vector_width = 1; - int prod_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + int prod_size = std::accumulate( + output_shape.begin(), output_shape.end(), 1, std::multiplies()); if (prod_size > num_thread) { auto splited = ir_sch.Split(fused, {-1, num_thread}); ir_sch.Bind(splited[0], "blockIdx.x"); @@ -138,12 +151,14 @@ void IRCudaScheduleInjective(ir::IRSchedule &ir_sch, } else { ir_sch.Bind(fused, "threadIdx.x"); } - VLOG(3) << "After IRCudaScheduleInjective, new ir is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRCudaScheduleInjective, new ir is : " + << ir_sch.GetModule().GetExprs().at(0); } -std::vector IRCudaScheduleMatMul(const common::CINNValuePack &arg_pack, - const std::vector &output_shape, - const common::Target &target) { +std::vector IRCudaScheduleMatMul( + const common::CINNValuePack &arg_pack, + const std::vector &output_shape, + const common::Target &target) { if (target.arch == Target::Arch::X86) { CINN_NOT_IMPLEMENTED } @@ -164,10 +179,11 @@ std::vector IRCudaScheduleMatMul(const common::CINNValuePack auto init_block = ir_sch.GetAllBlocks().front(); VLOG(3) << "Matmul lowered expr:\n" << ir_sch.GetModule().GetExprs().front(); - int prod_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + int prod_size = std::accumulate( + output_shape.begin(), output_shape.end(), 1, std::multiplies()); if (prod_size > 1) { int num_thread = target.max_num_threads(); - auto loops = ir_sch.GetLoops(init_block); + auto loops = ir_sch.GetLoops(init_block); if (loops.size() == 1) { if (ir::GetLoopExtent(loops[0]) > num_thread) { auto splited = ir_sch.Split(loops[0], {-1, num_thread}); @@ -182,7 +198,7 @@ std::vector IRCudaScheduleMatMul(const common::CINNValuePack init_block = ir_sch.GetAllBlocks().front(); ir_sch.Fuse(init_block, {0, 1}); init_block = ir_sch.GetAllBlocks().front(); - loops = ir_sch.GetLoops(init_block); + loops = ir_sch.GetLoops(init_block); } ir_sch.Bind(loops[0], "blockIdx.x"); ir_sch.Bind(loops[1], "threadIdx.x"); @@ -192,13 +208,15 @@ std::vector IRCudaScheduleMatMul(const common::CINNValuePack return {common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; } -void IRCudaScheduleMul(ir::IRSchedule &ir_sch, const std::vector &output_shape, const common::Target &target) { +void IRCudaScheduleMul(ir::IRSchedule &ir_sch, + const std::vector &output_shape, + const common::Target &target) { auto all_blocks = ir_sch.GetAllBlocks(); - auto loops = ir_sch.GetLoops(all_blocks.back()); + auto loops = ir_sch.GetLoops(all_blocks.back()); CHECK_GE(loops.size(), 2U); auto splited = ir_sch.Split(loops[1], {-1, 2}); - all_blocks = ir_sch.GetAllBlocks(); - 
loops = ir_sch.GetLoops(all_blocks.back()); + all_blocks = ir_sch.GetAllBlocks(); + loops = ir_sch.GetLoops(all_blocks.back()); ir_sch.Bind(loops[0], "blockIdx.x"); ir_sch.Bind(loops[1], "threadIdx.x"); } @@ -209,13 +227,13 @@ void IRMulScheduleCPU(ir::IRSchedule &ir_sch, ir_sch.MergeExprs(); auto all_blocks = ir_sch.GetAllBlocks(); CHECK_EQ(all_blocks.size(), 4U); - auto loops = ir_sch.GetLoops(all_blocks[1]); + auto loops = ir_sch.GetLoops(all_blocks[1]); int loop_size = loops.size(); // ir_sch.Reorder({loops[loop_size-1], loops[loop_size-2]}); if (reduce_first_shape.back() > 1) { all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[3]); + loops = ir_sch.GetLoops(all_blocks[3]); ir_sch.Unroll(loops.back()); } } @@ -224,7 +242,8 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch, const std::vector> &output_shapes, int axis, const common::Target &target) { - VLOG(3) << "In IRCudaSplitSchedule, Before schedule expr is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "In IRCudaSplitSchedule, Before schedule expr is : " + << ir_sch.GetModule().GetExprs().at(0); ir_sch.MergeExprs(); // if all output are with same shape bool with_same_shape = true; @@ -238,8 +257,11 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch, // collect block names auto get_block_name = [](ir::Expr expr) { CHECK(expr.As()); - CHECK(expr.As()->schedule_block.As()); - return expr.As()->schedule_block.As()->name; + CHECK(expr.As() + ->schedule_block.As()); + return expr.As() + ->schedule_block.As() + ->name; }; std::vector block_names; auto blocks = ir_sch.GetAllBlocks(); @@ -250,17 +272,22 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch, if (with_same_shape && target == common::DefaultNVGPUTarget()) { // flat loops. { - auto tsize = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); + auto tsize = std::accumulate(output_shapes[0].begin(), + output_shapes[0].end(), + 1, + std::multiplies()); for (auto &block_name : block_names) { ir_sch.FlattenLoops(ir_sch.GetLoops(block_name), false); if (tsize > target.max_num_threads()) { // split [-1, 256] - auto splited = ir_sch.Split(ir_sch.GetLoops(block_name)[0], {-1, target.max_num_threads() / 4}); + auto splited = ir_sch.Split(ir_sch.GetLoops(block_name)[0], + {-1, target.max_num_threads() / 4}); ir_sch.Bind(splited[0], "blockIdx.x"); ir_sch.Bind(splited[1], "threadIdx.x"); } else { - auto splited = ir_sch.Split(ir_sch.GetLoops(block_name)[0], {1, tsize}); + auto splited = + ir_sch.Split(ir_sch.GetLoops(block_name)[0], {1, tsize}); ir_sch.Bind(splited[0], "blockIdx.x"); ir_sch.Bind(splited[1], "threadIdx.x"); } @@ -270,7 +297,8 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch, { for (int idx = 1; idx < block_names.size(); ++idx) { auto master_loops = ir_sch.GetLoops(block_names[0]); - ir_sch.SimpleComputeAt(ir_sch.GetBlock(block_names[idx]), master_loops[1]); + ir_sch.SimpleComputeAt(ir_sch.GetBlock(block_names[idx]), + master_loops[1]); } } } else if (target == common::DefaultNVGPUTarget()) { @@ -283,11 +311,13 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch, auto tsize = first_loop.As()->extent.as_int32(); if (tsize > target.max_num_threads()) { // split [-1, 256] - auto splited = ir_sch.Split(ir_sch.GetLoops(block_names[idx])[0], {-1, target.max_num_threads() / 4}); + auto splited = ir_sch.Split(ir_sch.GetLoops(block_names[idx])[0], + {-1, target.max_num_threads() / 4}); ir_sch.Bind(splited[0], "blockIdx.x"); ir_sch.Bind(splited[1], "threadIdx.x"); } else { - auto splited = 
ir_sch.Split(ir_sch.GetLoops(block_names[idx])[0], {1, tsize}); + auto splited = + ir_sch.Split(ir_sch.GetLoops(block_names[idx])[0], {1, tsize}); ir_sch.Bind(splited[0], "blockIdx.x"); ir_sch.Bind(splited[1], "threadIdx.x"); } @@ -300,22 +330,29 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch, } } } - VLOG(3) << "In IRCudaSplitSchedule, After schedule expr is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "In IRCudaSplitSchedule, After schedule expr is : " + << ir_sch.GetModule().GetExprs().at(0); } void IRCudaScheduleReduce(ir::IRSchedule &ir_sch, ir::Tensor output, int last_dimension_num, const common::Target &target) { - VLOG(3) << "Before IRCudaScheduleReduce : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "Before IRCudaScheduleReduce : " + << ir_sch.GetModule().GetExprs().at(0); int parallel_thread_num = 1; - auto &output_shape = output->shape; - for (int idx = output_shape.size() - 1; idx >= static_cast(output_shape.size()) - last_dimension_num; --idx) { + auto &output_shape = output->shape; + for (int idx = output_shape.size() - 1; + idx >= static_cast(output_shape.size()) - last_dimension_num; + --idx) { parallel_thread_num *= output_shape[idx].as_int32(); } - int index = ir_sch.GetLoops(output->name + "__reduce_init").size() - last_dimension_num; - for (int idx = output_shape.size() - last_dimension_num; idx < static_cast(output_shape.size()) - 1; ++idx) { + int index = ir_sch.GetLoops(output->name + "__reduce_init").size() - + last_dimension_num; + for (int idx = output_shape.size() - last_dimension_num; + idx < static_cast(output_shape.size()) - 1; + ++idx) { auto loops = ir_sch.GetLoops(output->name); ir_sch.Fuse({loops[index], loops[index + 1]}); } @@ -349,14 +386,16 @@ void IRCudaScheduleReduce(ir::IRSchedule &ir_sch, auto loops = ir_sch.GetLoops(output->name); ir_sch.Bind(loops[0], "blockIdx.x"); } - VLOG(3) << "After IRCudaScheduleReduce : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRCudaScheduleReduce : " + << ir_sch.GetModule().GetExprs().at(0); } void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, ir::Tensor tmp_out, ir::Tensor out, const common::Target &target) { - VLOG(3) << "Before IRCudaScheduleBlockReduceInternal : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "Before IRCudaScheduleBlockReduceInternal : " + << ir_sch.GetModule().GetExprs().at(0); int fuse_times = ir_sch.GetLoops(tmp_out->name).size() - 2; for (int idx = 0; idx < fuse_times; ++idx) { for (auto &tensor : {tmp_out, out}) { @@ -371,30 +410,41 @@ void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, CHECK_EQ(out->shape[0], Expr(1)); // block and root - auto out_block = ir_sch.GetBlock(out->name); + auto out_block = ir_sch.GetBlock(out->name); auto root_block = ir_sch.GetRootBlock(out_block); CHECK(out_block->as()); - CHECK(out_block->as()->schedule_block->as()); + CHECK(out_block->as() + ->schedule_block->as()); // create var auto var = ir::Var(ir::Expr(0), ir::Expr(1), common::UniqName("i")); out_block->as()->iter_values.push_back(var); - out_block->as()->schedule_block->as()->iter_vars.push_back(var); + out_block->as() + ->schedule_block->as() + ->iter_vars.push_back(var); CHECK(root_block->as()); - CHECK(root_block->as()->schedule_block->as()); + CHECK(root_block->as() + ->schedule_block->as()); // create for and block node - auto for_node = - ir::For::Make(var, Expr(0), Expr(1), ir::ForType::Serial, ir::DeviceAPI::UNK, ir::Block::Make({out_block})); - auto block_node = ir::Block::Make({root_block->as() - ->schedule_block->as() - 
->body->as() - ->stmts[0], - for_node}); - - root_block->as()->schedule_block->as()->body = block_node; + auto for_node = ir::For::Make(var, + Expr(0), + Expr(1), + ir::ForType::Serial, + ir::DeviceAPI::UNK, + ir::Block::Make({out_block})); + auto block_node = + ir::Block::Make({root_block->as() + ->schedule_block->as() + ->body->as() + ->stmts[0], + for_node}); + + root_block->as() + ->schedule_block->as() + ->body = block_node; for (auto &tensor : {tmp_out, out}) { auto loops = ir_sch.GetLoops(tensor->name); @@ -403,7 +453,7 @@ void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, } auto loops_tmp_out = ir_sch.GetLoops(tmp_out->name); - auto loops_out = ir_sch.GetLoops(out->name); + auto loops_out = ir_sch.GetLoops(out->name); if (loops_tmp_out.size() == 1) { ir_sch.Bind(loops_tmp_out[0], "threadIdx.x"); ir_sch.Bind(loops_out[0], "threadIdx.x"); @@ -424,7 +474,8 @@ void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, ir_sch.SetBuffer(block, "local", true); } - VLOG(3) << "After IRCudaScheduleBlockReduceInternal : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRCudaScheduleBlockReduceInternal : " + << ir_sch.GetModule().GetExprs().at(0); } void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, @@ -432,7 +483,8 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, ir::Tensor tmp_out, ir::Tensor out, const common::Target &target) { - VLOG(3) << "Before IRCudaScheduleBlockReduce : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "Before IRCudaScheduleBlockReduce : " + << ir_sch.GetModule().GetExprs().at(0); int tmp_put_shape_size_without_reduce = 0; for (auto i : tmp_out->shape) { CHECK(i.is_constant()); @@ -447,17 +499,19 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, } int tmp_out_shape_size = tmp_put_shape_size_without_reduce + 1; - for (int idx = 0; idx < reduce_temp_out_shape_size - tmp_out_shape_size; ++idx) { - auto loops = ir_sch.GetLoops(reduce_tmp_out->name); + for (int idx = 0; idx < reduce_temp_out_shape_size - tmp_out_shape_size; + ++idx) { + auto loops = ir_sch.GetLoops(reduce_tmp_out->name); int reduce_axis = reduce_tmp_out->reduce_axis.size(); if (loops.size() >= tmp_put_shape_size_without_reduce + 2 + reduce_axis) - ir_sch.Fuse({loops[tmp_put_shape_size_without_reduce], loops[tmp_put_shape_size_without_reduce + 1]}); + ir_sch.Fuse({loops[tmp_put_shape_size_without_reduce], + loops[tmp_put_shape_size_without_reduce + 1]}); } // fuse parallel dimension for (int idx = 0; idx < tmp_put_shape_size_without_reduce - 1; ++idx) { for (auto &tensor : {reduce_tmp_out, tmp_out, out}) { - auto loops = ir_sch.GetLoops(tensor->name); + auto loops = ir_sch.GetLoops(tensor->name); int reduce_axis = tensor->reduce_axis.size(); if (loops.size() >= 2 + reduce_axis) { ir_sch.Fuse({loops[0], loops[1]}); @@ -465,39 +519,54 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, } } - // Special handling when keepdim = True in reduce stage 1. When keepdim = True, shape size may not be equal to 1. But - // we still need to split the loops, otherwise there will be a problem of data read and write conflict. - int numel = std::accumulate(tmp_out->shape.begin(), tmp_out->shape.end(), 1, [](const int &num, const ir::Expr &e) { - return num * e.as_int32(); - }); - if (tmp_out->shape.size() == 1 || (numel == tmp_out->shape.back().as_int32())) { + // Special handling when keepdim = True in reduce stage 1. When keepdim = + // True, shape size may not be equal to 1. 
But we still need to split the + // loops, otherwise there will be a problem of data read and write conflict. + int numel = std::accumulate( + tmp_out->shape.begin(), + tmp_out->shape.end(), + 1, + [](const int &num, const ir::Expr &e) { return num * e.as_int32(); }); + if (tmp_out->shape.size() == 1 || + (numel == tmp_out->shape.back().as_int32())) { CHECK_EQ(out->shape[0], Expr(1)); // block and root - auto out_block = ir_sch.GetBlock(out->name); + auto out_block = ir_sch.GetBlock(out->name); auto root_block = ir_sch.GetRootBlock(out_block); CHECK(out_block->as()); - CHECK(out_block->as()->schedule_block->as()); + CHECK(out_block->as() + ->schedule_block->as()); // create var auto var = ir::Var(ir::Expr(0), ir::Expr(1), cinn::UniqName("i")); out_block->as()->iter_values.push_back(var); - out_block->as()->schedule_block->as()->iter_vars.push_back(var); + out_block->as() + ->schedule_block->as() + ->iter_vars.push_back(var); CHECK(root_block->as()); - CHECK(root_block->as()->schedule_block->as()); + CHECK(root_block->as() + ->schedule_block->as()); // create for and block node - auto for_node = - ir::For::Make(var, Expr(0), Expr(1), ir::ForType::Serial, ir::DeviceAPI::UNK, ir::Block::Make({out_block})); - auto block_node = ir::Block::Make({root_block->as() - ->schedule_block->as() - ->body->as() - ->stmts[0], - for_node}); - - root_block->as()->schedule_block->as()->body = block_node; + auto for_node = ir::For::Make(var, + Expr(0), + Expr(1), + ir::ForType::Serial, + ir::DeviceAPI::UNK, + ir::Block::Make({out_block})); + auto block_node = + ir::Block::Make({root_block->as() + ->schedule_block->as() + ->body->as() + ->stmts[0], + for_node}); + + root_block->as() + ->schedule_block->as() + ->body = block_node; for (auto &tensor : {reduce_tmp_out, tmp_out, out}) { auto loops = ir_sch.GetLoops(tensor->name); @@ -548,7 +617,8 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, ir_sch.SetBuffer(block, "local", true); } - VLOG(3) << "After IRCudaScheduleBlockReduce : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRCudaScheduleBlockReduce : " + << ir_sch.GetModule().GetExprs().at(0); } void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, @@ -556,26 +626,29 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, ir::Tensor internal, ir::Tensor reduce_out, const common::Target &target) { - VLOG(3) << "Before IRCudaScheduleBlockShuffleReduce : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "Before IRCudaScheduleBlockShuffleReduce : " + << ir_sch.GetModule().GetExprs().at(0); // reshape compute inline { // simplify reshape index auto hand_write_simplify = [](std::vector loops, ir::Expr block) { // check exist select. 
- auto find_select = ir::CollectIRNodesInOrder(block, [&](const Expr *x) { return x->As(); }); + auto find_select = ir::CollectIRNodesInOrder( + block, [&](const Expr *x) { return x->As(); }); if (find_select.size() > 0) { return; } auto schedule_realize = block.As(); - auto schedule_block = block.As()->schedule_block.As(); + auto schedule_block = block.As() + ->schedule_block.As(); int stride = 1; std::unordered_map var_strides; for (int idx = loops.size() - 1; idx > 0; --idx) { stride = stride * GetLoopExtent(loops[idx]); - auto var = loops[idx - 1].As()->loop_var; + auto var = loops[idx - 1].As()->loop_var; var_strides[var->name] = ir::Expr(stride); } @@ -591,18 +664,21 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, } auto stride = var_strides.find(var->name)->second; - index = index + ir::Expr(schedule_block->iter_vars[idx]) * stride; + index = index + ir::Expr(schedule_block->iter_vars[idx]) * stride; } - auto exprs = ir::CollectIRNodesInOrder(block, [&](const Expr *x) { return x->As(); }); + auto exprs = ir::CollectIRNodesInOrder( + block, [&](const Expr *x) { return x->As(); }); CHECK_EQ(exprs.size(), 1); - auto load = exprs.front().As(); + auto load = exprs.front().As(); load->indices = {index}; }; - hand_write_simplify(ir_sch.GetLoops(reshape->name), ir_sch.GetBlock(reshape->name)); + hand_write_simplify(ir_sch.GetLoops(reshape->name), + ir_sch.GetBlock(reshape->name)); auto block = ir_sch.GetBlock(reshape->name); ir_sch.ComputeInline(block); - VLOG(4) << "After simplify reshape index : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(4) << "After simplify reshape index : " + << ir_sch.GetModule().GetExprs().at(0); } // internal bind shared @@ -613,10 +689,12 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, // auto get_loop_index = [&internal](ir::Expr inner_loop, ir::Expr block) { - auto loop_var = inner_loop.As()->loop_var; + auto loop_var = inner_loop.As()->loop_var; auto schedule_realize = block.As(); - auto schedule_block = block.As()->schedule_block.As(); - CHECK_EQ(schedule_realize->iter_values.size(), schedule_block->iter_vars.size()); + auto schedule_block = block.As() + ->schedule_block.As(); + CHECK_EQ(schedule_realize->iter_values.size(), + schedule_block->iter_vars.size()); ir::Var var_name; for (int idx = 0; idx < schedule_block->iter_vars.size(); ++idx) { @@ -631,10 +709,11 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, break; } - auto exprs = ir::CollectIRNodesInOrder(block, [&](const Expr *x) { return x->As(); }); + auto exprs = ir::CollectIRNodesInOrder( + block, [&](const Expr *x) { return x->As(); }); for (auto expr : exprs) { auto load = expr.As(); - auto t = load->tensor.as_tensor_ref(); + auto t = load->tensor.as_tensor_ref(); if (t->name != internal->name) { continue; } @@ -663,43 +742,53 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, } LOG(FATAL) << "Can't find var in tensor indeces!"; }; - auto loop_var_count = get_loop_index(ir_sch.GetLoops(reduce_out->name).back(), ir_sch.GetBlock(reduce_out->name)); + auto loop_var_count = get_loop_index(ir_sch.GetLoops(reduce_out->name).back(), + ir_sch.GetBlock(reduce_out->name)); // fuse loop to bind gpu block.x if (loop_var_count > 1) { auto internal_loops = ir_sch.GetLoops(internal->name); - std::vector fuse_internal_loops(internal_loops.begin(), internal_loops.begin() + loop_var_count); + std::vector fuse_internal_loops( + internal_loops.begin(), internal_loops.begin() + loop_var_count); ir_sch.Fuse(fuse_internal_loops); auto reduce_out_loops = 
ir_sch.GetLoops(reduce_out->name); - std::vector<ir::Expr> fuse_reduce_out_loops(reduce_out_loops.begin(), reduce_out_loops.begin() + loop_var_count); + std::vector<ir::Expr> fuse_reduce_out_loops( + reduce_out_loops.begin(), reduce_out_loops.begin() + loop_var_count); ir_sch.Fuse(fuse_reduce_out_loops); } - VLOG(4) << "After fuse loop for blockIdx.x : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(4) << "After fuse loop for blockIdx.x : " + << ir_sch.GetModule().GetExprs().at(0); // fuse reduce tail to bind gpu thread. - if (ir_sch.GetLoops(reduce_out->name + "__reduce_init").size() > (loop_var_count ? 2 : 1)) { + if (ir_sch.GetLoops(reduce_out->name + "__reduce_init").size() > + (loop_var_count ? 2 : 1)) { int start_index = loop_var_count == 0 ? 0 : 1; // first reduce step: // [block.x, thread.y, tail] or [thread.y, tail] auto internal_loops = ir_sch.GetLoops(internal->name + "__reduce_init"); - std::vector<ir::Expr> fuse_internal_loops(internal_loops.begin() + start_index + 1, internal_loops.end()); + std::vector<ir::Expr> fuse_internal_loops( + internal_loops.begin() + start_index + 1, internal_loops.end()); ir_sch.Fuse(fuse_internal_loops); // second reduce step: // [block.x, tail] or [tail] auto reduce_out_loops = ir_sch.GetLoops(reduce_out->name + "__reduce_init"); - std::vector<ir::Expr> fuse_reduce_out_loops(reduce_out_loops.begin() + start_index, reduce_out_loops.end()); + std::vector<ir::Expr> fuse_reduce_out_loops( + reduce_out_loops.begin() + start_index, reduce_out_loops.end()); ir_sch.Fuse(fuse_reduce_out_loops); } - VLOG(4) << "After fuse tail loop for threadIdx.x : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(4) << "After fuse tail loop for threadIdx.x : " + << ir_sch.GetModule().GetExprs().at(0); // split reduce loop to bind thread.y { if (loop_var_count > 0) { - auto reduce_out_loops = ir_sch.GetLoops(reduce_out->name + "__reduce_init"); + auto reduce_out_loops = + ir_sch.GetLoops(reduce_out->name + "__reduce_init"); ir_sch.Split(reduce_out_loops[1], {1, -1}); } else { - auto reduce_out_loops = ir_sch.GetLoops(reduce_out->name + "__reduce_init"); + auto reduce_out_loops = + ir_sch.GetLoops(reduce_out->name + "__reduce_init"); ir_sch.Split(reduce_out_loops[0], {1, -1}); } } @@ -708,8 +797,8 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, // split internal tail to bind thread { auto start_index = loop_var_count == 0 ?
0 : 1; - auto i_loops = ir_sch.GetLoops(internal->name + "__reduce_init"); - auto r_loops = ir_sch.GetLoops(reduce_out->name + "__reduce_init"); + auto i_loops = ir_sch.GetLoops(internal->name + "__reduce_init"); + auto r_loops = ir_sch.GetLoops(reduce_out->name + "__reduce_init"); // bind blockIdx.x if (loop_var_count) { ir_sch.Bind(i_loops[0], "blockIdx.x"); @@ -765,7 +854,8 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, bind_thread(64); } } - VLOG(4) << "After split tail loop for threadIdx.x : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(4) << "After split tail loop for threadIdx.x : " + << ir_sch.GetModule().GetExprs().at(0); // do reorder { ir_sch.Reorder(internal->name + "__reduce_init", axis_in_nroder); @@ -774,17 +864,20 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, // unroll last dim { auto i_loops = ir_sch.GetLoops(internal->name); - if (ir_sch.GetLoops(internal->name + "__reduce_init").size() < i_loops.size() && + if (ir_sch.GetLoops(internal->name + "__reduce_init").size() < + i_loops.size() && GetLoopExtent(i_loops.back()) <= 64) { ir_sch.Unroll(i_loops.back()); } auto r_loops = ir_sch.GetLoops(reduce_out->name); - if (ir_sch.GetLoops(reduce_out->name + "__reduce_init").size() < r_loops.size()) { + if (ir_sch.GetLoops(reduce_out->name + "__reduce_init").size() < + r_loops.size()) { ir_sch.Unroll(r_loops.back()); } } - VLOG(3) << "After IRCudaScheduleBlockShuffleReduce : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRCudaScheduleBlockShuffleReduce : " + << ir_sch.GetModule().GetExprs().at(0); } void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, @@ -793,13 +886,15 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, ir::Tensor tmp_out, ir::Tensor out, const common::Target &target) { - VLOG(3) << "Before IRCudaTwoStepReduceSchedule : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "Before IRCudaTwoStepReduceSchedule : " + << ir_sch.GetModule().GetExprs().at(0); // fuse axis - int fuse_times = ir_sch.GetLoops(internal->name).size() - internal->reduce_axis.size() - 2; + int fuse_times = + ir_sch.GetLoops(internal->name).size() - internal->reduce_axis.size() - 2; for (int idx = 0; idx < fuse_times; ++idx) { for (auto &tensor : {internal, tmp_out, out}) { - auto block = ir_sch.GetBlock(tensor->name); - auto loops = ir_sch.GetLoops(block); + auto block = ir_sch.GetBlock(tensor->name); + auto loops = ir_sch.GetLoops(block); int reduce_axis = tensor->reduce_axis.size(); ir_sch.Fuse({loops[0], loops[1]}); } @@ -807,41 +902,53 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, if (ir_sch.GetLoops(tmp_out->name).size() == 1) { // block and root - auto out_block = ir_sch.GetBlock(out->name); + auto out_block = ir_sch.GetBlock(out->name); auto root_block = ir_sch.GetRootBlock(out_block); CHECK(out_block->as()); - CHECK(out_block->as()->schedule_block->as()); + CHECK(out_block->as() + ->schedule_block->as()); // create var // auto var = ir::Var(ir::Expr(0), ir::Expr(1), "i_0"); auto var = ir::Var(ir::Expr(0), ir::Expr(1), cinn::UniqName("i")); out_block->as()->iter_values.push_back(var); - out_block->as()->schedule_block->as()->iter_vars.push_back(var); + out_block->as() + ->schedule_block->as() + ->iter_vars.push_back(var); CHECK(root_block->as()); - CHECK(root_block->as()->schedule_block->as()); + CHECK(root_block->as() + ->schedule_block->as()); // create for and block node - auto for_node = - ir::For::Make(var, Expr(0), Expr(1), ir::ForType::Serial, ir::DeviceAPI::UNK, ir::Block::Make({out_block})); - - 
auto block_node = ir::Block::Make({root_block->as() - ->schedule_block->as() - ->body->as() - ->stmts[0], - root_block->as() - ->schedule_block->as() - ->body->as() - ->stmts[1], - for_node}); - - root_block->as()->schedule_block->as()->body = block_node; + auto for_node = ir::For::Make(var, + Expr(0), + Expr(1), + ir::ForType::Serial, + ir::DeviceAPI::UNK, + ir::Block::Make({out_block})); + + auto block_node = + ir::Block::Make({root_block->as() + ->schedule_block->as() + ->body->as() + ->stmts[0], + root_block->as() + ->schedule_block->as() + ->body->as() + ->stmts[1], + for_node}); + + root_block->as() + ->schedule_block->as() + ->body = block_node; for (auto &tensor : {internal, tmp_out, out}) { auto block = ir_sch.GetBlock(tensor->name); auto loops = ir_sch.GetLoops(block); - if (!loops.empty()) ir_sch.Split(loops[0], {-1, ir::GetLoopExtent(loops[0])}); + if (!loops.empty()) + ir_sch.Split(loops[0], {-1, ir::GetLoopExtent(loops[0])}); } } auto reshape_block = ir_sch.GetBlock(reshape->name); @@ -856,8 +963,8 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, // The current one-dimensional reduce does not make full use of SM. // This case is optimized into a two-dimensional. auto internal_loops = ir_sch.GetLoops(internal->name); - auto block_dim_x = internal_loops[1].As()->extent.as_int32(); - int block_dim_y = block_dim_x <= 32 ? 2 : 1; + auto block_dim_x = internal_loops[1].As()->extent.as_int32(); + int block_dim_y = block_dim_x <= 32 ? 2 : 1; for (auto &tensor : {internal, tmp_out, out}) { auto loops = ir_sch.GetLoops(tensor->name); @@ -876,9 +983,12 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, ir_sch.Bind(loops[1], "threadIdx.x"); } } - VLOG(3) << "After IRCudaTwoStepReduceSchedule : " << ir_sch.GetModule().GetExprs().at(0); - // ir_sch.SimpleComputeAt(ir_sch.GetBlock(tmp_out->name), ir_sch.GetLoops(out->name)[0]); - // ir_sch.SimpleComputeAt(ir_sch.GetBlock(internal->name), ir_sch.GetLoops(out->name)[0]); + VLOG(3) << "After IRCudaTwoStepReduceSchedule : " + << ir_sch.GetModule().GetExprs().at(0); + // ir_sch.SimpleComputeAt(ir_sch.GetBlock(tmp_out->name), + // ir_sch.GetLoops(out->name)[0]); + // ir_sch.SimpleComputeAt(ir_sch.GetBlock(internal->name), + // ir_sch.GetLoops(out->name)[0]); } void IRSoftmaxScheduleCPU(ir::IRSchedule &ir_sch, int axis) { @@ -896,36 +1006,41 @@ void IRSoftmaxScheduleCPU(ir::IRSchedule &ir_sch, int axis) { ir_sch.Fuse(all_blocks[2], {0, 1}); } all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[2]); + loops = ir_sch.GetLoops(all_blocks[2]); ir_sch.ComputeAt(all_blocks[1], loops[0]); } -void IRPoolScheduleGPU(ir::IRSchedule &ir_sch, const common::Target &target, int arg_pack_size) { - VLOG(3) << "Before IRPoolScheduleGPU: " << ir_sch.GetModule().GetExprs().at(0); +void IRPoolScheduleGPU(ir::IRSchedule &ir_sch, + const common::Target &target, + int arg_pack_size) { + VLOG(3) << "Before IRPoolScheduleGPU: " + << ir_sch.GetModule().GetExprs().at(0); auto all_blocks = ir_sch.GetAllBlocks(); VLOG(3) << "all_blocks[0] is : " << all_blocks[0]; auto loops = ir_sch.GetLoops(all_blocks[0]); ir_sch.Fuse(loops); // Blocks were changed after Fuse, so we have to get all blocks again. 
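// [Editor's note] The comment above states the contract behind all the
// repeated GetAllBlocks()/GetLoops() calls in this file: mutating schedule
// primitives (Fuse, Split, ComputeAt, ...) rebuild the module IR, so any
// block/loop handle fetched before the mutation may be stale and must be
// re-queried. A minimal standalone sketch of that contract follows; the
// Schedule/LoopHandle types here are hypothetical stand-ins, not CINN's
// real IRSchedule API.
#include <cstddef>
#include <vector>

namespace sketch {

// A handle remembers the schedule "version" it was fetched at.
struct LoopHandle {
  std::size_t index;
  int version;
};

class Schedule {
 public:
  explicit Schedule(std::size_t num_loops) : num_loops_(num_loops) {}

  // Returns fresh handles stamped with the current version.
  std::vector<LoopHandle> GetLoops() const {
    std::vector<LoopHandle> loops;
    for (std::size_t i = 0; i < num_loops_; ++i) {
      loops.push_back({i, version_});
    }
    return loops;
  }

  bool IsStale(const LoopHandle &h) const { return h.version != version_; }

  // Fusing two adjacent loops rewrites the nest and bumps the version,
  // invalidating every handle fetched earlier.
  bool Fuse(const LoopHandle &a, const LoopHandle &b) {
    if (IsStale(a) || IsStale(b) || b.index != a.index + 1) return false;
    --num_loops_;
    ++version_;
    return true;
  }

 private:
  std::size_t num_loops_;
  int version_ = 0;
};

}  // namespace sketch

int main() {
  sketch::Schedule sch(3);
  auto loops = sch.GetLoops();
  sch.Fuse(loops[0], loops[1]);
  // 'loops' is stale now; re-fetch before using it again, exactly as the
  // schedule code around this note does after every Fuse/Split.
  loops = sch.GetLoops();
  return loops.size() == 2 ? 0 : 1;
}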
- all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[0]); + all_blocks = ir_sch.GetAllBlocks(); + loops = ir_sch.GetLoops(all_blocks[0]); auto splited = ir_sch.Split(loops[0], {-1, 1024}); ir_sch.Bind(splited[0], "blockIdx.x"); ir_sch.Bind(splited[1], "threadIdx.x"); VLOG(3) << "End IRPoolScheduleGPU: " << ir_sch.GetModule().GetExprs().at(0); } -void IRGlobalPoolScheduleGPU(ir::IRSchedule &ir_sch, const common::Target &target) { - VLOG(3) << "Before IRGlobalPoolScheduleGPU: " << ir_sch.GetModule().GetExprs().at(0); +void IRGlobalPoolScheduleGPU(ir::IRSchedule &ir_sch, + const common::Target &target) { + VLOG(3) << "Before IRGlobalPoolScheduleGPU: " + << ir_sch.GetModule().GetExprs().at(0); auto all_blocks = ir_sch.GetAllBlocks(); CHECK_EQ(all_blocks.size(), 2U); auto loops = ir_sch.GetLoops(all_blocks[1]); if (loops.size() > 1) { - auto fused = ir_sch.Fuse(all_blocks[0], {0, 1}); + auto fused = ir_sch.Fuse(all_blocks[0], {0, 1}); auto splited = ir_sch.Split(fused, {-1, 32}); - all_blocks = ir_sch.GetAllBlocks(); - fused = ir_sch.Fuse(all_blocks[1], {0, 1}); - splited = ir_sch.Split(fused, {-1, 32}); + all_blocks = ir_sch.GetAllBlocks(); + fused = ir_sch.Fuse(all_blocks[1], {0, 1}); + splited = ir_sch.Split(fused, {-1, 32}); ir_sch.Bind(splited[0], "blockIdx.x"); ir_sch.Bind(splited[1], "threadIdx.y"); all_blocks = ir_sch.GetAllBlocks(); @@ -936,15 +1051,15 @@ void IRGlobalPoolScheduleGPU(ir::IRSchedule &ir_sch, const common::Target &targe CHECK_GE(loops.size(), 3U); ir_sch.Bind(loops[2], "threadIdx.x"); } else { - loops = ir_sch.GetLoops(all_blocks[0]); + loops = ir_sch.GetLoops(all_blocks[0]); auto splited = ir_sch.Split(loops[0], {-1, 32}); - all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[1]); - splited = ir_sch.Split(loops[0], {-1, 32}); + all_blocks = ir_sch.GetAllBlocks(); + loops = ir_sch.GetLoops(all_blocks[1]); + splited = ir_sch.Split(loops[0], {-1, 32}); ir_sch.Bind(splited[0], "blockIdx.x"); ir_sch.Bind(splited[1], "threadIdx.y"); all_blocks = ir_sch.GetAllBlocks(); - splited = ir_sch.GetLoops(all_blocks[1]); + splited = ir_sch.GetLoops(all_blocks[1]); ir_sch.SimpleComputeAt(all_blocks[0], splited[1]); all_blocks = ir_sch.GetAllBlocks(); ir_sch.SetBuffer(all_blocks[0], "local", true); @@ -952,18 +1067,21 @@ void IRGlobalPoolScheduleGPU(ir::IRSchedule &ir_sch, const common::Target &targe CHECK_GE(loops.size(), 3U); ir_sch.Bind(loops[2], "threadIdx.x"); } - VLOG(3) << "After IRGlobalPoolScheduleGPU: " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRGlobalPoolScheduleGPU: " + << ir_sch.GetModule().GetExprs().at(0); } -void IRCudaScheduleDepthwiseConv(ir::IRSchedule &ir_sch, const std::vector &tensors) { +void IRCudaScheduleDepthwiseConv(ir::IRSchedule &ir_sch, + const std::vector &tensors) { if (tensors.size() == 3U) { CHECK(tensors[1].as_tensor()); auto input_pad = ir_sch.GetBlock(tensors[1].as_tensor_ref()->name); ir_sch.ComputeInline(input_pad); } auto all_blocks = ir_sch.GetAllBlocks(); - VLOG(3) << "Begin IRCudaScheduleDepthwiseConv with expr: " << ir_sch.GetModule().GetExprs().at(0); - auto OL = ir_sch.CacheWrite(all_blocks[0], 0, "local"); + VLOG(3) << "Begin IRCudaScheduleDepthwiseConv with expr: " + << ir_sch.GetModule().GetExprs().at(0); + auto OL = ir_sch.CacheWrite(all_blocks[0], 0, "local"); all_blocks = ir_sch.GetAllBlocks(); CHECK_GE(all_blocks.size(), 2); auto loops = ir_sch.GetLoops(all_blocks[1]); @@ -973,20 +1091,22 @@ void IRCudaScheduleDepthwiseConv(ir::IRSchedule &ir_sch, const 
std::vectorshape[2])); int h = output->shape[2].as_int32(); optim::Simplify(&(output->shape[3])); - int w = output->shape[3].as_int32(); + int w = output->shape[3].as_int32(); int rc = input_pad->shape[1].as_int32(); - std::string key = - "CudaDirectConvSchedule " + std::to_string(input_pad->shape[0].as_int32()) + " " + - std::to_string(input_pad->shape[1].as_int32()) + " " + std::to_string(input_pad->shape[2].as_int32()) + " " + - std::to_string(input_pad->shape[3].as_int32()) + " " + std::to_string(weights->shape[0].as_int32()) + " " + - std::to_string(weights->shape[1].as_int32()) + " " + std::to_string(weights->shape[2].as_int32()) + " " + - std::to_string(weights->shape[3].as_int32()) + " " + std::to_string(output->shape[0].as_int32()) + " " + - std::to_string(output->shape[1].as_int32()) + " " + std::to_string(output->shape[2].as_int32()) + " " + - std::to_string(output->shape[3].as_int32()); + std::string key = "CudaDirectConvSchedule " + + std::to_string(input_pad->shape[0].as_int32()) + " " + + std::to_string(input_pad->shape[1].as_int32()) + " " + + std::to_string(input_pad->shape[2].as_int32()) + " " + + std::to_string(input_pad->shape[3].as_int32()) + " " + + std::to_string(weights->shape[0].as_int32()) + " " + + std::to_string(weights->shape[1].as_int32()) + " " + + std::to_string(weights->shape[2].as_int32()) + " " + + std::to_string(weights->shape[3].as_int32()) + " " + + std::to_string(output->shape[0].as_int32()) + " " + + std::to_string(output->shape[1].as_int32()) + " " + + std::to_string(output->shape[2].as_int32()) + " " + + std::to_string(output->shape[3].as_int32()); if (res.count(key) == 0) { VLOG(3) << "Didn't find saved param, key is: " << key; } else { @@ -1015,27 +1140,28 @@ void IRCudaScheduleConv(ir::IRSchedule &ir_sch, const common::Target &target) { // return; } ir_sch.ComputeInline(all_blocks[0]); - int f_inner = GetInnerSplitter(c, h); - int block_z = SplitEven(c / f_inner); + int f_inner = GetInnerSplitter(c, h); + int block_z = SplitEven(c / f_inner); int thread_z = c / f_inner / block_z; int rc_factor = SplitEven(rc); while (w * thread_z > 1024 && thread_z % 2 == 0) { thread_z = thread_z / 2; - f_inner = f_inner * 2; + f_inner = f_inner * 2; } CHECK_LE(w * thread_z, 1024) << "Wrong Param of Conv2d!"; std::vector loops; - all_blocks = ir_sch.GetAllBlocks(); + all_blocks = ir_sch.GetAllBlocks(); auto reduce_init_name = GetTensor(all_blocks[0])->name; { // Do CacheWrite all_blocks = ir_sch.GetAllBlocks(); - auto OL = ir_sch.CacheWrite(all_blocks[1], 0, "local"); - VLOG(3) << "After CacheWrite with expr: " << ir_sch.GetModule().GetExprs().at(0); + auto OL = ir_sch.CacheWrite(all_blocks[1], 0, "local"); + VLOG(3) << "After CacheWrite with expr: " + << ir_sch.GetModule().GetExprs().at(0); } - all_blocks = ir_sch.GetAllBlocks(); - auto temp_output_name = GetTensor(all_blocks[1])->name; + all_blocks = ir_sch.GetAllBlocks(); + auto temp_output_name = GetTensor(all_blocks[1])->name; auto final_output_name = GetTensor(all_blocks[2])->name; { // Do Split @@ -1052,11 +1178,12 @@ void IRCudaScheduleConv(ir::IRSchedule &ir_sch, const common::Target &target) { { // Do ComputeAt auto temp_out = ir_sch.GetBlock(temp_output_name); - loops = ir_sch.GetLoops(final_output_name); + loops = ir_sch.GetLoops(final_output_name); CHECK_GE(loops.size(), 5U); ir_sch.ComputeAt(temp_out, loops[4]); } - VLOG(3) << "After ComputeAt with expr: " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After ComputeAt with expr: " + << ir_sch.GetModule().GetExprs().at(0); { // Do Split loops 
= ir_sch.GetLoops(temp_output_name); @@ -1065,11 +1192,13 @@ void IRCudaScheduleConv(ir::IRSchedule &ir_sch, const common::Target &target) { } { // Do Split - auto reduce_init = ir_sch.GetBlock(reduce_init_name); - ir::ScheduleBlockRealize *reduce_init_block = reduce_init.As(); - loops = ir_sch.GetLoops(reduce_init_name); - // If loops size is less than 4, it means one or more 1-loops are eliminated in the lowering process. - // Here we restore them by identifying the constant iter value in the ScheduleBlock + auto reduce_init = ir_sch.GetBlock(reduce_init_name); + ir::ScheduleBlockRealize *reduce_init_block = + reduce_init.As(); + loops = ir_sch.GetLoops(reduce_init_name); + // If loops size is less than 4, it means one or more 1-loops are eliminated + // in the lowering process. Here we restore them by identifying the constant + // iter value in the ScheduleBlock while (loops.size() < 4U) { for (int i = 0; i < reduce_init_block->iter_values.size(); ++i) { auto &v = reduce_init_block->iter_values[i]; @@ -1092,7 +1221,7 @@ void IRCudaScheduleConv(ir::IRSchedule &ir_sch, const common::Target &target) { { // Do SimpleComputeAt auto reduce_init = ir_sch.GetBlock(reduce_init_name); - loops = ir_sch.GetLoops(temp_output_name); + loops = ir_sch.GetLoops(temp_output_name); CHECK_GE(loops.size(), 6U); ir_sch.SimpleComputeAt(reduce_init, loops[5]); } @@ -1105,7 +1234,8 @@ void IRCudaScheduleConv(ir::IRSchedule &ir_sch, const common::Target &target) { ir_sch.Bind(loops[3], "threadIdx.z"); ir_sch.Bind(loops[4], "threadIdx.x"); } - VLOG(3) << "After IRCudaScheduleConv, expr is : " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRCudaScheduleConv, expr is : " + << ir_sch.GetModule().GetExprs().at(0); } void IRCudaScheduleConv2(ir::IRSchedule &ir_sch, @@ -1123,18 +1253,20 @@ void IRCudaScheduleConv2(ir::IRSchedule &ir_sch, optim::Simplify(&(output->shape[2])); optim::Simplify(&(output->shape[3])); - VLOG(3) << "Begin IRCudaScheduleConv2 with expr : " << ir_sch.GetModule().GetExprs().at(0); - auto input_cache = ir_sch.CacheRead(all_blocks[2], 1, "shared"); - all_blocks = ir_sch.GetAllBlocks(); + VLOG(3) << "Begin IRCudaScheduleConv2 with expr : " + << ir_sch.GetModule().GetExprs().at(0); + auto input_cache = ir_sch.CacheRead(all_blocks[2], 1, "shared"); + all_blocks = ir_sch.GetAllBlocks(); auto weights_cache = ir_sch.CacheRead(all_blocks[3], 2, "shared"); - all_blocks = ir_sch.GetAllBlocks(); - auto output_cache = ir_sch.CacheWrite(all_blocks[4], 0, "local"); - all_blocks = ir_sch.GetAllBlocks(); + all_blocks = ir_sch.GetAllBlocks(); + auto output_cache = ir_sch.CacheWrite(all_blocks[4], 0, "local"); + all_blocks = ir_sch.GetAllBlocks(); ir_sch.ComputeInline(all_blocks[1]); - VLOG(3) << "In the middle of IRCudaScheduleConv2, expr is: " << ir_sch.GetModule().GetExprs().at(0); - auto &x_param = res[key]["x"]; - auto &y_param = res[key]["y"]; - auto &f_param = res[key]["f"]; + VLOG(3) << "In the middle of IRCudaScheduleConv2, expr is: " + << ir_sch.GetModule().GetExprs().at(0); + auto &x_param = res[key]["x"]; + auto &y_param = res[key]["y"]; + auto &f_param = res[key]["f"]; auto &rx_param = res[key]["rx"]; auto &ry_param = res[key]["ry"]; auto &rc_param = res[key]["rc"]; @@ -1145,17 +1277,17 @@ void IRCudaScheduleConv2(ir::IRSchedule &ir_sch, ir_sch.Split(loops[3], {-1, x_param[1], x_param[2], x_param[3]}); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[4]); + loops = ir_sch.GetLoops(all_blocks[4]); CHECK_GE(loops.size(), 3U); ir_sch.Split(loops[2], {-1, y_param[1], 
y_param[2], y_param[3]}); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[4]); + loops = ir_sch.GetLoops(all_blocks[4]); CHECK_GE(loops.size(), 2U); ir_sch.Split(loops[1], {-1, f_param[1], f_param[2], f_param[3]}); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[4]); + loops = ir_sch.GetLoops(all_blocks[4]); CHECK_GE(loops.size(), 13U); ir_sch.Reorder({loops[0], loops[1], @@ -1172,7 +1304,7 @@ void IRCudaScheduleConv2(ir::IRSchedule &ir_sch, loops[12]}); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[4]); + loops = ir_sch.GetLoops(all_blocks[4]); CHECK_GE(loops.size(), 13U); ir_sch.Bind(loops[1], "blockIdx.z"); ir_sch.Bind(loops[2], "blockIdx.y"); @@ -1185,37 +1317,46 @@ void IRCudaScheduleConv2(ir::IRSchedule &ir_sch, ir_sch.Unroll(loops[12]); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[4]); + loops = ir_sch.GetLoops(all_blocks[4]); CHECK_GE(loops.size(), 10U); ir_sch.ComputeAt(all_blocks[3], loops[9]); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[3]); + loops = ir_sch.GetLoops(all_blocks[3]); CHECK_GE(loops.size(), 16U); ir_sch.Split(loops[15], {-1, rx_param[1]}); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[3]); + loops = ir_sch.GetLoops(all_blocks[3]); CHECK_GE(loops.size(), 15U); ir_sch.Split(loops[14], {-1, ry_param[1]}); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[3]); + loops = ir_sch.GetLoops(all_blocks[3]); CHECK_GE(loops.size(), 14U); ir_sch.Split(loops[13], {-1, rc_param[1]}); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[3]); + loops = ir_sch.GetLoops(all_blocks[3]); CHECK_GE(loops.size(), 14U); - ir_sch.Reorder({loops[13], loops[15], loops[17], loops[14], loops[16], loops[18], loops[10], loops[11], loops[12]}); + ir_sch.Reorder({loops[13], + loops[15], + loops[17], + loops[14], + loops[16], + loops[18], + loops[10], + loops[11], + loops[12]}); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[3]); + loops = ir_sch.GetLoops(all_blocks[3]); CHECK_GE(loops.size(), 13U); ir_sch.ComputeAt(all_blocks[0], loops[12]); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[3]); + loops = ir_sch.GetLoops(all_blocks[3]); CHECK_GE(loops.size(), 13U); ir_sch.ComputeAt(all_blocks[1], loops[12]); // Work In Progress - VLOG(3) << "After IRCudaScheduleConv2, expr is: " << ir_sch.GetModule().GetExprs().at(0); + VLOG(3) << "After IRCudaScheduleConv2, expr is: " + << ir_sch.GetModule().GetExprs().at(0); } } // namespace pe diff --git a/paddle/cinn/hlir/pe/ir_schedule_pe.h b/paddle/cinn/hlir/pe/ir_schedule_pe.h index 4c07c7cb1ad9f..e7839fcc1ae57 100644 --- a/paddle/cinn/hlir/pe/ir_schedule_pe.h +++ b/paddle/cinn/hlir/pe/ir_schedule_pe.h @@ -31,9 +31,13 @@ namespace cinn { namespace hlir { namespace pe { -void IRElementwiseSchedule(ir::IRSchedule &ir_sch, const std::vector &output_shape, const common::Target &target); +void IRElementwiseSchedule(ir::IRSchedule &ir_sch, + const std::vector &output_shape, + const common::Target &target); -void IRInjectiveSchedule(ir::IRSchedule &ir_sch, const std::vector &output_shape, const common::Target &target); +void IRInjectiveSchedule(ir::IRSchedule &ir_sch, + const std::vector &output_shape, + const common::Target &target); void IRScheduleInjectiveCPU(ir::IRSchedule &ir_sch, const std::vector &output_shape, @@ -44,20 +48,28 @@ void IRCudaScheduleInjective(ir::IRSchedule &ir_sch, const std::vector 
&output_shape, const common::Target &target); -std::vector IRCudaScheduleMatMul(const common::CINNValuePack &arg_pack, - const std::vector &output_shape, - const common::Target &target); +std::vector IRCudaScheduleMatMul( + const common::CINNValuePack &arg_pack, + const std::vector &output_shape, + const common::Target &target); -void IRCudaScheduleMul(ir::IRSchedule &ir_sch, const std::vector &output_shape, const common::Target &target); +void IRCudaScheduleMul(ir::IRSchedule &ir_sch, + const std::vector &output_shape, + const common::Target &target); -void IRMulScheduleCPU(ir::IRSchedule &ir_sch, const std::vector &reduce_first_shape, const common::Target &target); +void IRMulScheduleCPU(ir::IRSchedule &ir_sch, + const std::vector &reduce_first_shape, + const common::Target &target); void IRCudaSplitSchedule(ir::IRSchedule &ir_sch, const std::vector> &output_shapes, int axis, const common::Target &target); -void IRCudaScheduleReduce(ir::IRSchedule &ir_sch, ir::Tensor out, int last_dimension_num, const common::Target &target); +void IRCudaScheduleReduce(ir::IRSchedule &ir_sch, + ir::Tensor out, + int last_dimension_num, + const common::Target &target); void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, ir::Tensor reduce_tmp_out, @@ -70,8 +82,11 @@ void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, ir::Tensor out, const common::Target &target); -void IRCudaScheduleBlockShuffleReduce( - ir::IRSchedule &ir_sch, ir::Tensor reshape, ir::Tensor internal, ir::Tensor out, const common::Target &target); +void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, + ir::Tensor reshape, + ir::Tensor internal, + ir::Tensor out, + const common::Target &target); void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, ir::Tensor reshape, @@ -82,11 +97,15 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, void IRSoftmaxScheduleCPU(ir::IRSchedule &ir_sch, int axis = -1); -void IRPoolScheduleGPU(ir::IRSchedule &ir_sch, const common::Target &target, int arg_pack_size = 3); +void IRPoolScheduleGPU(ir::IRSchedule &ir_sch, + const common::Target &target, + int arg_pack_size = 3); -void IRCudaScheduleDepthwiseConv(ir::IRSchedule &ir_sch, const std::vector &tensors); +void IRCudaScheduleDepthwiseConv(ir::IRSchedule &ir_sch, + const std::vector &tensors); -void IRGlobalPoolScheduleGPU(ir::IRSchedule &ir_sch, const common::Target &target); +void IRGlobalPoolScheduleGPU(ir::IRSchedule &ir_sch, + const common::Target &target); void IRCudaScheduleConv2(ir::IRSchedule &ir_sch, ir::Tensor &input_pad, diff --git a/paddle/cinn/hlir/pe/load_params_test.cc b/paddle/cinn/hlir/pe/load_params_test.cc index bc05920b54f24..897e8186db4eb 100644 --- a/paddle/cinn/hlir/pe/load_params_test.cc +++ b/paddle/cinn/hlir/pe/load_params_test.cc @@ -22,24 +22,27 @@ namespace pe { using ir::Tensor; TEST(load_x86_params, load_x86_params) { - auto &res = ScheduleParam::get_x86_instance().GetParam(); - std::string key = "X86ScheduleConv input 1 3 224 224 weight 64 3 7 7 stride 2 2 padding 3 3 dilation 1 1"; + auto &res = ScheduleParam::get_x86_instance().GetParam(); + std::string key = + "X86ScheduleConv input 1 3 224 224 weight 64 3 7 7 stride 2 2 padding 3 " + "3 dilation 1 1"; ASSERT_EQ(res.count(key), 1); absl::flat_hash_map conv2d_factors; - auto target = common::DefaultHostTarget(); - std::vector shape_input = {1, 64, 56, 56}; + auto target = common::DefaultHostTarget(); + std::vector shape_input = {1, 64, 56, 56}; std::vector shape_weights = {64, 64, 3, 3}; - std::vector strides = {1, 1}; - std::vector pads = 
{1, 1}; - std::vector dilations = {1, 1}; - key = GenerateX86ConvKey(shape_input, shape_weights, strides, pads, dilations); + std::vector strides = {1, 1}; + std::vector pads = {1, 1}; + std::vector dilations = {1, 1}; + key = + GenerateX86ConvKey(shape_input, shape_weights, strides, pads, dilations); GetConv2dFactors(&conv2d_factors, -1, -1, -1, -1, -1, Float(32), target, key); int ic_bn_size = conv2d_factors["ic_bn"]; int oc_bn_size = conv2d_factors["oc_bn"]; int fc_bn_size = conv2d_factors["fc_bn"]; int ow_bn_size = conv2d_factors["ow_bn"]; - int unroll_kw = conv2d_factors["unroll_kw"]; + int unroll_kw = conv2d_factors["unroll_kw"]; ASSERT_EQ(ic_bn_size, 64); ASSERT_EQ(fc_bn_size, 64); ASSERT_EQ(oc_bn_size, 32); diff --git a/paddle/cinn/hlir/pe/load_x86_params.cc b/paddle/cinn/hlir/pe/load_x86_params.cc index 6de15b72096af..aa0fd02218f90 100644 --- a/paddle/cinn/hlir/pe/load_x86_params.cc +++ b/paddle/cinn/hlir/pe/load_x86_params.cc @@ -20,1279 +20,2302 @@ namespace cinn { namespace hlir { namespace pe { -void InputX86Param(absl::flat_hash_map>> *model_data, - const std::string &key, - const absl::flat_hash_map> &schedule_data) { +void InputX86Param( + absl::flat_hash_map>> + *model_data, + const std::string &key, + const absl::flat_hash_map> &schedule_data) { CHECK(model_data); (*model_data)[key] = schedule_data; } void LoadX86DefaultParams( - absl::flat_hash_map>> *model_data) { + absl::flat_hash_map>> + *model_data) { CHECK(model_data); // resnet 1 InputX86Param(model_data, - "X86ScheduleConv input 1 3 224 224 weight 64 3 7 7 stride 2 2 padding 3 3 dilation 1 1", - {{"ic_bn", {1, 3}}, {"oc_bn", {2, 32}}, {"ow_bn", {14, 8}}, {"unroll_kw", {0}}}); + "X86ScheduleConv input 1 3 224 224 weight 64 3 7 7 stride 2 2 " + "padding 3 3 dilation 1 1", + {{"ic_bn", {1, 3}}, + {"oc_bn", {2, 32}}, + {"ow_bn", {14, 8}}, + {"unroll_kw", {0}}}); // resnet 3 4 5 6 InputX86Param(model_data, - "X86ScheduleConv input 1 64 56 56 weight 64 64 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {1, 64}}, {"oc_bn", {2, 32}}, {"ow_bn", {8, 7}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 64 56 56 weight 64 64 3 3 stride 1 1 " + "padding 1 1 dilation 1 1", + {{"ic_bn", {1, 64}}, + {"oc_bn", {2, 32}}, + {"ow_bn", {8, 7}}, + {"unroll_kw", {1}}}); // resnet 8 InputX86Param(model_data, - "X86ScheduleConv input 1 64 56 56 weight 128 64 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {2, 32}}, {"oc_bn", {2, 64}}, {"ow_bn", {7, 4}}, {"unroll_kw", {0}}}); + "X86ScheduleConv input 1 64 56 56 weight 128 64 3 3 stride 2 2 " + "padding 1 1 dilation 1 1", + {{"ic_bn", {2, 32}}, + {"oc_bn", {2, 64}}, + {"ow_bn", {7, 4}}, + {"unroll_kw", {0}}}); // resnet 9 10 11 InputX86Param(model_data, - "X86ScheduleConv input 1 128 28 28 weight 128 128 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {1, 128}}, {"oc_bn", {4, 32}}, {"ow_bn", {4, 7}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 128 28 28 weight 128 128 3 3 stride 1 " + "1 padding 1 1 dilation 1 1", + {{"ic_bn", {1, 128}}, + {"oc_bn", {4, 32}}, + {"ow_bn", {4, 7}}, + {"unroll_kw", {1}}}); // resnet 7 InputX86Param(model_data, - "X86ScheduleConv input 1 64 56 56 weight 128 64 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {8, 8}}, {"oc_bn", {4, 32}}, {"ow_bn", {7, 4}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 64 56 56 weight 128 64 1 1 stride 2 2 " + "padding 0 0 dilation 1 1", + {{"ic_bn", {8, 8}}, + {"oc_bn", {4, 32}}, + {"ow_bn", {7, 4}}, + {"oh_bn", {1}}}); // resnet 13 InputX86Param(model_data, - "X86ScheduleConv input 1 128 28 28 
weight 256 128 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {16, 8}}, {"oc_bn", {8, 32}}, {"ow_bn", {2, 7}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 128 28 28 weight 256 128 3 3 stride 2 " + "2 padding 1 1 dilation 1 1", + {{"ic_bn", {16, 8}}, + {"oc_bn", {8, 32}}, + {"ow_bn", {2, 7}}, + {"unroll_kw", {1}}}); // resnet 14 15 16 InputX86Param(model_data, - "X86ScheduleConv input 1 256 14 14 weight 256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {2, 128}}, {"oc_bn", {16, 16}}, {"ow_bn", {1, 14}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 256 14 14 weight 256 256 3 3 stride 1 " + "1 padding 1 1 dilation 1 1", + {{"ic_bn", {2, 128}}, + {"oc_bn", {16, 16}}, + {"ow_bn", {1, 14}}, + {"unroll_kw", {1}}}); // resnet 12 InputX86Param(model_data, - "X86ScheduleConv input 1 128 28 28 weight 256 128 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {2, 64}}, {"oc_bn", {16, 16}}, {"ow_bn", {1, 14}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 128 28 28 weight 256 128 1 1 stride 2 " + "2 padding 0 0 dilation 1 1", + {{"ic_bn", {2, 64}}, + {"oc_bn", {16, 16}}, + {"ow_bn", {1, 14}}, + {"oh_bn", {1}}}); // resnet 18 InputX86Param(model_data, - "X86ScheduleConv input 1 256 14 14 weight 512 256 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {32, 8}}, {"oc_bn", {16, 32}}, {"ow_bn", {1, 7}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 256 14 14 weight 512 256 3 3 stride 2 " + "2 padding 1 1 dilation 1 1", + {{"ic_bn", {32, 8}}, + {"oc_bn", {16, 32}}, + {"ow_bn", {1, 7}}, + {"unroll_kw", {1}}}); // resnet 19 20 21 InputX86Param(model_data, - "X86ScheduleConv input 1 512 7 7 weight 512 512 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {1, 512}}, {"oc_bn", {16, 32}}, {"ow_bn", {1, 7}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 512 7 7 weight 512 512 3 3 stride 1 1 " + "padding 1 1 dilation 1 1", + {{"ic_bn", {1, 512}}, + {"oc_bn", {16, 32}}, + {"ow_bn", {1, 7}}, + {"unroll_kw", {1}}}); // resnet 17 InputX86Param(model_data, - "X86ScheduleConv input 1 256 14 14 weight 512 256 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {2, 128}}, {"oc_bn", {16, 32}}, {"ow_bn", {1, 7}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 256 14 14 weight 512 256 1 1 stride 2 " + "2 padding 0 0 dilation 1 1", + {{"ic_bn", {2, 128}}, + {"oc_bn", {16, 32}}, + {"ow_bn", {1, 7}}, + {"oh_bn", {1}}}); // resnet 2 InputX86Param(model_data, - "X86ScheduleConv input 1 64 56 56 weight 64 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {4, 16}}, {"oc_bn", {2, 32}}, {"ow_bn", {4, 14}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 64 56 56 weight 64 64 1 1 stride 1 1 " + "padding 0 0 dilation 1 1", + {{"ic_bn", {4, 16}}, + {"oc_bn", {2, 32}}, + {"ow_bn", {4, 14}}, + {"oh_bn", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 64 56 56 weight 256 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {16, 4}}, {"oc_bn", {8, 32}}, {"ow_bn", {8, 7}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 64 56 56 weight 256 64 1 1 stride 1 1 " + "padding 0 0 dilation 1 1", + {{"ic_bn", {16, 4}}, + {"oc_bn", {8, 32}}, + {"ow_bn", {8, 7}}, + {"oh_bn", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 256 56 56 weight 64 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {1, 256}}, {"oc_bn", {2, 32}}, {"ow_bn", {8, 7}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 256 56 56 weight 64 256 1 1 stride 1 " + "1 padding 0 0 dilation 1 1", + {{"ic_bn", {1, 256}}, + {"oc_bn", {2, 32}}, + {"ow_bn", {8, 7}}, + {"oh_bn", {1}}}); 
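// [Editor's note] Every InputX86Param call in this file follows the same
// pattern: a workload key string ("X86ScheduleConv input N C H W weight
// OC IC KH KW stride ... padding ... dilation ...") maps to named tiling
// factors such as "ic_bn"/"oc_bn"/"ow_bn". A standalone sketch of that
// lookup pattern is below; std::unordered_map stands in for the
// absl::flat_hash_map used by the real code, and BuildConvKey is a
// hypothetical helper mirroring the key format, not a CINN function.
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

using ScheduleData = std::unordered_map<std::string, std::vector<int>>;
using ParamTable = std::unordered_map<std::string, ScheduleData>;

std::string BuildConvKey(const std::vector<int> &input,
                         const std::vector<int> &weight,
                         int stride, int pad, int dilation) {
  std::ostringstream os;
  os << "X86ScheduleConv input";
  for (int d : input) os << " " << d;
  os << " weight";
  for (int d : weight) os << " " << d;
  os << " stride " << stride << " " << stride << " padding " << pad << " "
     << pad << " dilation " << dilation << " " << dilation;
  return os.str();
}

int main() {
  ParamTable table;
  // Mirrors the first entry of LoadX86DefaultParams above (resnet 1).
  table[BuildConvKey({1, 3, 224, 224}, {64, 3, 7, 7}, 2, 3, 1)] = {
      {"ic_bn", {1, 3}},
      {"oc_bn", {2, 32}},
      {"ow_bn", {14, 8}},
      {"unroll_kw", {0}}};

  const std::string key =
      BuildConvKey({1, 3, 224, 224}, {64, 3, 7, 7}, 2, 3, 1);
  auto it = table.find(key);
  if (it == table.end()) {
    // The real schedules log "Didn't find saved param" and fall back to
    // generic tiling factors in this case.
    std::cout << "no saved param for key: " << key << "\n";
    return 1;
  }
  std::cout << "oc_bn outer/inner: " << it->second["oc_bn"][0] << " "
            << it->second["oc_bn"][1] << "\n";
  return 0;
}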
InputX86Param(model_data, - "X86ScheduleConv input 1 256 56 56 weight 128 256 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {1, 256}}, {"oc_bn", {4, 32}}, {"ow_bn", {4, 7}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 256 56 56 weight 128 256 1 1 stride 2 " + "2 padding 0 0 dilation 1 1", + {{"ic_bn", {1, 256}}, + {"oc_bn", {4, 32}}, + {"ow_bn", {4, 7}}, + {"oh_bn", {1}}}); // resnet 50 InputX86Param(model_data, - "X86ScheduleConv input 1 256 56 56 weight 512 256 1 1 stride 2 2 padding 0 0 dilation 1 1", + "X86ScheduleConv input 1 256 56 56 weight 512 256 1 1 stride 2 " + "2 padding 0 0 dilation 1 1", // Todo: tempory fix, enhance alterlayout and test performance - {{"ic_bn", {1, 256}}, {"oc_bn", {16, 32}}, {"ow_bn", {7, 4}}, {"oh_bn", {1}}}); - // {{"ic_bn", {1, 256}}, {"oc_bn", {8, 64}}, {"ow_bn", {7, 4}}, {"oh_bn", {1}}}); - // resnet50 + {{"ic_bn", {1, 256}}, + {"oc_bn", {16, 32}}, + {"ow_bn", {7, 4}}, + {"oh_bn", {1}}}); + // {{"ic_bn", {1, 256}}, {"oc_bn", {8, 64}}, {"ow_bn", {7, 4}}, {"oh_bn", + // {1}}}); resnet50 InputX86Param(model_data, - "X86ScheduleConv input 1 128 28 28 weight 512 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {32, 4}}, {"oc_bn", {16, 32}}, {"ow_bn", {4, 7}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 128 28 28 weight 512 128 1 1 stride 1 " + "1 padding 0 0 dilation 1 1", + {{"ic_bn", {32, 4}}, + {"oc_bn", {16, 32}}, + {"ow_bn", {4, 7}}, + {"oh_bn", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 512 28 28 weight 128 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {1, 512}}, {"oc_bn", {2, 64}}, {"ow_bn", {7, 4}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 512 28 28 weight 128 512 1 1 stride 1 " + "1 padding 0 0 dilation 1 1", + {{"ic_bn", {1, 512}}, + {"oc_bn", {2, 64}}, + {"ow_bn", {7, 4}}, + {"oh_bn", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 512 28 28 weight 256 512 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {8, 64}}, {"oc_bn", {4, 64}}, {"ow_bn", {7, 2}}, {"oh_bn", {2}}}); + "X86ScheduleConv input 1 512 28 28 weight 256 512 1 1 stride 2 " + "2 padding 0 0 dilation 1 1", + {{"ic_bn", {8, 64}}, + {"oc_bn", {4, 64}}, + {"ow_bn", {7, 2}}, + {"oh_bn", {2}}}); // resnet 50 InputX86Param(model_data, - "X86ScheduleConv input 1 512 28 28 weight 1024 512 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {1, 512}}, {"oc_bn", {16, 64}}, {"ow_bn", {7, 2}}, {"oh_bn", {2}}}); + "X86ScheduleConv input 1 512 28 28 weight 1024 512 1 1 stride " + "2 2 padding 0 0 dilation 1 1", + {{"ic_bn", {1, 512}}, + {"oc_bn", {16, 64}}, + {"ow_bn", {7, 2}}, + {"oh_bn", {2}}}); // resnet 50 InputX86Param(model_data, - "X86ScheduleConv input 1 256 14 14 weight 1024 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {1, 256}}, {"oc_bn", {16, 64}}, {"ow_bn", {7, 2}}, {"oh_bn", {2}}}); + "X86ScheduleConv input 1 256 14 14 weight 1024 256 1 1 stride " + "1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {1, 256}}, + {"oc_bn", {16, 64}}, + {"ow_bn", {7, 2}}, + {"oh_bn", {2}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 1024 14 14 weight 256 1024 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {2, 512}}, {"oc_bn", {4, 64}}, {"ow_bn", {7, 2}}, {"oh_bn", {2}}}); + "X86ScheduleConv input 1 1024 14 14 weight 256 1024 1 1 stride " + "2 2 padding 0 0 dilation 1 1", + {{"ic_bn", {2, 512}}, + {"oc_bn", {4, 64}}, + {"ow_bn", {7, 2}}, + {"oh_bn", {2}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 1024 14 14 weight 512 1024 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {2, 512}}, 
{"oc_bn", {16, 32}}, {"ow_bn", {1, 7}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 1024 14 14 weight 512 1024 1 1 stride " + "2 2 padding 0 0 dilation 1 1", + {{"ic_bn", {2, 512}}, + {"oc_bn", {16, 32}}, + {"ow_bn", {1, 7}}, + {"oh_bn", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 1024 14 14 weight 2048 1024 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {1, 1024}}, {"oc_bn", {64, 32}}, {"ow_bn", {1, 7}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 1024 14 14 weight 2048 1024 1 1 " + "stride 2 2 padding 0 0 dilation 1 1", + {{"ic_bn", {1, 1024}}, + {"oc_bn", {64, 32}}, + {"ow_bn", {1, 7}}, + {"oh_bn", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 512 7 7 weight 2048 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {128, 4}}, {"oc_bn", {64, 32}}, {"ow_bn", {1, 7}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 512 7 7 weight 2048 512 1 1 stride 1 " + "1 padding 0 0 dilation 1 1", + {{"ic_bn", {128, 4}}, + {"oc_bn", {64, 32}}, + {"ow_bn", {1, 7}}, + {"oh_bn", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 2048 7 7 weight 512 2048 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {512, 4}}, {"oc_bn", {16, 32}}, {"ow_bn", {1, 7}}, {"oh_bn", {1}}}); + "X86ScheduleConv input 1 2048 7 7 weight 512 2048 1 1 stride 1 " + "1 padding 0 0 dilation 1 1", + {{"ic_bn", {512, 4}}, + {"oc_bn", {16, 32}}, + {"ow_bn", {1, 7}}, + {"oh_bn", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 3 224 224 weight 64 3 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {1, 3}}, {"oc_bn", {2, 32}}, {"ow_bn", {28, 8}}, {"unroll_kw", {0}}}); + "X86ScheduleConv input 1 3 224 224 weight 64 3 3 3 stride 1 1 " + "padding 1 1 dilation 1 1", + {{"ic_bn", {1, 3}}, + {"oc_bn", {2, 32}}, + {"ow_bn", {28, 8}}, + {"unroll_kw", {0}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 64 224 224 weight 64 64 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {4, 16}}, {"oc_bn", {2, 32}}, {"ow_bn", {28, 8}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 64 224 224 weight 64 64 3 3 stride 1 " + "1 padding 1 1 dilation 1 1", + {{"ic_bn", {4, 16}}, + {"oc_bn", {2, 32}}, + {"ow_bn", {28, 8}}, + {"unroll_kw", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 64 112 112 weight 128 64 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {2, 32}}, {"oc_bn", {2, 64}}, {"ow_bn", {28, 4}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 64 112 112 weight 128 64 3 3 stride 1 " + "1 padding 1 1 dilation 1 1", + {{"ic_bn", {2, 32}}, + {"oc_bn", {2, 64}}, + {"ow_bn", {28, 4}}, + {"unroll_kw", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 128 112 112 weight 128 128 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {2, 64}}, {"oc_bn", {2, 64}}, {"ow_bn", {28, 4}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 128 112 112 weight 128 128 3 3 stride " + "1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {2, 64}}, + {"oc_bn", {2, 64}}, + {"ow_bn", {28, 4}}, + {"unroll_kw", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 128 56 56 weight 256 128 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {4, 32}}, {"oc_bn", {8, 32}}, {"ow_bn", {7, 8}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 128 56 56 weight 256 128 3 3 stride 1 " + "1 padding 1 1 dilation 1 1", + {{"ic_bn", {4, 32}}, + {"oc_bn", {8, 32}}, + {"ow_bn", {7, 8}}, + {"unroll_kw", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 256 56 56 weight 256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {1, 256}}, {"oc_bn", {8, 32}}, 
{"ow_bn", {7, 8}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 256 56 56 weight 256 256 3 3 stride 1 " + "1 padding 1 1 dilation 1 1", + {{"ic_bn", {1, 256}}, + {"oc_bn", {8, 32}}, + {"ow_bn", {7, 8}}, + {"unroll_kw", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 256 28 28 weight 512 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {1, 256}}, {"oc_bn", {16, 32}}, {"ow_bn", {4, 7}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 256 28 28 weight 512 256 3 3 stride 1 " + "1 padding 1 1 dilation 1 1", + {{"ic_bn", {1, 256}}, + {"oc_bn", {16, 32}}, + {"ow_bn", {4, 7}}, + {"unroll_kw", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 512 28 28 weight 512 512 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {1, 512}}, {"oc_bn", {32, 16}}, {"ow_bn", {2, 14}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 512 28 28 weight 512 512 3 3 stride 1 " + "1 padding 1 1 dilation 1 1", + {{"ic_bn", {1, 512}}, + {"oc_bn", {32, 16}}, + {"ow_bn", {2, 14}}, + {"unroll_kw", {1}}}); InputX86Param(model_data, - "X86ScheduleConv input 1 512 14 14 weight 512 512 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {1, 512}}, {"oc_bn", {32, 16}}, {"ow_bn", {1, 14}}, {"unroll_kw", {1}}}); + "X86ScheduleConv input 1 512 14 14 weight 512 512 3 3 stride 1 " + "1 padding 1 1 dilation 1 1", + {{"ic_bn", {1, 512}}, + {"oc_bn", {32, 16}}, + {"ow_bn", {1, 14}}, + {"unroll_kw", {1}}}); } void LoadResNet18Params( - absl::flat_hash_map>> *model_data) { + absl::flat_hash_map>> + *model_data) { CHECK(model_data); - InputX86Param( - model_data, - "resnet18 index 0 X86ScheduleConv input 1 3 224 224 weight 64 3 7 7 stride 2 2 padding 3 3 dilation 1 1", - {{"ic_bn", {-1, 1}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "resnet18 index 1 X86ScheduleConv input 1 64 56 56 weight 64 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet18 index 2 X86ScheduleConv input 1 64 56 56 weight 64 64 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 3 X86ScheduleConv input 1 64 56 56 weight 64 64 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 4 X86ScheduleConv input 1 64 56 56 weight 64 64 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 5 X86ScheduleConv input 1 64 56 56 weight 64 64 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 6 X86ScheduleConv input 1 64 56 56 weight 128 64 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet18 index 7 X86ScheduleConv input 1 64 56 56 weight 128 64 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 8 X86ScheduleConv input 1 128 28 28 weight 128 128 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 
7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 9 X86ScheduleConv input 1 128 28 28 weight 128 128 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 10 X86ScheduleConv input 1 128 28 28 weight 128 128 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 11 X86ScheduleConv input 1 128 28 28 weight 256 128 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet18 index 12 X86ScheduleConv input 1 128 28 28 weight 256 128 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 13 X86ScheduleConv input 1 256 14 14 weight 256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 14 X86ScheduleConv input 1 256 14 14 weight 256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 15 X86ScheduleConv input 1 256 14 14 weight 256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 16 X86ScheduleConv input 1 256 14 14 weight 512 256 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet18 index 17 X86ScheduleConv input 1 256 14 14 weight 512 256 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 18 X86ScheduleConv input 1 512 7 7 weight 512 512 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 19 X86ScheduleConv input 1 512 7 7 weight 512 512 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet18 index 20 X86ScheduleConv input 1 512 7 7 weight 512 512 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 0 X86ScheduleConv input 1 3 224 224 weight 64 " + "3 7 7 stride 2 2 padding 3 3 dilation 1 1", + {{"ic_bn", {-1, 1}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "resnet18 index 1 X86ScheduleConv input 1 64 56 56 weight 64 " + "64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet18 index 2 X86ScheduleConv input 1 64 56 56 weight 64 " + "64 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 3 X86ScheduleConv input 1 64 56 56 weight 
64 " + "64 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 4 X86ScheduleConv input 1 64 56 56 weight 64 " + "64 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 5 X86ScheduleConv input 1 64 56 56 weight 64 " + "64 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 6 X86ScheduleConv input 1 64 56 56 weight 128 " + "64 1 1 stride 2 2 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 4}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet18 index 7 X86ScheduleConv input 1 64 56 56 weight 128 " + "64 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 8 X86ScheduleConv input 1 128 28 28 weight 128 " + "128 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 9 X86ScheduleConv input 1 128 28 28 weight 128 " + "128 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 10 X86ScheduleConv input 1 128 28 28 weight " + "128 128 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 11 X86ScheduleConv input 1 128 28 28 weight " + "256 128 1 1 stride 2 2 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet18 index 12 X86ScheduleConv input 1 128 28 28 weight " + "256 128 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 13 X86ScheduleConv input 1 256 14 14 weight " + "256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 14 X86ScheduleConv input 1 256 14 14 weight " + "256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 15 X86ScheduleConv input 1 256 14 14 weight " + "256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 16 X86ScheduleConv input 1 256 14 14 weight " + "512 256 1 1 stride 2 2 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet18 index 17 X86ScheduleConv input 1 256 14 14 weight " + "512 256 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 18 X86ScheduleConv input 1 512 7 7 weight 512 " + "512 3 3 
stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 19 X86ScheduleConv input 1 512 7 7 weight 512 " + "512 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet18 index 20 X86ScheduleConv input 1 512 7 7 weight 512 " + "512 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); } void LoadResNet50Params( - absl::flat_hash_map>> *model_data) { + absl::flat_hash_map>> + *model_data) { CHECK(model_data); - InputX86Param( - model_data, - "resnet50 index 0 X86ScheduleConv input 1 3 224 224 weight 64 3 7 7 stride 2 2 padding 3 3 dilation 1 1", - {{"ic_bn", {-1, 1}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "resnet50 index 1 X86ScheduleConv input 1 64 56 56 weight 256 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 2 X86ScheduleConv input 1 64 56 56 weight 64 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet50 index 3 X86ScheduleConv input 1 64 56 56 weight 64 64 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 4 X86ScheduleConv input 1 64 56 56 weight 256 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 5 X86ScheduleConv input 1 256 56 56 weight 64 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 6 X86ScheduleConv input 1 64 56 56 weight 64 64 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 7 X86ScheduleConv input 1 64 56 56 weight 256 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 8 X86ScheduleConv input 1 256 56 56 weight 64 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 9 X86ScheduleConv input 1 64 56 56 weight 64 64 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 10 X86ScheduleConv input 1 64 56 56 weight 256 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 11 X86ScheduleConv input 1 256 56 56 weight 512 256 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet50 index 12 X86ScheduleConv input 1 256 56 56 weight 128 256 1 1 stride 1 1 padding 0 0 
dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 13 X86ScheduleConv input 1 128 56 56 weight 128 128 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 14 X86ScheduleConv input 1 128 28 28 weight 512 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 15 X86ScheduleConv input 1 512 28 28 weight 128 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet50 index 16 X86ScheduleConv input 1 128 28 28 weight 128 128 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 17 X86ScheduleConv input 1 128 28 28 weight 512 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 18 X86ScheduleConv input 1 512 28 28 weight 128 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet50 index 19 X86ScheduleConv input 1 128 28 28 weight 128 128 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 20 X86ScheduleConv input 1 128 28 28 weight 512 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 21 X86ScheduleConv input 1 512 28 28 weight 128 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet50 index 22 X86ScheduleConv input 1 128 28 28 weight 128 128 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 23 X86ScheduleConv input 1 128 28 28 weight 512 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 24 X86ScheduleConv input 1 512 28 28 weight 1024 512 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 128}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 25 X86ScheduleConv input 1 512 28 28 weight 256 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet50 index 26 X86ScheduleConv input 1 256 28 28 weight 256 256 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 27 X86ScheduleConv input 1 256 14 14 weight 1024 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 
index 28 X86ScheduleConv input 1 1024 14 14 weight 256 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet50 index 29 X86ScheduleConv input 1 256 14 14 weight 256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 30 X86ScheduleConv input 1 256 14 14 weight 1024 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 31 X86ScheduleConv input 1 1024 14 14 weight 256 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet50 index 32 X86ScheduleConv input 1 256 14 14 weight 256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 33 X86ScheduleConv input 1 256 14 14 weight 1024 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 34 X86ScheduleConv input 1 1024 14 14 weight 256 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet50 index 35 X86ScheduleConv input 1 256 14 14 weight 256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 36 X86ScheduleConv input 1 256 14 14 weight 1024 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 37 X86ScheduleConv input 1 1024 14 14 weight 256 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet50 index 38 X86ScheduleConv input 1 256 14 14 weight 256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 39 X86ScheduleConv input 1 256 14 14 weight 1024 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 40 X86ScheduleConv input 1 1024 14 14 weight 256 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "resnet50 index 41 X86ScheduleConv input 1 256 14 14 weight 256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 42 X86ScheduleConv input 1 256 14 14 weight 1024 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 43 X86ScheduleConv input 1 1024 14 14 weight 2048 1024 1 1 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, 
{"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 44 X86ScheduleConv input 1 1024 14 14 weight 512 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 128}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 45 X86ScheduleConv input 1 512 14 14 weight 512 512 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 46 X86ScheduleConv input 1 512 7 7 weight 2048 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 47 X86ScheduleConv input 1 2048 7 7 weight 512 2048 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 48 X86ScheduleConv input 1 512 7 7 weight 512 512 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 49 X86ScheduleConv input 1 512 7 7 weight 2048 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 50 X86ScheduleConv input 1 2048 7 7 weight 512 2048 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "resnet50 index 51 X86ScheduleConv input 1 512 7 7 weight 512 512 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "resnet50 index 52 X86ScheduleConv input 1 512 7 7 weight 2048 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 0 X86ScheduleConv input 1 3 224 224 weight 64 " + "3 7 7 stride 2 2 padding 3 3 dilation 1 1", + {{"ic_bn", {-1, 1}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "resnet50 index 1 X86ScheduleConv input 1 64 56 56 weight 256 " + "64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 2 X86ScheduleConv input 1 64 56 56 weight 64 " + "64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet50 index 3 X86ScheduleConv input 1 64 56 56 weight 64 " + "64 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 4 X86ScheduleConv input 1 64 56 56 weight 256 " + "64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 5 X86ScheduleConv input 1 256 56 56 weight 64 " + "256 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 6 X86ScheduleConv 
input 1 64 56 56 weight 64 " + "64 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 7 X86ScheduleConv input 1 64 56 56 weight 256 " + "64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 8 X86ScheduleConv input 1 256 56 56 weight 64 " + "256 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 9 X86ScheduleConv input 1 64 56 56 weight 64 " + "64 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 10 X86ScheduleConv input 1 64 56 56 weight 256 " + "64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 11 X86ScheduleConv input 1 256 56 56 weight " + "512 256 1 1 stride 2 2 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet50 index 12 X86ScheduleConv input 1 256 56 56 weight " + "128 256 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 13 X86ScheduleConv input 1 128 56 56 weight " + "128 128 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 14 X86ScheduleConv input 1 128 28 28 weight " + "512 128 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 15 X86ScheduleConv input 1 512 28 28 weight " + "128 512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet50 index 16 X86ScheduleConv input 1 128 28 28 weight " + "128 128 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 17 X86ScheduleConv input 1 128 28 28 weight " + "512 128 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 18 X86ScheduleConv input 1 512 28 28 weight " + "128 512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet50 index 19 X86ScheduleConv input 1 128 28 28 weight " + "128 128 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 20 X86ScheduleConv input 1 128 28 28 weight " + "512 128 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 21 X86ScheduleConv input 1 512 28 28 weight " + "128 512 1 
1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet50 index 22 X86ScheduleConv input 1 128 28 28 weight " + "128 128 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 23 X86ScheduleConv input 1 128 28 28 weight " + "512 128 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 24 X86ScheduleConv input 1 512 28 28 weight " + "1024 512 1 1 stride 2 2 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 128}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 25 X86ScheduleConv input 1 512 28 28 weight " + "256 512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet50 index 26 X86ScheduleConv input 1 256 28 28 weight " + "256 256 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 27 X86ScheduleConv input 1 256 14 14 weight " + "1024 256 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 28 X86ScheduleConv input 1 1024 14 14 weight " + "256 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet50 index 29 X86ScheduleConv input 1 256 14 14 weight " + "256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 30 X86ScheduleConv input 1 256 14 14 weight " + "1024 256 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 31 X86ScheduleConv input 1 1024 14 14 weight " + "256 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet50 index 32 X86ScheduleConv input 1 256 14 14 weight " + "256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 33 X86ScheduleConv input 1 256 14 14 weight " + "1024 256 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 34 X86ScheduleConv input 1 1024 14 14 weight " + "256 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet50 index 35 X86ScheduleConv input 1 256 14 14 weight " + "256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 36 X86ScheduleConv input 1 256 14 14 weight " + "1024 256 1 1 stride 1 1 
padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 37 X86ScheduleConv input 1 1024 14 14 weight " + "256 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet50 index 38 X86ScheduleConv input 1 256 14 14 weight " + "256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 39 X86ScheduleConv input 1 256 14 14 weight " + "1024 256 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 40 X86ScheduleConv input 1 1024 14 14 weight " + "256 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "resnet50 index 41 X86ScheduleConv input 1 256 14 14 weight " + "256 256 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 42 X86ScheduleConv input 1 256 14 14 weight " + "1024 256 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 43 X86ScheduleConv input 1 1024 14 14 weight " + "2048 1024 1 1 stride 2 2 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 44 X86ScheduleConv input 1 1024 14 14 weight " + "512 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 128}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 45 X86ScheduleConv input 1 512 14 14 weight " + "512 512 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 46 X86ScheduleConv input 1 512 7 7 weight 2048 " + "512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 4}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 47 X86ScheduleConv input 1 2048 7 7 weight 512 " + "2048 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 48 X86ScheduleConv input 1 512 7 7 weight 512 " + "512 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 49 X86ScheduleConv input 1 512 7 7 weight 2048 " + "512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 4}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 50 X86ScheduleConv input 1 2048 7 7 weight 512 " + "2048 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "resnet50 index 51 X86ScheduleConv input 1 512 7 7 weight 512 " + "512 3 3 stride 1 1 padding 1 1 dilation 1 
1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "resnet50 index 52 X86ScheduleConv input 1 512 7 7 weight 2048 " + "512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 4}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); } void LoadMobileNetV1Params( - absl::flat_hash_map>> *model_data) { + absl::flat_hash_map>> + *model_data) { CHECK(model_data); - InputX86Param( - model_data, - "mobilenetv1 index 0 X86ScheduleConv input 1 3 224 224 weight 32 3 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 1}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "mobilenetv1 index 1 X86ScheduleConv input 1 32 112 112 weight 32 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv1 index 2 X86ScheduleConv input 1 32 112 112 weight 64 32 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv1 index 3 X86ScheduleConv input 1 64 112 112 weight 64 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv1 index 4 X86ScheduleConv input 1 64 56 56 weight 128 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv1 index 5 X86ScheduleConv input 1 128 56 56 weight 128 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv1 index 6 X86ScheduleConv input 1 128 56 56 weight 128 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv1 index 7 X86ScheduleConv input 1 128 56 56 weight 128 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "mobilenetv1 index 8 X86ScheduleConv input 1 128 28 28 weight 256 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv1 index 9 X86ScheduleConv input 1 256 28 28 weight 256 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 4}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv1 index 10 X86ScheduleConv input 1 256 28 28 weight 256 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv1 index 11 X86ScheduleConv input 1 256 28 28 weight 256 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv1 index 12 X86ScheduleConv input 1 256 14 14 weight 512 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv1 index 13 X86ScheduleConv input 1 512 14 14 weight 512 1 3 3 stride 1 1 padding 1 1 
dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv1 index 14 X86ScheduleConv input 1 512 14 14 weight 512 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv1 index 15 X86ScheduleConv input 1 512 14 14 weight 512 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv1 index 16 X86ScheduleConv input 1 512 14 14 weight 512 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv1 index 17 X86ScheduleConv input 1 512 14 14 weight 512 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv1 index 18 X86ScheduleConv input 1 512 14 14 weight 512 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv1 index 19 X86ScheduleConv input 1 512 14 14 weight 512 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv1 index 20 X86ScheduleConv input 1 512 14 14 weight 512 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv1 index 21 X86ScheduleConv input 1 512 14 14 weight 512 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv1 index 22 X86ScheduleConv input 1 512 14 14 weight 512 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv1 index 23 X86ScheduleConv input 1 512 14 14 weight 512 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 4}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "mobilenetv1 index 24 X86ScheduleConv input 1 512 7 7 weight 1024 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv1 index 25 X86ScheduleConv input 1 1024 7 7 weight 1024 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv1 index 26 X86ScheduleConv input 1 1024 7 7 weight 1024 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv1 index 0 X86ScheduleConv input 1 3 224 224 weight " + "32 3 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 1}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "mobilenetv1 index 1 X86ScheduleConv input 1 32 112 112 weight " + "32 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", 
{0}}}); + InputX86Param(model_data, + "mobilenetv1 index 2 X86ScheduleConv input 1 32 112 112 weight " + "64 32 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv1 index 3 X86ScheduleConv input 1 64 112 112 weight " + "64 1 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv1 index 4 X86ScheduleConv input 1 64 56 56 weight " + "128 64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv1 index 5 X86ScheduleConv input 1 128 56 56 weight " + "128 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv1 index 6 X86ScheduleConv input 1 128 56 56 weight " + "128 128 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv1 index 7 X86ScheduleConv input 1 128 56 56 weight " + "128 1 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "mobilenetv1 index 8 X86ScheduleConv input 1 128 28 28 weight " + "256 128 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv1 index 9 X86ScheduleConv input 1 256 28 28 weight " + "256 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 4}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv1 index 10 X86ScheduleConv input 1 256 28 28 weight " + "256 256 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv1 index 11 X86ScheduleConv input 1 256 28 28 weight " + "256 1 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv1 index 12 X86ScheduleConv input 1 256 14 14 weight " + "512 256 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv1 index 13 X86ScheduleConv input 1 512 14 14 weight " + "512 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv1 index 14 X86ScheduleConv input 1 512 14 14 weight " + "512 512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv1 index 15 X86ScheduleConv input 1 512 14 14 weight " + "512 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv1 index 16 X86ScheduleConv input 1 512 14 14 weight " + "512 512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", 
{2}}}); + InputX86Param(model_data, + "mobilenetv1 index 17 X86ScheduleConv input 1 512 14 14 weight " + "512 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv1 index 18 X86ScheduleConv input 1 512 14 14 weight " + "512 512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv1 index 19 X86ScheduleConv input 1 512 14 14 weight " + "512 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv1 index 20 X86ScheduleConv input 1 512 14 14 weight " + "512 512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv1 index 21 X86ScheduleConv input 1 512 14 14 weight " + "512 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv1 index 22 X86ScheduleConv input 1 512 14 14 weight " + "512 512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv1 index 23 X86ScheduleConv input 1 512 14 14 weight " + "512 1 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 4}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "mobilenetv1 index 24 X86ScheduleConv input 1 512 7 7 weight " + "1024 512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 4}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv1 index 25 X86ScheduleConv input 1 1024 7 7 weight " + "1024 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv1 index 26 X86ScheduleConv input 1 1024 7 7 weight " + "1024 1024 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 4}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); } void LoadMobileNetV2Params( - absl::flat_hash_map>> *model_data) { + absl::flat_hash_map>> + *model_data) { CHECK(model_data); - InputX86Param( - model_data, - "mobilenetv2 index 0 X86ScheduleConv input 1 3 224 224 weight 32 3 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 1}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 1 X86ScheduleConv input 1 32 112 112 weight 32 32 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 2 X86ScheduleConv input 1 32 112 112 weight 32 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 3 X86ScheduleConv input 1 32 112 112 weight 16 32 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 16}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 4 X86ScheduleConv input 1 16 112 112 weight 96 16 1 1 stride 1 1 padding 0 0 
dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 4}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 5 X86ScheduleConv input 1 96 112 112 weight 96 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 2}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 6 X86ScheduleConv input 1 96 56 56 weight 24 96 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 14}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 7 X86ScheduleConv input 1 24 56 56 weight 144 24 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 8 X86ScheduleConv input 1 144 56 56 weight 144 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv2 index 9 X86ScheduleConv input 1 144 56 56 weight 24 144 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 28}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 10 X86ScheduleConv input 1 24 56 56 weight 144 24 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 11 X86ScheduleConv input 1 144 56 56 weight 144 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 2}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 12 X86ScheduleConv input 1 144 28 28 weight 32 144 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 13 X86ScheduleConv input 1 32 28 28 weight 192 32 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 14 X86ScheduleConv input 1 192 28 28 weight 192 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv2 index 15 X86ScheduleConv input 1 192 28 28 weight 32 192 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 16 X86ScheduleConv input 1 32 28 28 weight 192 32 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 17 X86ScheduleConv input 1 192 28 28 weight 192 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv2 index 18 X86ScheduleConv input 1 192 28 28 weight 32 192 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 19 X86ScheduleConv input 1 32 28 28 weight 192 32 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 
index 20 X86ScheduleConv input 1 192 28 28 weight 192 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 21 X86ScheduleConv input 1 192 14 14 weight 64 192 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 96}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 22 X86ScheduleConv input 1 64 14 14 weight 384 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 23 X86ScheduleConv input 1 384 14 14 weight 384 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv2 index 24 X86ScheduleConv input 1 384 14 14 weight 64 384 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 25 X86ScheduleConv input 1 64 14 14 weight 384 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 26 X86ScheduleConv input 1 384 14 14 weight 384 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv2 index 27 X86ScheduleConv input 1 384 14 14 weight 64 384 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 28 X86ScheduleConv input 1 64 14 14 weight 384 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 29 X86ScheduleConv input 1 384 14 14 weight 384 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv2 index 30 X86ScheduleConv input 1 384 14 14 weight 64 384 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 31 X86ScheduleConv input 1 64 14 14 weight 384 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 32 X86ScheduleConv input 1 384 14 14 weight 384 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv2 index 33 X86ScheduleConv input 1 384 14 14 weight 96 384 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 34 X86ScheduleConv input 1 96 14 14 weight 576 96 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 35 X86ScheduleConv input 1 576 14 14 weight 576 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, 
{"oc_bn", {-1, 16}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 36 X86ScheduleConv input 1 576 14 14 weight 96 576 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 37 X86ScheduleConv input 1 96 14 14 weight 576 96 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 38 X86ScheduleConv input 1 576 14 14 weight 576 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 39 X86ScheduleConv input 1 576 14 14 weight 96 576 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 40 X86ScheduleConv input 1 96 14 14 weight 576 96 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 41 X86ScheduleConv input 1 576 14 14 weight 576 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "mobilenetv2 index 42 X86ScheduleConv input 1 576 7 7 weight 160 576 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 3}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 43 X86ScheduleConv input 1 160 7 7 weight 960 160 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 80}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 44 X86ScheduleConv input 1 960 7 7 weight 960 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 192}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - // InputX86Param(model_data, "mobilenetv2 index 45 X86ScheduleConv input 1 960 7 7 weight 160 960 1 1 stride 1 1 - // padding 0 0 dilation 1 1", {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 45 X86ScheduleConv input 1 960 7 7 weight 160 960 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 46 X86ScheduleConv input 1 160 7 7 weight 960 160 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 80}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 47 X86ScheduleConv input 1 960 7 7 weight 960 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 192}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - // InputX86Param(model_data, "mobilenetv2 index 48 X86ScheduleConv input 1 960 7 7 weight 160 960 1 1 stride 1 1 - // padding 0 0 dilation 1 1", {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 48 X86ScheduleConv input 1 960 7 7 weight 160 960 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 49 
X86ScheduleConv input 1 160 7 7 weight 960 160 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 80}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 50 X86ScheduleConv input 1 960 7 7 weight 960 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 80}}, {"oc_bn", {-1, 80}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "mobilenetv2 index 51 X86ScheduleConv input 1 960 7 7 weight 320 960 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 80}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "mobilenetv2 index 52 X86ScheduleConv input 1 320 7 7 weight 1280 320 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 0 X86ScheduleConv input 1 3 224 224 weight " + "32 3 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 1}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 1 X86ScheduleConv input 1 32 112 112 weight " + "32 32 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 2 X86ScheduleConv input 1 32 112 112 weight " + "32 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 3 X86ScheduleConv input 1 32 112 112 weight " + "16 32 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 16}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 4 X86ScheduleConv input 1 16 112 112 weight " + "96 16 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 5 X86ScheduleConv input 1 96 112 112 weight " + "96 1 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 2}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 6 X86ScheduleConv input 1 96 56 56 weight " + "24 96 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 7 X86ScheduleConv input 1 24 56 56 weight " + "144 24 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 8 X86ScheduleConv input 1 144 56 56 weight " + "144 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv2 index 9 X86ScheduleConv input 1 144 56 56 weight " + "24 144 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 28}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 10 X86ScheduleConv input 1 24 56 56 weight " + "144 24 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 11 X86ScheduleConv input 1 144 56 56 weight " + "144 1 3 3 stride 2 2 
padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 2}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 12 X86ScheduleConv input 1 144 28 28 weight " + "32 144 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 13 X86ScheduleConv input 1 32 28 28 weight " + "192 32 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 14 X86ScheduleConv input 1 192 28 28 weight " + "192 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv2 index 15 X86ScheduleConv input 1 192 28 28 weight " + "32 192 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 16 X86ScheduleConv input 1 32 28 28 weight " + "192 32 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 17 X86ScheduleConv input 1 192 28 28 weight " + "192 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv2 index 18 X86ScheduleConv input 1 192 28 28 weight " + "32 192 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 19 X86ScheduleConv input 1 32 28 28 weight " + "192 32 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 20 X86ScheduleConv input 1 192 28 28 weight " + "192 1 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 21 X86ScheduleConv input 1 192 14 14 weight " + "64 192 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 96}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 22 X86ScheduleConv input 1 64 14 14 weight " + "384 64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 23 X86ScheduleConv input 1 384 14 14 weight " + "384 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv2 index 24 X86ScheduleConv input 1 384 14 14 weight " + "64 384 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 25 X86ScheduleConv input 1 64 14 14 weight " + "384 64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 26 X86ScheduleConv input 1 384 14 14 weight " + "384 1 3 3 stride 1 1 padding 1 
1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv2 index 27 X86ScheduleConv input 1 384 14 14 weight " + "64 384 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 28 X86ScheduleConv input 1 64 14 14 weight " + "384 64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 29 X86ScheduleConv input 1 384 14 14 weight " + "384 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv2 index 30 X86ScheduleConv input 1 384 14 14 weight " + "64 384 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 31 X86ScheduleConv input 1 64 14 14 weight " + "384 64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 32 X86ScheduleConv input 1 384 14 14 weight " + "384 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv2 index 33 X86ScheduleConv input 1 384 14 14 weight " + "96 384 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 34 X86ScheduleConv input 1 96 14 14 weight " + "576 96 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 35 X86ScheduleConv input 1 576 14 14 weight " + "576 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 36 X86ScheduleConv input 1 576 14 14 weight " + "96 576 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 37 X86ScheduleConv input 1 96 14 14 weight " + "576 96 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 38 X86ScheduleConv input 1 576 14 14 weight " + "576 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 39 X86ScheduleConv input 1 576 14 14 weight " + "96 576 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 40 X86ScheduleConv input 1 96 14 14 weight " + "576 96 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 41 X86ScheduleConv input 1 576 14 14 weight " + "576 1 3 3 stride 2 2 padding 1 1 
dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "mobilenetv2 index 42 X86ScheduleConv input 1 576 7 7 weight " + "160 576 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 3}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 43 X86ScheduleConv input 1 160 7 7 weight " + "960 160 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 80}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 44 X86ScheduleConv input 1 960 7 7 weight " + "960 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 192}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + // InputX86Param(model_data, "mobilenetv2 index 45 X86ScheduleConv input 1 960 + // 7 7 weight 160 960 1 1 stride 1 1 padding 0 0 dilation 1 1", {{"ic_bn", + // {-1, 64}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 45 X86ScheduleConv input 1 960 7 7 weight " + "160 960 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 46 X86ScheduleConv input 1 160 7 7 weight " + "960 160 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 80}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 47 X86ScheduleConv input 1 960 7 7 weight " + "960 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 192}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + // InputX86Param(model_data, "mobilenetv2 index 48 X86ScheduleConv input 1 960 + // 7 7 weight 160 960 1 1 stride 1 1 padding 0 0 dilation 1 1", {{"ic_bn", + // {-1, 64}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 48 X86ScheduleConv input 1 960 7 7 weight " + "160 960 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 49 X86ScheduleConv input 1 160 7 7 weight " + "960 160 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 80}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 50 X86ScheduleConv input 1 960 7 7 weight " + "960 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 80}}, + {"oc_bn", {-1, 80}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "mobilenetv2 index 51 X86ScheduleConv input 1 960 7 7 weight " + "320 960 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 80}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "mobilenetv2 index 52 X86ScheduleConv input 1 320 7 7 weight " + "1280 320 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 4}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); } void LoadSqueezeNetParams( - absl::flat_hash_map>> *model_data) { + absl::flat_hash_map>> + *model_data) { CHECK(model_data); - InputX86Param( - model_data, - "squeezenet index 0 X86ScheduleConv input 1 3 227 227 weight 64 3 3 3 stride 2 2 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 1}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - 
"squeezenet index 1 X86ScheduleConv input 1 64 56 56 weight 16 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "squeezenet index 3 X86ScheduleConv input 1 16 56 56 weight 64 16 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "squeezenet index 2 X86ScheduleConv input 1 16 56 56 weight 64 16 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "squeezenet index 4 X86ScheduleConv input 1 128 56 56 weight 16 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 14}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "squeezenet index 6 X86ScheduleConv input 1 16 56 56 weight 64 16 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "squeezenet index 5 X86ScheduleConv input 1 16 56 56 weight 64 16 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "squeezenet index 7 X86ScheduleConv input 1 128 28 28 weight 32 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "squeezenet index 9 X86ScheduleConv input 1 32 28 28 weight 128 32 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "squeezenet index 8 X86ScheduleConv input 1 32 28 28 weight 128 32 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "squeezenet index 10 X86ScheduleConv input 1 256 28 28 weight 32 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "squeezenet index 12 X86ScheduleConv input 1 32 28 28 weight 128 32 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "squeezenet index 11 X86ScheduleConv input 1 32 28 28 weight 128 32 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 2}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "squeezenet index 13 X86ScheduleConv input 1 256 14 14 weight 48 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "squeezenet index 15 X86ScheduleConv input 1 48 14 14 weight 192 48 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 48}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 8}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "squeezenet index 14 X86ScheduleConv input 1 48 14 14 weight 192 48 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 48}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "squeezenet index 16 X86ScheduleConv input 1 384 14 14 weight 48 384 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 16}}, 
{"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "squeezenet index 18 X86ScheduleConv input 1 48 14 14 weight 192 48 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 48}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 8}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "squeezenet index 17 X86ScheduleConv input 1 48 14 14 weight 192 48 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 48}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "squeezenet index 19 X86ScheduleConv input 1 384 14 14 weight 64 384 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "squeezenet index 21 X86ScheduleConv input 1 64 14 14 weight 256 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "squeezenet index 20 X86ScheduleConv input 1 64 14 14 weight 256 64 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "squeezenet index 22 X86ScheduleConv input 1 512 14 14 weight 64 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "squeezenet index 24 X86ScheduleConv input 1 64 14 14 weight 256 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "squeezenet index 23 X86ScheduleConv input 1 64 14 14 weight 256 64 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "squeezenet index 25 X86ScheduleConv input 1 512 14 14 weight 1000 512 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 1}}, {"oc_bn", {-1, 10}}, {"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); + InputX86Param(model_data, + "squeezenet index 0 X86ScheduleConv input 1 3 227 227 weight " + "64 3 3 3 stride 2 2 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 1}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 2}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "squeezenet index 1 X86ScheduleConv input 1 64 56 56 weight 16 " + "64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "squeezenet index 3 X86ScheduleConv input 1 16 56 56 weight 64 " + "16 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "squeezenet index 2 X86ScheduleConv input 1 16 56 56 weight 64 " + "16 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "squeezenet index 4 X86ScheduleConv input 1 128 56 56 weight " + "16 128 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "squeezenet index 6 X86ScheduleConv input 1 16 56 56 weight 64 " + "16 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "squeezenet index 5 
X86ScheduleConv input 1 16 56 56 weight 64 " + "16 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "squeezenet index 7 X86ScheduleConv input 1 128 28 28 weight " + "32 128 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "squeezenet index 9 X86ScheduleConv input 1 32 28 28 weight " + "128 32 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "squeezenet index 8 X86ScheduleConv input 1 32 28 28 weight " + "128 32 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "squeezenet index 10 X86ScheduleConv input 1 256 28 28 weight " + "32 256 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "squeezenet index 12 X86ScheduleConv input 1 32 28 28 weight " + "128 32 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "squeezenet index 11 X86ScheduleConv input 1 32 28 28 weight " + "128 32 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 2}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "squeezenet index 13 X86ScheduleConv input 1 256 14 14 weight " + "48 256 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "squeezenet index 15 X86ScheduleConv input 1 48 14 14 weight " + "192 48 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 48}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 8}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "squeezenet index 14 X86ScheduleConv input 1 48 14 14 weight " + "192 48 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 48}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "squeezenet index 16 X86ScheduleConv input 1 384 14 14 weight " + "48 384 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "squeezenet index 18 X86ScheduleConv input 1 48 14 14 weight " + "192 48 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 48}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 8}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "squeezenet index 17 X86ScheduleConv input 1 48 14 14 weight " + "192 48 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 48}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "squeezenet index 19 X86ScheduleConv input 1 384 14 14 weight " + "64 384 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "squeezenet index 21 X86ScheduleConv input 1 64 14 14 weight " + "256 64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "squeezenet index 20 X86ScheduleConv input 
1 64 14 14 weight " + "256 64 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "squeezenet index 22 X86ScheduleConv input 1 512 14 14 weight " + "64 512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "squeezenet index 24 X86ScheduleConv input 1 64 14 14 weight " + "256 64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "squeezenet index 23 X86ScheduleConv input 1 64 14 14 weight " + "256 64 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "squeezenet index 25 X86ScheduleConv input 1 512 14 14 weight " + "1000 512 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 1}}, + {"oc_bn", {-1, 10}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); } void LoadFaceDetParams( - absl::flat_hash_map>> *model_data) { + absl::flat_hash_map>> + *model_data) { CHECK(model_data); InputX86Param(model_data, - "facedet index 0 X86ScheduleConv input 1 3 240 320 weight 16 3 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 1}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 1 X86ScheduleConv input 1 16 120 160 weight 16 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 2 X86ScheduleConv input 1 16 120 160 weight 32 16 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 20}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "facedet index 3 X86ScheduleConv input 1 32 120 160 weight 32 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - InputX86Param(model_data, - "facedet index 4 X86ScheduleConv input 1 32 60 80 weight 32 32 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 5}}, {"oh_bn", {2}}}); - InputX86Param(model_data, - "facedet index 5 X86ScheduleConv input 1 32 60 80 weight 32 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {1}}}); - InputX86Param(model_data, - "facedet index 6 X86ScheduleConv input 1 32 60 80 weight 32 32 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 5}}, {"oh_bn", {2}}}); - InputX86Param(model_data, - "facedet index 7 X86ScheduleConv input 1 32 60 80 weight 32 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param(model_data, - "facedet index 8 X86ScheduleConv input 1 32 30 40 weight 64 32 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param(model_data, - "facedet index 9 X86ScheduleConv input 1 64 30 40 weight 64 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "facedet index 10 X86ScheduleConv input 1 64 30 40 weight 64 64 1 1 stride 1 1 padding 
0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param(model_data, - "facedet index 11 X86ScheduleConv input 1 64 30 40 weight 64 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "facedet index 13 X86ScheduleConv input 1 64 30 40 weight 64 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 4}}, {"oh_bn", {1}}}); - InputX86Param(model_data, - "facedet index 26 X86ScheduleConv input 1 64 30 40 weight 64 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "facedet index 12 X86ScheduleConv input 1 64 30 40 weight 64 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 20}}, {"oh_bn", {1}}}); - InputX86Param(model_data, - "facedet index 14 X86ScheduleConv input 1 64 30 40 weight 8 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 40}}, {"oh_bn", {1}}}); - InputX86Param(model_data, - "facedet index 18 X86ScheduleConv input 1 8 30 40 weight 16 8 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 16}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 22 X86ScheduleConv input 1 16 30 40 weight 16 16 3 3 stride 1 1 padding 2 2 dilation 2 2", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"unroll_kw", {1}}}); - InputX86Param(model_data, - "facedet index 15 X86ScheduleConv input 1 64 30 40 weight 8 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 40}}, {"oh_bn", {1}}}); - InputX86Param(model_data, - "facedet index 19 X86ScheduleConv input 1 8 30 40 weight 16 8 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "facedet index 21 X86ScheduleConv input 1 16 30 40 weight 16 16 3 3 stride 1 1 padding 3 3 dilation 3 3", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {1}}}); - InputX86Param(model_data, - "facedet index 16 X86ScheduleConv input 1 64 30 40 weight 8 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 40}}, {"oh_bn", {1}}}); - InputX86Param(model_data, - "facedet index 17 X86ScheduleConv input 1 8 30 40 weight 12 8 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 12}}, {"ow_bn", {-1, 10}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 20 X86ScheduleConv input 1 12 30 40 weight 16 12 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 12}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 10}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 23 X86ScheduleConv input 1 16 30 40 weight 16 16 3 3 stride 1 1 padding 5 5 dilation 5 5", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 24 X86ScheduleConv input 1 48 30 40 weight 64 48 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 6}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 5}}, {"oh_bn", {1}}}); - InputX86Param(model_data, - "facedet index 27 X86ScheduleConv input 1 64 30 40 weight 64 1 3 3 stride 1 1 padding 1 1 dilation 1 
1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {0}}}); - InputX86Param(model_data, - "facedet index 29 X86ScheduleConv input 1 64 30 40 weight 6 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 6}}, {"ow_bn", {-1, 40}}, {"oh_bn", {1}}}); - InputX86Param(model_data, - "facedet index 25 X86ScheduleConv input 1 64 30 40 weight 64 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 5}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 30 X86ScheduleConv input 1 64 15 20 weight 128 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 5}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "facedet index 31 X86ScheduleConv input 1 128 15 20 weight 128 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 32 X86ScheduleConv input 1 128 15 20 weight 128 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 5}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "facedet index 33 X86ScheduleConv input 1 128 15 20 weight 128 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 34 X86ScheduleConv input 1 128 15 20 weight 128 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 5}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "facedet index 36 X86ScheduleConv input 1 128 15 20 weight 128 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 39 X86ScheduleConv input 1 128 15 20 weight 4 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 4}}, {"ow_bn", {-1, 20}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "facedet index 35 X86ScheduleConv input 1 128 15 20 weight 128 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 40 X86ScheduleConv input 1 128 8 10 weight 256 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "facedet index 41 X86ScheduleConv input 1 256 8 10 weight 256 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 42 X86ScheduleConv input 1 256 8 10 weight 256 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 5}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "facedet index 44 X86ScheduleConv input 1 256 8 10 weight 256 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 48 X86ScheduleConv input 1 256 8 10 weight 4 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 4}}, {"ow_bn", {-1, 10}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "facedet index 43 X86ScheduleConv input 1 256 8 10 weight 64 256 1 1 
stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param(model_data, - "facedet index 46 X86ScheduleConv input 1 64 8 10 weight 64 1 3 3 stride 2 2 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {0}}}); - InputX86Param(model_data, - "facedet index 49 X86ScheduleConv input 1 64 4 5 weight 256 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 5}}, {"oh_bn", {2}}}); - InputX86Param(model_data, - "facedet index 51 X86ScheduleConv input 1 256 4 5 weight 6 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 16}}, {"oc_bn", {-1, 6}}, {"ow_bn", {-1, 5}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 28 X86ScheduleConv input 1 64 30 40 weight 12 64 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 12}}, {"ow_bn", {-1, 40}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "facedet index 37 X86ScheduleConv input 1 128 15 20 weight 128 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 38 X86ScheduleConv input 1 128 15 20 weight 8 128 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 20}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "facedet index 45 X86ScheduleConv input 1 256 8 10 weight 256 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "facedet index 47 X86ScheduleConv input 1 256 8 10 weight 8 256 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 10}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "facedet index 50 X86ScheduleConv input 1 256 4 5 weight 12 256 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 1}}, {"oc_bn", {-1, 12}}, {"ow_bn", {-1, 5}}, {"unroll_kw", {0}}}); + "facedet index 0 X86ScheduleConv input 1 3 240 320 weight 16 3 " + "3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 1}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 1 X86ScheduleConv input 1 16 120 160 weight 16 " + "1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 2 X86ScheduleConv input 1 16 120 160 weight 32 " + "16 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 20}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 3 X86ScheduleConv input 1 32 120 160 weight 32 " + "1 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "facedet index 4 X86ScheduleConv input 1 32 60 80 weight 32 32 " + "1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 5}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "facedet index 5 X86ScheduleConv input 1 32 60 80 weight 32 1 " + "3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 6 X86ScheduleConv input 1 32 
60 80 weight 32 32 " + "1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 5}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "facedet index 7 X86ScheduleConv input 1 32 60 80 weight 32 1 " + "3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 8 X86ScheduleConv input 1 32 30 40 weight 64 32 " + "1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 9 X86ScheduleConv input 1 64 30 40 weight 64 1 " + "3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "facedet index 10 X86ScheduleConv input 1 64 30 40 weight 64 " + "64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 11 X86ScheduleConv input 1 64 30 40 weight 64 1 " + "3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "facedet index 13 X86ScheduleConv input 1 64 30 40 weight 64 " + "64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 4}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 26 X86ScheduleConv input 1 64 30 40 weight 64 1 " + "3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "facedet index 12 X86ScheduleConv input 1 64 30 40 weight 64 " + "64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 20}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 14 X86ScheduleConv input 1 64 30 40 weight 8 64 " + "1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 40}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 18 X86ScheduleConv input 1 8 30 40 weight 16 8 " + "3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 16}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 22 X86ScheduleConv input 1 16 30 40 weight 16 " + "16 3 3 stride 1 1 padding 2 2 dilation 2 2", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 15 X86ScheduleConv input 1 64 30 40 weight 8 64 " + "1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 40}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 19 X86ScheduleConv input 1 8 30 40 weight 16 8 " + "3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "facedet index 21 X86ScheduleConv input 1 16 30 40 weight 16 " + "16 3 3 stride 1 1 padding 3 3 dilation 3 3", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 16 X86ScheduleConv input 1 64 30 40 weight 8 64 " + "1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + 
{"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 40}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 17 X86ScheduleConv input 1 8 30 40 weight 12 8 " + "3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 12}}, + {"ow_bn", {-1, 10}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 20 X86ScheduleConv input 1 12 30 40 weight 16 " + "12 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 12}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 10}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 23 X86ScheduleConv input 1 16 30 40 weight 16 " + "16 3 3 stride 1 1 padding 5 5 dilation 5 5", + {{"ic_bn", {-1, 16}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 24 X86ScheduleConv input 1 48 30 40 weight 64 " + "48 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 6}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 5}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 27 X86ScheduleConv input 1 64 30 40 weight 64 1 " + "3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "facedet index 29 X86ScheduleConv input 1 64 30 40 weight 6 64 " + "1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 6}}, + {"ow_bn", {-1, 40}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 25 X86ScheduleConv input 1 64 30 40 weight 64 1 " + "3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 5}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 30 X86ScheduleConv input 1 64 15 20 weight 128 " + "64 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 5}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 31 X86ScheduleConv input 1 128 15 20 weight 128 " + "1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 32 X86ScheduleConv input 1 128 15 20 weight 128 " + "128 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 5}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 33 X86ScheduleConv input 1 128 15 20 weight 128 " + "1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 34 X86ScheduleConv input 1 128 15 20 weight 128 " + "128 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 5}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 36 X86ScheduleConv input 1 128 15 20 weight 128 " + "1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "facedet index 39 X86ScheduleConv input 1 128 15 20 weight 4 " + "128 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 4}}, + {"ow_bn", {-1, 20}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "facedet index 35 X86ScheduleConv input 1 128 15 20 weight 128 " + "1 3 3 stride 2 2 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {1}}}); + 
InputX86Param(model_data,
+                "facedet index 40 X86ScheduleConv input 1 128 8 10 weight 256 "
+                "128 1 1 stride 1 1 padding 0 0 dilation 1 1",
+                {{"ic_bn", {-1, 64}},
+                 {"oc_bn", {-1, 64}},
+                 {"ow_bn", {-1, 2}},
+                 {"oh_bn", {2}}});
+  InputX86Param(model_data,
+                "facedet index 41 X86ScheduleConv input 1 256 8 10 weight 256 "
+                "1 3 3 stride 1 1 padding 1 1 dilation 1 1",
+                {{"ic_bn", {-1, 64}},
+                 {"oc_bn", {-1, 64}},
+                 {"ow_bn", {-1, 1}},
+                 {"unroll_kw", {1}}});
+  InputX86Param(model_data,
+                "facedet index 42 X86ScheduleConv input 1 256 8 10 weight 256 "
+                "256 1 1 stride 1 1 padding 0 0 dilation 1 1",
+                {{"ic_bn", {-1, 64}},
+                 {"oc_bn", {-1, 64}},
+                 {"ow_bn", {-1, 5}},
+                 {"oh_bn", {1}}});
+  InputX86Param(model_data,
+                "facedet index 44 X86ScheduleConv input 1 256 8 10 weight 256 "
+                "1 3 3 stride 1 1 padding 1 1 dilation 1 1",
+                {{"ic_bn", {-1, 64}},
+                 {"oc_bn", {-1, 64}},
+                 {"ow_bn", {-1, 1}},
+                 {"unroll_kw", {1}}});
+  InputX86Param(model_data,
+                "facedet index 48 X86ScheduleConv input 1 256 8 10 weight 4 "
+                "256 1 1 stride 1 1 padding 0 0 dilation 1 1",
+                {{"ic_bn", {-1, 64}},
+                 {"oc_bn", {-1, 4}},
+                 {"ow_bn", {-1, 10}},
+                 {"oh_bn", {1}}});
+  InputX86Param(model_data,
+                "facedet index 43 X86ScheduleConv input 1 256 8 10 weight 64 "
+                "256 1 1 stride 1 1 padding 0 0 dilation 1 1",
+                {{"ic_bn", {-1, 64}},
+                 {"oc_bn", {-1, 32}},
+                 {"ow_bn", {-1, 2}},
+                 {"oh_bn", {2}}});
+  InputX86Param(model_data,
+                "facedet index 46 X86ScheduleConv input 1 64 8 10 weight 64 1 "
+                "3 3 stride 2 2 padding 1 1 dilation 1 1",
+                {{"ic_bn", {-1, 64}},
+                 {"oc_bn", {-1, 64}},
+                 {"ow_bn", {-1, 1}},
+                 {"unroll_kw", {0}}});
+  InputX86Param(model_data,
+                "facedet index 49 X86ScheduleConv input 1 64 4 5 weight 256 64 "
+                "1 1 stride 1 1 padding 0 0 dilation 1 1",
+                {{"ic_bn", {-1, 4}},
+                 {"oc_bn", {-1, 16}},
+                 {"ow_bn", {-1, 5}},
+                 {"oh_bn", {2}}});
+  InputX86Param(model_data,
+                "facedet index 51 X86ScheduleConv input 1 256 4 5 weight 6 256 "
+                "3 3 stride 1 1 padding 1 1 dilation 1 1",
+                {{"ic_bn", {-1, 16}},
+                 {"oc_bn", {-1, 6}},
+                 {"ow_bn", {-1, 5}},
+                 {"unroll_kw", {1}}});
+  InputX86Param(model_data,
+                "facedet index 28 X86ScheduleConv input 1 64 30 40 weight 12 "
+                "64 1 1 stride 1 1 padding 0 0 dilation 1 1",
+                {{"ic_bn", {-1, 8}},
+                 {"oc_bn", {-1, 12}},
+                 {"ow_bn", {-1, 40}},
+                 {"oh_bn", {1}}});
+  InputX86Param(model_data,
+                "facedet index 37 X86ScheduleConv input 1 128 15 20 weight 128 "
+                "1 3 3 stride 1 1 padding 1 1 dilation 1 1",
+                {{"ic_bn", {-1, 64}},
+                 {"oc_bn", {-1, 64}},
+                 {"ow_bn", {-1, 1}},
+                 {"unroll_kw", {1}}});
+  InputX86Param(model_data,
+                "facedet index 38 X86ScheduleConv input 1 128 15 20 weight 8 "
+                "128 1 1 stride 1 1 padding 0 0 dilation 1 1",
+                {{"ic_bn", {-1, 64}},
+                 {"oc_bn", {-1, 8}},
+                 {"ow_bn", {-1, 20}},
+                 {"oh_bn", {1}}});
+  InputX86Param(model_data,
+                "facedet index 45 X86ScheduleConv input 1 256 8 10 weight 256 "
+                "1 3 3 stride 1 1 padding 1 1 dilation 1 1",
+                {{"ic_bn", {-1, 64}},
+                 {"oc_bn", {-1, 64}},
+                 {"ow_bn", {-1, 1}},
+                 {"unroll_kw", {1}}});
+  InputX86Param(model_data,
+                "facedet index 47 X86ScheduleConv input 1 256 8 10 weight 8 "
+                "256 1 1 stride 1 1 padding 0 0 dilation 1 1",
+                {{"ic_bn", {-1, 64}},
+                 {"oc_bn", {-1, 8}},
+                 {"ow_bn", {-1, 10}},
+                 {"oh_bn", {1}}});
+  InputX86Param(model_data,
+                "facedet index 50 X86ScheduleConv input 1 256 4 5 weight 12 "
+                "256 3 3 stride 1 1 padding 1 1 dilation 1 1",
+                {{"ic_bn", {-1, 1}},
+                 {"oc_bn", {-1, 12}},
+                 {"ow_bn", {-1, 5}},
+                 {"unroll_kw", {0}}});
 }
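One reading of the factor lists (an inference from how they are used, not something this diff states): a `{-1, k}` entry describes a loop split whose inner extent is pinned at `k` and whose outer extent is left to be derived from the axis length, while single-element entries such as `{"unroll_kw", {1}}` or `{"oh_bn", {2}}` act as a flag or a fixed factor. Under that assumption, resolving a `-1` is a plain division:

    #include <utility>

    // Hypothetical helper: {-1, k} -> {extent / k, k}. For facedet index 50
    // above, ow = 5 with {"ow_bn", {-1, 5}} resolves to outer 1, inner 5.
    std::pair<int, int> ResolveSplit(int extent, int outer, int inner) {
      if (outer == -1) outer = extent / inner;
      if (inner == -1) inner = extent / outer;
      return {outer, inner};
    }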

 void LoadEfficientNetParams(
-    absl::flat_hash_map<std::string, absl::flat_hash_map<std::string, std::vector<int>>> *model_data) {
+    absl::flat_hash_map<std::string,
+                        absl::flat_hash_map<std::string, std::vector<int>>>
+        *model_data) {
   CHECK(model_data);
-  InputX86Param(
-      model_data,
-      "efficientnet index 0 X86ScheduleConv input 1 3 224 224 weight 32 3 3 3 stride 2 2 padding 2 2 dilation 1 1",
-      {{"ic_bn", {-1, 1}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 4}}, {"unroll_kw", {1}}});
-  InputX86Param(
-      model_data,
-      "efficientnet index 1 X86ScheduleConv input 1 32 112 112 weight 32 1 3 3 stride 1 1 padding 1 1 dilation 1 1",
-      {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {0}}});
-  InputX86Param(
-      model_data,
-      "efficientnet index 2 X86ScheduleConv input 1 32 1 1 weight 8 32 1 1 stride 1 1 padding 0 0 dilation 1 1",
-      {{"ic_bn", {-1, 32}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}});
-  // InputX86Param(model_data, "efficientnet index 3 X86ScheduleConv input 1 8 1 1 weight 32 8 1 1 stride 1 1 padding 0
-  // 0 dilation 1 1", {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}});
-  InputX86Param(
-      model_data,
-      "efficientnet index 3 X86ScheduleConv input 1 8 1 1 weight 32 8 1 1 stride 1 1 padding 0 0 dilation 1 1",
-      {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}});
-  InputX86Param(
-      model_data,
-      "efficientnet index 4 X86ScheduleConv input 1 32 112 112 weight 16 32 1 1 stride 1 1 padding 0 0 dilation 1 1",
-      {{"ic_bn", {-1, 2}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 28}}, {"oh_bn", {1}}});
-  InputX86Param(
-      model_data,
-      "efficientnet index 5 X86ScheduleConv input 1 16 112 112 weight 96 16 1 1 stride 1 1 padding 0 0 dilation 1 1",
-      {{"ic_bn", {-1, 2}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}});
-  InputX86Param(
-      model_data,
-      "efficientnet index 6 X86ScheduleConv input 1 96 112 112 weight 96 1 3 3 stride 2 2 padding 2 2 dilation 1 1",
-      {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}});
-  InputX86Param(
-      model_data,
-      "efficientnet index 7 X86ScheduleConv input 1 96 1 1 weight 4 96 1 1 stride 1 1 padding 0 0 dilation 1 1",
-      {{"ic_bn", {-1, 96}}, {"oc_bn", {-1, 4}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}});
-  // InputX86Param(model_data, "efficientnet index 8 X86ScheduleConv input 1 4 1 1 weight 96 4 1 1 stride 1 1 padding 0
-  // 0 dilation 1 1", {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 96}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}});
-  InputX86Param(
-      model_data,
-      "efficientnet index 8 X86ScheduleConv input 1 4 1 1 weight 96 4 1 1 stride 1 1 padding 0 0 dilation 1 1",
-      {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}});
-  InputX86Param(
-      model_data,
-      "efficientnet index 9 X86ScheduleConv input 1 96 56 56 weight 24 96 1 1 stride 1 1 padding 0 0 dilation 1 1",
-      {{"ic_bn", {-1, 6}}, {"oc_bn", {-1, 12}}, {"ow_bn", {-1, 14}}, {"oh_bn", {1}}});
-  InputX86Param(
-      model_data,
-      "efficientnet index 10 X86ScheduleConv input 1 24 56 56 weight 144 24 1 1 stride 1 1 padding 0 0 dilation 1 1",
-      {{"ic_bn", {-1, 6}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}});
-  InputX86Param(
-      model_data,
-      "efficientnet index 11 X86ScheduleConv input 1 144 56 56 weight 144 1 3 3 stride 1 1 padding 1 1 dilation 1 1",
-      {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {0}}});
-  InputX86Param(
-      model_data,
-      "efficientnet index 12 X86ScheduleConv input 1 144 1 1 weight 6 144 1 1 stride 1 1 padding 0 0 dilation 1 1",
-      {{"ic_bn", {-1, 36}}, {"oc_bn", {-1, 6}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}});
-  // InputX86Param(model_data, "efficientnet index 13 X86ScheduleConv input 1 6 1 1 weight 144 6 1 1 stride 1 1
padding - // 0 0 dilation 1 1", {{"ic_bn", {-1, 6}}, {"oc_bn", {-1, 144}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 13 X86ScheduleConv input 1 6 1 1 weight 144 6 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 6}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 14 X86ScheduleConv input 1 144 56 56 weight 24 144 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 12}}, {"ow_bn", {-1, 14}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 15 X86ScheduleConv input 1 144 56 56 weight 144 1 5 5 stride 2 2 padding 3 3 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 29}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "efficientnet index 16 X86ScheduleConv input 1 144 28 28 weight 40 144 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 28}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 17 X86ScheduleConv input 1 40 28 28 weight 240 40 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "efficientnet index 18 X86ScheduleConv input 1 240 28 28 weight 240 1 5 5 stride 1 1 padding 2 2 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 4}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "efficientnet index 19 X86ScheduleConv input 1 240 1 1 weight 10 240 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 6}}, {"oc_bn", {-1, 10}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 20 X86ScheduleConv input 1 10 1 1 weight 240 10 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 5}}, {"oc_bn", {-1, 4}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 21 X86ScheduleConv input 1 240 28 28 weight 40 240 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "efficientnet index 22 X86ScheduleConv input 1 240 28 28 weight 240 1 3 3 stride 2 2 padding 2 2 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 5}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "efficientnet index 23 X86ScheduleConv input 1 240 14 14 weight 80 240 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 80}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 24 X86ScheduleConv input 1 80 14 14 weight 480 80 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 7}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 25 X86ScheduleConv input 1 480 14 14 weight 480 1 3 3 stride 1 1 padding 1 1 dilation 1 1", - {{"ic_bn", {-1, 80}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 14}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "efficientnet index 26 X86ScheduleConv input 1 480 1 1 weight 20 480 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 24}}, {"oc_bn", {-1, 20}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 27 X86ScheduleConv input 1 20 1 1 weight 480 20 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 20}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); - InputX86Param( - 
model_data, - "efficientnet index 28 X86ScheduleConv input 1 480 14 14 weight 80 480 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 80}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "efficientnet index 29 X86ScheduleConv input 1 480 14 14 weight 480 1 5 5 stride 1 1 padding 2 2 dilation 1 1", - {{"ic_bn", {-1, 96}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "efficientnet index 30 X86ScheduleConv input 1 480 14 14 weight 112 480 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 80}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "efficientnet index 31 X86ScheduleConv input 1 112 14 14 weight 672 112 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 56}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 2}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "efficientnet index 32 X86ScheduleConv input 1 672 14 14 weight 672 1 5 5 stride 1 1 padding 2 2 dilation 1 1", - {{"ic_bn", {-1, 96}}, {"oc_bn", {-1, 48}}, {"ow_bn", {-1, 2}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "efficientnet index 33 X86ScheduleConv input 1 672 1 1 weight 28 672 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 6}}, {"oc_bn", {-1, 14}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 34 X86ScheduleConv input 1 28 1 1 weight 672 28 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 1}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 35 X86ScheduleConv input 1 672 14 14 weight 112 672 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 96}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 14}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "efficientnet index 36 X86ScheduleConv input 1 672 14 14 weight 672 1 5 5 stride 2 2 padding 3 3 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 8}}, {"unroll_kw", {0}}}); - InputX86Param( - model_data, - "efficientnet index 37 X86ScheduleConv input 1 672 7 7 weight 192 672 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "efficientnet index 38 X86ScheduleConv input 1 192 7 7 weight 1152 192 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 3}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "efficientnet index 39 X86ScheduleConv input 1 1152 7 7 weight 1152 1 5 5 stride 1 1 padding 2 2 dilation 1 1", - {{"ic_bn", {-1, 8}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 7}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "efficientnet index 40 X86ScheduleConv input 1 1152 1 1 weight 48 1152 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 576}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 41 X86ScheduleConv input 1 48 1 1 weight 1152 48 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 12}}, {"oc_bn", {-1, 8}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); - InputX86Param( - model_data, - "efficientnet index 42 X86ScheduleConv input 1 1152 7 7 weight 192 1152 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 72}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "efficientnet index 43 X86ScheduleConv input 1 1152 7 7 weight 1152 1 3 3 stride 1 1 padding 1 1 
dilation 1 1", - {{"ic_bn", {-1, 64}}, {"oc_bn", {-1, 64}}, {"ow_bn", {-1, 1}}, {"unroll_kw", {1}}}); - InputX86Param( - model_data, - "efficientnet index 44 X86ScheduleConv input 1 1152 7 7 weight 320 1152 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 384}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); - InputX86Param( - model_data, - "efficientnet index 45 X86ScheduleConv input 1 320 7 7 weight 1280 320 1 1 stride 1 1 padding 0 0 dilation 1 1", - {{"ic_bn", {-1, 4}}, {"oc_bn", {-1, 16}}, {"ow_bn", {-1, 7}}, {"oh_bn", {2}}}); + InputX86Param(model_data, + "efficientnet index 0 X86ScheduleConv input 1 3 224 224 weight " + "32 3 3 3 stride 2 2 padding 2 2 dilation 1 1", + {{"ic_bn", {-1, 1}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 4}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "efficientnet index 1 X86ScheduleConv input 1 32 112 112 " + "weight 32 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "efficientnet index 2 X86ScheduleConv input 1 32 1 1 weight 8 " + "32 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 32}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + // InputX86Param(model_data, "efficientnet index 3 X86ScheduleConv input 1 8 1 + // 1 weight 32 8 1 1 stride 1 1 padding 0 0 dilation 1 1", {{"ic_bn", {-1, + // 8}}, {"oc_bn", {-1, 32}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 3 X86ScheduleConv input 1 8 1 1 weight 32 " + "8 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 4 X86ScheduleConv input 1 32 112 112 " + "weight 16 32 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 2}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 28}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 5 X86ScheduleConv input 1 16 112 112 " + "weight 96 16 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 2}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 6 X86ScheduleConv input 1 96 112 112 " + "weight 96 1 3 3 stride 2 2 padding 2 2 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "efficientnet index 7 X86ScheduleConv input 1 96 1 1 weight 4 " + "96 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 96}}, + {"oc_bn", {-1, 4}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + // InputX86Param(model_data, "efficientnet index 8 X86ScheduleConv input 1 4 1 + // 1 weight 96 4 1 1 stride 1 1 padding 0 0 dilation 1 1", {{"ic_bn", {-1, + // 4}}, {"oc_bn", {-1, 96}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 8 X86ScheduleConv input 1 4 1 1 weight 96 " + "4 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 4}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 9 X86ScheduleConv input 1 96 56 56 weight " + "24 96 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 6}}, + {"oc_bn", {-1, 12}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 10 X86ScheduleConv input 1 24 56 56 weight " + "144 24 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 6}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", 
{-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "efficientnet index 11 X86ScheduleConv input 1 144 56 56 " + "weight 144 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "efficientnet index 12 X86ScheduleConv input 1 144 1 1 weight " + "6 144 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 36}}, + {"oc_bn", {-1, 6}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + // InputX86Param(model_data, "efficientnet index 13 X86ScheduleConv input 1 6 + // 1 1 weight 144 6 1 1 stride 1 1 padding 0 0 dilation 1 1", {{"ic_bn", {-1, + // 6}}, {"oc_bn", {-1, 144}}, {"ow_bn", {-1, 1}}, {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 13 X86ScheduleConv input 1 6 1 1 weight " + "144 6 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 6}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 14 X86ScheduleConv input 1 144 56 56 " + "weight 24 144 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 12}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 15 X86ScheduleConv input 1 144 56 56 " + "weight 144 1 5 5 stride 2 2 padding 3 3 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 29}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "efficientnet index 16 X86ScheduleConv input 1 144 28 28 " + "weight 40 144 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 28}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 17 X86ScheduleConv input 1 40 28 28 weight " + "240 40 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "efficientnet index 18 X86ScheduleConv input 1 240 28 28 " + "weight 240 1 5 5 stride 1 1 padding 2 2 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 4}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "efficientnet index 19 X86ScheduleConv input 1 240 1 1 weight " + "10 240 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 6}}, + {"oc_bn", {-1, 10}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 20 X86ScheduleConv input 1 10 1 1 weight " + "240 10 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 5}}, + {"oc_bn", {-1, 4}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 21 X86ScheduleConv input 1 240 28 28 " + "weight 40 240 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "efficientnet index 22 X86ScheduleConv input 1 240 28 28 " + "weight 240 1 3 3 stride 2 2 padding 2 2 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 5}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "efficientnet index 23 X86ScheduleConv input 1 240 14 14 " + "weight 80 240 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 80}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 24 X86ScheduleConv input 1 80 14 14 weight " + "480 80 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 4}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", 
{1}}}); + InputX86Param(model_data, + "efficientnet index 25 X86ScheduleConv input 1 480 14 14 " + "weight 480 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 80}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 14}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "efficientnet index 26 X86ScheduleConv input 1 480 1 1 weight " + "20 480 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 24}}, + {"oc_bn", {-1, 20}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 27 X86ScheduleConv input 1 20 1 1 weight " + "480 20 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 20}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 28 X86ScheduleConv input 1 480 14 14 " + "weight 80 480 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 80}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "efficientnet index 29 X86ScheduleConv input 1 480 14 14 " + "weight 480 1 5 5 stride 1 1 padding 2 2 dilation 1 1", + {{"ic_bn", {-1, 96}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "efficientnet index 30 X86ScheduleConv input 1 480 14 14 " + "weight 112 480 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 80}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "efficientnet index 31 X86ScheduleConv input 1 112 14 14 " + "weight 672 112 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 56}}, + {"oc_bn", {-1, 32}}, + {"ow_bn", {-1, 2}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "efficientnet index 32 X86ScheduleConv input 1 672 14 14 " + "weight 672 1 5 5 stride 1 1 padding 2 2 dilation 1 1", + {{"ic_bn", {-1, 96}}, + {"oc_bn", {-1, 48}}, + {"ow_bn", {-1, 2}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "efficientnet index 33 X86ScheduleConv input 1 672 1 1 weight " + "28 672 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 6}}, + {"oc_bn", {-1, 14}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 34 X86ScheduleConv input 1 28 1 1 weight " + "672 28 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 1}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 35 X86ScheduleConv input 1 672 14 14 " + "weight 112 672 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 96}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 14}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "efficientnet index 36 X86ScheduleConv input 1 672 14 14 " + "weight 672 1 5 5 stride 2 2 padding 3 3 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 8}}, + {"unroll_kw", {0}}}); + InputX86Param(model_data, + "efficientnet index 37 X86ScheduleConv input 1 672 7 7 weight " + "192 672 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 4}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "efficientnet index 38 X86ScheduleConv input 1 192 7 7 weight " + "1152 192 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 3}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "efficientnet index 39 X86ScheduleConv input 1 1152 7 7 weight " + "1152 1 5 5 stride 1 1 padding 2 2 dilation 1 1", + {{"ic_bn", {-1, 8}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 7}}, + {"unroll_kw", 
{1}}}); + InputX86Param(model_data, + "efficientnet index 40 X86ScheduleConv input 1 1152 1 1 weight " + "48 1152 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 576}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 41 X86ScheduleConv input 1 48 1 1 weight " + "1152 48 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 12}}, + {"oc_bn", {-1, 8}}, + {"ow_bn", {-1, 1}}, + {"oh_bn", {1}}}); + InputX86Param(model_data, + "efficientnet index 42 X86ScheduleConv input 1 1152 7 7 weight " + "192 1152 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 72}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "efficientnet index 43 X86ScheduleConv input 1 1152 7 7 weight " + "1152 1 3 3 stride 1 1 padding 1 1 dilation 1 1", + {{"ic_bn", {-1, 64}}, + {"oc_bn", {-1, 64}}, + {"ow_bn", {-1, 1}}, + {"unroll_kw", {1}}}); + InputX86Param(model_data, + "efficientnet index 44 X86ScheduleConv input 1 1152 7 7 weight " + "320 1152 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 384}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); + InputX86Param(model_data, + "efficientnet index 45 X86ScheduleConv input 1 320 7 7 weight " + "1280 320 1 1 stride 1 1 padding 0 0 dilation 1 1", + {{"ic_bn", {-1, 4}}, + {"oc_bn", {-1, 16}}, + {"ow_bn", {-1, 7}}, + {"oh_bn", {2}}}); } -absl::flat_hash_map>> CreateX86Params() { - absl::flat_hash_map>> model_data; +absl::flat_hash_map>> +CreateX86Params() { + absl::flat_hash_map>> + model_data; LoadX86DefaultParams(&model_data); LoadResNet18Params(&model_data); LoadResNet50Params(&model_data); diff --git a/paddle/cinn/hlir/pe/load_x86_params.h b/paddle/cinn/hlir/pe/load_x86_params.h index 9273b242dae79..e897792ef7ecd 100644 --- a/paddle/cinn/hlir/pe/load_x86_params.h +++ b/paddle/cinn/hlir/pe/load_x86_params.h @@ -23,27 +23,49 @@ namespace cinn { namespace hlir { namespace pe { -void InputX86Param(absl::flat_hash_map>> *model_data, - const std::string &key, - const absl::flat_hash_map> &schedule_data); +void InputX86Param( + absl::flat_hash_map>> + *model_data, + const std::string &key, + const absl::flat_hash_map> &schedule_data); -absl::flat_hash_map>> CreateX86Params(); +absl::flat_hash_map>> +CreateX86Params(); void LoadResNet18Params( - absl::flat_hash_map>> *model_data); + absl::flat_hash_map>> + *model_data); void LoadResNet50Params( - absl::flat_hash_map>> *model_data); + absl::flat_hash_map>> + *model_data); void LoadMobileNetV1Params( - absl::flat_hash_map>> *model_data); + absl::flat_hash_map>> + *model_data); void LoadMobileNetV2Params( - absl::flat_hash_map>> *model_data); + absl::flat_hash_map>> + *model_data); void LoadFaceDetParams( - absl::flat_hash_map>> *model_data); + absl::flat_hash_map>> + *model_data); void LoadEfficientNetParams( - absl::flat_hash_map>> *model_data); + absl::flat_hash_map>> + *model_data); void LoadSqueezeNetParams( - absl::flat_hash_map>> *model_data); + absl::flat_hash_map>> + *model_data); -void CreateX86Params(absl::flat_hash_map>> *model_data); +void CreateX86Params( + absl::flat_hash_map>> + *model_data); } // namespace pe } // namespace hlir diff --git a/paddle/cinn/hlir/pe/nn.cc b/paddle/cinn/hlir/pe/nn.cc index 590f9e66b4dc7..faf459800d523 100644 --- a/paddle/cinn/hlir/pe/nn.cc +++ b/paddle/cinn/hlir/pe/nn.cc @@ -58,27 +58,51 @@ std::string Type2StrForNN(common::Type type) { return ""; } -ir::Tensor Relu(const ir::Tensor &A, double threshold, const std::string 
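The tables above are consumed as a plain string-keyed lookup: a workload key built from the conv shapes maps to named tiling factors such as "ic_bn" and "ow_bn". A minimal consumer sketch under that assumption; the DumpFactors helper and its output format are illustrative only, not CINN API:

// Sketch: looking up one saved x86 conv schedule by key.
// Assumes the absl types restored in load_x86_params.h above.
#include <iostream>
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"

using ScheduleData = absl::flat_hash_map<std::string, std::vector<int>>;
using ModelData = absl::flat_hash_map<std::string, ScheduleData>;

void DumpFactors(const ModelData &model_data, const std::string &key) {
  auto it = model_data.find(key);
  if (it == model_data.end()) {
    std::cout << "no saved schedule for key: " << key << "\n";
    return;
  }
  for (const auto &kv : it->second) {  // e.g. "oc_bn" -> {-1, 16}
    std::cout << kv.first << ":";
    for (int f : kv.second) std::cout << " " << f;
    std::cout << "\n";
  }
}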
diff --git a/paddle/cinn/hlir/pe/nn.cc b/paddle/cinn/hlir/pe/nn.cc
index 590f9e66b4dc7..faf459800d523 100644
--- a/paddle/cinn/hlir/pe/nn.cc
+++ b/paddle/cinn/hlir/pe/nn.cc
@@ -58,27 +58,51 @@ std::string Type2StrForNN(common::Type type) {
   return "";
 }
 
-ir::Tensor Relu(const ir::Tensor &A, double threshold, const std::string &output_name) {
+ir::Tensor Relu(const ir::Tensor &A,
+                double threshold,
+                const std::string &output_name) {
   return lang::Compute(
-      A->shape, [=](const std::vector<Expr> &indice) { return lang::Relu(A(indice), threshold); }, output_name);
+      A->shape,
+      [=](const std::vector<Expr> &indice) {
+        return lang::Relu(A(indice), threshold);
+      },
+      output_name);
 }
 
-ir::Tensor Relu6(const ir::Tensor &A, double threshold, const std::string &output_name) {
+ir::Tensor Relu6(const ir::Tensor &A,
+                 double threshold,
+                 const std::string &output_name) {
   return lang::Compute(
-      A->shape, [=](const std::vector<Expr> &indice) { return lang::Relu6(A(indice), threshold); }, output_name);
+      A->shape,
+      [=](const std::vector<Expr> &indice) {
+        return lang::Relu6(A(indice), threshold);
+      },
+      output_name);
 }
 
-Tensor LeakyRelu(const Tensor &A, double alpha, const std::string &output_name) {
+Tensor LeakyRelu(const Tensor &A,
+                 double alpha,
+                 const std::string &output_name) {
   return Compute(
-      A->shape, [=](const std::vector<Expr> &indice) { return lang::LeakyRelu(A(indice), alpha); }, output_name);
+      A->shape,
+      [=](const std::vector<Expr> &indice) {
+        return lang::LeakyRelu(A(indice), alpha);
+      },
+      output_name);
 }
 
-Tensor PRelu(const Tensor &A, const Tensor &slope, const int axis, const std::string &output_name) {
+Tensor PRelu(const Tensor &A,
+             const Tensor &slope,
+             const int axis,
+             const std::string &output_name) {
   CHECK_LT(axis, A->shape.size()) << "Wrong axis value: " << axis << std::endl;
-  CHECK(A->shape[axis] == slope->shape[0]) << "Wrong slope shape: " << slope->shape[0] << std::endl;
+  CHECK(A->shape[axis] == slope->shape[0])
+      << "Wrong slope shape: " << slope->shape[0] << std::endl;
   return Compute(
       A->shape,
-      [=](const std::vector<Expr> &indice) { return lang::LeakyRelu(A(indice), slope(indice[axis])); },
+      [=](const std::vector<Expr> &indice) {
+        return lang::LeakyRelu(A(indice), slope(indice[axis]));
+      },
      output_name);
 }
 
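All four activations follow the same pattern: lang::Compute takes the input's shape, a lambda from an index vector to a scalar expression, and an output name. A sketch of what one more elementwise op in this style looks like; ReluCopy is a hypothetical name, and only calls that already appear in this hunk are used (compiles in nn.cc's context, not standalone):

// Elementwise pattern shared by Relu/Relu6/LeakyRelu/PRelu above.
ir::Tensor ReluCopy(const ir::Tensor &A, const std::string &output_name) {
  return lang::Compute(
      A->shape,                               // output has the input's shape
      [=](const std::vector<Expr> &indice) {  // one scalar per output index
        return lang::Relu(A(indice), 0.0);    // threshold fixed at 0
      },
      output_name);
}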
Please " + "check."; std::vector output_shape; std::vector new_weights_shape; std::vector input_pad_shape; @@ -107,30 +134,45 @@ std::vector Conv2d_winograd_NCHW(const ir::Tensor &input, auto weights_dilation = Compute( new_weights_shape, [=](Expr nn, Expr cc, Expr yy, Expr xx) { - auto cond = lang::logic_and({(yy) % dilation_h == 0, xx % dilation_w == 0}); + auto cond = + lang::logic_and({(yy) % dilation_h == 0, xx % dilation_w == 0}); return ir::Select::Make( - cond, weights(nn, cc, (yy / dilation_h), (xx / dilation_w)), common::make_const(weights->type(), 0)); + cond, + weights(nn, cc, (yy / dilation_h), (xx / dilation_w)), + common::make_const(weights->type(), 0)); }, UniqName("weights_dilation")); - CHECK(MathEqual((weights->shape[0] * weights->shape[1]) % input->shape[1], Expr(0))) + CHECK(MathEqual((weights->shape[0] * weights->shape[1]) % input->shape[1], + Expr(0))) << "filter's output channel size must be divisible by group\n"; int alpha = weights_dilation->shape[3].as_int32() + tile_size - 1; - input_pad_shape = {input->shape[0], input->shape[1], input->shape[2] + 2 * pad_h, input->shape[3] + 2 * pad_w}; + input_pad_shape = {input->shape[0], + input->shape[1], + input->shape[2] + 2 * pad_h, + input->shape[3] + 2 * pad_w}; ir::Tensor input_pad; if (pad_h == 0 && pad_w == 0) { input_pad = Compute( - input->shape, [=](Expr nn, Expr cc, Expr yy, Expr xx) { return input(nn, cc, yy, xx); }, UniqName("input_pad")); + input->shape, + [=](Expr nn, Expr cc, Expr yy, Expr xx) { + return input(nn, cc, yy, xx); + }, + UniqName("input_pad")); } else { input_pad = Compute( input_pad_shape, [=](Expr nn, Expr cc, Expr yy, Expr xx) { - auto cond = - lang::logic_and({yy >= pad_h, yy < input->shape[2] + pad_h, xx >= pad_w, xx < input->shape[3] + pad_w}); - return ir::Select::Make(cond, input(nn, cc, yy - pad_h, xx - pad_w), ir::Zero(input->type())); + auto cond = lang::logic_and({yy >= pad_h, + yy < input->shape[2] + pad_h, + xx >= pad_w, + xx < input->shape[3] + pad_w}); + return ir::Select::Make(cond, + input(nn, cc, yy - pad_h, xx - pad_w), + ir::Zero(input->type())); }, UniqName("input_pad")); } @@ -143,15 +185,22 @@ std::vector Conv2d_winograd_NCHW(const ir::Tensor &input, input->shape[0], // B weights->shape[0], // O common::AutoSimplify( - (input->shape[2] - ((weights_dilation->shape[2] - 1) * dilation_h + 1) + 2 * pad_h) / stride_h + 1), // H + (input->shape[2] - + ((weights_dilation->shape[2] - 1) * dilation_h + 1) + 2 * pad_h) / + stride_h + + 1), // H common::AutoSimplify( - (input->shape[3] - ((weights_dilation->shape[3] - 1) * dilation_w + 1) + 2 * pad_w) / stride_w + 1) // W + (input->shape[3] - + ((weights_dilation->shape[3] - 1) * dilation_w + 1) + 2 * pad_w) / + stride_w + + 1) // W }; - std::vector winograd_transform = winograd_transform_matrices(m, r); - ir::Tensor A = winograd_transform[0]; - ir::Tensor B = winograd_transform[1]; - ir::Tensor G = winograd_transform[2]; + std::vector winograd_transform = + winograd_transform_matrices(m, r); + ir::Tensor A = winograd_transform[0]; + ir::Tensor B = winograd_transform[1]; + ir::Tensor G = winograd_transform[2]; int nH = (common::AutoSimplify(output_shape[2]).as_int32() + m - 1) / m; int nW = (common::AutoSimplify(output_shape[3]).as_int32() + m - 1) / m; @@ -160,61 +209,86 @@ std::vector Conv2d_winograd_NCHW(const ir::Tensor &input, Var r_kh(weights_dilation->shape[2], UniqName("r_kh")); Var r_kw(weights_dilation->shape[3], UniqName("r_kw")); - std::vector kernel_shape = {Expr(alpha), Expr(alpha), weights_dilation->shape[1], 
weights_dilation->shape[0]}; - auto kernel_pack = Compute( + std::vector kernel_shape = {Expr(alpha), + Expr(alpha), + weights_dilation->shape[1], + weights_dilation->shape[0]}; + auto kernel_pack = Compute( kernel_shape, [=](Expr eps, Expr nu, Expr ci, Expr co) { - return lang::ReduceSum(weights_dilation(co, ci, r_kh, r_kw) * G(eps, r_kh) * G(nu, r_kw), {r_kh, r_kw}); + return lang::ReduceSum( + weights_dilation(co, ci, r_kh, r_kw) * G(eps, r_kh) * G(nu, r_kw), + {r_kh, r_kw}); }, UniqName("kernel_pack")); // pack input tile - std::vector input_tile_shape = {weights_dilation->shape[1], Expr(P), Expr(alpha), Expr(alpha)}; - auto input_tile = Compute( + std::vector input_tile_shape = { + weights_dilation->shape[1], Expr(P), Expr(alpha), Expr(alpha)}; + auto input_tile = Compute( input_tile_shape, [=](Expr c, Expr p, Expr eps, Expr nu) { - return input_pad((p / (nH * nW)), c, ((p / nW) % nH) * m + eps, (p % nW) * m + nu); + return input_pad( + (p / (nH * nW)), c, ((p / nW) % nH) * m + eps, (p % nW) * m + nu); }, UniqName("input_tile")); - std::vector data_pack_shape = {Expr(alpha), Expr(alpha), weights_dilation->shape[1], Expr(P)}; + std::vector data_pack_shape = { + Expr(alpha), Expr(alpha), weights_dilation->shape[1], Expr(P)}; Var r_a(input_tile->shape[2], UniqName("r_a")); Var r_b(input_tile->shape[3], UniqName("r_b")); auto data_pack = Compute( data_pack_shape, [=](Expr eps, Expr nu, Expr ci, Expr p) { - return lang::ReduceSum(input_tile(ci, p, r_a, r_b) * B(r_a, eps) * B(r_b, nu), {r_a, r_b}); + return lang::ReduceSum( + input_tile(ci, p, r_a, r_b) * B(r_a, eps) * B(r_b, nu), {r_a, r_b}); }, UniqName("data_pack")); // do batch gemm - std::vector bgemm_shape = {Expr(alpha), Expr(alpha), weights_dilation->shape[0], Expr(P)}; + std::vector bgemm_shape = { + Expr(alpha), Expr(alpha), weights_dilation->shape[0], Expr(P)}; Var ci(kernel_pack->shape[2], UniqName("ci")); auto bgemm = Compute( bgemm_shape, [=](Expr eps, Expr nu, Expr co, Expr p) { - return lang::ReduceSum(kernel_pack(eps, nu, ci, co) * data_pack(eps, nu, ci, p), {ci}); + return lang::ReduceSum( + kernel_pack(eps, nu, ci, co) * data_pack(eps, nu, ci, p), {ci}); }, UniqName("bgemm")); // # inverse transform - std::vector inverse_shape = {weights_dilation->shape[0], Expr(P), Expr(m), Expr(m)}; + std::vector inverse_shape = { + weights_dilation->shape[0], Expr(P), Expr(m), Expr(m)}; Var r_g_a(bgemm->shape[0], UniqName("r_g_a")); Var r_g_b(bgemm->shape[1], UniqName("r_g_b")); auto inverse = Compute( inverse_shape, [=](Expr co, Expr p, Expr vh, Expr vw) { - return lang::ReduceSum(bgemm(r_g_a, r_g_b, co, p) * A(r_g_a, vh) * A(r_g_b, vw), {r_g_a, r_g_b}); + return lang::ReduceSum( + bgemm(r_g_a, r_g_b, co, p) * A(r_g_a, vh) * A(r_g_b, vw), + {r_g_a, r_g_b}); }, UniqName("inverse")); auto res = Compute( output_shape, [=](Expr n, Expr co, Expr h, Expr w) { - return inverse(co, n * nH * nW + (h / m) * nW + (w / m), (h % m), (w % m)); + return inverse( + co, n * nH * nW + (h / m) * nW + (w / m), (h % m), (w % m)); }, output_name); - return {weights_dilation, input_pad, A, B, G, kernel_pack, input_tile, data_pack, bgemm, inverse, res}; + return {weights_dilation, + input_pad, + A, + B, + G, + kernel_pack, + input_tile, + data_pack, + bgemm, + inverse, + res}; } std::vector Conv2d_NCHW(const ir::Tensor &input, @@ -227,29 +301,37 @@ std::vector Conv2d_NCHW(const ir::Tensor &input, int dilation_w, const std::string &output_name, bool choose_direct_compute) { - CHECK_EQ(input->shape.size(), 4U) << "Input's dimension of Conv2d_NCHW op is not 4! 
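For readers tracing the Winograd constants: with output tile m and kernel extent r, the transformed tile is alpha = m + r - 1 (computed above as kernel extent + tile_size - 1), the output is split into nH x nW tiles of m x m, and P = batch * nH * nW tiles are packed per channel. A standalone arithmetic check, with example values not taken from the PR:

// Illustrative arithmetic for the Winograd tiling above, assuming m = 2
// and a 3x3 kernel (r = 3).
#include <cstdio>

int main() {
  const int m = 2, r = 3;              // output tile, kernel size
  const int alpha = m + r - 1;         // transformed tile extent: 4
  const int out_h = 56, out_w = 56;    // conv output extent (example)
  const int nH = (out_h + m - 1) / m;  // 28 tiles vertically
  const int nW = (out_w + m - 1) / m;  // 28 tiles horizontally
  const int batch = 1;
  const int P = batch * nH * nW;       // 784 packed tiles
  std::printf("alpha=%d nH=%d nW=%d P=%d\n", alpha, nH, nW, P);
  return 0;
}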
Please check."; - CHECK_EQ(weights->shape.size(), 4U) << "Weight's dimension of Conv2d_NCHW op is not 4! Please check."; + CHECK_EQ(input->shape.size(), 4U) + << "Input's dimension of Conv2d_NCHW op is not 4! Please check."; + CHECK_EQ(weights->shape.size(), 4U) + << "Weight's dimension of Conv2d_NCHW op is not 4! Please check."; std::vector output_shape_int; std::vector new_weights_shape_int; std::vector input_pad_shape_int; output_shape_int = { input->shape[0].as_int32(), // B weights->shape[0].as_int32(), // O - (input->shape[2].as_int32() - ((weights->shape[2].as_int32() - 1) * dilation_h + 1) + 2 * pad_h) / stride_h + + (input->shape[2].as_int32() - + ((weights->shape[2].as_int32() - 1) * dilation_h + 1) + 2 * pad_h) / + stride_h + 1, // H - (input->shape[3].as_int32() - ((weights->shape[3].as_int32() - 1) * dilation_w + 1) + 2 * pad_w) / stride_w + + (input->shape[3].as_int32() - + ((weights->shape[3].as_int32() - 1) * dilation_w + 1) + 2 * pad_w) / + stride_w + 1 // W }; new_weights_shape_int = {weights->shape[0].as_int32(), weights->shape[1].as_int32(), dilation_h * (weights->shape[2].as_int32() - 1) + 1, dilation_w * (weights->shape[3].as_int32() - 1) + 1}; - input_pad_shape_int = {input->shape[0].as_int32(), + input_pad_shape_int = {input->shape[0].as_int32(), input->shape[1].as_int32(), input->shape[2].as_int32() + 2 * pad_h, input->shape[3].as_int32() + 2 * pad_w}; - std::vector output_shape{ - Expr(output_shape_int[0]), Expr(output_shape_int[1]), Expr(output_shape_int[2]), Expr(output_shape_int[3])}; + std::vector output_shape{Expr(output_shape_int[0]), + Expr(output_shape_int[1]), + Expr(output_shape_int[2]), + Expr(output_shape_int[3])}; std::vector new_weights_shape{Expr(new_weights_shape_int[0]), Expr(new_weights_shape_int[1]), Expr(new_weights_shape_int[2]), @@ -263,34 +345,57 @@ std::vector Conv2d_NCHW(const ir::Tensor &input, CHECK(weights->shape[3].is_constant()); int kh = weights->shape[2].as_int32(); int kw = weights->shape[3].as_int32(); - if (!choose_direct_compute && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1 && 2 < kh && - kh < 8 && 2 < kw && kw < 8) { - auto &res = ScheduleParam::get_cuda_instance().GetParam(); - std::string key = "CudaWinogradConvSchedule " + std::to_string(input_pad_shape_int[0]) + " " + - std::to_string(input_pad_shape_int[1]) + " " + std::to_string(input_pad_shape_int[2]) + " " + - std::to_string(input_pad_shape_int[3]) + " " + std::to_string(new_weights_shape_int[0]) + " " + - std::to_string(new_weights_shape_int[1]) + " " + std::to_string(new_weights_shape_int[2]) + " " + - std::to_string(new_weights_shape_int[3]) + " " + std::to_string(output_shape_int[0]) + " " + - std::to_string(output_shape_int[1]) + " " + std::to_string(output_shape_int[2]) + " " + + if (!choose_direct_compute && stride_h == 1 && stride_w == 1 && + dilation_h == 1 && dilation_w == 1 && 2 < kh && kh < 8 && 2 < kw && + kw < 8) { + auto &res = ScheduleParam::get_cuda_instance().GetParam(); + std::string key = "CudaWinogradConvSchedule " + + std::to_string(input_pad_shape_int[0]) + " " + + std::to_string(input_pad_shape_int[1]) + " " + + std::to_string(input_pad_shape_int[2]) + " " + + std::to_string(input_pad_shape_int[3]) + " " + + std::to_string(new_weights_shape_int[0]) + " " + + std::to_string(new_weights_shape_int[1]) + " " + + std::to_string(new_weights_shape_int[2]) + " " + + std::to_string(new_weights_shape_int[3]) + " " + + std::to_string(output_shape_int[0]) + " " + + std::to_string(output_shape_int[1]) + " " + + 
std::to_string(output_shape_int[2]) + " " + std::to_string(output_shape_int[3]); if (res.count(key) > 0) { VLOG(3) << "Find saved winograd_conv2d schedule param! key is: " << key; - return Conv2d_winograd_NCHW( - input, weights, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, output_name); + return Conv2d_winograd_NCHW(input, + weights, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + output_name); } - VLOG(3) << "Didn't find saved winograd_conv2d schedule param! key is: " << key; + VLOG(3) << "Didn't find saved winograd_conv2d schedule param! key is: " + << key; } ir::Tensor input_pad; if (pad_h == 0 && pad_w == 0) { input_pad = Compute( - input->shape, [=](Expr nn, Expr cc, Expr yy, Expr xx) { return input(nn, cc, yy, xx); }, UniqName("input_pad")); + input->shape, + [=](Expr nn, Expr cc, Expr yy, Expr xx) { + return input(nn, cc, yy, xx); + }, + UniqName("input_pad")); } else { input_pad = Compute( input_pad_shape, [=](Expr nn, Expr cc, Expr yy, Expr xx) { - auto cond = - lang::logic_and({yy >= pad_h, yy < input->shape[2] + pad_h, xx >= pad_w, xx < input->shape[3] + pad_w}); - return ir::Select::Make(cond, input(nn, cc, yy - pad_h, xx - pad_w), ir::Zero(input->type())); + auto cond = lang::logic_and({yy >= pad_h, + yy < input->shape[2] + pad_h, + xx >= pad_w, + xx < input->shape[3] + pad_w}); + return ir::Select::Make(cond, + input(nn, cc, yy - pad_h, xx - pad_w), + ir::Zero(input->type())); }, UniqName("input_pad")); } @@ -299,12 +404,16 @@ std::vector Conv2d_NCHW(const ir::Tensor &input, Var ry(weights->shape[2], UniqName("ry")); Var rx(weights->shape[3], UniqName("rx")); - CHECK(MathEqual((weights->shape[0] * weights->shape[1]) % input->shape[1], Expr(0))) + CHECK(MathEqual((weights->shape[0] * weights->shape[1]) % input->shape[1], + Expr(0))) << "filter's output channel size must be divisible by group\n"; auto res = Compute( output_shape, [=](Expr nn, Expr ff, Expr yy, Expr xx) { - return lang::ReduceSum(input_pad(nn, rc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w) * + return lang::ReduceSum(input_pad(nn, + rc, + yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w) * weights(ff, rc, ry, rx), {rc, ry, rx}); }, @@ -326,21 +435,24 @@ std::vector Conv2d_NCHW_5D(const ir::Tensor &input, // input: 4D to 5D, NCHW->NCHWc // [batch, in_channel, in_height, in_width] -> // [batch, in_channel_chunk, in_height, in_width, in_channel_block] - auto type = input->type(); - std::vector shape_input = input->shape; + auto type = input->type(); + std::vector shape_input = input->shape; std::vector shape_weights = weights->shape; CHECK_EQ(shape_input.size(), 4U) << "input's shape size should be 4"; CHECK_EQ(shape_weights.size(), 4U) << "weight's shape size should be 4"; - Expr c_in = common::AutoSimplify(shape_input[1]); + Expr c_in = common::AutoSimplify(shape_input[1]); Expr c_filter = common::AutoSimplify(shape_weights[1]); - Expr c_out = common::AutoSimplify(shape_weights[0]); + Expr c_out = common::AutoSimplify(shape_weights[0]); absl::flat_hash_map conv2d_factors; - int oc = c_out.as_int32(); - int ic = c_in.as_int32(); + int oc = c_out.as_int32(); + int ic = c_in.as_int32(); int fc_size = c_filter.as_int32(); if (key.empty()) { - key = - GenerateX86ConvKey(shape_input, shape_weights, {stride_h, stride_w}, {pad_h, pad_w}, {dilation_h, dilation_w}); + key = GenerateX86ConvKey(shape_input, + shape_weights, + {stride_h, stride_w}, + {pad_h, pad_w}, + {dilation_h, dilation_w}); } GetConv2dFactors(&conv2d_factors, oc, ic, fc_size, -1, 
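The reduction just above is the textbook direct convolution: output(n, f, y, x) sums input_pad(n, rc, y*stride + ry*dilation, x*stride + rx*dilation) * weights(f, rc, ry, rx) over rc, ry, rx. A plain-loop scalar reference for comparison; the nested-vector containers and ConvPoint name are invented for illustration:

// Reference semantics of the Conv2d_NCHW reduction above.
#include <vector>

using T4 = std::vector<std::vector<std::vector<std::vector<float>>>>;

// Computes one output point out[n][f][y][x].
float ConvPoint(const T4 &in, const T4 &w, int n, int f, int y, int x,
                int stride, int dilation) {
  float acc = 0.f;
  for (size_t c = 0; c < w[f].size(); ++c)
    for (size_t ry = 0; ry < w[f][c].size(); ++ry)
      for (size_t rx = 0; rx < w[f][c][ry].size(); ++rx)
        acc += in[n][c][y * stride + ry * dilation]
                 [x * stride + rx * dilation] *
               w[f][c][ry][rx];
  return acc;
}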
@@ -326,21 +435,24 @@ std::vector<ir::Tensor> Conv2d_NCHW_5D(const ir::Tensor &input,
   // input: 4D to 5D, NCHW->NCHWc
   // [batch, in_channel, in_height, in_width] ->
   // [batch, in_channel_chunk, in_height, in_width, in_channel_block]
-  auto type = input->type();
-  std::vector<Expr> shape_input = input->shape;
+  auto type = input->type();
+  std::vector<Expr> shape_input = input->shape;
   std::vector<Expr> shape_weights = weights->shape;
   CHECK_EQ(shape_input.size(), 4U) << "input's shape size should be 4";
   CHECK_EQ(shape_weights.size(), 4U) << "weight's shape size should be 4";
 
-  Expr c_in = common::AutoSimplify(shape_input[1]);
+  Expr c_in = common::AutoSimplify(shape_input[1]);
   Expr c_filter = common::AutoSimplify(shape_weights[1]);
-  Expr c_out = common::AutoSimplify(shape_weights[0]);
+  Expr c_out = common::AutoSimplify(shape_weights[0]);
   absl::flat_hash_map<std::string, int> conv2d_factors;
-  int oc = c_out.as_int32();
-  int ic = c_in.as_int32();
+  int oc = c_out.as_int32();
+  int ic = c_in.as_int32();
   int fc_size = c_filter.as_int32();
   if (key.empty()) {
-    key =
-        GenerateX86ConvKey(shape_input, shape_weights, {stride_h, stride_w}, {pad_h, pad_w}, {dilation_h, dilation_w});
+    key = GenerateX86ConvKey(shape_input,
+                             shape_weights,
+                             {stride_h, stride_w},
+                             {pad_h, pad_w},
+                             {dilation_h, dilation_w});
   }
   GetConv2dFactors(&conv2d_factors, oc, ic, fc_size, -1, -1, type, target, key);
   int ic_bn_size = conv2d_factors["ic_bn"];
@@ -349,26 +461,29 @@ std::vector<ir::Tensor> Conv2d_NCHW_5D(const ir::Tensor &input,
   VLOG(3) << "oc_bn: " << oc_bn_size;
   VLOG(3) << "ic_bn: " << ic_bn_size;
   VLOG(3) << "fc_bn: " << fc_bn_size;
-  Expr ic_bn    = Expr(ic_bn_size);
-  Expr oc_bn    = Expr(oc_bn_size);
-  Expr fc_bn    = Expr(fc_bn_size);
+  Expr ic_bn = Expr(ic_bn_size);
+  Expr oc_bn = Expr(oc_bn_size);
+  Expr fc_bn = Expr(fc_bn_size);
   Expr ic_chunk = c_in / ic_bn;
   Expr oc_chunk = c_out / oc_bn;
   Expr fc_chunk = c_filter / fc_bn;
 
   // pack data, 4D->5D
   Expr batch = shape_input[0];
-  Expr h_in  = shape_input[2];
-  Expr w_in  = shape_input[3];
-  Expr h_f   = shape_weights[2];
-  Expr w_f   = shape_weights[3];
-  auto data  = Compute(
+  Expr h_in = shape_input[2];
+  Expr w_in = shape_input[3];
+  Expr h_f = shape_weights[2];
+  Expr w_f = shape_weights[3];
+  auto data = Compute(
       {batch, ic_chunk, h_in, w_in, ic_bn},
-      [=](Expr n, Expr icc, Expr h, Expr w, Expr icb) { return input(n, icc * ic_bn + icb, h, w); },
+      [=](Expr n, Expr icc, Expr h, Expr w, Expr icb) {
+        return input(n, icc * ic_bn + icb, h, w);
+      },
      UniqName("data_vec"));
 
   // pack kernel, 4D->6D
   std::vector<Expr> new_weights_shape;
-  new_weights_shape = {oc_chunk, fc_chunk, shape_weights[2], shape_weights[3], fc_bn, oc_bn};
+  new_weights_shape = {
+      oc_chunk, fc_chunk, shape_weights[2], shape_weights[3], fc_bn, oc_bn};
 
   auto weights_dilation = Compute(
       new_weights_shape,
@@ -377,20 +492,33 @@ std::vector<ir::Tensor> Conv2d_NCHW_5D(const ir::Tensor &input,
      },
      UniqName("weights_dilation_vec"));
 
-  auto tensors = Conv2d_NCHWc(data, weights_dilation, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w);
+  auto tensors = Conv2d_NCHWc(data,
+                              weights_dilation,
+                              pad_h,
+                              pad_w,
+                              stride_h,
+                              stride_w,
+                              dilation_h,
+                              dilation_w);
   CHECK_EQ(tensors.size(), 2U) << "Conv2d_NCHWc should return 2 tensors";
   auto packed_out = tensors[0];
-  auto input_pad  = tensors[1];
+  auto input_pad = tensors[1];
   // 5D back to 4D, NCHWc->NCHW
   std::vector<Expr> output_shape = {
-      batch,  // B
-      c_out,  // O
-      common::AutoSimplify((h_in - ((h_f - 1) * dilation_h + 1) + 2 * pad_h) / stride_h + 1),  // H
-      common::AutoSimplify((w_in - ((w_f - 1) * dilation_w + 1) + 2 * pad_w) / stride_w + 1)   // W
+      batch,  // B
+      c_out,  // O
+      common::AutoSimplify((h_in - ((h_f - 1) * dilation_h + 1) + 2 * pad_h) /
+                               stride_h +
+                           1),  // H
+      common::AutoSimplify((w_in - ((w_f - 1) * dilation_w + 1) + 2 * pad_w) /
+                               stride_w +
+                           1)  // W
   };
   auto res = Compute(
       output_shape,
-      [=](Expr n, Expr c, Expr h, Expr w) { return packed_out(n, c / oc_bn, h, w, c % oc_bn); },
+      [=](Expr n, Expr c, Expr h, Expr w) {
+        return packed_out(n, c / oc_bn, h, w, c % oc_bn);
+      },
      UniqName("conv2d_nchw_out"));
   return {res, packed_out, weights_dilation, input_pad, data};
 }
@@ -406,53 +534,68 @@ std::vector<ir::Tensor> Conv2d_NCHWc(const ir::Tensor &input,
                                      const std::string &output_name,
                                      const common::Target &target) {
   // input: [N, c_in_outer, H, W, c_in_inner]
-  // weight: [c_out_outer, c_filter_outer, filter_h, filter_w, c_filter_inner, c_out_inner]
-  auto type = input->type();
-  std::vector<Expr> shape_input = input->shape;
+  // weight: [c_out_outer, c_filter_outer, filter_h, filter_w, c_filter_inner,
+  // c_out_inner]
+  auto type = input->type();
+  std::vector<Expr> shape_input = input->shape;
   std::vector<Expr> shape_weights = weights->shape;
-  CHECK_EQ(shape_input.size(), 5U) << "Conv2d_NCHWc input's shape size should be 5";
-  CHECK_EQ(shape_weights.size(), 6U) << "Conv2d_NCHWc weight's shape size should be 6";
+  CHECK_EQ(shape_input.size(), 5U)
+      << "Conv2d_NCHWc input's shape size should be 5";
+  CHECK_EQ(shape_weights.size(), 6U)
+      << "Conv2d_NCHWc weight's shape size should be 6";
 
-  Expr batch      = shape_input[0];
+  Expr batch = shape_input[0];
   Expr c_in_outer = common::AutoSimplify(shape_input[1]);
-  Expr h_in       = shape_input[2];
-  Expr w_in       = shape_input[3];
+  Expr h_in = shape_input[2];
+  Expr w_in = shape_input[3];
   Expr c_in_inner = common::AutoSimplify(shape_input[4]);
 
-  Expr c_out_outer    = shape_weights[0];
+  Expr c_out_outer = shape_weights[0];
   Expr c_filter_outer = common::AutoSimplify(shape_weights[1]);
-  Expr h_f            = shape_weights[2];
-  Expr w_f            = shape_weights[3];
+  Expr h_f = shape_weights[2];
+  Expr w_f = shape_weights[3];
   Expr c_filter_inner = common::AutoSimplify(shape_weights[4]);
-  Expr c_out_inner    = common::AutoSimplify(shape_weights[5]);
+  Expr c_out_inner = common::AutoSimplify(shape_weights[5]);
 
   Expr c_filter = common::AutoSimplify(c_filter_outer * c_filter_inner);
-  Expr c_out    = common::AutoSimplify(c_out_outer * c_out_inner);
-  Expr c_in     = common::AutoSimplify(c_in_outer * c_in_inner);
+  Expr c_out = common::AutoSimplify(c_out_outer * c_out_inner);
+  Expr c_in = common::AutoSimplify(c_in_outer * c_in_inner);
   Var fc(c_filter, UniqName("fc"));
   Var fy(h_f, UniqName("fy"));
   Var fx(w_f, UniqName("fx"));
   std::vector<Expr> output_shape = {
-      batch,        // B
-      c_out_outer,  // O
-      common::AutoSimplify((h_in - ((h_f - 1) * dilation_h + 1) + 2 * pad_h) / stride_h + 1),  // H
-      common::AutoSimplify((w_in - ((w_f - 1) * dilation_w + 1) + 2 * pad_w) / stride_w + 1),  // W
+      batch,        // B
+      c_out_outer,  // O
+      common::AutoSimplify((h_in - ((h_f - 1) * dilation_h + 1) + 2 * pad_h) /
+                               stride_h +
+                           1),  // H
+      common::AutoSimplify((w_in - ((w_f - 1) * dilation_w + 1) + 2 * pad_w) /
+                               stride_w +
+                           1),  // W
      c_out_inner};
   ir::Tensor input_pad;
   if (pad_h == 0 && pad_w == 0) {
     input_pad = Compute(
        input->shape,
-        [=](Expr n, Expr icc, Expr yy, Expr xx, Expr icb) { return input(n, icc, yy, xx, icb); },
+        [=](Expr n, Expr icc, Expr yy, Expr xx, Expr icb) {
+          return input(n, icc, yy, xx, icb);
+        },
        UniqName("input_pad"));
   } else {
-    auto pad_h_bound = common::AutoSimplify((output_shape[2] - 1) * stride_h + (h_f - 1) * dilation_h + 1);
-    auto pad_w_bound = common::AutoSimplify((output_shape[3] - 1) * stride_w + (w_f - 1) * dilation_w + 1);
-    auto pad_out_h   = std::min(pad_h_bound.as_int32(), common::AutoSimplify(h_in + 2 * pad_h).as_int32());
-    auto pad_out_w   = std::min(pad_w_bound.as_int32(), common::AutoSimplify(w_in + 2 * pad_w).as_int32());
-    auto h_in_pad    = common::AutoSimplify(h_in + pad_h);
-    auto w_in_pad    = common::AutoSimplify(w_in + pad_w);
-    input_pad        = Compute(
+    auto pad_h_bound = common::AutoSimplify((output_shape[2] - 1) * stride_h +
+                                            (h_f - 1) * dilation_h + 1);
+    auto pad_w_bound = common::AutoSimplify((output_shape[3] - 1) * stride_w +
+                                            (w_f - 1) * dilation_w + 1);
+    auto pad_out_h =
+        std::min(pad_h_bound.as_int32(),
+                 common::AutoSimplify(h_in + 2 * pad_h).as_int32());
+    auto pad_out_w =
+        std::min(pad_w_bound.as_int32(),
+                 common::AutoSimplify(w_in + 2 * pad_w).as_int32());
+    auto h_in_pad = common::AutoSimplify(h_in + pad_h);
+    auto w_in_pad = common::AutoSimplify(w_in + pad_w);
+    input_pad = Compute(
        {batch, c_in_outer, Expr(pad_out_h), Expr(pad_out_w), c_in_inner},
        [=](Expr n, Expr icc, Expr yy, Expr xx, Expr icb) {
          auto cond = lang::logic_and({yy >= pad_h, xx >= pad_w});
@@ -462,7 +605,8 @@ std::vector<ir::Tensor> Conv2d_NCHWc(const ir::Tensor &input,
          if (pad_out_w > w_in_pad.as_int32()) {
            cond = lang::logic_and({cond, xx < w_in_pad});
          }
-          return ir::Select::Make(cond, input(n, icc, yy - pad_h, xx - pad_w, icb), ir::Zero(type));
+          return ir::Select::Make(
+              cond, input(n, icc, yy - pad_h, xx - pad_w, icb), ir::Zero(type));
        },
        UniqName("input_pad"));
   }
@@ -476,15 +620,27 @@ std::vector<ir::Tensor> Conv2d_NCHWc(const ir::Tensor &input,
        ic_outer = common::AutoSimplify(fc / c_in_inner);
        ic_inner = common::AutoSimplify(fc % c_in_inner);
      } else {
-        ic_outer = common::AutoSimplify(((oc_chunk * c_out_inner + oc_block) / c_out_per_group * c_filter + fc) /
+        ic_outer = common::AutoSimplify(((oc_chunk * c_out_inner + oc_block) /
+                                             c_out_per_group * c_filter +
+                                         fc) /
                                        c_in_inner);
-        ic_inner = common::AutoSimplify(((oc_chunk * c_out_inner + oc_block) / c_out_per_group * c_filter + fc) %
+        ic_inner = common::AutoSimplify(((oc_chunk * c_out_inner + oc_block) /
+                                             c_out_per_group * c_filter +
+                                         fc) %
                                        c_in_inner);
      }
-      return lang::ReduceSum(
-          input_pad(n, ic_outer, oh * stride_h + fy * dilation_h, ow * stride_w + fx * dilation_w, ic_inner) *
-              weights(oc_chunk, fc / c_filter_inner, fy, fx, fc % c_filter_inner, oc_block),
-          {fc, fy, fx});
+      return lang::ReduceSum(input_pad(n,
+                                       ic_outer,
+                                       oh * stride_h + fy * dilation_h,
+                                       ow * stride_w + fx * dilation_w,
+                                       ic_inner) *
+                                 weights(oc_chunk,
+                                         fc / c_filter_inner,
+                                         fy,
+                                         fx,
+                                         fc % c_filter_inner,
+                                         oc_block),
+                             {fc, fy, fx});
    },
    UniqName("conv2d_NCHWc_out"));
  return {packed_out, input_pad};
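The NCHW to NCHWc packing above rests on one identity: a flat channel c splits into an outer chunk icc = c / ic_bn and an inner block icb = c % ic_bn, and icc * ic_bn + icb recovers c. A toy demonstration with invented sizes, not CINN code:

// Index arithmetic behind the NCHW <-> NCHWc packing.
#include <cassert>

int main() {
  const int ic_bn = 8;  // inner channel block, as chosen by GetConv2dFactors
  for (int c = 0; c < 32; ++c) {
    int icc = c / ic_bn;  // outer (chunk) index
    int icb = c % ic_bn;  // inner (block) index
    // data(n, icc, h, w, icb) holds the same value as input(n, c, h, w)
    assert(icc * ic_bn + icb == c);
  }
  return 0;
}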
Please check."; std::vector output_shape; std::vector new_weights_shape; std::vector input_pad_shape; output_shape = { - input->shape[0], // B - Expr((input->shape[1] - ((weights->shape[2] - 1) * dilation_h + 1) + 2 * pad_h) / stride_h + 1), // H - Expr((input->shape[2] - ((weights->shape[3] - 1) * dilation_w + 1) + 2 * pad_w) / stride_w + 1), // W - weights->shape[0] // O + input->shape[0], // B + Expr((input->shape[1] - ((weights->shape[2] - 1) * dilation_h + 1) + + 2 * pad_h) / + stride_h + + 1), // H + Expr((input->shape[2] - ((weights->shape[3] - 1) * dilation_w + 1) + + 2 * pad_w) / + stride_w + + 1), // W + weights->shape[0] // O }; new_weights_shape = {weights->shape[0], weights->shape[1], dilation_h * (weights->shape[2] - 1) + 1, dilation_w * (weights->shape[3] - 1) + 1}; - input_pad_shape = {input->shape[0], input->shape[1] + 2 * pad_h, input->shape[2] + 2 * pad_w, input->shape[3]}; - auto input_pad = Compute( + input_pad_shape = {input->shape[0], + input->shape[1] + 2 * pad_h, + input->shape[2] + 2 * pad_w, + input->shape[3]}; + auto input_pad = Compute( input_pad_shape, [=](Expr nn, Expr yy, Expr xx, Expr cc) { - auto cond = - lang::logic_and({yy >= pad_h, yy - pad_h < input->shape[1], xx >= pad_w, xx - pad_w < input->shape[2]}); - return ir::Select::Make(cond, input(nn, yy - pad_h, xx - pad_w, cc), ir::Zero(input->type())); + auto cond = lang::logic_and({yy >= pad_h, + yy - pad_h < input->shape[1], + xx >= pad_w, + xx - pad_w < input->shape[2]}); + return ir::Select::Make(cond, + input(nn, yy - pad_h, xx - pad_w, cc), + ir::Zero(input->type())); }, UniqName("input_pad")); auto weights_dilation = Compute( new_weights_shape, [=](Expr nn, Expr cc, Expr yy, Expr xx) { - auto cond = lang::logic_and({(yy) % dilation_h == 0, xx % dilation_w == 0}); + auto cond = + lang::logic_and({(yy) % dilation_h == 0, xx % dilation_w == 0}); return ir::Select::Make( - cond, weights(nn, cc, yy / dilation_h, xx / dilation_w), common::make_const(weights->type(), 0)); + cond, + weights(nn, cc, yy / dilation_h, xx / dilation_w), + common::make_const(weights->type(), 0)); }, UniqName("weights_dilation")); @@ -586,16 +762,20 @@ std::vector Conv2d_NHWC(const ir::Tensor &input, Var fy(weights_dilation->shape[2], UniqName("fy")); Var fx(weights_dilation->shape[3], UniqName("fx")); - CHECK(MathEqual((weights->shape[0] * weights->shape[1]) % input->shape[3], Expr(0))) + CHECK(MathEqual((weights->shape[0] * weights->shape[1]) % input->shape[3], + Expr(0))) << "filter's output channel size must be divisible by group\n"; auto res = Compute( output_shape, [=](Expr nn, Expr yy, Expr xx, Expr ff) { return lang::ReduceSum( - input_pad(nn, - yy * stride_h + fy, - xx * stride_w + fx, - ff / (weights->shape[0] * weights->shape[1] / input->shape[3]) * weights->shape[1] + fc) * + input_pad( + nn, + yy * stride_h + fy, + xx * stride_w + fx, + ff / (weights->shape[0] * weights->shape[1] / input->shape[3]) * + weights->shape[1] + + fc) * weights_dilation(ff, fc, fy, fx), {fy, fx, fc}); }, @@ -610,11 +790,14 @@ std::vector Depthwise_Conv2d_NCHW(const Tensor &input, int stride_h, int stride_w, const std::string output_name) { - CHECK_EQ(input->shape.size(), 4U) << "Input's dimension of Depthwise_Conv2d_NCHW is not 4! Please check.\n"; - CHECK_EQ(weight->shape.size(), 4U) << "Weight's dimension of Depthwise_Conv2d_NCHW is not 4! Please check.\n"; + CHECK_EQ(input->shape.size(), 4U) + << "Input's dimension of Depthwise_Conv2d_NCHW is not 4! 
Please check.\n"; + CHECK_EQ(weight->shape.size(), 4U) + << "Weight's dimension of Depthwise_Conv2d_NCHW is not 4! Please " + "check.\n"; Expr in_h = input->shape[2]; Expr in_w = input->shape[3]; - Expr c_m = weight->shape[1]; // channel_multiplier + Expr c_m = weight->shape[1]; // channel_multiplier std::vector output_shape; CHECK(input->shape[0].is_constant()); CHECK(input->shape[1].is_constant()); @@ -624,9 +807,16 @@ std::vector Depthwise_Conv2d_NCHW(const Tensor &input, CHECK(weight->shape[2].is_constant()); CHECK(weight->shape[3].is_constant()); int B = (int)input->shape[0].get_constant(); - int O = (int)weight->shape[1].get_constant() * (int)input->shape[1].get_constant(); - int H = ((int)input->shape[2].get_constant() - (int)weight->shape[2].get_constant() + 2 * pad_h) / stride_h + 1; - int W = ((int)input->shape[3].get_constant() - (int)weight->shape[3].get_constant() + 2 * pad_w) / stride_w + 1; + int O = (int)weight->shape[1].get_constant() * + (int)input->shape[1].get_constant(); + int H = ((int)input->shape[2].get_constant() - + (int)weight->shape[2].get_constant() + 2 * pad_h) / + stride_h + + 1; + int W = ((int)input->shape[3].get_constant() - + (int)weight->shape[3].get_constant() + 2 * pad_w) / + stride_w + + 1; output_shape = { Expr(B), // B Expr(O), // O @@ -634,7 +824,9 @@ std::vector Depthwise_Conv2d_NCHW(const Tensor &input, Expr(W) // W }; auto input_pad = - (pad_h == 0 && pad_w == 0) ? Identity(input).front() : Pad(input, {Expr(0), Expr(0), Expr(pad_h), Expr(pad_w)}); + (pad_h == 0 && pad_w == 0) + ? Identity(input).front() + : Pad(input, {Expr(0), Expr(0), Expr(pad_h), Expr(pad_w)}); Var kernel_h = Var(weight->shape[2], "kh"); Var kernel_w = Var(weight->shape[3], "kw"); @@ -642,9 +834,13 @@ std::vector Depthwise_Conv2d_NCHW(const Tensor &input, auto res = Compute( output_shape, [=](Expr nn, Expr ff, Expr yy, Expr xx) { - return lang::ReduceSum(input_pad(nn, ff / c_m, yy * stride_h + kernel_h, xx * stride_w + kernel_w) * - weight(ff / c_m, ff % c_m, kernel_h, kernel_w), - {kernel_h, kernel_w}); + return lang::ReduceSum( + input_pad(nn, + ff / c_m, + yy * stride_h + kernel_h, + xx * stride_w + kernel_w) * + weight(ff / c_m, ff % c_m, kernel_h, kernel_w), + {kernel_h, kernel_w}); }, output_name); return {res, input_pad}; @@ -657,11 +853,14 @@ std::vector Depthwise_Conv2d_NHWC(const Tensor &input, int stride_h, int stride_w, const std::string output_name) { - CHECK_EQ(input->shape.size(), 4U) << "Input's dimension of Depthwise_Conv2d_NCHW is not 4! Please check.\n"; - CHECK_EQ(weight->shape.size(), 4U) << "Weight's dimension of Depthwise_Conv2d_NCHW is not 4! Please check.\n"; + CHECK_EQ(input->shape.size(), 4U) + << "Input's dimension of Depthwise_Conv2d_NCHW is not 4! Please check.\n"; + CHECK_EQ(weight->shape.size(), 4U) + << "Weight's dimension of Depthwise_Conv2d_NCHW is not 4! Please " + "check.\n"; Expr in_h = input->shape[1]; Expr in_w = input->shape[2]; - Expr c_m = weight->shape[1]; // channel_multiplier + Expr c_m = weight->shape[1]; // channel_multiplier std::vector output_shape; output_shape = { @@ -672,25 +871,31 @@ std::vector Depthwise_Conv2d_NHWC(const Tensor &input, }; auto input_pad = - (pad_h == 0 && pad_w == 0) ? Identity(input).front() : Pad(input, {Expr(0), Expr(pad_h), Expr(pad_w), Expr(0)}); + (pad_h == 0 && pad_w == 0) + ? 
@@ -657,11 +853,14 @@ std::vector<ir::Tensor> Depthwise_Conv2d_NHWC(const Tensor &input,
                                               int stride_h,
                                               int stride_w,
                                               const std::string output_name) {
-  CHECK_EQ(input->shape.size(), 4U) << "Input's dimension of Depthwise_Conv2d_NCHW is not 4! Please check.\n";
-  CHECK_EQ(weight->shape.size(), 4U) << "Weight's dimension of Depthwise_Conv2d_NCHW is not 4! Please check.\n";
+  CHECK_EQ(input->shape.size(), 4U)
+      << "Input's dimension of Depthwise_Conv2d_NHWC is not 4! Please check.\n";
+  CHECK_EQ(weight->shape.size(), 4U)
+      << "Weight's dimension of Depthwise_Conv2d_NHWC is not 4! Please "
+         "check.\n";
   Expr in_h = input->shape[1];
   Expr in_w = input->shape[2];
-  Expr c_m  = weight->shape[1];  // channel_multiplier
+  Expr c_m = weight->shape[1];  // channel_multiplier
   std::vector<Expr> output_shape;
 
   output_shape = {
@@ -672,25 +871,31 @@ std::vector<ir::Tensor> Depthwise_Conv2d_NHWC(const Tensor &input,
   };
 
   auto input_pad =
-      (pad_h == 0 && pad_w == 0) ? Identity(input).front() : Pad(input, {Expr(0), Expr(pad_h), Expr(pad_w), Expr(0)});
+      (pad_h == 0 && pad_w == 0)
+          ? Identity(input).front()
+          : Pad(input, {Expr(0), Expr(pad_h), Expr(pad_w), Expr(0)});
 
   Var kernel_h = Var(weight->shape[2], "kh");
   Var kernel_w = Var(weight->shape[3], "kw");
-  auto res     = Compute(
+  auto res = Compute(
       output_shape,
       [=](Expr nn, Expr yy, Expr xx, Expr ff) {
-        return lang::ReduceSum(input_pad(nn, yy * stride_h + kernel_h, xx * stride_w + kernel_w, ff / c_m) *
-                                   weight(ff / c_m, ff % c_m, kernel_h, kernel_w),
-                               {kernel_h, kernel_w});
+        return lang::ReduceSum(
+            input_pad(nn,
+                      yy * stride_h + kernel_h,
+                      xx * stride_w + kernel_w,
+                      ff / c_m) *
+                weight(ff / c_m, ff % c_m, kernel_h, kernel_w),
+            {kernel_h, kernel_w});
      },
      output_name);
   return {res, input_pad};
 }
 
 /**
- * Can be used as a normalizer function for convolution or fully_connected operations.
- * Specified for NCHW layout.
- * Math: Y = (X - mean) / sqrt(variance + epsilon) * scale + bias
+ * Can be used as a normalizer function for convolution or fully_connected
+ * operations. Specified for NCHW layout. Math: Y = (X - mean) / sqrt(variance +
+ * epsilon) * scale + bias
 * @param input The input variable.
 * @param weights The weights containing mean, variance, scale and bias.
 * @param epsilon The param epsilon is added to avoid divide zero.
@@ -704,16 +909,22 @@ ir::Tensor BatchNorm_NCHW(const ir::Tensor &input,
                           const ir::Tensor &variance,
                           float epsilon,
                           const std::string &output_name) {
-  CHECK_EQ(input->shape.size(), 4U) << "Input's dimension of BatchNorm op is not 4! Please check.";
-  CHECK_EQ(scale->shape.size(), 1U) << "Scale's dimension of BatchNorm op is not 1! Please check.";
-  CHECK_EQ(bias->shape.size(), 1U) << "Bias's dimension of BatchNorm op is not 1! Please check.";
-  CHECK_EQ(mean->shape.size(), 1U) << "Mean's dimension of BatchNorm op is not 1! Please check.";
-  CHECK_EQ(variance->shape.size(), 1U) << "Variance's dimension of BatchNorm op is not 1! Please check.";
+  CHECK_EQ(input->shape.size(), 4U)
+      << "Input's dimension of BatchNorm op is not 4! Please check.";
+  CHECK_EQ(scale->shape.size(), 1U)
+      << "Scale's dimension of BatchNorm op is not 1! Please check.";
+  CHECK_EQ(bias->shape.size(), 1U)
+      << "Bias's dimension of BatchNorm op is not 1! Please check.";
+  CHECK_EQ(mean->shape.size(), 1U)
+      << "Mean's dimension of BatchNorm op is not 1! Please check.";
+  CHECK_EQ(variance->shape.size(), 1U)
+      << "Variance's dimension of BatchNorm op is not 1! Please check.";
   auto res = Compute(
      input->shape,
      [=](Expr n, Expr c, Expr h, Expr w) {
        return (input(n, c, h, w) - mean(c)) * scale(c) /
-                   lang::Sqrt(variance(c) + common::make_const(input->type(), epsilon)) +
+                   lang::Sqrt(variance(c) +
+                              common::make_const(input->type(), epsilon)) +
               bias(c);
      },
      UniqName(output_name));
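A scalar reference for the inference-time batch-norm expression above, y = (x - mean) / sqrt(variance + epsilon) * scale + bias, applied per channel c; standalone sketch, with names mirroring the tensor arguments:

// One point of BatchNorm_NCHW, as plain arithmetic.
#include <cmath>

float BatchNormPoint(float x, float mean, float variance, float scale,
                     float bias, float epsilon = 1e-5f) {
  // epsilon guards against division by zero when variance is ~0.
  return (x - mean) * scale / std::sqrt(variance + epsilon) + bias;
}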
Please check."; + CHECK_EQ(bias->shape.size(), 1U) + << "Bias's dimension of BatchNorm op is not 1! Please check."; + CHECK_EQ(mean->shape.size(), 1U) + << "Mean's dimension of BatchNorm op is not 1! Please check."; + CHECK_EQ(variance->shape.size(), 1U) + << "Variance's dimension of BatchNorm op is not 1! Please check."; Expr ic_bn = input->shape.back(); - auto res = Compute( + auto res = Compute( input->shape, [=](Expr n, Expr icc, Expr h, Expr w, Expr icb) { Expr new_c = icc * ic_bn + icb; return (input(n, icc, h, w, icb) - mean(new_c)) * scale(new_c) / - lang::Sqrt(variance(new_c) + common::make_const(input->type(), epsilon)) + + lang::Sqrt(variance(new_c) + + common::make_const(input->type(), epsilon)) + bias(new_c); }, UniqName(output_name)); @@ -752,7 +969,9 @@ ir::Tensor BatchNorm_NCHWc(const ir::Tensor &input, * @param output_name The name of output tensor. * @return The calculated output tensor. */ -std::vector Softmax(const ir::Tensor &A, int axis, const std::string &output_name) { +std::vector Softmax(const ir::Tensor &A, + int axis, + const std::string &output_name) { if (axis == -1) { axis = A->shape.size() - 1; } @@ -796,8 +1015,11 @@ std::vector Softmax(const ir::Tensor &A, int axis, const std::string } #ifdef CINN_WITH_MKLDNN -std::vector SoftmaxMKLDNN(const ir::Tensor &A, int axis, const std::string &output_name) { - CHECK_LE(A->shape.size(), 4U) << "Input's dimension of mkldnn softmax op is less than 4! Please check."; +std::vector SoftmaxMKLDNN(const ir::Tensor &A, + int axis, + const std::string &output_name) { + CHECK_LE(A->shape.size(), 4U) + << "Input's dimension of mkldnn softmax op is less than 4! Please check."; if (axis == -1) { axis = A->shape.size() - 1; } @@ -829,26 +1051,25 @@ std::vector SoftmaxMKLDNN(const ir::Tensor &A, int axis, const std:: /** * @brief Perform padding operation. * @param tensor The input tensor. - * @param pad_before Vector of Exprs describing the padding before the respective dimension - * @param pad_after Vector of Exprs describing the padding after the respective dimension + * @param pad_before Vector of Exprs describing the padding before the + * respective dimension + * @param pad_after Vector of Exprs describing the padding after the respective + * dimension * @param pad_value The value to fill padding elements with. Default is zero. * @param name The name of the output padding tensor - * @param pad_mode Padding type to use: "constant" pads with constant_value; "edge" pads using the edge values of the - * input array; "reflect" pads by reflecting values with respect to the edges. + * @param pad_mode Padding type to use: "constant" pads with constant_value; + * "edge" pads using the edge values of the input array; "reflect" pads by + * reflecting values with respect to the edges. * * @return the output tensor after padding. * * @note - * The pad_after vector must either be empty or have the same length as pad_before - * When pad_after is empty, it takes the same values as pad_before (symmetric padding) - * The pad vector applies from the leading dimensions and skips missing trailing dimensions: - * e.g. 
- * pad(t(i, j, k), {1}, {1}) returns the equivalent operation for - * the following pseudocode: - * for i in [0, t.shape[0] + 2): - * for j in [0, t.shape[0] + 2): - * for k in [0, t.shape[0] + 2): - * name(i,j,k) = + * The pad_after vector must either be empty or have the same length as + * pad_before When pad_after is empty, it takes the same values as pad_before + * (symmetric padding) The pad vector applies from the leading dimensions and + * skips missing trailing dimensions: e.g. pad(t(i, j, k), {1}, {1}) returns the + * equivalent operation for the following pseudocode: for i in [0, t.shape[0] + + * 2): for j in [0, t.shape[0] + 2): for k in [0, t.shape[0] + 2): name(i,j,k) = * i < 1 ? 0 : * ((1 <= i < t.shape[0] + 1) ? * t(i-1, j, k) : 0)); @@ -860,7 +1081,8 @@ Tensor Pad(const Tensor &tensor, Expr pad_value, const std::string &name, const std::string &pad_mode) { - // When pad_after is empty, it takes the same values as pad_before (symmetric padding) + // When pad_after is empty, it takes the same values as pad_before (symmetric + // padding) if (pad_after.size() < pad_before.size()) { for (size_t i = pad_after.size(); i < pad_before.size(); ++i) { pad_after.push_back(pad_before[i]); @@ -879,7 +1101,8 @@ Tensor Pad(const Tensor &tensor, if (i >= pad_before.size()) { output_shape.push_back(tensor->shape[i]); } else { - auto shape = common::AutoSimplify(tensor->shape[i] + pad_before[i] + pad_after[i]); + auto shape = + common::AutoSimplify(tensor->shape[i] + pad_before[i] + pad_after[i]); output_shape.push_back(shape); } } @@ -905,21 +1128,24 @@ Tensor Pad(const Tensor &tensor, } Expr sel_after; if (!MathEqual(pad_after[i], Expr(0))) { - sel_after = common::AutoSimplify(ovars[i] < pad_before[i] + tensor->shape[i]); + sel_after = + common::AutoSimplify(ovars[i] < pad_before[i] + tensor->shape[i]); sel.push_back(sel_after); } if (pad_mode == "edge") { pad_idx.push_back(Select::Make( ovars[i] < pad_before[i], 0, - Select::Make( - ovars[i] >= pad_before[i] + tensor->shape[i], tensor->shape[i] - 1, ovars[i] - pad_before[i]))); + Select::Make(ovars[i] >= pad_before[i] + tensor->shape[i], + tensor->shape[i] - 1, + ovars[i] - pad_before[i]))); } else if (pad_mode == "reflect") { - pad_idx.push_back(Select::Make(ovars[i] < pad_before[i], - pad_before[i] - ovars[i], - Select::Make(ovars[i] >= pad_before[i] + tensor->shape[i], - tensor->shape[i] * 2 - ovars[i] + pad_before[i] - 2, - ovars[i] - pad_before[i]))); + pad_idx.push_back(Select::Make( + ovars[i] < pad_before[i], + pad_before[i] - ovars[i], + Select::Make(ovars[i] >= pad_before[i] + tensor->shape[i], + tensor->shape[i] * 2 - ovars[i] + pad_before[i] - 2, + ovars[i] - pad_before[i]))); } } if (sel.size() != 0) { @@ -927,7 +1153,8 @@ Tensor Pad(const Tensor &tensor, if (pad_mode == "constant") { return Select::Make(FoldExpr(fn, sel), tensor(indices), pad_value); } else if (pad_mode == "edge" || pad_mode == "reflect") { - return Select::Make(FoldExpr(fn, sel), tensor(indices), tensor(pad_idx)); + return Select::Make( + FoldExpr(fn, sel), tensor(indices), tensor(pad_idx)); } } return tensor(indices); @@ -938,14 +1165,17 @@ Tensor Pad(const Tensor &tensor, /** * @brief Perform pooling on N-dimension of data. * - * @param tensor The input tensor with the shape of {N, C, H, W} or {N, H, W, C}. - * @param kernel_size Vector of N ints that indicates pooling kernel size. If N is 2, then is {pool_kernel_Height, - * pool_kernel_Width}. - * @param stride_size Vector of N ints that indicates pooling stride size. 
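A usage sketch for Pad as documented above: symmetric 1-element padding on the leading dimension only, relying on the documented rule that an empty pad_after copies pad_before. The argument order follows the @param list; the tensor t and the output name are illustrative, and this compiles only in nn.cc's context:

// Pad t's first dimension by 1 on each side with zeros; trailing
// dimensions are skipped because the pad vectors are shorter than the rank.
ir::Tensor padded = Pad(t,
                        /*pad_before=*/{Expr(1)},
                        /*pad_after=*/{},       // implies pad_after = {1}
                        /*pad_value=*/Expr(0),
                        "t_pad",
                        /*pad_mode=*/"constant");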
@@ -938,14 +1165,17 @@ Tensor Pad(const Tensor &tensor,
 
 /**
 * @brief Perform pooling on N-dimension of data.
 *
- * @param tensor The input tensor with the shape of {N, C, H, W} or {N, H, W, C}.
- * @param kernel_size Vector of N ints that indicates pooling kernel size. If N is 2, then is {pool_kernel_Height,
- * pool_kernel_Width}.
- * @param stride_size Vector of N ints that indicates pooling stride size. If N is 2, then is {pool_stride_Height,
- * pool_stride_Width}.
- * @param padding_size Vector of N*2 ints {head_pad_d1, head_pad_d2, ..., head_pad_dN, tail_pad_d1, tail_pad_d2, ...,
- * tail_pad_dN}. If N is 2, then is {pad_height_top, pad_width_left, pad_height_bottom, pad_width_right]}.
- * @param pool_type The type of pooling operator, currently support "max" and "avg".
+ * @param tensor The input tensor with the shape of {N, C, H, W} or {N, H, W,
+ * C}.
+ * @param kernel_size Vector of N ints that indicates pooling kernel size. If N
+ * is 2, then is {pool_kernel_Height, pool_kernel_Width}.
+ * @param stride_size Vector of N ints that indicates pooling stride size. If N
+ * is 2, then is {pool_stride_Height, pool_stride_Width}.
+ * @param padding_size Vector of N*2 ints {head_pad_d1, head_pad_d2, ...,
+ * head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN}. If N is 2, then is
+ * {pad_height_top, pad_width_left, pad_height_bottom, pad_width_right}.
+ * @param pool_type The type of pooling operator, currently support "max" and
+ * "avg".
 * @param axis Vector of axes of the tensor for pooling.
 * @param ceil_mode Whether to use ceil when calculating the output size.
 * @param exclusive Whether include padding in the calculation'.
@@ -966,15 +1196,19 @@ std::vector<ir::Tensor> PoolImpl(const Tensor &tensor,
   CHECK(!kernel_size.empty()) << "Pooling kernel_size should not be empty\n";
   int k_size = kernel_size.size();
   int x_size = tensor->shape.size();
-  CHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel\n";
-  CHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must have double elements as kernel\n";
+  CHECK_EQ(stride_size.size(), k_size)
+      << "Pooling stride_size must have same elements as kernel\n";
+  CHECK_EQ(padding_size.size(), k_size * 2)
+      << "Pooling padding_size must have double elements as kernel\n";
   CHECK_EQ(axis.size(), k_size) << "Axis must have same elements as kernel\n";
 
   std::string pool_type;
-  std::transform(pooling_type.begin(), pooling_type.end(), std::back_inserter(pool_type), [](unsigned char c) {
-    return std::tolower(c);
-  });
-  CHECK(pool_type == "max" || pool_type == "avg") << "pool_type for pool2d should be max or avg.\n";
+  std::transform(pooling_type.begin(),
+                 pooling_type.end(),
+                 std::back_inserter(pool_type),
+                 [](unsigned char c) { return std::tolower(c); });
+  CHECK(pool_type == "max" || pool_type == "avg")
+      << "pool_type for pool2d should be max or avg.\n";
 
   std::vector<Var> daxis;
   std::vector<Expr> kernel(k_size);
@@ -987,12 +1221,12 @@ std::vector<ir::Tensor> PoolImpl(const Tensor &tensor,
 
   bool do_pad = false;
   for (int i = 0; i < k_size; i++) {
-    int ii      = axis[i];
-    kernel[i]   = Expr(kernel_size[i]);
-    stride[i]   = Expr(stride_size[i]);
+    int ii = axis[i];
+    kernel[i] = Expr(kernel_size[i]);
+    stride[i] = Expr(stride_size[i]);
     pad_head[i] = Expr(padding_size[i]);
     pad_tail[i] = Expr(padding_size[i + k_size]);
-    do_pad      = (do_pad) ? do_pad : (padding_size[i] || padding_size[i + k_size]);
+    do_pad = (do_pad) ? do_pad : (padding_size[i] || padding_size[i + k_size]);
 
     if (ceil_mode) {
       pad_tail[i] = common::AutoSimplify(pad_tail[i] + stride[i] - 1);
@@ -1001,9 +1235,12 @@ std::vector<ir::Tensor> PoolImpl(const Tensor &tensor,
     daxis.emplace_back(Var(kernel[i], UniqName("kernel_idx")));
 
     pad_before[ii] = pad_head[i];
-    pad_after[ii]  = pad_tail[i];
+    pad_after[ii] = pad_tail[i];
 
-    auto out_dim = common::AutoSimplify((tensor->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i]) / stride[i] + 1);
+    auto out_dim = common::AutoSimplify(
+        (tensor->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i]) /
+            stride[i] +
+        1);
 
     out_shape[ii] = out_dim;
   }
@@ -1014,15 +1251,20 @@ std::vector<ir::Tensor> PoolImpl(const Tensor &tensor,
   if (pool_type == "max") {
     Expr min_value = lang::min_value(tensor->type());
     // Pad the input tensor with the pad_value of type's minimum value
-    temp = do_pad ? Pad(tensor, pad_before, pad_after, min_value, UniqName("pad_temp")) : tensor;
-    res  = Compute(
+    temp = do_pad ? Pad(tensor,
+                        pad_before,
+                        pad_after,
+                        min_value,
+                        UniqName("pad_temp"))
+                  : tensor;
+    res = Compute(
        out_shape,
        [=](const std::vector<Expr> &output) {
          std::vector<Expr> indices;
          for (auto &var : output) indices.push_back(var);
 
          for (int i = 0; i < k_size; i++) {
-            int ii      = axis[i];
+            int ii = axis[i];
            indices[ii] = output[ii] * stride[i] + daxis[i];
          }
 
@@ -1031,15 +1273,16 @@ std::vector<ir::Tensor> PoolImpl(const Tensor &tensor,
        output_name);
   } else if (pool_type == "avg") {
     // Pad the input tensor with pad_value zero
-    temp = do_pad ? Pad(tensor, pad_before, pad_after, 0, UniqName("pad_temp")) : tensor;
-    res  = Compute(
+    temp = do_pad ? Pad(tensor, pad_before, pad_after, 0, UniqName("pad_temp"))
+                  : tensor;
+    res = Compute(
        out_shape,
        [=](const std::vector<Expr> &output) {
          std::vector<Expr> indices;
          for (const Expr &var : output) indices.push_back(var);
 
          for (int i = 0; i < k_size; i++) {
-            int ii      = axis[i];
+            int ii = axis[i];
            indices[ii] = output[ii] * stride[i] + daxis[i];
          }
 
@@ -1048,22 +1291,29 @@ std::vector<ir::Tensor> PoolImpl(const Tensor &tensor,
            std::vector<Expr> start(k_size);
            std::vector<Expr> end(k_size);
            auto temp_factor = make_const(Int(32), 1);
            for (int i = 0; i < k_size; i++) {
-              int ii      = axis[i];
-              start[i]    = common::AutoSimplify(output[ii] * stride[i] - pad_head[i]);
-              end[i]      = Min::Make(start[i] + kernel[i], tensor->shape[ii]);
-              start[i]    = Max::Make(start[i], make_const(Int(32), 0));
+              int ii = axis[i];
+              start[i] =
+                  common::AutoSimplify(output[ii] * stride[i] - pad_head[i]);
+              end[i] = Min::Make(start[i] + kernel[i], tensor->shape[ii]);
+              start[i] = Max::Make(start[i], make_const(Int(32), 0));
              temp_factor = temp_factor * (end[i] - start[i]);
            }
            common::AutoSimplify(temp_factor);
            Expr divide_factor = Max::Make(temp_factor, make_const(Int(32), 1));
-            return lang::ReduceSum(ir::Div::Make(temp(indices), ir::Cast::Make(temp->type(), divide_factor)), {daxis});
+            return lang::ReduceSum(
+                ir::Div::Make(temp(indices),
+                              ir::Cast::Make(temp->type(), divide_factor)),
+                {daxis});
          } else {
            auto temp_factor = make_const(Int(32), 1);
            for (int i = 0; i < k_size; i++) {
              temp_factor = temp_factor * kernel[i];
            }
            common::AutoSimplify(temp_factor);
-            return lang::ReduceSum(ir::Div::Make(temp(indices), ir::Cast::Make(temp->type(), temp_factor)), daxis);
+            return lang::ReduceSum(
+                ir::Div::Make(temp(indices),
+                              ir::Cast::Make(temp->type(), temp_factor)),
+                daxis);
          }
        },
        output_name);
@@ -1080,12 +1330,15 @@ std::vector<ir::Tensor> PoolImpl(const Tensor &tensor,
    }
    VLOG(4) << "PoolImpl out_shape: " << cinn::utils::Join(out_shape, ",");
    CHECK(!do_pad);
-    temp = do_pad ? Pad(tensor, pad_before, pad_after, 0, UniqName("pad_temp")) : tensor;
+    temp = do_pad ? Pad(tensor, pad_before, pad_after, 0, UniqName("pad_temp"))
+                  : tensor;
 
    std::vector<Var> reduce_axis;
 
    for (int i = 0; i < k_size; i++) {
-      reduce_axis.emplace_back(Var(Expr(static_cast<int>(tensor->shape[axis[i]].get_constant()) / kernel_size[i]),
-                                   UniqName("adaptive_reduce")));
+      reduce_axis.emplace_back(
+          Var(Expr(static_cast<int>(tensor->shape[axis[i]].get_constant()) /
+                   kernel_size[i]),
+              UniqName("adaptive_reduce")));
    }
 
    res = Compute(
@@ -1096,18 +1349,26 @@ std::vector<ir::Tensor> PoolImpl(const Tensor &tensor,
 
          for (int i = 0; i < k_size; i++) {
            indices[axis[i]] =
-                output[axis[i]] * Expr(static_cast<int>(tensor->shape[axis[i]].get_constant()) / kernel_size[i]) +
+                output[axis[i]] *
+                    Expr(static_cast<int>(
+                             tensor->shape[axis[i]].get_constant()) /
+                         kernel_size[i]) +
                reduce_axis[i];
          }
 
          auto temp_factor = make_const(Int(32), 1);
          for (int i = 0; i < k_size; i++) {
-            temp_factor = temp_factor * Expr(static_cast<int>(tensor->shape[axis[i]].get_constant()) / kernel_size[i]);
+            temp_factor =
+                temp_factor *
+                Expr(static_cast<int>(tensor->shape[axis[i]].get_constant()) /
+                     kernel_size[i]);
          }
          common::AutoSimplify(temp_factor);
          Expr divide_factor = Max::Make(temp_factor, make_const(Int(32), 1));
-          return lang::ReduceSum(ir::Div::Make(temp(indices), ir::Cast::Make(temp->type(), divide_factor)),
-                                 {reduce_axis});
+          return lang::ReduceSum(
+              ir::Div::Make(temp(indices),
+                            ir::Cast::Make(temp->type(), divide_factor)),
+              {reduce_axis});
        },
        output_name);
  }
using `cub` with NVRTC Expr extend = tensor->shape[2] * tensor->shape[3]; @@ -1149,8 +1421,11 @@ std::vector GlobalPool2d(const Tensor &tensor, const std::string &pool_t auto temp = Compute( {tensor->shape[0], tensor->shape[1], Expr(32)}, [=](Expr n, Expr c, Expr k) -> Expr { - Expr offset = common::IndiceToAbsOffset(tensor->shape, {n, c, Expr(0), Expr(0)}); - return lang::CallExtern("cinn_warp_reduce_max_" + Type2StrForNN(tensor->type()), {tensor, offset, extend}); + Expr offset = common::IndiceToAbsOffset(tensor->shape, + {n, c, Expr(0), Expr(0)}); + return lang::CallExtern( + "cinn_warp_reduce_max_" + Type2StrForNN(tensor->type()), + {tensor, offset, extend}); }, UniqName(output_name + "_temp")); temp->WithBuffer(tensor->type()); @@ -1165,8 +1440,11 @@ std::vector GlobalPool2d(const Tensor &tensor, const std::string &pool_t auto temp = Compute( {tensor->shape[0], tensor->shape[1], Expr(32)}, [=](Expr n, Expr c, Expr k) -> Expr { - Expr offset = common::IndiceToAbsOffset(tensor->shape, {n, c, Expr(0), Expr(0)}); - return lang::CallExtern("cinn_warp_reduce_avg_" + Type2StrForNN(tensor->type()), {tensor, offset, extend}); + Expr offset = common::IndiceToAbsOffset(tensor->shape, + {n, c, Expr(0), Expr(0)}); + return lang::CallExtern( + "cinn_warp_reduce_avg_" + Type2StrForNN(tensor->type()), + {tensor, offset, extend}); }, UniqName(output_name + "_temp")); temp->WithBuffer(tensor->type()); @@ -1194,24 +1472,32 @@ std::vector Pool2d(const Tensor &tensor, bool adaptive, const std::string &output_name) { int height_axis = -1; - int width_axis = -1; + int width_axis = -1; if (data_format == "NCHW") { height_axis = 2; - width_axis = 3; + width_axis = 3; } else if (data_format == "NHWC") { height_axis = 1; - width_axis = 2; + width_axis = 2; } else if (data_format == "AnyLayout") { height_axis = 2; - width_axis = 3; + width_axis = 3; } else { LOG(FATAL) << "Unsupported data format: " << data_format << std::endl; } CHECK(tensor->shape.size() == 4U || tensor->shape.size() == 5U) << "pool2d requires tensor's shape_size to be 4 or 5\n"; std::vector axis = {height_axis, width_axis}; - return PoolImpl( - tensor, kernel_size, stride_size, padding_size, pool_type, axis, ceil_mode, exclusive, adaptive, output_name); + return PoolImpl(tensor, + kernel_size, + stride_size, + padding_size, + pool_type, + axis, + ceil_mode, + exclusive, + adaptive, + output_name); } std::vector Pool3d(const Tensor &tensor, @@ -1224,20 +1510,21 @@ std::vector Pool3d(const Tensor &tensor, const std::string &data_format, const std::string &output_name) { int height_axis = -1; - int width_axis = -1; - int depth_axis = -1; + int width_axis = -1; + int depth_axis = -1; if (data_format == "NCDHW") { - depth_axis = 2; + depth_axis = 2; height_axis = 3; - width_axis = 4; + width_axis = 4; } else if (data_format == "NDHWC") { - depth_axis = 1; + depth_axis = 1; height_axis = 2; - width_axis = 3; + width_axis = 3; } else { LOG(FATAL) << "Unsupported data format: " << data_format << std::endl; } - CHECK_EQ(tensor->shape.size(), 5U) << "pool1d requires tensor's shape_size to be 5\n"; + CHECK_EQ(tensor->shape.size(), 5U) + << "pool3d requires tensor's shape_size to be 5\n"; std::vector axis = {depth_axis, height_axis, width_axis}; return PoolImpl(tensor, kernel_size, @@ -1259,14 +1546,17 @@ Tensor DropoutInfer(const ir::Tensor &tensor, return Compute( tensor->shape, [=](const std::vector &indice) { - return tensor(indice) * common::make_const(tensor->type(), 1 - dropout_prob); + return tensor(indice) * + common::make_const(tensor->type(),
1 - dropout_prob); }, output_name); } else if (dropout_implementation == "upscale_in_train") { - // The name here must be consistent, otherwise it cannot participate in the fusion schedule. + // The name here must be consistent, otherwise it cannot participate in the + // fusion schedule. return Identity(tensor, output_name).front(); } else { - LOG(FATAL) << "dropout_implementation attr must be 'downgrade_in_infer' or 'upscale_in_train'\n"; + LOG(FATAL) << "dropout_implementation attr must be 'downgrade_in_infer' or " + "'upscale_in_train'\n"; } } @@ -1274,13 +1564,16 @@ ir::Tensor Select(const ir::Tensor &condition, const ir::Tensor &true_value, const ir::Tensor &false_value, const std::string &output_name) { - CHECK(condition->type().is_bool()) << "The condition tensor type should be bool!"; - CHECK(condition->shape == true_value->shape && true_value->shape == false_value->shape) + CHECK(condition->type().is_bool()) + << "The condition tensor type should be bool!"; + CHECK(condition->shape == true_value->shape && + true_value->shape == false_value->shape) << "The input tensor shapes are not equal!"; return lang::Compute( condition->shape, [=](const std::vector &indice) { - return common::select(condition(indice), true_value(indice), false_value(indice)); + return common::select( + condition(indice), true_value(indice), false_value(indice)); }, output_name); } diff --git a/paddle/cinn/hlir/pe/nn.h b/paddle/cinn/hlir/pe/nn.h index eb98754db0a9b..438f088cd5535 100755 --- a/paddle/cinn/hlir/pe/nn.h +++ b/paddle/cinn/hlir/pe/nn.h @@ -35,7 +35,9 @@ namespace pe { * * @return The result Tensor. */ -ir::Tensor Relu(const ir::Tensor &A, double threshold = 0.0, const std::string &output_name = UniqName("T_Relu_out")); +ir::Tensor Relu(const ir::Tensor &A, + double threshold = 0.0, + const std::string &output_name = UniqName("T_Relu_out")); /** * @brief Rectified Linear Unit bounded by six. @@ -46,7 +48,9 @@ ir::Tensor Relu(const ir::Tensor &A, double threshold = 0.0, const std::string & * * @return The result Tensor. */ -ir::Tensor Relu6(const ir::Tensor &A, double threshold = 0.0, const std::string &output_name = UniqName("T_Relu6_out")); +ir::Tensor Relu6(const ir::Tensor &A, + double threshold = 0.0, + const std::string &output_name = UniqName("T_Relu6_out")); /** * @brief Leaky Rectified Linear Unit. @@ -57,9 +61,10 @@ ir::Tensor Relu6(const ir::Tensor &A, double threshold = 0.0, const std::string * * @return The result Tensor. */ -ir::Tensor LeakyRelu(const ir::Tensor &A, - double alpha = 0.1, - const std::string &output_name = UniqName("T_LeakyRelu_out")); +ir::Tensor LeakyRelu( + const ir::Tensor &A, + double alpha = 0.1, + const std::string &output_name = UniqName("T_LeakyRelu_out")); /** * @brief Parametric Rectified Linear Unit. @@ -73,11 +78,12 @@ ir::Tensor LeakyRelu(const ir::Tensor &A, */ ir::Tensor PRelu(const ir::Tensor &A, const ir::Tensor &slope, - const int axis = 1, + const int axis = 1, const std::string &output_name = UniqName("T_PRelu_out")); /** - * @brief Perform a 2-D convolution with an NCHW-layout using winograd algorithm. + * @brief Perform a 2-D convolution with an NCHW-layout using winograd + * algorithm.
* * @param input The 4-D input tensor {N, C_in, H, W} * @param weights The 4-D weight tensor {C_out, C_in/group, filter_h, filter_w} @@ -91,18 +97,20 @@ ir::Tensor PRelu(const ir::Tensor &A, * * @return the output tensor */ -std::vector Conv2d_winograd_NCHW(const ir::Tensor &input, - const ir::Tensor &weights, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - const std::string &output_name = UniqName("T_Conv2d_winograd_NCHW_out")); +std::vector Conv2d_winograd_NCHW( + const ir::Tensor &input, + const ir::Tensor &weights, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + const std::string &output_name = UniqName("T_Conv2d_winograd_NCHW_out")); /** - * @brief Perform a 2-D convolution with an NCHW-layout and support group and depthwise convolution. + * @brief Perform a 2-D convolution with an NCHW-layout and support group and + * depthwise convolution. * * @param input The 4-D input tensor {N, C_in, H, W} * @param weights The 4-D weight tensor {C_out, C_in/group, filter_h, filter_w} @@ -116,34 +124,37 @@ std::vector Conv2d_winograd_NCHW(const ir::Tensor &input, * * @return the output tensor */ -std::vector Conv2d_NCHW(const ir::Tensor &input, - const ir::Tensor &weights, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - const std::string &output_name = UniqName("T_Conv2d_NCHW_out"), - bool choose_direct_compute = false); +std::vector Conv2d_NCHW( + const ir::Tensor &input, + const ir::Tensor &weights, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + const std::string &output_name = UniqName("T_Conv2d_NCHW_out"), + bool choose_direct_compute = false); -std::vector Conv2d_NCHW_5D(const ir::Tensor &input, - const ir::Tensor &weights, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - std::string key, - const std::string &output_name = UniqName("T_Conv2d_NCHW_5D_out"), - const common::Target &target = common::DefaultHostTarget()); +std::vector Conv2d_NCHW_5D( + const ir::Tensor &input, + const ir::Tensor &weights, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + std::string key, + const std::string &output_name = UniqName("T_Conv2d_NCHW_5D_out"), + const common::Target &target = common::DefaultHostTarget()); /** * @brief Perform a 2-D convolution with an NCHWc-layout. 
* * @param input The 5-D input tensor {N, C_in_outer, H, W, C_in_inner} - * @param weight The 6-D weight tensor {C_out_outer, C_filter_outer, filter_h, filter_w, C_filter_inner, C_out_inner} + * @param weight The 6-D weight tensor {C_out_outer, C_filter_outer, filter_h, + * filter_w, C_filter_inner, C_out_inner} * @param pad_h padding applied to the height of the image, default is 0 * @param pad_w padding applied to the width of the image, default is 0 * @param stride_h striding applied to the height of the image, default is 1 @@ -155,31 +166,34 @@ std::vector Conv2d_NCHW_5D(const ir::Tensor &input, * * @return the output tensor */ -std::vector Conv2d_NCHWc(const ir::Tensor &input, - const ir::Tensor &weights, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - const std::string &output_name = UniqName("T_Conv2d_NCHWc_out"), - const common::Target &target = common::DefaultHostTarget()); +std::vector Conv2d_NCHWc( + const ir::Tensor &input, + const ir::Tensor &weights, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + const std::string &output_name = UniqName("T_Conv2d_NCHWc_out"), + const common::Target &target = common::DefaultHostTarget()); #ifdef CINN_WITH_MKLDNN -std::vector Conv2d_NCHW_MKLDNN(const ir::Tensor &input, - const ir::Tensor &weights, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - const std::string &output_name = UniqName("T_Conv2d_NCHW_out")); +std::vector Conv2d_NCHW_MKLDNN( + const ir::Tensor &input, + const ir::Tensor &weights, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + const std::string &output_name = UniqName("T_Conv2d_NCHW_out")); #endif /** - * @brief Perform a 2-D convolution with an NHWC-layout and support group and depthwise convolution. + * @brief Perform a 2-D convolution with an NHWC-layout and support group and + * depthwise convolution. 
 * * @param input The 4-D input tensor {N, H, W, C_in} * @param weight The 4-D weight tensor {C_out, C_in/group, filter_h, filter_w} * @param pad_h padding applied to the height of the image, default is 0 * @param pad_w padding applied to the width of the image, default is 0 * @param stride_h striding applied to the height of the image, default is 1 @@ -194,99 +208,113 @@ std::vector Conv2d_NCHW_MKLDNN(const ir::Tensor &input, * * @return the output tensors */ -std::vector Conv2d_NHWC(const ir::Tensor &input, - const ir::Tensor &weights, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - const std::string &output_name = UniqName("T_Conv2d_NHWC_out")); +std::vector Conv2d_NHWC( + const ir::Tensor &input, + const ir::Tensor &weights, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + const std::string &output_name = UniqName("T_Conv2d_NHWC_out")); /** * @brief Perform a 2-D depthwise convolution with an NCHW-layout * * @param input The 4-D input tensor {N, C_in, H, W} - * @param weight The 4-D weight tensor {C_in, channel_multiplier, filter_h, filter_w} - * @param pad_h padding counts applied to the height of the image, before and after (symmetric padding) - * @param pad_w padding counts applied to the width of the image, before and after (symmetric padding) - * @param stride_h striding counts applied to the height of the image, before and after (symmetric padding) - * @param stride_w striding counts applied to the width of the image, before and after (symmetric padding) + * @param weight The 4-D weight tensor {C_in, channel_multiplier, filter_h, + * filter_w} + * @param pad_h padding counts applied to the height of the image, before and + * after (symmetric padding) + * @param pad_w padding counts applied to the width of the image, before and + * after (symmetric padding) + * @param stride_h striding counts applied to the height of the image + * @param stride_w striding counts applied to the width of the image * @param output_shapes The shape of the output tensors * @param output_name The name of the output tensors * * @return the output tensor */ -std::vector Depthwise_Conv2d_NCHW(const ir::Tensor &input, - const ir::Tensor &weight, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - const std::string output_name = UniqName("T_depthwise_conv2d_nchw")); +std::vector Depthwise_Conv2d_NCHW( + const ir::Tensor &input, + const ir::Tensor &weight, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + const std::string output_name = UniqName("T_depthwise_conv2d_nchw")); /** * @brief Perform a 2-D depthwise convolution with an NHWC-layout * * @param input The 4-D input tensor {N, H, W, C_in} - * @param weight The 4-D weight tensor {C_in, channel_multiplier, filter_h, filter_w} - * @param pad_h padding counts applied to the height of the image, before and after (symmetric padding) - * @param pad_w padding counts applied to the width of the image, before and after (symmetric padding) - * @param stride_h striding counts applied to the height of the image, before and after (symmetric padding) - * @param stride_w striding counts applied to the width of the image, before and after (symmetric padding) + * @param weight The 4-D weight tensor {C_in, channel_multiplier, filter_h, + * filter_w} + * @param pad_h padding counts applied to the height of the image, before and + * after (symmetric padding) + * @param pad_w padding counts applied to the width of the image, before and + * after (symmetric padding) + * @param stride_h striding counts applied to the height of the image + * @param stride_w striding counts applied to the width of the image * @param output_shapes The shape of the output tensors * @param output_name The name of the output tensor * * @return the output tensors */ -std::vector Depthwise_Conv2d_NHWC(const ir::Tensor &input, - const ir::Tensor &weight, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - const std::string output_name = UniqName("T_depthwise_conv2d_nhwc")); +std::vector Depthwise_Conv2d_NHWC( + const ir::Tensor &input, + const ir::Tensor &weight, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + const std::string output_name = UniqName("T_depthwise_conv2d_nhwc")); -ir::Tensor BatchNorm_NCHW(const ir::Tensor &input, - const ir::Tensor &scale, - const ir::Tensor &bias, - const ir::Tensor &mean, - const ir::Tensor &variance, - float epsilon, - const std::string &output_name = UniqName("T_BatchNorm_NCHW_out")); +ir::Tensor BatchNorm_NCHW( + const ir::Tensor &input, + const ir::Tensor &scale, + const ir::Tensor &bias, + const ir::Tensor &mean, + const ir::Tensor &variance, + float epsilon, + const std::string &output_name = UniqName("T_BatchNorm_NCHW_out")); -ir::Tensor BatchNorm_NCHWc(const ir::Tensor &input, - const ir::Tensor &scale, - const ir::Tensor &bias, - const ir::Tensor &mean, - const ir::Tensor &variance, - float epsilon, - const std::string &output_name = UniqName("T_BatchNorm_NCHWc_out")); +ir::Tensor BatchNorm_NCHWc( + const ir::Tensor &input, + const ir::Tensor &scale, + const ir::Tensor &bias, + const ir::Tensor &mean, + const ir::Tensor &variance, + float epsilon, + const std::string &output_name = UniqName("T_BatchNorm_NCHWc_out")); /** * @brief Perform padding operation. * @param tensor The input tensor. - * @param pad_before Vector of Exprs describing the padding before the respective dimension - * @param pad_after Vector of Exprs describing the padding after the respective dimension + * @param pad_before Vector of Exprs describing the padding before the + * respective dimension + * @param pad_after Vector of Exprs describing the padding after the respective + * dimension * @param pad_value The value to fill padding elements with. Default is zero. * @param name The name of the output padding tensor - * @param pad_mode Padding type to use: "constant" pads with constant_value; "edge" pads using the edge values of the - * input array; "reflect" pads by reflecting values with respect to the edges. + * @param pad_mode Padding type to use: "constant" pads with constant_value; + * "edge" pads using the edge values of the input array; "reflect" pads by + * reflecting values with respect to the edges. * * @return the output tensor after padding. * * @note - * The pad_after vector must either be empty or have the same length as pad_before - * When pad_after is empty, it takes the same values as pad_before (symmetric padding) - * The pad vector applies from the leading dimensions and skips missing trailing dimensions: - * e.g. - * pad(t(i, j, k), {1}, {1}) returns the equivalent operation for - * the following pseudocode: - * for i in [0, t.shape[0] + 2): - * for j in [0, t.shape[0] + 2): - * for k in [0, t.shape[0] + 2): - * name(i,j,k) = + * The pad_after vector must either be empty or have the same length as + * pad_before. When pad_after is empty, it takes the same values as pad_before + * (symmetric padding). The pad vector applies from the leading dimensions and + * skips missing trailing dimensions: e.g.
pad(t(i, j, k), {1}, {1}) returns the + equivalent operation for the following pseudocode: for i in [0, t.shape[0] + + 2): for j in [0, t.shape[1]): for k in [0, t.shape[2]): name(i,j,k) = * i < 1 ? 0 : * ((1 <= i < t.shape[0] + 1) ? * t(i-1, j, k) : 0)); @@ -295,115 +323,138 @@ ir::Tensor BatchNorm_NCHWc(const ir::Tensor &input, ir::Tensor Pad(const ir::Tensor &tensor, const std::vector &pad_before, std::vector pad_after = std::vector(), - Expr pad_value = Expr(), - const std::string &name = UniqName("T_pad_out"), + Expr pad_value = Expr(), + const std::string &name = UniqName("T_pad_out"), const std::string &pad_mode = "constant"); -std::vector Softmax(const ir::Tensor &A, - int axis = -1, - const std::string &output_name = UniqName("T_softmax_out")); +std::vector Softmax( + const ir::Tensor &A, + int axis = -1, + const std::string &output_name = UniqName("T_softmax_out")); #ifdef CINN_WITH_MKLDNN -std::vector SoftmaxMKLDNN(const ir::Tensor &A, - int axis = -1, - const std::string &output_name = UniqName("T_softmax_out")); +std::vector SoftmaxMKLDNN( + const ir::Tensor &A, + int axis = -1, + const std::string &output_name = UniqName("T_softmax_out")); #endif /** * @brief Perform pooling on the width dimension of the tensor. - * Width axis is determined by the data_format string in which 'W' means width. Only support NCW and NWC - * data_format. + * Width axis is determined by the data_format string in which 'W' means + * width. Only supports NCW and NWC data_format. * @param tensor The input tensor with shape of {N, C, W} or {N, W, C} * @param kernel_size Vector of ints: {pool_kernel_width} * @param stride_size Vector of ints: {pool_stride_width} * @param padding_size Vector of ints: {head_pad_width, tail_pad_width} - * @param pool_type The type of pooling operator, currently support "max" and "avg". Default is "max". - * @param ceil_mode Whether to use ceil when calculating the output size. Default is false. + * @param pool_type The type of pooling operator, currently supports "max" and + * "avg". Default is "max". + * @param ceil_mode Whether to use ceil when calculating the output size. + * Default is false. * @param exclusive Whether to include padding in the calculation. Default is True. - * @param data_format The input data format. Only support NCW and NWC data_format. + * @param data_format The input data format. Only supports NCW and NWC + * data_format. * @param output_name the name of the output tensor after padding and pooling. * * @return the vector of padding tensor and pooling tensor. */ -std::vector Pool1d(const ir::Tensor &tensor, - const std::vector &kernel_size, - const std::vector &stride_size, - const std::vector &padding_size, - const std::string &pool_type = "max", - bool ceil_mode = false, - bool exclusive = true, - const std::string &data_format = "NCW", - const std::string &output_name = UniqName("T_Pool1d_out")); +std::vector Pool1d( + const ir::Tensor &tensor, + const std::vector &kernel_size, + const std::vector &stride_size, + const std::vector &padding_size, + const std::string &pool_type = "max", + bool ceil_mode = false, + bool exclusive = true, + const std::string &data_format = "NCW", + const std::string &output_name = UniqName("T_Pool1d_out")); /** * @brief Perform pooling on the height and width dimension of the tensor. - * Height and width axes are determined by the data_format string in which 'H' means height and 'W' means width. - * Only support NCHW and NHWC data_format.
 + * Height and width axes are determined by the data_format string in + which 'H' means height and 'W' means width. Only supports NCHW and NHWC + * data_format. * @param tensor The input tensor with shape of {N, C, H, W} or {N, H, W, C} * @param kernel_size Vector of ints: {pool_kernel_height, pool_kernel_width} * @param stride_size Vector of ints: {pool_stride_height, pool_stride_width} - * @param padding_size Vector of ints: {head_pad_height, head_pad_width, tail_pad_height, tail_pad_width} - * @param pool_type The type of pooling operator, currently support "max" and "avg". Default is "max". - * @param ceil_mode Whether to use ceil when calculating the output size. Default is false. + * @param padding_size Vector of ints: {head_pad_height, head_pad_width, + * tail_pad_height, tail_pad_width} + * @param pool_type The type of pooling operator, currently supports "max" and + * "avg". Default is "max". + * @param ceil_mode Whether to use ceil when calculating the output size. + * Default is false. * @param exclusive Whether to include padding in the calculation. Default is True. - * @param data_format The input data format. Only support NCHW and NHWC data_format. + * @param data_format The input data format. Only supports NCHW and NHWC + * data_format. * @param output_name the name of the output tensor after padding and pooling. * * @return the vector of padding tensor and pooling tensor. */ -std::vector Pool2d(const ir::Tensor &tensor, - const std::vector &kernel_size, - const std::vector &stride_size, - const std::vector &padding_size, - const std::string &pool_type = "max", - bool ceil_mode = false, - bool exclusive = true, - const std::string &data_format = "NCHW", - bool adaptive = false, - const std::string &output_name = UniqName("T_Pool2d_out")); +std::vector Pool2d( + const ir::Tensor &tensor, + const std::vector &kernel_size, + const std::vector &stride_size, + const std::vector &padding_size, + const std::string &pool_type = "max", + bool ceil_mode = false, + bool exclusive = true, + const std::string &data_format = "NCHW", + bool adaptive = false, + const std::string &output_name = UniqName("T_Pool2d_out")); std::vector GlobalPool2d(const ir::Tensor &tensor, const std::string &pool_type, const std::string &output_name); /** - * @brief Perform pooling on the depth, height and width dimension of the tensor. - * Depth, height and width axis is determined by the data_format string in which 'D' means depth, 'H' means - * height and 'W' means width. Only support NCDHW and NDHWC data_format. - * @param tensor The input tensor with shape of {N, C, D, H, W} or {N, D, H, W, C} - * @param kernel_size Vector of ints: {pool_kernel_depth, pool_kernel_height, pool_kernel_width} - * @param stride_size Vector of ints: {pool_stride_depth, pool_stride_height, pool_stride_width} - * @param padding_size Vector of ints: {head_pad_depth, head_pad_height, head_pad_width, tail_pad_depth, - * tail_pad_height, tail_pad_width} - * @param pool_type The type of pooling operator, currently support "max" and "avg". Default is "max". - * @param ceil_mode Whether to use ceil when calculating the output size. Default is false. + * @brief Perform pooling on the depth, height and width dimensions of the + * tensor. Depth, height and width axes are determined by the data_format string + * in which 'D' means depth, 'H' means height and 'W' means width. Only supports + * NCDHW and NDHWC data_format.
 + * @param tensor The input tensor with shape of {N, C, D, H, W} or {N, D, H, W, + * C} + * @param kernel_size Vector of ints: {pool_kernel_depth, pool_kernel_height, + * pool_kernel_width} + * @param stride_size Vector of ints: {pool_stride_depth, pool_stride_height, + * pool_stride_width} + * @param padding_size Vector of ints: {head_pad_depth, head_pad_height, + * head_pad_width, tail_pad_depth, tail_pad_height, tail_pad_width} + * @param pool_type The type of pooling operator, currently supports "max" and + * "avg". Default is "max". + * @param ceil_mode Whether to use ceil when calculating the output size. + * Default is false. * @param exclusive Whether to include padding in the calculation. Default is True. - * @param data_format The input data format. Only support NCDHW and NDHWC data_format. + * @param data_format The input data format. Only supports NCDHW and NDHWC + * data_format. * @param output_name the name of the output tensor after padding and pooling. */ -std::vector Pool3d(const ir::Tensor &tensor, - const std::vector &kernel_size, - const std::vector &stride_size, - const std::vector &padding_size, - const std::string &pool_type = "max", - bool ceil_mode = false, - bool exclusive = true, - const std::string &data_format = "NCDHW", - const std::string &output_name = UniqName("T_Pool3d_out")); +std::vector Pool3d( + const ir::Tensor &tensor, + const std::vector &kernel_size, + const std::vector &stride_size, + const std::vector &padding_size, + const std::string &pool_type = "max", + bool ceil_mode = false, + bool exclusive = true, + const std::string &data_format = "NCDHW", + const std::string &output_name = UniqName("T_Pool3d_out")); /** - * @brief Perform dropout in the inference which will downgrade the outcome at inference or keep the same. + * @brief Perform dropout at inference time, which either downgrades the + * outcome or keeps it the same. * @param tensor The input tensor * @param dropout_prob float. Probability of setting units to zero. - * @param dropout_implementation ['downgrade_in_infer'(default)|'upscale_in_train'] + * @param dropout_implementation + * ['downgrade_in_infer'(default)|'upscale_in_train'] * 1. downgrade_in_infer(default), downgrade the outcome at inference * out = input * (1.0 - dropout_prob) * 2. upscale_in_train, keep the same * out = input * @param output_name the name of the output tensor. */ -ir::Tensor DropoutInfer(const ir::Tensor &tensor, - float dropout_prob, - const std::string &dropout_implementation = "downgrade_in_infer", - const std::string &output_name = UniqName("T_Dropout_infer_out")); +ir::Tensor DropoutInfer( + const ir::Tensor &tensor, + float dropout_prob, + const std::string &dropout_implementation = "downgrade_in_infer", + const std::string &output_name = UniqName("T_Dropout_infer_out")); /** * @brief Perform Select for meta op 'Select'.
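All of the Pool1d / Pool2d / Pool3d declarations above funnel into PoolImpl, which computes each pooled extent as (input - kernel + pad_head + pad_tail) / stride + 1, with ceil_mode pre-widening pad_tail by stride - 1 so that the integer division rounds up instead of down. A minimal standalone sketch of that arithmetic, assuming only the C++ standard library (pooled_extent is an illustrative name, not a CINN symbol):

#include <cassert>

// Per-axis output extent, mirroring PoolImpl:
//   out = (in - kernel + pad_head + pad_tail) / stride + 1
// ceil_mode widens pad_tail by (stride - 1), turning the floor
// division into a ceiling division.
int pooled_extent(int in, int kernel, int stride,
                  int pad_head, int pad_tail, bool ceil_mode) {
  if (ceil_mode) pad_tail += stride - 1;
  return (in - kernel + pad_head + pad_tail) / stride + 1;
}

int main() {
  // A 2x2 pool with stride 2 and symmetric padding 1 over an extent of 7:
  assert(pooled_extent(7, 2, 2, 1, 1, /*ceil_mode=*/false) == 4);
  assert(pooled_extent(7, 2, 2, 1, 1, /*ceil_mode=*/true) == 5);
  return 0;
}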
diff --git a/paddle/cinn/hlir/pe/nn_util.cc b/paddle/cinn/hlir/pe/nn_util.cc index 3a7d4cbce516f..30d316bcaf845 100644 --- a/paddle/cinn/hlir/pe/nn_util.cc +++ b/paddle/cinn/hlir/pe/nn_util.cc @@ -24,23 +24,35 @@ namespace pe { using cinn::lang::Compute; using ir::Tensor; -std::vector>> get_winograd_val(const int& tile_size, const int& kernel_size) { - std::unordered_map>>> all_vals; +std::vector>> get_winograd_val( + const int& tile_size, const int& kernel_size) { + std::unordered_map>>> + all_vals; { std::string keys = "2+3"; std::vector>> nums; - std::vector> A = {{1., 0.}, {1., -1.}, {1., 1.}, {0., 1.}}; + std::vector> A = { + {1., 0.}, {1., -1.}, {1., 1.}, {0., 1.}}; nums.push_back(A); - std::vector> B = {{1., 0., 0., 0.}, {0., -1., 1., -1.}, {-1., 1., 1., 0.}, {0., 0., 0., 1.}}; + std::vector> B = {{1., 0., 0., 0.}, + {0., -1., 1., -1.}, + {-1., 1., 1., 0.}, + {0., 0., 0., 1.}}; nums.push_back(B); - std::vector> G = {{1., 0., 0.}, {0.5, -0.5, 0.5}, {0.5, 0.5, 0.5}, {0., 0., 1.}}; + std::vector> G = { + {1., 0., 0.}, {0.5, -0.5, 0.5}, {0.5, 0.5, 0.5}, {0., 0., 1.}}; nums.push_back(G); all_vals[keys] = nums; } { std::string keys = "2+5"; std::vector>> nums; - std::vector> A = {{1.0, 0.0}, {1.0, -1.0}, {1.0, 1.0}, {1.0, 0.5}, {1.0, -2.0}, {0.0, 1.0}}; + std::vector> A = {{1.0, 0.0}, + {1.0, -1.0}, + {1.0, 1.0}, + {1.0, 0.5}, + {1.0, -2.0}, + {0.0, 1.0}}; nums.push_back(A); std::vector> B = {{1.0, 0.0, 0.0, 0.0, 0.0, 0.0}, {-1.5, 1.0, -1.0, -2.0, 0.5, 1.0}, @@ -49,30 +61,52 @@ std::vector>> get_winograd_val(const int& tile_si {1.0, 1.0, 1.0, 1.0, 1.0, 1.5}, {0.0, 0.0, 0.0, 0.0, 0.0, 1.0}}; nums.push_back(B); - std::vector> G = { - {1.0, 0.0, 0.0, 0.0, 0.0}, - {-0.3333333333333333, 0.3333333333333333, -0.3333333333333333, 0.3333333333333333, -0.3333333333333333}, - {0.3333333333333333, 0.3333333333333333, 0.3333333333333333, 0.3333333333333333, 0.3333333333333333}, - {-1.0666666666666667, -0.5333333333333333, -0.26666666666666666, -0.13333333333333333, -0.06666666666666667}, - {0.06666666666666667, -0.13333333333333333, 0.26666666666666666, -0.5333333333333333, 1.0666666666666667}, - {0.0, 0.0, 0.0, 0.0, 1.0}}; + std::vector> G = {{1.0, 0.0, 0.0, 0.0, 0.0}, + {-0.3333333333333333, + 0.3333333333333333, + -0.3333333333333333, + 0.3333333333333333, + -0.3333333333333333}, + {0.3333333333333333, + 0.3333333333333333, + 0.3333333333333333, + 0.3333333333333333, + 0.3333333333333333}, + {-1.0666666666666667, + -0.5333333333333333, + -0.26666666666666666, + -0.13333333333333333, + -0.06666666666666667}, + {0.06666666666666667, + -0.13333333333333333, + 0.26666666666666666, + -0.5333333333333333, + 1.0666666666666667}, + {0.0, 0.0, 0.0, 0.0, 1.0}}; nums.push_back(G); all_vals[keys] = nums; } { std::string keys = "2+7"; std::vector>> nums; - std::vector> A = { - {1.0, 0.0}, {1.0, -1.0}, {1.0, 1.0}, {1.0, 0.5}, {1.0, -0.5}, {1.0, 2.0}, {1.0, -2.0}, {0.0, 1.0}}; + std::vector> A = {{1.0, 0.0}, + {1.0, -1.0}, + {1.0, 1.0}, + {1.0, 0.5}, + {1.0, -0.5}, + {1.0, 2.0}, + {1.0, -2.0}, + {0.0, 1.0}}; nums.push_back(A); - std::vector> B = {{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, - {0.0, -1.0, 1.0, 2.0, -2.0, 0.5, -0.5, -1.0}, - {-5.25, 1.0, 1.0, 4.0, 4.0, 0.25, 0.25, 0.0}, - {0.0, 4.25, -4.25, -2.5, 2.5, -2.5, 2.5, 5.25}, - {5.25, -4.25, -4.25, -5.0, -5.0, -1.25, -1.25, 0.0}, - {0.0, -1.0, 1.0, 0.5, -0.5, 2.0, -2.0, -5.25}, - {-1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0}, - {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0}}; + std::vector> B = { + {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, + {0.0, -1.0, 1.0, 2.0, 
-2.0, 0.5, -0.5, -1.0}, + {-5.25, 1.0, 1.0, 4.0, 4.0, 0.25, 0.25, 0.0}, + {0.0, 4.25, -4.25, -2.5, 2.5, -2.5, 2.5, 5.25}, + {5.25, -4.25, -4.25, -5.0, -5.0, -1.25, -1.25, 0.0}, + {0.0, -1.0, 1.0, 0.5, -0.5, 2.0, -2.0, -5.25}, + {-1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0}, + {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0}}; nums.push_back(B); std::vector> G = {{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, {-0.2222222222222222, @@ -138,12 +172,13 @@ std::vector>> get_winograd_val(const int& tile_si {1.0, 1.0, 1.0, 1.0, 1.0, 1.5}, {0.0, 0.0, 0.0, 0.0, 0.0, 1.0}}; nums.push_back(B); - std::vector> G = {{1.0, 0.0, 0.0}, - {-0.3333333333333333, 0.3333333333333333, -0.3333333333333333}, - {0.3333333333333333, 0.3333333333333333, 0.3333333333333333}, - {-1.0666666666666667, -0.5333333333333333, -0.26666666666666666}, - {0.06666666666666667, -0.13333333333333333, 0.26666666666666666}, - {0.0, 0.0, 1.0}}; + std::vector> G = { + {1.0, 0.0, 0.0}, + {-0.3333333333333333, 0.3333333333333333, -0.3333333333333333}, + {0.3333333333333333, 0.3333333333333333, 0.3333333333333333}, + {-1.0666666666666667, -0.5333333333333333, -0.26666666666666666}, + {0.06666666666666667, -0.13333333333333333, 0.26666666666666666}, + {0.0, 0.0, 1.0}}; nums.push_back(G); all_vals[keys] = nums; } @@ -159,24 +194,48 @@ std::vector>> get_winograd_val(const int& tile_si {1.0, -2.0, 4.0, -8.0}, {0.0, 0.0, 0.0, 1.0}}; nums.push_back(A); - std::vector> B = {{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, - {0.0, -1.0, 1.0, 2.0, -2.0, 0.5, -0.5, -1.0}, - {-5.25, 1.0, 1.0, 4.0, 4.0, 0.25, 0.25, 0.0}, - {0.0, 4.25, -4.25, -2.5, 2.5, -2.5, 2.5, 5.25}, - {5.25, -4.25, -4.25, -5.0, -5.0, -1.25, -1.25, 0.0}, - {0.0, -1.0, 1.0, 0.5, -0.5, 2.0, -2.0, -5.25}, - {-1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0}, - {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0}}; + std::vector> B = { + {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, + {0.0, -1.0, 1.0, 2.0, -2.0, 0.5, -0.5, -1.0}, + {-5.25, 1.0, 1.0, 4.0, 4.0, 0.25, 0.25, 0.0}, + {0.0, 4.25, -4.25, -2.5, 2.5, -2.5, 2.5, 5.25}, + {5.25, -4.25, -4.25, -5.0, -5.0, -1.25, -1.25, 0.0}, + {0.0, -1.0, 1.0, 0.5, -0.5, 2.0, -2.0, -5.25}, + {-1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0}, + {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0}}; nums.push_back(B); - std::vector> G = { - {1.0, 0.0, 0.0, 0.0, 0.0}, - {-0.2222222222222222, 0.2222222222222222, -0.2222222222222222, 0.2222222222222222, -0.2222222222222222}, - {-0.2222222222222222, -0.2222222222222222, -0.2222222222222222, -0.2222222222222222, -0.2222222222222222}, - {0.7111111111111111, 0.35555555555555557, 0.17777777777777778, 0.08888888888888889, 0.044444444444444446}, - {0.7111111111111111, -0.35555555555555557, 0.17777777777777778, -0.08888888888888889, 0.044444444444444446}, - {0.011111111111111112, 0.022222222222222223, 0.044444444444444446, 0.08888888888888889, 0.17777777777777778}, - {0.011111111111111112, -0.022222222222222223, 0.044444444444444446, -0.08888888888888889, 0.17777777777777778}, - {0.0, 0.0, 0.0, 0.0, 1.0}}; + std::vector> G = {{1.0, 0.0, 0.0, 0.0, 0.0}, + {-0.2222222222222222, + 0.2222222222222222, + -0.2222222222222222, + 0.2222222222222222, + -0.2222222222222222}, + {-0.2222222222222222, + -0.2222222222222222, + -0.2222222222222222, + -0.2222222222222222, + -0.2222222222222222}, + {0.7111111111111111, + 0.35555555555555557, + 0.17777777777777778, + 0.08888888888888889, + 0.044444444444444446}, + {0.7111111111111111, + -0.35555555555555557, + 0.17777777777777778, + -0.08888888888888889, + 0.044444444444444446}, + {0.011111111111111112, + 0.022222222222222223, + 0.044444444444444446, 
+ 0.08888888888888889, + 0.17777777777777778}, + {0.011111111111111112, + -0.022222222222222223, + 0.044444444444444446, + -0.08888888888888889, + 0.17777777777777778}, + {0.0, 0.0, 0.0, 0.0, 1.0}}; nums.push_back(G); all_vals[keys] = nums; } @@ -197,12 +256,57 @@ std::vector>> get_winograd_val(const int& tile_si std::vector> B = { {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, {3.75, 1.0, -1.0, -2.0, 2.0, -0.5, 0.5, 4.0, -0.25, 1.0}, - {-6.25, 2.75, -4.75, -11.5, 3.4999999999999996, -2.125, 1.625, -1.0, -1.0, 3.75}, - {-19.6875, -9.0, 1.5, -10.5, -19.5, 2.0625, -3.9375, -21.0, 1.3125, -6.25}, - {10.5, -10.6875, 21.1875, 18.375, -0.375, 10.875, -7.875, 5.25, 5.25, -19.6875}, - {19.6875, 21.1875, 10.6875, 15.75, 21.75, 0.1875, 9.1875, 21.0, -1.3125, 10.5}, + {-6.25, + 2.75, + -4.75, + -11.5, + 3.4999999999999996, + -2.125, + 1.625, + -1.0, + -1.0, + 3.75}, + {-19.6875, + -9.0, + 1.5, + -10.5, + -19.5, + 2.0625, + -3.9375, + -21.0, + 1.3125, + -6.25}, + {10.5, + -10.6875, + 21.1875, + 18.375, + -0.375, + 10.875, + -7.875, + 5.25, + 5.25, + -19.6875}, + {19.6875, + 21.1875, + 10.6875, + 15.75, + 21.75, + 0.1875, + 9.1875, + 21.0, + -1.3125, + 10.5}, {-6.25, -1.5, -9.0, -7.875, -4.125, -9.75, 5.25, -5.25, -5.25, 19.6875}, - {-3.75, -4.75, -2.75, -3.25, -4.25, -1.7499999999999998, -5.75, -4.0, 0.25, -6.25}, + {-3.75, + -4.75, + -2.75, + -3.25, + -4.25, + -1.7499999999999998, + -5.75, + -4.0, + 0.25, + -6.25}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -3.7500000000000004}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0}}; nums.push_back(B); @@ -267,25 +371,30 @@ std::vector>> get_winograd_val(const int& tile_si nums.push_back(G); all_vals[keys] = nums; } - std::string keys = std::to_string(tile_size) + "+" + std::to_string(kernel_size); + std::string keys = + std::to_string(tile_size) + "+" + std::to_string(kernel_size); return all_vals[keys]; } -ir::Tensor const_matrix(const std::vector>& input, const std::string& name) { - int row = input.size(); - int col = input[0].size(); +ir::Tensor const_matrix(const std::vector>& input, + const std::string& name) { + int row = input.size(); + int col = input[0].size(); std::vector tensor_shape = {Expr(row), Expr(col)}; - auto result = Compute( + auto result = Compute( tensor_shape, [=](Expr yy, Expr xx) { auto now = cinn::common::make_const(1.0f); for (int ii = 0; ii < row; ii++) { for (int jj = 0; jj < col; jj++) { - // if (common::is_zero(Expr(ii)-yy) && common::is_zero(Expr(jj)-xx)) { + // if (common::is_zero(Expr(ii)-yy) && common::is_zero(Expr(jj)-xx)) + // { // now = cinn::common::make_const(input[ii][jj]); // } - auto cond = common::and_all({Expr(ii) - yy == 0, Expr(jj) - xx == 0}); - now = common::select(cond, cinn::common::make_const(input[ii][jj]), now); + auto cond = + common::and_all({Expr(ii) - yy == 0, Expr(jj) - xx == 0}); + now = common::select( + cond, cinn::common::make_const(input[ii][jj]), now); } } return now; @@ -294,8 +403,10 @@ ir::Tensor const_matrix(const std::vector>& input, const std: return result; } -std::vector winograd_transform_matrices(const int& tile_size, const int& kernel_size) { - std::vector>> vals = get_winograd_val(tile_size, kernel_size); +std::vector winograd_transform_matrices(const int& tile_size, + const int& kernel_size) { + std::vector>> vals = + get_winograd_val(tile_size, kernel_size); CHECK_EQ(vals.size(), 3U) << "vals_size of winograd is not 3! 
Please check."; std::vector> A = vals[0]; @@ -303,18 +414,19 @@ std::vector winograd_transform_matrices(const int& tile_size, const std::vector> G = vals[2]; std::string name_a = "A_matrix"; - auto tensor_a = const_matrix(A, name_a); + auto tensor_a = const_matrix(A, name_a); std::string name_b = "B_matrix"; - auto tensor_b = const_matrix(B, name_b); + auto tensor_b = const_matrix(B, name_b); std::string name_g = "G_matrix"; - auto tensor_g = const_matrix(G, name_g); + auto tensor_g = const_matrix(G, name_g); return {tensor_a, tensor_b, tensor_g}; } -int GetPostParallelSize(const std::vector& inshape, const std::vector& axes) { +int GetPostParallelSize(const std::vector& inshape, + const std::vector& axes) { int parallel_size = 1; for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { parallel_size *= inshape[idx]; @@ -322,7 +434,8 @@ int GetPostParallelSize(const std::vector& inshape, const std::vector& return parallel_size; } -int GetParallelSize(const std::vector& inshape, const std::vector& axes) { +int GetParallelSize(const std::vector& inshape, + const std::vector& axes) { int parallel_size = 1; for (int idx = 0; idx < inshape.size(); ++idx) { if (std::find(axes.begin(), axes.end(), idx) != axes.end()) { @@ -352,8 +465,8 @@ std::vector GetFirstStepReduceShape(const std::vector& shape, CHECK_GT(unfold_size, 1); // fuse reduce axis. - int insert_zero_num = 0; - int last_axis_index = axes.size() - 1; + int insert_zero_num = 0; + int last_axis_index = axes.size() - 1; int last_reduce_size = shape[axes.back()]; for (; last_axis_index >= 1; --last_axis_index) { if (axes[last_axis_index] - 1 != axes[last_axis_index - 1]) { diff --git a/paddle/cinn/hlir/pe/nn_util.h b/paddle/cinn/hlir/pe/nn_util.h index fe27e6c8b9b1e..7ea201e0461ab 100644 --- a/paddle/cinn/hlir/pe/nn_util.h +++ b/paddle/cinn/hlir/pe/nn_util.h @@ -30,11 +30,14 @@ namespace cinn { namespace hlir { namespace pe { -ir::Tensor const_matrix(const std::vector>& input, const std::string& name); +ir::Tensor const_matrix(const std::vector>& input, + const std::string& name); -std::vector>> get_winograd_val(const int& tile_size, const int& kernel_size); +std::vector>> get_winograd_val( + const int& tile_size, const int& kernel_size); -std::vector winograd_transform_matrices(const int& tile_size, const int& kernel_size); +std::vector winograd_transform_matrices(const int& tile_size, + const int& kernel_size); std::vector GetFirstStepReduceShape(const std::vector& shape, const std::vector& axes, diff --git a/paddle/cinn/hlir/pe/pe_broadcast_test.cc b/paddle/cinn/hlir/pe/pe_broadcast_test.cc index 7248a55bd65aa..533c3ebdd9706 100644 --- a/paddle/cinn/hlir/pe/pe_broadcast_test.cc +++ b/paddle/cinn/hlir/pe/pe_broadcast_test.cc @@ -27,11 +27,13 @@ namespace hlir { namespace pe { using ir::Tensor; -void TestBroadcastPE( - const std::string &fn_name, - Tensor (*func_op)(const Tensor &A, const Tensor &B, const std::string &output_name, const Expr &axis), - float (*fn_runtime)(float, float), - int set_value = 0) { +void TestBroadcastPE(const std::string &fn_name, + Tensor (*func_op)(const Tensor &A, + const Tensor &B, + const std::string &output_name, + const Expr &axis), + float (*fn_runtime)(float, float), + int set_value = 0) { Expr M(100), N(32); Placeholder A("A", {M, N}); @@ -47,7 +49,7 @@ void TestBroadcastPE( builder.AddFunction(func); LOG(INFO) << "func:\n" << func; - auto jit = backends::ExecutionEngine::Create({}); + auto jit = backends::ExecutionEngine::Create({}); auto module = builder.Build(); jit->Link(module); @@ -58,13 
+60,23 @@ void TestBroadcastPE( cinn_buffer_t *A_buf; cinn_buffer_t *B_buf; if (set_value != 0) { - A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_val(set_value).Build(); - B_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_val(set_value).Build(); + A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_val(set_value) + .Build(); + B_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_val(set_value) + .Build(); } else { - A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - B_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); + A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); + B_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); } - auto *C_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_zero().Build(); + auto *C_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_zero() + .Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; @@ -78,22 +90,24 @@ void TestBroadcastPE( } } -void TestBroadcastPE1( - const std::string &fn_name, - Tensor (*func_op)(const Tensor &A, const Tensor &B, const std::string &output_name, const Expr &axis), - float (*fn_runtime)(float, float), - int set_value = 0) { +void TestBroadcastPE1(const std::string &fn_name, + Tensor (*func_op)(const Tensor &A, + const Tensor &B, + const std::string &output_name, + const Expr &axis), + float (*fn_runtime)(float, float), + int set_value = 0) { Expr M(100), N(32), K(10); Placeholder A("A", {M, N, K}); Placeholder B("B", {N}); - auto C = func_op(A.tensor(), B.tensor(), "C", Expr(1)); - auto stages = CreateStages({C}); + auto C = func_op(A.tensor(), B.tensor(), "C", Expr(1)); + auto stages = CreateStages({C}); Target target = common::DefaultHostTarget(); Module::Builder builder("module0", target); auto func = Lower("fn", stages, {A, B, C}); builder.AddFunction(func); LOG(INFO) << "func:\n" << func; - auto jit = backends::ExecutionEngine::Create({}); + auto jit = backends::ExecutionEngine::Create({}); auto module = builder.Build(); jit->Link(module); auto fn = jit->Lookup("fn"); @@ -102,13 +116,25 @@ void TestBroadcastPE1( cinn_buffer_t *A_buf; cinn_buffer_t *B_buf; if (set_value != 0) { - A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32(), K.as_int32()}).set_val(set_value).Build(); - B_buf = common::BufferBuilder(Float(32), {N.as_int32()}).set_val(set_value).Build(); + A_buf = common::BufferBuilder(Float(32), + {M.as_int32(), N.as_int32(), K.as_int32()}) + .set_val(set_value) + .Build(); + B_buf = common::BufferBuilder(Float(32), {N.as_int32()}) + .set_val(set_value) + .Build(); } else { - A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32(), K.as_int32()}).set_random().Build(); - B_buf = common::BufferBuilder(Float(32), {N.as_int32()}).set_random().Build(); + A_buf = common::BufferBuilder(Float(32), + {M.as_int32(), N.as_int32(), K.as_int32()}) + .set_random() + .Build(); + B_buf = + common::BufferBuilder(Float(32), {N.as_int32()}).set_random().Build(); } - auto *C_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32(), K.as_int32()}).set_zero().Build(); + auto *C_buf = common::BufferBuilder( + Float(32), {M.as_int32(), N.as_int32(), K.as_int32()}) + .set_zero() + .Build(); cinn_pod_value_t 
a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; fn_(args, 3); @@ -125,22 +151,24 @@ void TestBroadcastPE1( } } -void TestBroadcastPE2( - const std::string &fn_name, - Tensor (*func_op)(const Tensor &A, const Tensor &B, const std::string &output_name, const Expr &axis), - float (*fn_runtime)(float, float), - int set_value = 0) { +void TestBroadcastPE2(const std::string &fn_name, + Tensor (*func_op)(const Tensor &A, + const Tensor &B, + const std::string &output_name, + const Expr &axis), + float (*fn_runtime)(float, float), + int set_value = 0) { Expr M(100), N(32), K(10), R(1); Placeholder A("A", {M, N, K, R}); Placeholder B("B", {N, K}); - auto C = func_op(A.tensor(), B.tensor(), "C", Expr(1)); - auto stages = CreateStages({C}); + auto C = func_op(A.tensor(), B.tensor(), "C", Expr(1)); + auto stages = CreateStages({C}); Target target = common::DefaultHostTarget(); Module::Builder builder("module0", target); auto func = Lower("fn", stages, {A, B, C}); builder.AddFunction(func); LOG(INFO) << "func:\n" << func; - auto jit = backends::ExecutionEngine::Create({}); + auto jit = backends::ExecutionEngine::Create({}); auto module = builder.Build(); jit->Link(module); auto fn = jit->Lookup("fn"); @@ -149,17 +177,29 @@ void TestBroadcastPE2( cinn_buffer_t *A_buf; cinn_buffer_t *B_buf; if (set_value != 0) { - A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32(), K.as_int32(), R.as_int32()}) + A_buf = + common::BufferBuilder( + Float(32), {M.as_int32(), N.as_int32(), K.as_int32(), R.as_int32()}) + .set_val(set_value) + .Build(); + B_buf = common::BufferBuilder(Float(32), {N.as_int32(), K.as_int32()}) .set_val(set_value) .Build(); - B_buf = common::BufferBuilder(Float(32), {N.as_int32(), K.as_int32()}).set_val(set_value).Build(); } else { A_buf = - common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32(), K.as_int32(), R.as_int32()}).set_random().Build(); - B_buf = common::BufferBuilder(Float(32), {N.as_int32(), K.as_int32()}).set_random().Build(); + common::BufferBuilder( + Float(32), {M.as_int32(), N.as_int32(), K.as_int32(), R.as_int32()}) + .set_random() + .Build(); + B_buf = common::BufferBuilder(Float(32), {N.as_int32(), K.as_int32()}) + .set_random() + .Build(); } auto *C_buf = - common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32(), K.as_int32(), R.as_int32()}).set_zero().Build(); + common::BufferBuilder( + Float(32), {M.as_int32(), N.as_int32(), K.as_int32(), R.as_int32()}) + .set_zero() + .Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; fn_(args, 3); @@ -181,11 +221,16 @@ void TestBroadcastPE2( #define RULE(test_name__, rule__) \ float test_name__(float a, float b) { rule__ } -#define TEST_BROADCAST_PE_FP32_BASIC(test_name__) \ - TEST(broadcast_pe, test_name__) { TestBroadcastPE("PE_Broadcast_" #test_name__ "_fp32", test_name__, test_name__); } +#define TEST_BROADCAST_PE_FP32_BASIC(test_name__) \ + TEST(broadcast_pe, test_name__) { \ + TestBroadcastPE( \ + "PE_Broadcast_" #test_name__ "_fp32", test_name__, test_name__); \ + } -#define TEST_BROADCAST_PE_FP32_SET_BASIC(test_name__) \ - TEST(broadcast_pe, test_name__) { TestBroadcastPE("PE_Broadcast_" #test_name__ "_fp32", test_name__, value); } +#define TEST_BROADCAST_PE_FP32_SET_BASIC(test_name__) \ + TEST(broadcast_pe, test_name__) { \ + TestBroadcastPE("PE_Broadcast_" #test_name__ "_fp32", test_name__, value); \ + } #define TEST_BROADCAST_PE_FP32(test_name__, rule__) \ RULE(test_name__, rule__) \ diff 
--git a/paddle/cinn/hlir/pe/pe_elementwise_test.cc b/paddle/cinn/hlir/pe/pe_elementwise_test.cc index d31aaa95ba33c..c96a28a19762b 100644 --- a/paddle/cinn/hlir/pe/pe_elementwise_test.cc +++ b/paddle/cinn/hlir/pe/pe_elementwise_test.cc @@ -37,8 +37,8 @@ template void TestElementwisePE(const std::string &fn_name, const FuncOp &func_op, const FuncRuntime &fn_runtime, - Type type = Float(32), - int set_value = 0, + Type type = Float(32), + int set_value = 0, bool test_benchmark = true) { Expr M(1024), N(2048); @@ -60,7 +60,7 @@ void TestElementwisePE(const std::string &fn_name, LOG(INFO) << "func:\n" << func; builder.AddFunction(func); - auto jit = backends::ExecutionEngine::Create({}); + auto jit = backends::ExecutionEngine::Create({}); auto module = builder.Build(); jit->Link(module); @@ -70,11 +70,17 @@ void TestElementwisePE(const std::string &fn_name, cinn_buffer_t *A_buf; if (set_value != 0) { - A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_val(set_value).Build(); + A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_val(set_value) + .Build(); } else { - A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); + A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); } - auto *B_buf = common::BufferBuilder(type, {M.as_int32(), N.as_int32()}).set_align(type.bits()).Build(); + auto *B_buf = common::BufferBuilder(type, {M.as_int32(), N.as_int32()}) + .set_align(type.bits()) + .Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf); cinn_pod_value_t args[] = {a_arg, b_arg}; @@ -91,7 +97,8 @@ void TestElementwisePE(const std::string &fn_name, fn_(args, 2); } test_op_time = timer.Stop() / repeat_; - LOG(INFO) << "repeat times: " << repeat_ << ", kernel run time: " << test_op_time << " ms"; + LOG(INFO) << "repeat times: " << repeat_ + << ", kernel run time: " << test_op_time << " ms"; } else { fn_(args, 2); } @@ -115,17 +122,23 @@ bool isfinite(float e) { return std::isfinite(e); } bool isinf(float e) { return std::isinf(e); } float rsqrt(float e) { return 1.0f / sqrtf(e); } -#define TEST_ELEMENTWISE_PE_FP32(test_name__, PE__) \ - TEST(elementwise_pe, test_name__) { \ - cinn::hlir::pe::TestElementwisePE("PE_Elementwise_" #test_name__ "_fp32", PE__, test_name__); \ +#define TEST_ELEMENTWISE_PE_FP32(test_name__, PE__) \ + TEST(elementwise_pe, test_name__) { \ + cinn::hlir::pe::TestElementwisePE( \ + "PE_Elementwise_" #test_name__ "_fp32", PE__, test_name__); \ } -#define TEST_ELEMENTWISE_PE_FP32_BOOL(test_name__, PE__) \ - TEST(elementwise_pe, test_name__) { \ - cinn::hlir::pe::TestElementwisePE("PE_Elementwise_" #test_name__ "_fp32", PE__, test_name__, Bool()); \ +#define TEST_ELEMENTWISE_PE_FP32_BOOL(test_name__, PE__) \ + TEST(elementwise_pe, test_name__) { \ + cinn::hlir::pe::TestElementwisePE( \ + "PE_Elementwise_" #test_name__ "_fp32", PE__, test_name__, Bool()); \ } -#define TEST_ELEMENTWISE_PE_FP32_SET(test_name__, PE__, value__) \ - TEST(elementwise_pe, test_name__) { \ - cinn::hlir::pe::TestElementwisePE("PE_Elementwise_" #test_name__ "_fp32", PE__, test_name__, Float(32), value__); \ +#define TEST_ELEMENTWISE_PE_FP32_SET(test_name__, PE__, value__) \ + TEST(elementwise_pe, test_name__) { \ + cinn::hlir::pe::TestElementwisePE("PE_Elementwise_" #test_name__ "_fp32", \ + PE__, \ + test_name__, \ + Float(32), \ + value__); \ } TEST_ELEMENTWISE_PE_FP32(expf, Exp) diff --git a/paddle/cinn/hlir/pe/pe_transform_test.cc b/paddle/cinn/hlir/pe/pe_transform_test.cc 
index 0fcd520623d5f..b69b48b4b85bf 100644 --- a/paddle/cinn/hlir/pe/pe_transform_test.cc +++ b/paddle/cinn/hlir/pe/pe_transform_test.cc @@ -46,7 +46,7 @@ TEST(MatmulPE, MatmulCase1) { auto C = hlir::pe::Matmul(A.tensor(), B.tensor(), false, false, 1, "C"); - auto stages = CreateStages({A, B}); + auto stages = CreateStages({A, B}); std::vector tensor_args = {A, B}; for (size_t i = 0; i < C.size(); i++) { tensor_args.push_back(C[i]); @@ -58,15 +58,17 @@ TEST(MatmulPE, MatmulCase1) { builder.AddFunction(func); LOG(INFO) << "func:\n" << func; - auto jit = backends::ExecutionEngine::Create({}); + auto jit = backends::ExecutionEngine::Create({}); auto module = builder.Build(); jit->Link(module); auto fn = jit->Lookup("fn"); CHECK(fn); - auto fn_ = reinterpret_cast(fn); - cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), {m, k}).set_random().Build(); - cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {k, n}).set_random().Build(); + auto fn_ = reinterpret_cast(fn); + cinn_buffer_t *A_buf = + common::BufferBuilder(Float(32), {m, k}).set_random().Build(); + cinn_buffer_t *B_buf = + common::BufferBuilder(Float(32), {k, n}).set_random().Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf); std::vector args = {a_arg, b_arg}; std::vector C_buf; @@ -82,9 +84,9 @@ TEST(MatmulPE, MatmulCase1) { args.push_back(arg); } fn_(reinterpret_cast(args.data()), args.size()); - auto *ad = reinterpret_cast(A_buf->memory); - auto *bd = reinterpret_cast(B_buf->memory); - auto *cd = reinterpret_cast(C_buf[0]->memory); + auto *ad = reinterpret_cast(A_buf->memory); + auto *bd = reinterpret_cast(B_buf->memory); + auto *cd = reinterpret_cast(C_buf[0]->memory); int size_a = m; int size_b = n; int size_c = k; @@ -118,19 +120,20 @@ TEST(ScatterAssign, ScatterAssign) { auto target = common::DefaultHostTarget(); #endif - auto output = hlir::pe::ScatterAssign(input.tensor(), assign.tensor(), indexs.tensor(), target, axis); + auto output = hlir::pe::ScatterAssign( + input.tensor(), assign.tensor(), indexs.tensor(), target, axis); auto stages = CreateStages({input, assign, indexs, output}); - auto func = Lower("fn", stages, {input, assign, indexs, output}); + auto func = Lower("fn", stages, {input, assign, indexs, output}); LOG(INFO) << "func:\n" << func; #ifdef CINN_WITH_CUDA Module::Builder builder("ScatterAssign_Builder", target); builder.AddFunction(func); - auto module = builder.Build(); + auto module = builder.Build(); auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); + auto &host_module = std::get<0>(host_module_device_module); + auto &device_module = std::get<1>(host_module_device_module); backends::CodeGenCUDA_Dev codegen(target); auto source_code = codegen.Compile(builder.Build()); @@ -141,7 +144,8 @@ TEST(ScatterAssign, ScatterAssign) { auto ptx = compiler(source_code); CHECK(!ptx.empty()); // cuda_module load ptx - runtime::cuda::CUDAModule cuda_module(ptx, runtime::cuda::CUDAModule::Kind::PTX); + runtime::cuda::CUDAModule cuda_module(ptx, + runtime::cuda::CUDAModule::Kind::PTX); #endif // CINN_WITH_CUDA } @@ -151,17 +155,18 @@ TEST(SliceAssign, SliceAssign) { int k = 32; Expr M(m), N(n), K(k); - std::vector axis = {0, 1}; - std::vector starts = {32, 32}; - std::vector ends = {64, 64}; + std::vector axis = {0, 1}; + std::vector starts = {32, 32}; + std::vector ends = {64, 64}; std::vector strides = {1, 1}; Placeholder input("A", {M, M}); Placeholder 
assign("B", {N, N}); - auto output = hlir::pe::SliceAssign(input.tensor(), assign.tensor(), axis, starts, ends, strides); + auto output = hlir::pe::SliceAssign( + input.tensor(), assign.tensor(), axis, starts, ends, strides); auto stages = CreateStages({input, assign, output}); - auto func = Lower("fn", stages, {input, assign, output}); + auto func = Lower("fn", stages, {input, assign, output}); LOG(INFO) << "func:\n" << func; #ifdef CINN_WITH_CUDA @@ -169,10 +174,10 @@ TEST(SliceAssign, SliceAssign) { Module::Builder builder("SliceAssign_Builder", target); builder.AddFunction(func); - auto module = builder.Build(); + auto module = builder.Build(); auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); + auto &host_module = std::get<0>(host_module_device_module); + auto &device_module = std::get<1>(host_module_device_module); backends::CodeGenCUDA_Dev codegen(target); auto source_code = codegen.Compile(builder.Build()); @@ -183,7 +188,8 @@ TEST(SliceAssign, SliceAssign) { auto ptx = compiler(source_code); CHECK(!ptx.empty()); - runtime::cuda::CUDAModule cuda_module(ptx, runtime::cuda::CUDAModule::Kind::PTX); + runtime::cuda::CUDAModule cuda_module(ptx, + runtime::cuda::CUDAModule::Kind::PTX); #endif } @@ -197,10 +203,11 @@ TEST(Concat, ConcatCase0) { Placeholder C("C", {M, N}); Placeholder D("D", {M, N}); - std::vector inputs{A.tensor(), B.tensor(), C.tensor(), D.tensor()}; + std::vector inputs{ + A.tensor(), B.tensor(), C.tensor(), D.tensor()}; auto output = hlir::pe::Concat(inputs, 1); auto stages = CreateStages({output}); - auto func = Lower("fn", stages, {A, B, C, D, output}); + auto func = Lower("fn", stages, {A, B, C, D, output}); LOG(INFO) << "func:\n" << func; #ifdef CINN_WITH_CUDA @@ -208,10 +215,10 @@ TEST(Concat, ConcatCase0) { Module::Builder builder("Concat_Builder", target); builder.AddFunction(func); - auto module = builder.Build(); + auto module = builder.Build(); auto host_module_device_module = backends::SplitCudaAndHostModule(module); - auto &host_module = std::get<0>(host_module_device_module); - auto &device_module = std::get<1>(host_module_device_module); + auto &host_module = std::get<0>(host_module_device_module); + auto &device_module = std::get<1>(host_module_device_module); backends::CodeGenCUDA_Dev codegen(target); auto source_code = codegen.Compile(builder.Build()); diff --git a/paddle/cinn/hlir/pe/reduction.cc b/paddle/cinn/hlir/pe/reduction.cc index 073e8a80549d7..d4e2daa893a26 100644 --- a/paddle/cinn/hlir/pe/reduction.cc +++ b/paddle/cinn/hlir/pe/reduction.cc @@ -37,17 +37,21 @@ using ir::Tensor; using lang::Compute; /** - * @brief transform reduction axes which could be empty or have negative elements into real axes with valid dimension - * indices. + * @brief transform reduction axes which could be empty or have negative + * elements into real axes with valid dimension indices. * * @param ndim Number of dimensions of the output tensor. * @param axes The axes parameter. - * @param real_axes A non-empty sorted array of valid dimension indices, with no duplicates. + * @param real_axes A non-empty sorted array of valid dimension indices, with no + * duplicates. * - * @notes If the input axes are empty, the result will be axes including all dimensions. If any input element is - * negative, it will be treated as an offset from the last dimension (same as python indexing rules). 
+ * @notes If the input axes are empty, the result will be axes including all + * dimensions. If any input element is negative, it will be treated as an offset + * from the last dimension (same as python indexing rules). */ -void GetRealAxes(int ndim, const std::vector& axes, std::vector* real_axes) { +void GetRealAxes(int ndim, + const std::vector& axes, + std::vector* real_axes) { CHECK(real_axes); if (axes.empty()) { for (int i = 0; i < ndim; ++i) { @@ -58,11 +62,13 @@ void GetRealAxes(int ndim, const std::vector& axes, std::vector* real_ if (axis < 0) { axis += ndim; } - CHECK_LE(axis, ndim) << "exceeds the maximum dimension: " << ndim << std::endl; + CHECK_LE(axis, ndim) << "exceeds the maximum dimension: " << ndim + << std::endl; CHECK_GE(axis, 0); real_axes->push_back(axis); } - real_axes->resize(std::unique(real_axes->begin(), real_axes->end()) - real_axes->begin()); + real_axes->resize(std::unique(real_axes->begin(), real_axes->end()) - + real_axes->begin()); std::sort(real_axes->begin(), real_axes->end()); } } @@ -91,11 +97,13 @@ std::string Type2StrForReduce(common::Type type) { /** * @brief Calculate the target reduced shape. * - * @param real_axes A non-empty sorted array of valid dimension indices, with no duplicates. + * @param real_axes A non-empty sorted array of valid dimension indices, with no + * duplicates. * @param output_shape The output Tensor shape. * @param tensor The input tensor. - * @param keep_dims If this is set to true, the reduced axes are kept as dimensions with size one. This enables the - * result to broadcast correctly against the input array. + * @param keep_dims If this is set to true, the reduced axes are kept as + * dimensions with size one. This enables the result to broadcast correctly + * against the input array. */ void GetOutputShape(const std::vector& real_axes, std::vector* output_shape, @@ -130,7 +138,8 @@ void GetOutputShape(const std::vector& real_axes, * @param fn The reduction function eg. ReduceSum * @param output_shape The output Tensor shape. * @param real_axes The real axes where the reduction is performed. - * @param squeeze_axes The real axes to squeeze. If unsqueezed, reduced axes will have shape 1 in the output tensor. + * @param squeeze_axes The real axes to squeeze. If unsqueezed, reduced axes + * will have shape 1 in the output tensor. * @param initial Starting value for the sum. * @param output_name The name of the output Tensor. * @@ -147,7 +156,8 @@ Tensor DoReduce(const Tensor& tensor, std::vector reduce_axes; int reduce_k_id = 0; for (auto& axis : real_axes) { - std::string name = cinn::UniqName(std::string("reduce_k_") + std::to_string(reduce_k_id)); + std::string name = + cinn::UniqName(std::string("reduce_k_") + std::to_string(reduce_k_id)); reduce_axes.push_back(Var(tensor->shape[axis], name)); reduce_k_id++; } @@ -157,7 +167,8 @@ Tensor DoReduce(const Tensor& tensor, int reduce_cnt = 0; for (size_t i = 0; i < tensor->shape.size(); ++i) { - bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) != squeeze_axes.end(); + bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) != + squeeze_axes.end(); if (std::find(real_axes.begin(), real_axes.end(), i) != real_axes.end()) { eval_indice.push_back(reduce_axes[reduce_cnt]); reduce_cnt++; @@ -180,7 +191,8 @@ Tensor DoReduce(const Tensor& tensor, * @param tensor The input tensor. * @param axes The axes along which the reduction are performed. * @param fn The reduction function eg. 
ReduceSum - * @param keep_dims If it is set true, the axes which are reduced are left in the result as dimensions with size one. + * @param keep_dims If it is set true, the axes which are reduced are left in + * the result as dimensions with size one. * @param initial Starting value for the sum. * * @return The result tensor. @@ -198,31 +210,66 @@ Tensor Reduce(const Tensor& tensor, GetRealAxes(static_cast(ndim), axes, &real_axes); std::vector output_shapes; GetOutputShape(real_axes, &output_shapes, tensor, keep_dims); - return DoReduce( - tensor, fn, output_shapes, real_axes, keep_dims ? std::vector() : real_axes, initial, output_name); + return DoReduce(tensor, + fn, + output_shapes, + real_axes, + keep_dims ? std::vector() : real_axes, + initial, + output_name); } -Tensor ReduceSum(const Tensor& A, const std::vector& axes, const bool keep_dims, const std::string& output_name) { - return Reduce(A, axes, lang::ReduceSum, keep_dims, ir::Zero(A->type()), output_name); +Tensor ReduceSum(const Tensor& A, + const std::vector& axes, + const bool keep_dims, + const std::string& output_name) { + return Reduce( + A, axes, lang::ReduceSum, keep_dims, ir::Zero(A->type()), output_name); } -Tensor ReduceProd(const Tensor& A, const std::vector& axes, const bool keep_dims, const std::string& output_name) { - return Reduce(A, axes, lang::ReduceMul, keep_dims, lang::One(A->type()), output_name); +Tensor ReduceProd(const Tensor& A, + const std::vector& axes, + const bool keep_dims, + const std::string& output_name) { + return Reduce( + A, axes, lang::ReduceMul, keep_dims, lang::One(A->type()), output_name); } -Tensor ReduceMax(const Tensor& A, const std::vector& axes, const bool keep_dims, const std::string& output_name) { - return Reduce(A, axes, lang::ReduceMax, keep_dims, lang::min_value(A->type()), output_name); +Tensor ReduceMax(const Tensor& A, + const std::vector& axes, + const bool keep_dims, + const std::string& output_name) { + return Reduce(A, + axes, + lang::ReduceMax, + keep_dims, + lang::min_value(A->type()), + output_name); } -Tensor ReduceMin(const Tensor& A, const std::vector& axes, const bool keep_dims, const std::string& output_name) { - return Reduce(A, axes, lang::ReduceMin, keep_dims, lang::max_value(A->type()), output_name); +Tensor ReduceMin(const Tensor& A, + const std::vector& axes, + const bool keep_dims, + const std::string& output_name) { + return Reduce(A, + axes, + lang::ReduceMin, + keep_dims, + lang::max_value(A->type()), + output_name); } -Tensor ReduceAll(const Tensor& A, const std::vector& axes, const bool keep_dims, const std::string& output_name) { +Tensor ReduceAll(const Tensor& A, + const std::vector& axes, + const bool keep_dims, + const std::string& output_name) { return Reduce(A, axes, lang::ReduceAll, keep_dims, Expr(true), output_name); } -Tensor ReduceAny(const Tensor& A, const std::vector& axes, const bool keep_dims, const std::string& output_name) { +Tensor ReduceAny(const Tensor& A, + const std::vector& axes, + const bool keep_dims, + const std::string& output_name) { return Reduce(A, axes, lang::ReduceAny, keep_dims, Expr(false), output_name); } @@ -241,12 +288,14 @@ std::vector WarpReduce(const ir::Tensor& A, } // comput tmp output shape. 
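Note (the WarpReduce body resumes just below): the ReduceSum/Prod/Max/Min/All/Any wrappers reflowed above each pair a combining function with its identity element as the initial value: ir::Zero for sum, lang::One for product, the type's lowest value for max, its highest for min, true for all, false for any. A minimal sketch of why those initials are safe, with illustrative names only:

```cpp
#include <cassert>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>

// Each reduction pairs a binary op with its identity: starting the fold from
// the identity means empty extents and padded slots leave the result intact.
template <typename T, typename Op>
T Fold(const std::vector<T>& xs, T identity, Op op) {
  return std::accumulate(xs.begin(), xs.end(), identity, op);
}

int main() {
  std::vector<float> v = {2.0f, 3.0f};
  assert(Fold(v, 0.0f, std::plus<float>()) == 5.0f);        // sum: 0
  assert(Fold(v, 1.0f, std::multiplies<float>()) == 6.0f);  // prod: 1
  // max starts from the type's lowest value, min from its highest.
  float mx = Fold(v, std::numeric_limits<float>::lowest(),
                  [](float a, float b) { return a > b ? a : b; });
  assert(mx == 3.0f);
}
```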
- std::vector tmp_shape(A->shape.begin(), A->shape.begin() + shape_size_without_reduce_dim); + std::vector tmp_shape(A->shape.begin(), + A->shape.begin() + shape_size_without_reduce_dim); tmp_shape.push_back(Expr(32)); auto tmp_out = Compute( tmp_shape, [=](const std::vector& indexs) -> Expr { - std::vector tmp_indexs(indexs.begin(), indexs.begin() + indexs.size() - 1); + std::vector tmp_indexs(indexs.begin(), + indexs.begin() + indexs.size() - 1); for (int idx = 0; idx < last_reduce_dim_num; ++idx) { tmp_indexs.push_back(Expr(0)); } @@ -257,7 +306,8 @@ std::vector WarpReduce(const ir::Tensor& A, UniqName(output_name + "_" + reduce_type)); // compute output shape. - std::vector out_shape(A->shape.begin(), A->shape.begin() + shape_size_without_reduce_dim); + std::vector out_shape(A->shape.begin(), + A->shape.begin() + shape_size_without_reduce_dim); for (int idx = 0; idx < last_reduce_dim_num && keep_dim; ++idx) { out_shape.push_back(Expr(1)); } @@ -268,7 +318,8 @@ std::vector WarpReduce(const ir::Tensor& A, auto out = Compute( out_shape, [=](const std::vector& indexs) -> Expr { - std::vector tmp_indexs(indexs.begin(), indexs.begin() + shape_size_without_reduce_dim); + std::vector tmp_indexs( + indexs.begin(), indexs.begin() + shape_size_without_reduce_dim); tmp_indexs.push_back(Expr(0)); return tmp_out(tmp_indexs); }, @@ -281,24 +332,33 @@ std::vector WarpReduceMax(const ir::Tensor& A, const int last_reduce_dim_num, const bool keep_dim, const std::string& output_name) { - return WarpReduce( - A, last_reduce_dim_num, keep_dim, "cinn_warp_reduce_max" + Type2StrForReduce(A->type()), output_name); + return WarpReduce(A, + last_reduce_dim_num, + keep_dim, + "cinn_warp_reduce_max" + Type2StrForReduce(A->type()), + output_name); } std::vector WarpReduceSum(const ir::Tensor& A, const int last_reduce_dim_num, const bool keep_dim, const std::string& output_name) { - return WarpReduce( - A, last_reduce_dim_num, keep_dim, "cinn_warp_reduce_sum" + Type2StrForReduce(A->type()), output_name); + return WarpReduce(A, + last_reduce_dim_num, + keep_dim, + "cinn_warp_reduce_sum" + Type2StrForReduce(A->type()), + output_name); } std::vector WarpReduceAvg(const ir::Tensor& A, const int last_reduce_dim_num, const bool keep_dim, const std::string& output_name) { - return WarpReduce( - A, last_reduce_dim_num, keep_dim, "cinn_warp_reduce_avg" + Type2StrForReduce(A->type()), output_name); + return WarpReduce(A, + last_reduce_dim_num, + keep_dim, + "cinn_warp_reduce_avg" + Type2StrForReduce(A->type()), + output_name); } std::vector BlockReduceInternal(const ir::Tensor& A, @@ -314,12 +374,15 @@ std::vector BlockReduceInternal(const ir::Tensor& A, } // compute tmp output shape. - std::vector tmp_shape(A->shape.begin(), A->shape.begin() + axes.front()); + std::vector tmp_shape(A->shape.begin(), + A->shape.begin() + axes.front()); tmp_shape.push_back(reduce_width); // compute the reduce dimension stride. std::vector last_reduce_stride(A->shape.size() - axes.front(), Expr(1)); - for (int idx = A->shape.size(), index = int(last_reduce_stride.size()) - 2; index >= 0; --index) { + for (int idx = A->shape.size(), index = int(last_reduce_stride.size()) - 2; + index >= 0; + --index) { last_reduce_stride[index] = last_reduce_stride[index + 1] * A->shape[--idx]; } @@ -328,7 +391,8 @@ std::vector BlockReduceInternal(const ir::Tensor& A, [=](const std::vector& indexs) -> Expr { // compute index map from output to input.
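Note: the last_reduce_stride loop above computes row-major strides for the trailing dimensions back to front, and the index map in the lambda continuing below recovers multi-dimensional coordinates from a flattened index by successive division and modulo. A self-contained sketch of both steps, with an illustrative shape:

```cpp
#include <cassert>
#include <vector>

int main() {
  // Row-major strides for the trailing dims, computed back to front as in
  // BlockReduceInternal (shape values are illustrative).
  std::vector<int> shape = {4, 5, 6};
  std::vector<int> stride(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    stride[i] = stride[i + 1] * shape[i + 1];
  }
  assert(stride[0] == 30 && stride[1] == 6 && stride[2] == 1);

  // Decompose a flat index with successive div/mod, mirroring the
  // "index map from output to input" in the Compute lambda.
  int flat = 97;  // 97 == 3 * 30 + 1 * 6 + 1
  std::vector<int> coord;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    coord.push_back(flat / stride[i]);
    flat %= stride[i];
  }
  assert(coord[0] == 3 && coord[1] == 1 && coord[2] == 1);
}
```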
auto last_index = indexs.back(); - std::vector input_indexs(indexs.begin(), indexs.begin() + indexs.size() - 1); + std::vector input_indexs(indexs.begin(), + indexs.begin() + indexs.size() - 1); for (int idx = 0; idx < A->shape.size() - axes.front(); ++idx) { input_indexs.push_back(last_index / last_reduce_stride[idx]); last_index = last_index % last_reduce_stride[idx]; @@ -341,8 +405,10 @@ std::vector BlockReduceInternal(const ir::Tensor& A, UniqName(output_name + "_tmp")); // compute output shape. - std::vector out_shape(A->shape.begin(), A->shape.begin() + axes.front()); - int tailf = keep_dim ? (int(A->shape.size()) - axes.front()) : (int(A->shape.size()) - axes.back() - 1); + std::vector out_shape(A->shape.begin(), + A->shape.begin() + axes.front()); + int tailf = keep_dim ? (int(A->shape.size()) - axes.front()) + : (int(A->shape.size()) - axes.back() - 1); for (int idx = 0; idx < tailf; ++idx) { out_shape.push_back(Expr(1)); } @@ -353,7 +419,8 @@ std::vector BlockReduceInternal(const ir::Tensor& A, auto out = Compute( out_shape, [=](const std::vector& indexs) -> Expr { - std::vector tmp_indexs(indexs.begin(), indexs.begin() + axes.front()); + std::vector tmp_indexs(indexs.begin(), + indexs.begin() + axes.front()); tmp_indexs.push_back(Expr(0)); return tmp_out(tmp_indexs); }, @@ -366,15 +433,24 @@ std::vector BlockReduceSumInternal(const ir::Tensor& A, const bool keep_dim, const std::string& output_name) { return BlockReduceInternal( - A, axes, keep_dim, "cinn_block_reduce_sum" + Type2StrForReduce(A->type()) + "_internal", output_name); + A, + axes, + keep_dim, + "cinn_block_reduce_sum" + Type2StrForReduce(A->type()) + "_internal", + output_name); } -std::vector BlockReduceProdInternal(const ir::Tensor& A, - const std::vector& axes, - const bool keep_dim, - const std::string& output_name) { +std::vector BlockReduceProdInternal( + const ir::Tensor& A, + const std::vector& axes, + const bool keep_dim, + const std::string& output_name) { return BlockReduceInternal( - A, axes, keep_dim, "cinn_block_reduce_prod" + Type2StrForReduce(A->type()) + "_internal", output_name); + A, + axes, + keep_dim, + "cinn_block_reduce_prod" + Type2StrForReduce(A->type()) + "_internal", + output_name); } std::vector BlockReduceMaxInternal(const ir::Tensor& A, @@ -382,7 +458,11 @@ std::vector BlockReduceMaxInternal(const ir::Tensor& A, const bool keep_dim, const std::string& output_name) { return BlockReduceInternal( - A, axes, keep_dim, "cinn_block_reduce_max" + Type2StrForReduce(A->type()) + "_internal", output_name); + A, + axes, + keep_dim, + "cinn_block_reduce_max" + Type2StrForReduce(A->type()) + "_internal", + output_name); } std::vector BlockReduceMinInternal(const ir::Tensor& A, @@ -390,25 +470,32 @@ std::vector BlockReduceMinInternal(const ir::Tensor& A, const bool keep_dim, const std::string& output_name) { return BlockReduceInternal( - A, axes, keep_dim, "cinn_block_reduce_min" + Type2StrForReduce(A->type()) + "_internal", output_name); + A, + axes, + keep_dim, + "cinn_block_reduce_min" + Type2StrForReduce(A->type()) + "_internal", + output_name); } std::vector BlockReduceAllInternal(const ir::Tensor& A, const std::vector& axes, const bool keep_dim, const std::string& output_name) { - return BlockReduceInternal(A, axes, keep_dim, "cinn_block_reduce_all_internal", output_name); + return BlockReduceInternal( + A, axes, keep_dim, "cinn_block_reduce_all_internal", output_name); } std::vector BlockReduceAnyInternal(const ir::Tensor& A, const std::vector& axes, const bool keep_dim, const std::string& 
output_name) { - return BlockReduceInternal(A, axes, keep_dim, "cinn_block_reduce_any_internal", output_name); + return BlockReduceInternal( + A, axes, keep_dim, "cinn_block_reduce_any_internal", output_name); } /** - * @brief compute the sum of array elements over the last dimension with block reduce + * @brief compute the sum of array elements over the last dimension with block + * reduce * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. @@ -428,12 +515,14 @@ std::vector BlockReduce(const ir::Tensor& A, } // compute tmp output tensor shape - std::vector tmp_shape(A->shape.begin(), A->shape.begin() + axes.front()); + std::vector tmp_shape(A->shape.begin(), + A->shape.begin() + axes.front()); tmp_shape.push_back(Expr(block_size)); auto tmp_out = Compute( tmp_shape, [=](const std::vector& indexs) -> Expr { - std::vector tmp_indexs(indexs.begin(), indexs.begin() + axes.front()); + std::vector tmp_indexs(indexs.begin(), + indexs.begin() + axes.front()); for (int idx = 0; idx < A->shape.size() - axes.front(); ++idx) { tmp_indexs.push_back(Expr(0)); } @@ -447,8 +536,10 @@ std::vector BlockReduce(const ir::Tensor& A, UniqName(output_name + "_tmp")); // compute output tensor shape. - std::vector out_shape(A->shape.begin(), A->shape.begin() + axes.front()); - int tailf = keep_dim ? (int(A->shape.size()) - axes.front()) : (int(A->shape.size()) - axes.back() - 1); + std::vector out_shape(A->shape.begin(), + A->shape.begin() + axes.front()); + int tailf = keep_dim ? (int(A->shape.size()) - axes.front()) + : (int(A->shape.size()) - axes.back() - 1); for (int idx = 0; idx < tailf; ++idx) { out_shape.push_back(Expr(1)); } @@ -460,7 +551,8 @@ std::vector BlockReduce(const ir::Tensor& A, out_shape, [=](const std::vector& indexs) -> Expr { // compute input index - std::vector tmp_indexs(indexs.begin(), indexs.begin() + axes.front()); + std::vector tmp_indexs(indexs.begin(), + indexs.begin() + axes.front()); tmp_indexs.push_back(Expr(0)); return tmp_out(tmp_indexs); }, @@ -474,8 +566,12 @@ std::vector BlockReduceSum(const ir::Tensor& A, const int block_size, const bool keep_dim, const std::string& output_name) { - return BlockReduce( - A, axes, block_size, keep_dim, "cinn_block_reduce_sum" + Type2StrForReduce(A->type()), output_name); + return BlockReduce(A, + axes, + block_size, + keep_dim, + "cinn_block_reduce_sum" + Type2StrForReduce(A->type()), + output_name); } std::vector BlockReduceProd(const ir::Tensor& A, @@ -483,8 +579,12 @@ std::vector BlockReduceProd(const ir::Tensor& A, const int block_size, const bool keep_dim, const std::string& output_name) { - return BlockReduce( - A, axes, block_size, keep_dim, "cinn_block_reduce_prod" + Type2StrForReduce(A->type()), output_name); + return BlockReduce(A, + axes, + block_size, + keep_dim, + "cinn_block_reduce_prod" + Type2StrForReduce(A->type()), + output_name); } std::vector BlockReduceMax(const ir::Tensor& A, @@ -492,8 +592,12 @@ std::vector BlockReduceMax(const ir::Tensor& A, const int block_size, const bool keep_dim, const std::string& output_name) { - return BlockReduce( - A, axes, block_size, keep_dim, "cinn_block_reduce_max" + Type2StrForReduce(A->type()), output_name); + return BlockReduce(A, + axes, + block_size, + keep_dim, + "cinn_block_reduce_max" + Type2StrForReduce(A->type()), + output_name); } std::vector BlockReduceMin(const ir::Tensor& A, @@ -501,8 +605,12 @@ std::vector BlockReduceMin(const ir::Tensor& A, const int block_size, const bool keep_dim, const std::string& output_name) { - return 
BlockReduce( - A, axes, block_size, keep_dim, "cinn_block_reduce_min" + Type2StrForReduce(A->type()), output_name); + return BlockReduce(A, + axes, + block_size, + keep_dim, + "cinn_block_reduce_min" + Type2StrForReduce(A->type()), + output_name); } std::vector BlockReduceAll(const ir::Tensor& A, @@ -510,7 +618,8 @@ std::vector BlockReduceAll(const ir::Tensor& A, const int block_size, const bool keep_dim, const std::string& output_name) { - return BlockReduce(A, axes, block_size, keep_dim, "cinn_block_reduce_all", output_name); + return BlockReduce( + A, axes, block_size, keep_dim, "cinn_block_reduce_all", output_name); } std::vector BlockReduceAny(const ir::Tensor& A, @@ -518,7 +627,8 @@ std::vector BlockReduceAny(const ir::Tensor& A, const int block_size, const bool keep_dim, const std::string& output_name) { - return BlockReduce(A, axes, block_size, keep_dim, "cinn_block_reduce_any", output_name); + return BlockReduce( + A, axes, block_size, keep_dim, "cinn_block_reduce_any", output_name); } int GetPostParallelSize(const ir::Tensor& A, const std::vector& axes) { @@ -540,8 +650,10 @@ int GetParallelSize(const ir::Tensor& A, const std::vector& axes) { return parallel_size; } -using ReduceFunc = - std::function&, const bool, const std::string&)>; +using ReduceFunc = std::function&, + const bool, + const std::string&)>; std::vector ReduceInternal(const ir::Tensor& A, const std::vector& axes, @@ -550,17 +662,21 @@ std::vector ReduceInternal(const ir::Tensor& A, ReduceFunc reduce_func, ir::Expr initial, std::string reduce_type) { - int tail = 0; + int tail = 0; bool inbound = true; std::vector inshape; - std::transform( - A->shape.begin(), A->shape.end(), std::back_inserter(inshape), [](ir::Expr expr) { return expr.as_int32(); }); + std::transform(A->shape.begin(), + A->shape.end(), + std::back_inserter(inshape), + [](ir::Expr expr) { return expr.as_int32(); }); auto reduce_shape = GetFirstStepReduceShape(inshape, axes, inbound, tail); CHECK_GT(reduce_shape.size(), 0); - VLOG(4) << "Reduce " << output_name << " on " << reduce_type << " with input shape=[" - << cinn::utils::Join(inshape, ", ") << "], and first step reduce_shape=[" - << cinn::utils::Join(reduce_shape, ", ") << "] at axes=[" << cinn::utils::Join(axes, ", ") << "]"; + VLOG(4) << "Reduce " << output_name << " on " << reduce_type + << " with input shape=[" << cinn::utils::Join(inshape, ", ") + << "], and first step reduce_shape=[" + << cinn::utils::Join(reduce_shape, ", ") << "] at axes=[" + << cinn::utils::Join(axes, ", ") << "]"; // reshape input auto do_reshape_inbound = [&]() { @@ -579,19 +695,23 @@ std::vector ReduceInternal(const ir::Tensor& A, strides.insert(strides.begin(), strides.front() * ir::Expr(inshape[idx])); } CHECK_EQ(strides.size(), axes.size() - axis_index); - std::transform(reduce_shape.begin(), reduce_shape.end(), std::back_inserter(reshape_output_shape), [](int val) { - return ir::Expr(val); - }); + std::transform(reduce_shape.begin(), + reduce_shape.end(), + std::back_inserter(reshape_output_shape), + [](int val) { return ir::Expr(val); }); return Compute( reshape_output_shape, [=](const std::vector& indexs) -> Expr { // index is last axis in axes and index is last axis >= tail. 
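Note: in the lambda continuing below, do_reshape_inbound pads out-of-range slots of the reshaped tensor with the reduction's initial value via ir::Select::Make, so the padding cannot affect the result. Here is a deliberately simplified sketch using a flat bound check instead of the per-axis condition in the real code; ReshapeWithPad is a hypothetical name.

```cpp
#include <cassert>
#include <vector>

// When times * tail overshoots the real extent, out-of-range slots take the
// reduction's identity, so a later reduce over `tail` cannot see them.
std::vector<std::vector<float>> ReshapeWithPad(const std::vector<float>& a,
                                               int times, int tail,
                                               float identity) {
  std::vector<std::vector<float>> out(
      times, std::vector<float>(tail, identity));
  for (int r = 0; r < times; ++r) {
    for (int c = 0; c < tail; ++c) {
      int flat = r * tail + c;
      if (flat < static_cast<int>(a.size())) out[r][c] = a[flat];  // in bounds
    }
  }
  return out;
}

int main() {
  std::vector<float> a(100, 1.0f);
  auto r = ReshapeWithPad(a, 4, 32, 0.0f);  // 128 slots, 28 padded with 0
  float sum = 0.0f;
  for (const auto& row : r)
    for (float e : row) sum += e;
  assert(sum == 100.0f);  // padding is inert under summation
}
```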
- auto selected = ir::And::Make(ir::EQ::Make(indexs[axis], ir::Expr(reduce_shape[axis] - 1)), - ir::GE::Make(indexs[axis + 1], ir::Expr(tail))); - auto index = indexs[axis] * ir::Expr(reshape_output_shape[axis + 1]) + indexs[axis + 1]; + auto selected = ir::And::Make( + ir::EQ::Make(indexs[axis], ir::Expr(reduce_shape[axis] - 1)), + ir::GE::Make(indexs[axis + 1], ir::Expr(tail))); + auto index = indexs[axis] * ir::Expr(reshape_output_shape[axis + 1]) + + indexs[axis + 1]; // first part index - std::vector tmp_indexs(indexs.begin(), indexs.begin() + axes[axis_index]); + std::vector tmp_indexs(indexs.begin(), + indexs.begin() + axes[axis_index]); // second part index for (int idx = 0; idx < strides.size(); ++idx) { tmp_indexs.push_back(index / strides[idx]); @@ -602,14 +722,18 @@ std::vector ReduceInternal(const ir::Tensor& A, tmp_indexs.push_back(indexs[idx]); } - CHECK_EQ(tmp_indexs.size(), A->shape.size()) << "Indexs size is not equal to Input shape!"; + CHECK_EQ(tmp_indexs.size(), A->shape.size()) + << "Indexs size is not equal to Input shape!"; return ir::Select::Make(selected, A(tmp_indexs), initial); }, UniqName(output_name + "_reshape")); }; - auto reshape = inbound ? pe::Reshape(A, reduce_shape, output_name + "_reshape") : do_reshape_inbound(); + auto reshape = inbound + ? pe::Reshape(A, reduce_shape, output_name + "_reshape") + : do_reshape_inbound(); // do first step reduce - auto internal = reduce_func(reshape, axes, keep_dim, output_name + "_internal"); + auto internal = + reduce_func(reshape, axes, keep_dim, output_name + "_internal"); // do second step reduce std::vector s_axes = {}; if (keep_dim) { @@ -622,28 +746,45 @@ std::vector ReduceInternal(const ir::Tensor& A, return {reduce_out, internal, reshape}; } -#define BLOCK_SHUFFLE_REDUCE(name, reduce_type, initial) \ - std::vector BlockShuffleReduce##name( \ - const ir::Tensor& A, const std::vector& axes, const bool keep_dim, const std::string& output_name) { \ - if (common::GetMaxThreads() / GetParallelSize(A, axes) <= 1) { \ - return {Reduce##name(A, axes, keep_dim, output_name)}; \ - } else { \ - auto rs = ReduceInternal(A, axes, keep_dim, output_name, Reduce##name, initial, reduce_type); \ - if (rs.size() == 0) { \ - return {Reduce##name(A, axes, keep_dim, output_name)}; \ - } else \ - return rs; \ - } \ +#define BLOCK_SHUFFLE_REDUCE(name, reduce_type, initial) \ + std::vector BlockShuffleReduce##name( \ + const ir::Tensor& A, \ + const std::vector& axes, \ + const bool keep_dim, \ + const std::string& output_name) { \ + if (common::GetMaxThreads() / GetParallelSize(A, axes) <= 1) { \ + return {Reduce##name(A, axes, keep_dim, output_name)}; \ + } else { \ + auto rs = ReduceInternal( \ + A, axes, keep_dim, output_name, Reduce##name, initial, reduce_type); \ + if (rs.size() == 0) { \ + return {Reduce##name(A, axes, keep_dim, output_name)}; \ + } else \ + return rs; \ + } \ } -BLOCK_SHUFFLE_REDUCE(Sum, "block_shuffle_sum" + Type2StrForReduce(A->type()), ir::Zero(A->type())); -BLOCK_SHUFFLE_REDUCE(Prod, "block_shuffle_prod" + Type2StrForReduce(A->type()), lang::One(A->type())); -BLOCK_SHUFFLE_REDUCE(Max, "block_shuffle_max" + Type2StrForReduce(A->type()), lang::min_value(A->type())); -BLOCK_SHUFFLE_REDUCE(Min, "block_shuffle_min" + Type2StrForReduce(A->type()), lang::max_value(A->type())); -BLOCK_SHUFFLE_REDUCE(All, "block_shuffle_all" + Type2StrForReduce(A->type()), Expr(true)); -BLOCK_SHUFFLE_REDUCE(Any, "block_shuffle_any" + Type2StrForReduce(A->type()), Expr(false)); - -bool WithoutLastDimInReduce(const std::vector& 
inshape, const std::vector& axes) { +BLOCK_SHUFFLE_REDUCE(Sum, + "block_shuffle_sum" + Type2StrForReduce(A->type()), + ir::Zero(A->type())); +BLOCK_SHUFFLE_REDUCE(Prod, + "block_shuffle_prod" + Type2StrForReduce(A->type()), + lang::One(A->type())); +BLOCK_SHUFFLE_REDUCE(Max, + "block_shuffle_max" + Type2StrForReduce(A->type()), + lang::min_value(A->type())); +BLOCK_SHUFFLE_REDUCE(Min, + "block_shuffle_min" + Type2StrForReduce(A->type()), + lang::max_value(A->type())); +BLOCK_SHUFFLE_REDUCE(All, + "block_shuffle_all" + Type2StrForReduce(A->type()), + Expr(true)); +BLOCK_SHUFFLE_REDUCE(Any, + "block_shuffle_any" + Type2StrForReduce(A->type()), + Expr(false)); + +bool WithoutLastDimInReduce(const std::vector& inshape, + const std::vector& axes) { // if last axis is in reduce. if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || std::find(axes.begin(), axes.end(), -1) != axes.end()) { @@ -663,20 +804,25 @@ bool WithoutLastDimInReduce(const std::vector& inshape, const std::vec }; using BlockReduceFunc = - std::function(const ir::Tensor&, const std::vector&, const bool, const std::string&)>; - -std::vector TwoStepBlockReduceInternal(const ir::Tensor& A, - const std::vector& axes, - const bool keep_dim, - const std::string& output_name, - ReduceFunc reduce_func, - BlockReduceFunc block_reduce_func, - ir::Expr initial) { - CHECK(!WithoutLastDimInReduce(A->shape, axes)) << "Can't find last axis in reduce!"; + std::function(const ir::Tensor&, + const std::vector&, + const bool, + const std::string&)>; + +std::vector TwoStepBlockReduceInternal( + const ir::Tensor& A, + const std::vector& axes, + const bool keep_dim, + const std::string& output_name, + ReduceFunc reduce_func, + BlockReduceFunc block_reduce_func, + ir::Expr initial) { + CHECK(!WithoutLastDimInReduce(A->shape, axes)) + << "Can't find last axis in reduce!"; // If the number of current device SM is smaller than the number of SM // required by Warp Reduce, the performance of Warp Reduce is better. // Otherwise, use Block Reduce. 
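Note: the comment above describes the warp-versus-block choice implemented in the code that follows: estimate how many SMs a warp-per-output-row mapping would occupy (32 threads per remaining row, divided by threads per SM) and, when the device has fewer SMs than that, cap the thread budget at a single warp. A standalone sketch with illustrative device numbers; ChooseMaxThreads and its parameters are assumptions, not the CINN API.

```cpp
#include <cmath>
#include <cstdio>

// Sketch of the warp-vs-block choice in TwoStepBlockReduceInternal.
int ChooseMaxThreads(int rows_to_reduce,        // product of non-reduce dims
                     int max_threads_per_sm,    // e.g. 2048
                     int sm_count,              // e.g. 80
                     int device_max_threads) {  // e.g. 1024
  int warp_reduce_need_sm_count = static_cast<int>(std::ceil(
      (rows_to_reduce * 32) / static_cast<float>(max_threads_per_sm)));
  // Fewer SMs available than a warp-per-row mapping wants: prefer warp
  // reduce, i.e. cap the thread budget at a single warp.
  if (sm_count < warp_reduce_need_sm_count) return 32;
  return device_max_threads;
}

int main() {
  // 2^20 rows * 32 threads / 2048 threads-per-SM = 16384 SMs wanted; an
  // 80-SM device therefore falls back to the 32-thread warp reduce path.
  std::printf("%d\n", ChooseMaxThreads(1 << 20, 2048, 80, 1024));  // 32
}
```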
- auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); int need_reduce_last_count = 1; for (int i = 0; i < A->shape.size(); i++) { if (find(axes.begin(), axes.end(), i) == axes.end()) { @@ -684,13 +830,15 @@ std::vector TwoStepBlockReduceInternal(const ir::Tensor& A, } } int warp_reduce_need_sm_count = - ceil((need_reduce_last_count * 32) / float(common::DefaultNVGPUTarget().get_max_threads_per_sm())); + ceil((need_reduce_last_count * 32) / + float(common::DefaultNVGPUTarget().get_max_threads_per_sm())); // Set Num_max_threads to 32 is Warp Reduce - if (common::DefaultNVGPUTarget().get_multi_processor_count() < warp_reduce_need_sm_count) { + if (common::DefaultNVGPUTarget().get_multi_processor_count() < + warp_reduce_need_sm_count) { max_num_threads = 32; } - int lane = A->shape[axes.back()].as_int32(); + int lane = A->shape[axes.back()].as_int32(); int index = static_cast(axes.size()) - 2; for (; index >= 0; --index) { if (lane >= max_num_threads / 2) { @@ -704,8 +852,8 @@ std::vector TwoStepBlockReduceInternal(const ir::Tensor& A, std::vector first_axes(axes.begin(), axes.begin() + index + 1); std::vector second_axes(axes.begin() + index + 1, axes.end()); - bool keep_dim_first = keep_dim; - bool keep_dim_second = keep_dim; + bool keep_dim_first = keep_dim; + bool keep_dim_second = keep_dim; auto reduce_reshape_func = [&first_axes, &keep_dim_first, &second_axes, @@ -719,10 +867,11 @@ std::vector TwoStepBlockReduceInternal(const ir::Tensor& A, max_num_threads, &initial]() { bool check_bound = true; - std::vector out_shape(A->shape.begin(), A->shape.begin() + second_axes.front()); + std::vector out_shape(A->shape.begin(), + A->shape.begin() + second_axes.front()); if (second_axes.size() == 1) { int times = 1; - int tail = max_num_threads; + int tail = max_num_threads; for (; tail >= max_num_threads / 2; --tail) { if (lane % tail == 0) { check_bound = false; @@ -740,13 +889,15 @@ std::vector TwoStepBlockReduceInternal(const ir::Tensor& A, } } else { int times = 1; - int head = A->shape[second_axes.front()].as_int32(); - int tail = lane / head; + int head = A->shape[second_axes.front()].as_int32(); + int tail = lane / head; // from (1024, 512) check one size as tail. 
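Note: the single-axis branch above scans downward from max_num_threads to max_num_threads / 2 looking for a divisor of `lane`; finding one means the reshape is exact and the bound check can be dropped (the two-axis search continues below). A standalone version of that scan, with FindTail as a made-up name:

```cpp
#include <cstdio>

// Scan downward for a divisor of `lane` in [max_num_threads / 2,
// max_num_threads]; an exact split means no padding and no bound check.
int FindTail(int lane, int max_num_threads, bool* check_bound) {
  *check_bound = true;
  for (int tail = max_num_threads; tail >= max_num_threads / 2; --tail) {
    if (lane % tail == 0) {
      *check_bound = false;
      return tail;  // lane splits exactly into (lane / tail) x tail
    }
  }
  return max_num_threads;  // no exact split; padding will be needed
}

int main() {
  bool check_bound;
  int tail = FindTail(1536, 1024, &check_bound);
  std::printf("tail=%d exact=%d\n", tail, !check_bound);  // tail=768 exact=1
}
```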
- for (int idx = (max_num_threads / tail); idx > (max_num_threads / 2 / tail); --idx) { + for (int idx = (max_num_threads / tail); + idx > (max_num_threads / 2 / tail); + --idx) { if (head % idx == 0) { check_bound = false; - times = idx; + times = idx; tail *= idx; break; } @@ -769,26 +920,32 @@ std::vector TwoStepBlockReduceInternal(const ir::Tensor& A, keep_dim_second = false; } else { keep_dim_second = true; - tail_count = A->shape.size() - out_shape.size(); + tail_count = A->shape.size() - out_shape.size(); for (int idx = 0; idx < tail_count; ++idx) { out_shape.push_back(Expr(1)); } } } else { - second_axes = {static_cast(out_shape.size()) - static_cast(first_axes.size()) - 1}; + second_axes = {static_cast(out_shape.size()) - + static_cast(first_axes.size()) - 1}; } int size_without_tail = out_shape.size() - tail_count; std::vector tail_strides(A->shape.size() - (size_without_tail - 2), 1); - for (int idx = static_cast(tail_strides.size()) - 2, index = static_cast(A->shape.size()) - 1; idx >= 0; + for (int idx = static_cast(tail_strides.size()) - 2, + index = static_cast(A->shape.size()) - 1; + idx >= 0; --idx, --index) { tail_strides[idx] = tail_strides[idx + 1] * A->shape[index].as_int32(); } auto out = Compute( out_shape, [=](const std::vector& indexs) -> Expr { - Expr index = indexs[size_without_tail - 1] + indexs[size_without_tail - 2] * out_shape[size_without_tail - 1]; - std::vector tmp_indexs(indexs.begin(), indexs.begin() + size_without_tail - 2); + Expr index = + indexs[size_without_tail - 1] + + indexs[size_without_tail - 2] * out_shape[size_without_tail - 1]; + std::vector tmp_indexs(indexs.begin(), + indexs.begin() + size_without_tail - 2); // last and the second of last. auto selected = ir::LT::Make(index, Expr(lane)); for (auto tail_stride : tail_strides) { @@ -796,7 +953,8 @@ std::vector TwoStepBlockReduceInternal(const ir::Tensor& A, index = index % Expr(tail_stride); } - CHECK_EQ(tmp_indexs.size(), A->shape.size()) << "Indexs size is not equal to Input shape!"; + CHECK_EQ(tmp_indexs.size(), A->shape.size()) + << "Indexs size is not equal to Input shape!"; if (check_bound) { return ir::Select::Make(selected, A(tmp_indexs), initial); } else { @@ -819,13 +977,18 @@ std::vector TwoStepBlockReduceInternal(const ir::Tensor& A, } if (first_axes.size()) { VLOG(3) << "Do Reduce Internal!"; - results.push_back( - reduce_func(results.size() ? results.back() : A, first_axes, keep_dim_first, output_name + "_internal")); + results.push_back(reduce_func(results.size() ? results.back() : A, + first_axes, + keep_dim_first, + output_name + "_internal")); results.back()->WithBuffer("local"); } if (second_axes.size()) { VLOG(3) << "Do Block Reduce!"; - auto res = block_reduce_func(results.size() ? results.back() : A, second_axes, keep_dim_second, output_name); + auto res = block_reduce_func(results.size() ? 
results.back() : A, + second_axes, + keep_dim_second, + output_name); results.push_back(res[1]); results.push_back(res[0]); } @@ -837,46 +1000,78 @@ std::vector TwoStepBlockReduceSum(const ir::Tensor& A, const std::vector& axes, const bool keep_dim, const std::string& output_name) { - return TwoStepBlockReduceInternal( - A, axes, keep_dim, output_name, ReduceSum, BlockReduceSumInternal, ir::Zero(A->type())); + return TwoStepBlockReduceInternal(A, + axes, + keep_dim, + output_name, + ReduceSum, + BlockReduceSumInternal, + ir::Zero(A->type())); } std::vector TwoStepBlockReduceProd(const ir::Tensor& A, const std::vector& axes, const bool keep_dim, const std::string& output_name) { - return TwoStepBlockReduceInternal( - A, axes, keep_dim, output_name, ReduceProd, BlockReduceProdInternal, lang::One(A->type())); + return TwoStepBlockReduceInternal(A, + axes, + keep_dim, + output_name, + ReduceProd, + BlockReduceProdInternal, + lang::One(A->type())); } std::vector TwoStepBlockReduceMax(const ir::Tensor& A, const std::vector& axes, const bool keep_dim, const std::string& output_name) { - return TwoStepBlockReduceInternal( - A, axes, keep_dim, output_name, ReduceMax, BlockReduceMaxInternal, lang::min_value(A->type())); + return TwoStepBlockReduceInternal(A, + axes, + keep_dim, + output_name, + ReduceMax, + BlockReduceMaxInternal, + lang::min_value(A->type())); } std::vector TwoStepBlockReduceMin(const ir::Tensor& A, const std::vector& axes, const bool keep_dim, const std::string& output_name) { - return TwoStepBlockReduceInternal( - A, axes, keep_dim, output_name, ReduceMin, BlockReduceMinInternal, lang::max_value(A->type())); + return TwoStepBlockReduceInternal(A, + axes, + keep_dim, + output_name, + ReduceMin, + BlockReduceMinInternal, + lang::max_value(A->type())); } std::vector TwoStepBlockReduceAll(const ir::Tensor& A, const std::vector& axes, const bool keep_dim, const std::string& output_name) { - return TwoStepBlockReduceInternal(A, axes, keep_dim, output_name, ReduceAll, BlockReduceAllInternal, Expr(true)); + return TwoStepBlockReduceInternal(A, + axes, + keep_dim, + output_name, + ReduceAll, + BlockReduceAllInternal, + Expr(true)); } std::vector TwoStepBlockReduceAny(const ir::Tensor& A, const std::vector& axes, const bool keep_dim, const std::string& output_name) { - return TwoStepBlockReduceInternal(A, axes, keep_dim, output_name, ReduceAny, BlockReduceAnyInternal, Expr(false)); + return TwoStepBlockReduceInternal(A, + axes, + keep_dim, + output_name, + ReduceAny, + BlockReduceAnyInternal, + Expr(false)); } } // namespace pe diff --git a/paddle/cinn/hlir/pe/reduction.h b/paddle/cinn/hlir/pe/reduction.h index 117eaf3aec73b..ceb82e8f6fe0b 100644 --- a/paddle/cinn/hlir/pe/reduction.h +++ b/paddle/cinn/hlir/pe/reduction.h @@ -26,10 +26,12 @@ namespace pe { * * @param A The input Tensor * @param stages The stage map - * @param axis Axis or axes along which a sum is performed. If axis is empty, the operation will sum over all elements - * of the input array. If axis is negative it counts from the last to the first axis. - * @param keep_dims If it is set true, the axes which are reduced are left in the result as dimensions with size one. - * With this option, the result will broadcast correctly against the input array. + * @param axis Axis or axes along which a sum is performed. If axis is empty, + * the operation will sum over all elements of the input array. If axis is + * negative it counts from the last to the first axis. 
+ * @param keep_dims If it is set true, the axes which are reduced are left in + * the result as dimensions with size one. With this option, the result will + * broadcast correctly against the input array. * @param initial Starting value for the sum. * @param output_name The name of the output Tensor * @@ -37,7 +39,7 @@ namespace pe { */ ir::Tensor ReduceSum(const ir::Tensor& A, const std::vector& axis, - const bool keep_dims = false, + const bool keep_dims = false, const std::string& output_name = "T_Reduce_Sum_out"); /** @@ -45,10 +47,12 @@ ir::Tensor ReduceSum(const ir::Tensor& A, * * @param A The input Tensor * @param stages The stage map - * @param axis Axis or axes along which a production is performed. If axis is empty, the operation will product over all - * elements of the input array. If axis is negative it counts from the last to the first axis. - * @param keep_dims If it is set true, the axes which are reduced are left in the result as dimensions with size one. - * With this option, the result will broadcast correctly against the input array. + * @param axis Axis or axes along which a production is performed. If axis is + * empty, the operation will product over all elements of the input array. If + * axis is negative it counts from the last to the first axis. + * @param keep_dims If it is set true, the axes which are reduced are left in + * the result as dimensions with size one. With this option, the result will + * broadcast correctly against the input array. * @param initial Starting value for the production. * @param output_name The name of the output Tensor * @@ -56,7 +60,7 @@ ir::Tensor ReduceSum(const ir::Tensor& A, */ ir::Tensor ReduceProd(const ir::Tensor& A, const std::vector& axis, - const bool keep_dims = false, + const bool keep_dims = false, const std::string& output_name = "T_Reduce_Prod_out"); /** @@ -64,17 +68,19 @@ ir::Tensor ReduceProd(const ir::Tensor& A, * * @param A The input Tensor * @param stages The stage map - * @param axis Axis or axes to find the maximum over. If axis is empty, the operation will product over all elements of - * the input array. If axis is negative it counts from the last to the first axis. - * @param keep_dims If it is set true, the axes which are reduced are left in the result as dimensions with size one. - * With this option, the result will broadcast correctly against the input array. + * @param axis Axis or axes to find the maximum over. If axis is empty, the + * operation will product over all elements of the input array. If axis is + * negative it counts from the last to the first axis. + * @param keep_dims If it is set true, the axes which are reduced are left in + * the result as dimensions with size one. With this option, the result will + * broadcast correctly against the input array. * @param output_name The name of the output Tensor * * @return The result Tensor. */ ir::Tensor ReduceMax(const ir::Tensor& A, const std::vector& axis, - const bool keep_dims = false, + const bool keep_dims = false, const std::string& output_name = "T_Reduce_Max_out"); /** @@ -82,17 +88,19 @@ ir::Tensor ReduceMax(const ir::Tensor& A, * * @param A The input Tensor * @param stages The stage map - * @param axis Axis or axes to find the minimum over. If axis is empty, the operation will product over all elements of - * the input array. If axis is negative it counts from the last to the first axis. - * @param keep_dims If it is set true, the axes which are reduced are left in the result as dimensions with size one. 
- * With this option, the result will broadcast correctly against the input array. + * @param axis Axis or axes to find the minimum over. If axis is empty, the + * operation will product over all elements of the input array. If axis is + * negative it counts from the last to the first axis. + * @param keep_dims If it is set true, the axes which are reduced are left in + * the result as dimensions with size one. With this option, the result will + * broadcast correctly against the input array. * @param output_name The name of the output Tensor * * @return The result Tensor. */ ir::Tensor ReduceMin(const ir::Tensor& A, const std::vector& axis, - const bool keep_dims = false, + const bool keep_dims = false, const std::string& output_name = "T_Reduce_Min_out"); /** @@ -100,17 +108,19 @@ ir::Tensor ReduceMin(const ir::Tensor& A, * * @param A The input Tensor * @param stages The stage map - * @param axis Axis or axes to find the logic and over. If axis is empty, the operation will product over all elements - * of the input array. If axis is negative it counts from the last to the first axis. - * @param keep_dims If it is set true, the axes which are reduced are left in the result as dimensions with size one. - * With this option, the result will broadcast correctly against the input array. + * @param axis Axis or axes to find the logic and over. If axis is empty, the + * operation will product over all elements of the input array. If axis is + * negative it counts from the last to the first axis. + * @param keep_dims If it is set true, the axes which are reduced are left in + * the result as dimensions with size one. With this option, the result will + * broadcast correctly against the input array. * @param output_name The name of the output Tensor * * @return The result Tensor. */ ir::Tensor ReduceAll(const ir::Tensor& A, const std::vector& axis, - const bool keep_dims = false, + const bool keep_dims = false, const std::string& output_name = "T_Reduce_All_out"); /** @@ -118,17 +128,19 @@ ir::Tensor ReduceAll(const ir::Tensor& A, * * @param A The input Tensor * @param stages The stage map - * @param axis Axis or axes to find the logic or over. If axis is empty, the operation will product over all elements of - * the input array. If axis is negative it counts from the last to the first axis. - * @param keep_dims If it is set true, the axes which are reduced are left in the result as dimensions with size one. - * With this option, the result will broadcast correctly against the input array. + * @param axis Axis or axes to find the logic or over. If axis is empty, the + * operation will product over all elements of the input array. If axis is + * negative it counts from the last to the first axis. + * @param keep_dims If it is set true, the axes which are reduced are left in + * the result as dimensions with size one. With this option, the result will + * broadcast correctly against the input array. * @param output_name The name of the output Tensor * * @return The result Tensor. */ ir::Tensor ReduceAny(const ir::Tensor& A, const std::vector& axis, - const bool keep_dims = false, + const bool keep_dims = false, const std::string& output_name = "T_Reduce_Any_out"); /** @@ -139,10 +151,11 @@ ir::Tensor ReduceAny(const ir::Tensor& A, * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. 
*/ -std::vector WarpReduceMax(const ir::Tensor& A, - const int last_reduce_dim_num, - const bool keep_dim = false, - const std::string& output_name = "T_Warp_Reduce_Max_out"); +std::vector WarpReduceMax( + const ir::Tensor& A, + const int last_reduce_dim_num, + const bool keep_dim = false, + const std::string& output_name = "T_Warp_Reduce_Max_out"); /** * @brief compute the sum of array elements over the last dimension @@ -152,10 +165,11 @@ std::vector WarpReduceMax(const ir::Tensor& A, * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector WarpReduceSum(const ir::Tensor& A, - const int last_reduce_dim_num, - const bool keep_dim = false, - const std::string& output_name = "T_Warp_Reduce_Sum_out"); +std::vector WarpReduceSum( + const ir::Tensor& A, + const int last_reduce_dim_num, + const bool keep_dim = false, + const std::string& output_name = "T_Warp_Reduce_Sum_out"); /** * @brief compute the average of array elements over the last dimension @@ -165,219 +179,252 @@ std::vector WarpReduceSum(const ir::Tensor& A, * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector WarpReduceAvg(const ir::Tensor& A, - const int last_reduce_dim_num, - const bool keep_dim = false, - const std::string& output_name = "T_Warp_Reduce_Avg_out"); +std::vector WarpReduceAvg( + const ir::Tensor& A, + const int last_reduce_dim_num, + const bool keep_dim = false, + const std::string& output_name = "T_Warp_Reduce_Avg_out"); /** - * @brief compute the sum of array elements over the last dimension with block reduce. - * 'BlockReduceSumInternal' is used as the internal compute of reduce sum, do not use it directly. + * @brief compute the sum of array elements over the last dimension with block + * reduce. 'BlockReduceSumInternal' is used as the internal compute of reduce + * sum, do not use it directly. * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector BlockReduceSumInternal(const ir::Tensor& A, - const std::vector& axes, - const bool keep_dim = false, - const std::string& output_name = "T_Block_Reduce_Sum_Internal_out"); +std::vector BlockReduceSumInternal( + const ir::Tensor& A, + const std::vector& axes, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_Sum_Internal_out"); /** - * @brief compute the Product of array elements over the last dimension with block reduce. - * 'BlockReduceSumInternal' is used as the internal compute of reduce sum, do not use it directly. + * @brief compute the Product of array elements over the last dimension with + * block reduce. 'BlockReduceSumInternal' is used as the internal compute of + * reduce sum, do not use it directly. * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. 
*/ -std::vector BlockReduceProdInternal(const ir::Tensor& A, - const std::vector& axes, - const bool keep_dim = false, - const std::string& output_name = "T_Block_Reduce_Prod_Internal_out"); +std::vector BlockReduceProdInternal( + const ir::Tensor& A, + const std::vector& axes, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_Prod_Internal_out"); /** - * @brief compute the Max of array elements over the last dimension with block reduce. - * 'BlockReduceSumInternal' is used as the internal compute of reduce sum, do not use it directly. + * @brief compute the Max of array elements over the last dimension with block + * reduce. 'BlockReduceSumInternal' is used as the internal compute of reduce + * sum, do not use it directly. * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector BlockReduceMaxInternal(const ir::Tensor& A, - const std::vector& axes, - const bool keep_dim = false, - const std::string& output_name = "T_Block_Reduce_Max_Internal_out"); +std::vector BlockReduceMaxInternal( + const ir::Tensor& A, + const std::vector& axes, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_Max_Internal_out"); /** - * @brief compute the Min of array elements over the last dimension with block reduce. - * 'BlockReduceSumInternal' is used as the internal compute of reduce sum, do not use it directly. + * @brief compute the Min of array elements over the last dimension with block + * reduce. 'BlockReduceSumInternal' is used as the internal compute of reduce + * sum, do not use it directly. * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector BlockReduceMinInternal(const ir::Tensor& A, - const std::vector& axes, - const bool keep_dim = false, - const std::string& output_name = "T_Block_Reduce_Min_Internal_out"); +std::vector BlockReduceMinInternal( + const ir::Tensor& A, + const std::vector& axes, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_Min_Internal_out"); /** - * @brief compute the logic and of array elements over the last dimension with block reduce. - * 'BlockReduceSumInternal' is used as the internal compute of reduce sum, do not use it directly. + * @brief compute the logic and of array elements over the last dimension with + * block reduce. 'BlockReduceSumInternal' is used as the internal compute of + * reduce sum, do not use it directly. * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector BlockReduceAllInternal(const ir::Tensor& A, - const std::vector& axes, - const bool keep_dim = false, - const std::string& output_name = "T_Block_Reduce_All_Internal_out"); +std::vector BlockReduceAllInternal( + const ir::Tensor& A, + const std::vector& axes, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_All_Internal_out"); /** - * @brief compute the logic or of array elements over the last dimension with block reduce. - * 'BlockReduceSumInternal' is used as the internal compute of reduce sum, do not use it directly. 
+ * @brief compute the logic or of array elements over the last dimension with + * block reduce. 'BlockReduceSumInternal' is used as the internal compute of + * reduce sum, do not use it directly. * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector BlockReduceAnyInternal(const ir::Tensor& A, - const std::vector& axes, - const bool keep_dim = false, - const std::string& output_name = "T_Block_Reduce_Any_Internal_out"); +std::vector BlockReduceAnyInternal( + const ir::Tensor& A, + const std::vector& axes, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_Any_Internal_out"); /** - * @brief compute the Sum of array elements over the last dimension with block reduce + * @brief compute the Sum of array elements over the last dimension with block + * reduce * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector BlockReduceSum(const ir::Tensor& A, - const std::vector& axes, - const int block_size, - const bool keep_dim = false, - const std::string& output_name = "T_Block_Reduce_Sum_out"); +std::vector BlockReduceSum( + const ir::Tensor& A, + const std::vector& axes, + const int block_size, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_Sum_out"); /** - * @brief compute the Product of array elements over the last dimension with block reduce + * @brief compute the Product of array elements over the last dimension with + * block reduce * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector BlockReduceProd(const ir::Tensor& A, - const std::vector& axes, - const int block_size, - const bool keep_dim = false, - const std::string& output_name = "T_Block_Reduce_Prod_out"); +std::vector BlockReduceProd( + const ir::Tensor& A, + const std::vector& axes, + const int block_size, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_Prod_out"); /** - * @brief compute the Max of array elements over the last dimension with block reduce + * @brief compute the Max of array elements over the last dimension with block + * reduce * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector BlockReduceMax(const ir::Tensor& A, - const std::vector& axes, - const int block_size, - const bool keep_dim = false, - const std::string& output_name = "T_Block_Reduce_Max_out"); +std::vector BlockReduceMax( + const ir::Tensor& A, + const std::vector& axes, + const int block_size, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_Max_out"); /** - * @brief compute the Min of array elements over the last dimension with block reduce + * @brief compute the Min of array elements over the last dimension with block + * reduce * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. 
*/ -std::vector BlockReduceMin(const ir::Tensor& A, - const std::vector& axes, - const int block_size, - const bool keep_dim = false, - const std::string& output_name = "T_Block_Reduce_Min_out"); +std::vector BlockReduceMin( + const ir::Tensor& A, + const std::vector& axes, + const int block_size, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_Min_out"); /** - * @brief compute the logic and of array elements over the last dimension with block reduce + * @brief compute the logic and of array elements over the last dimension with + * block reduce * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector BlockReduceAll(const ir::Tensor& A, - const std::vector& axes, - const int block_size, - const bool keep_dim = false, - const std::string& output_name = "T_Block_Reduce_All_out"); +std::vector BlockReduceAll( + const ir::Tensor& A, + const std::vector& axes, + const int block_size, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_All_out"); /** - * @brief compute the logic or of array elements over the last dimension with block reduce + * @brief compute the logic or of array elements over the last dimension with + * block reduce * * @param A The input Tensor. * @param last_reduce_dim_num the number of last reduce dimension. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector BlockReduceAny(const ir::Tensor& A, - const std::vector& axes, - const int block_size, - const bool keep_dim = false, - const std::string& output_name = "T_Block_Reduce_Any_out"); +std::vector BlockReduceAny( + const ir::Tensor& A, + const std::vector& axes, + const int block_size, + const bool keep_dim = false, + const std::string& output_name = "T_Block_Reduce_Any_out"); /** - * @brief compute the value of array elements over the last dimension with block reduce + * @brief compute the value of array elements over the last dimension with block + * reduce * * @param A The input Tensor. * @param axes the reduce axes. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. 
/** - * @brief compute the value of array elements over the last dimension with block reduce + * @brief compute the value of array elements over the last dimension with block + * reduce * * @param A The input Tensor. * @param axes the reduce axes. * @param keep_dim keep the output tensor shape size as input. * @param output_name The name of the output Tensor. */ -std::vector<ir::Tensor> BlockShuffleReduceSum(const ir::Tensor& A, - const std::vector<int>& axes, - const bool keep_dim, - const std::string& output_name = "T_Reduce_Sum_out"); - -std::vector<ir::Tensor> BlockShuffleReduceProd(const ir::Tensor& A, - const std::vector<int>& axes, - const bool keep_dim, - const std::string& output_name = "T_Reduce_Prod_out"); - -std::vector<ir::Tensor> BlockShuffleReduceMax(const ir::Tensor& A, - const std::vector<int>& axes, - const bool keep_dim, - const std::string& output_name = "T_Reduce_Max_out"); - -std::vector<ir::Tensor> BlockShuffleReduceMin(const ir::Tensor& A, - const std::vector<int>& axes, - const bool keep_dim, - const std::string& output_name = "T_Reduce_Min_out"); - -std::vector<ir::Tensor> BlockShuffleReduceAll(const ir::Tensor& A, - const std::vector<int>& axes, - const bool keep_dim, - const std::string& output_name = "T_Reduce_All_out"); - -std::vector<ir::Tensor> BlockShuffleReduceAny(const ir::Tensor& A, - const std::vector<int>& axes, - const bool keep_dim, - const std::string& output_name = "T_Reduce_Any_out"); +std::vector<ir::Tensor> BlockShuffleReduceSum( + const ir::Tensor& A, + const std::vector<int>& axes, + const bool keep_dim, + const std::string& output_name = "T_Reduce_Sum_out"); + +std::vector<ir::Tensor> BlockShuffleReduceProd( + const ir::Tensor& A, + const std::vector<int>& axes, + const bool keep_dim, + const std::string& output_name = "T_Reduce_Prod_out"); + +std::vector<ir::Tensor> BlockShuffleReduceMax( + const ir::Tensor& A, + const std::vector<int>& axes, + const bool keep_dim, + const std::string& output_name = "T_Reduce_Max_out"); + +std::vector<ir::Tensor> BlockShuffleReduceMin( + const ir::Tensor& A, + const std::vector<int>& axes, + const bool keep_dim, + const std::string& output_name = "T_Reduce_Min_out"); + +std::vector<ir::Tensor> BlockShuffleReduceAll( + const ir::Tensor& A, + const std::vector<int>& axes, + const bool keep_dim, + const std::string& output_name = "T_Reduce_All_out"); + +std::vector<ir::Tensor> BlockShuffleReduceAny( + const ir::Tensor& A, + const std::vector<int>& axes, + const bool keep_dim, + const std::string& output_name = "T_Reduce_Any_out"); /** - * @brief compute the value of array elements over the last dimension with block reduce + * @brief compute the value of array elements over the last dimension with block + * reduce * * @param A The input Tensor. * @param axes the reduce axes. @@ -385,35 +432,41 @@ std::vector<ir::Tensor> BlockShuffleReduceAny(const ir::Tensor& A, * @param output_name The name of the output Tensor.
*/ -std::vector<ir::Tensor> TwoStepBlockReduceSum(const ir::Tensor& A, - const std::vector<int>& axes, - const bool keep_dim, - const std::string& output_name = "T_Reduce_Sum_out"); - -std::vector<ir::Tensor> TwoStepBlockReduceProd(const ir::Tensor& A, - const std::vector<int>& axes, - const bool keep_dim, - const std::string& output_name = "T_Reduce_Prod_out"); - -std::vector<ir::Tensor> TwoStepBlockReduceMax(const ir::Tensor& A, - const std::vector<int>& axes, - const bool keep_dim, - const std::string& output_name = "T_Reduce_Max_out"); - -std::vector<ir::Tensor> TwoStepBlockReduceMin(const ir::Tensor& A, - const std::vector<int>& axes, - const bool keep_dim, - const std::string& output_name = "T_Reduce_Min_out"); - -std::vector<ir::Tensor> TwoStepBlockReduceAll(const ir::Tensor& A, - const std::vector<int>& axes, - const bool keep_dim, - const std::string& output_name = "T_Reduce_All_out"); - -std::vector<ir::Tensor> TwoStepBlockReduceAny(const ir::Tensor& A, - const std::vector<int>& axes, - const bool keep_dim, - const std::string& output_name = "T_Reduce_Any_out"); +std::vector<ir::Tensor> TwoStepBlockReduceSum( + const ir::Tensor& A, + const std::vector<int>& axes, + const bool keep_dim, + const std::string& output_name = "T_Reduce_Sum_out"); + +std::vector<ir::Tensor> TwoStepBlockReduceProd( + const ir::Tensor& A, + const std::vector<int>& axes, + const bool keep_dim, + const std::string& output_name = "T_Reduce_Prod_out"); + +std::vector<ir::Tensor> TwoStepBlockReduceMax( + const ir::Tensor& A, + const std::vector<int>& axes, + const bool keep_dim, + const std::string& output_name = "T_Reduce_Max_out"); + +std::vector<ir::Tensor> TwoStepBlockReduceMin( + const ir::Tensor& A, + const std::vector<int>& axes, + const bool keep_dim, + const std::string& output_name = "T_Reduce_Min_out"); + +std::vector<ir::Tensor> TwoStepBlockReduceAll( + const ir::Tensor& A, + const std::vector<int>& axes, + const bool keep_dim, + const std::string& output_name = "T_Reduce_All_out"); + +std::vector<ir::Tensor> TwoStepBlockReduceAny( + const ir::Tensor& A, + const std::vector<int>& axes, + const bool keep_dim, + const std::string& output_name = "T_Reduce_Any_out"); } // namespace pe } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/pe/schedule.cc b/paddle/cinn/hlir/pe/schedule.cc index 68dd70044aa9d..6e7e571f53aba 100644 --- a/paddle/cinn/hlir/pe/schedule.cc +++ b/paddle/cinn/hlir/pe/schedule.cc @@ -47,7 +47,8 @@ ScheduleParam::ScheduleParam(common::Target::Arch arch) { break; } default: { - LOG(FATAL) << "Schedule params must be initialized with target x86 or nvgpu."; + LOG(FATAL) + << "Schedule params must be initialized with target x86 or nvgpu."; } } } @@ -61,12 +62,12 @@ int GetInnerSplitter(int origin, int other_axis) { two_exp *= 2; } two_exp = two_exp / 2; - int a = SplitEven(two_exp); - int b = two_exp / a; + int a = SplitEven(two_exp); + int b = two_exp / a; while (a * other_axis >= 1024 || b * other_axis >= 1024) { two_exp = two_exp / 2; - a = SplitEven(two_exp); - b = two_exp / a; + a = SplitEven(two_exp); + b = two_exp / a; } if (origin == two_exp) { return 2; @@ -86,7 +87,7 @@ int SplitEven(int origin) { int GetBasicFactor(const Type &type, const common::Target &target) { int target_native_vector_bits = target.get_target_bits() * 8; - int type_bits = type.bits(); + int type_bits = type.bits(); return target_native_vector_bits / type_bits; } @@ -95,7 +96,8 @@ int GetBetterSplitFactor(int shape, int split_factor) { while (better_factor > shape) { better_factor /= 2; } - if (better_factor < shape && better_factor != split_factor) return better_factor * 2; + if (better_factor < shape && better_factor != split_factor) + return better_factor * 2; return better_factor; }
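// Worked examples of GetBetterSplitFactor as shown above: the factor is
// halved until it no longer exceeds the shape, then doubled back once if it
// undershoots, so the result is the smallest halving of split_factor that
// still covers shape. A small self-contained C++ sketch that mirrors the
// logic above rather than linking against CINN:
int BetterSplitFactorSketch(int shape, int split_factor) {
  int better_factor = split_factor;
  // halve until the factor fits within the shape
  while (better_factor > shape) better_factor /= 2;
  // if we overshot downward, step back up once
  if (better_factor < shape && better_factor != split_factor)
    return better_factor * 2;
  return better_factor;
}
// BetterSplitFactorSketch(32, 16) == 16  (factor already fits)
// BetterSplitFactorSketch(6, 16)  == 8   (16 -> 8 -> 4, then back up to 8)
// BetterSplitFactorSketch(3, 16)  == 4   (16 -> 8 -> 4 -> 2, then back to 4)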
@@ -114,8 +116,8 @@ void ScheduleInjectiveCPU(poly::Stage *stage, const std::vector<int> &output_shape, const common::Target &target, bool vectorizable) { - int dims = stage->n_out_dims(); - int factor = GetBasicFactor(stage->tensor()->type(), target); + int dims = stage->n_out_dims(); + int factor = GetBasicFactor(stage->tensor()->type(), target); poly::Iterator fused = stage->axis(0); if (dims >= 5) { fused = stage->Fuse({0, 1, 2}); @@ -128,8 +130,8 @@ void ScheduleInjectiveCPU(poly::Stage *stage, if (vectorizable) { poly::Iterator lo; poly::Iterator li; - int last_shape = stage->GetDimRange(dims - 1); - factor = GetVectorizeFactor(last_shape, factor); + int last_shape = stage->GetDimRange(dims - 1); + factor = GetVectorizeFactor(last_shape, factor); std::tie(lo, li) = stage->Split(stage->axis(dims - 1), factor); stage->Vectorize(li, factor); if (dims == 1) { @@ -144,17 +146,20 @@ void ScheduleInjectiveCPU1(poly::Stage *stage, bool vectorizable) { int dims = stage->n_out_dims(); if (dims > 1) { - CHECK_EQ(stage->n_out_dims(), stage->n_in_dims()) << "The dims of op are not equal"; + CHECK_EQ(stage->n_out_dims(), stage->n_in_dims()) + << "The dims of op are not equal"; CHECK_EQ(stage->n_out_dims(), output_shape.size()) << "The origin stage out dims should be same with output_shape sizes"; - poly::Iterator fused = stage->axis(dims - 1); + poly::Iterator fused = stage->axis(dims - 1); int target_native_vector_bits = target.get_target_bits() * 8; - int type_bits = stage->tensor()->type().bits(); - int prod_size = output_shape.back(); - // fuse conservatively for the complex index from poly and may not benefit a lot compared with llvm optimization, - // only fuse the last two dims when the last dimension is too small and can split and vectorize Todo: try reorder + int type_bits = stage->tensor()->type().bits(); + int prod_size = output_shape.back(); + // Fuse conservatively: the complex index from poly may not benefit much + // compared with llvm optimization, so only fuse the last two dims when the + // last dimension is too small, then split and vectorize. Todo: try reorder if (output_shape.back() * type_bits < target_native_vector_bits) { - int last_two_dim_bits = output_shape[dims - 2] * output_shape[dims - 1] * type_bits; + int last_two_dim_bits = + output_shape[dims - 2] * output_shape[dims - 1] * type_bits; if (last_two_dim_bits % target_native_vector_bits == 0) { fused = stage->Fuse(dims - 2, dims - 1); prod_size *= output_shape[dims - 2]; @@ -168,7 +173,7 @@ void ScheduleInjectiveCPU1(poly::Stage *stage, stage->Vectorize(fused, split_factor); } } else { - auto ssplit = stage->Split(fused, split_factor); + auto ssplit = stage->Split(fused, split_factor); auto &j_outer = std::get<0>(ssplit); auto &j_inner = std::get<1>(ssplit); stage->Vectorize(j_inner, split_factor); @@ -180,8 +185,10 @@ void ScheduleInjectiveCPU1(poly::Stage *stage, } } -int GetArrayPackingFactor(int shape, const Type &type, const common::Target &target) { - int split_base = GetBasicFactor(type, target); +int GetArrayPackingFactor(int shape, + const Type &type, + const common::Target &target) { + int split_base = GetBasicFactor(type, target); int split_factor = 1;
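// For orientation, the arithmetic behind GetBasicFactor as defined earlier
// in this file (the 64-bit/512-bit figures below are illustrative
// assumptions, not values read from a concrete target):
//   target_native_vector_bits = get_target_bits() * 8, e.g. 64 * 8 = 512
//   float  (32 bits) -> 512 / 32 = 16 lanes, so the basic factor is 16
//   double (64 bits) -> 512 / 64 = 8 lanes, so the basic factor is 8
// ScheduleInjectiveCPU then splits the innermost axis by this factor before
// calling Vectorize on the inner iterator, and ScheduleInjectiveCPU1 uses
// the same quantity to decide whether the last dimension is too narrow to
// vectorize on its own.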
// temporarily use shape-1 instead of shape, to avoid isl wrongly eliminating a for-1 loop int i = split_base * split_base < shape ? split_base * split_base : shape; @@ -194,7 +201,9 @@ int GetArrayPackingFactor(int shape, const Type &type, const common::Target &tar return split_factor; } -void MatmulScheduleCUDA(poly::StageMap stages, const ir::Tensor &output, const common::Target &target) { +void MatmulScheduleCUDA(poly::StageMap stages, + const ir::Tensor &output, + const common::Target &target) { stages[output]->Split(1, 2); stages[output]->Bind(0, "blockIdx.x"); stages[output]->Bind(1, "threadIdx.x"); @@ -207,20 +216,22 @@ void MatmulScheduleCPU(poly::StageMap stages, CHECK_EQ(output->type(), packedB->type()); int basic_split_factor = GetBasicFactor(packedB->type(), target); // packedB - int packedB_dims = stages[packedB]->axis_names().size(); - int packed_last_dim = packedB->shape[packedB_dims - 1].as_int32(); - int packedB_split_factor = GetBetterSplitFactor(packed_last_dim, basic_split_factor); + int packedB_dims = stages[packedB]->axis_names().size(); + int packed_last_dim = packedB->shape[packedB_dims - 1].as_int32(); + int packedB_split_factor = + GetBetterSplitFactor(packed_last_dim, basic_split_factor); // temporary solution for the indivisible case - if (packedB_split_factor >= 8 && packed_last_dim % packedB_split_factor == 0) { + if (packedB_split_factor >= 8 && + packed_last_dim % packedB_split_factor == 0) { stages[packedB]->Vectorize(packedB_dims - 1, packedB_split_factor); } // output int output_size = output->shape.size(); // M, N - int M = output->shape[output_size - 2].as_int32(); - int N = output->shape[output_size - 1].as_int32(); - int bm = GetArrayPackingFactor(M, output->type(), target); - int bn = GetArrayPackingFactor(N, output->type(), target); + int M = output->shape[output_size - 2].as_int32(); + int N = output->shape[output_size - 1].as_int32(); + int bm = GetArrayPackingFactor(M, output->type(), target); + int bn = GetArrayPackingFactor(N, output->type(), target); int out_axis_dims = stages[output]->axis_names().size(); CHECK_GE(out_axis_dims, 3U) << "output tensor's size should be at least 3"; poly::Iterator i_axis = stages[output]->axis(out_axis_dims - 3); @@ -252,11 +263,11 @@ void MatmulScheduleCPU(poly::StageMap stages, all_axes_outer.push_back(j_axis); } // K - int K = packedB->shape[packedB->shape.size() - 2].as_int32(); + int K = packedB->shape[packedB->shape.size() - 2].as_int32(); int k_split_factor = GetBetterSplitFactor(K, basic_split_factor); - out_axis_dims = stages[output]->axis_names().size(); - auto k_axis = stages[output]->axis(out_axis_dims - 1); - bool is_k_splited = false; + out_axis_dims = stages[output]->axis_names().size(); + auto k_axis = stages[output]->axis(out_axis_dims - 1); + bool is_k_splited = false; if (k_split_factor >= 4) { auto axes = stages[output]->Split(k_axis, k_split_factor); k_axes.push_back(std::get<0>(axes)); @@ -296,12 +307,14 @@ void MatmulScheduleCPU(poly::StageMap stages, stages[output]->Reorder(all_axes); // vectorize output's last dimension auto out_domain = stages[output]->transformed_domain(); - auto range = poly::isl_set_get_axis_range(out_domain.get(), out_axis_dims - 1); - auto &min = std::get<0>(range); - auto &max = std::get<1>(range); + auto range = + poly::isl_set_get_axis_range(out_domain.get(), out_axis_dims - 1); + auto &min = std::get<0>(range); + auto &max = std::get<1>(range); CHECK_EQ(min.get_num_si(), 0) << "axis range should begin from zero"; - int out_last_dim = max.get_num_si() + 1; - int output_split_factor = GetBetterSplitFactor(out_last_dim, basic_split_factor); + int out_last_dim = max.get_num_si() + 1; + int
output_split_factor = + GetBetterSplitFactor(out_last_dim, basic_split_factor); // temporary solution for the indivisible case if (output_split_factor >= 8 && packed_last_dim % output_split_factor == 0) { stages[output]->Vectorize(out_axis_dims - 1, output_split_factor); @@ -312,19 +325,20 @@ void MulScheduleCPU(poly::StageMap stages, const ir::Tensor &output, const ir::Tensor &reduce_first, const common::Target &target) { - int split_factor = GetBasicFactor(output->type(), target); - auto out_reduce_axis = output->reduce_axis; + int split_factor = GetBasicFactor(output->type(), target); + auto out_reduce_axis = output->reduce_axis; std::vector<Expr> reduce_first_shape = reduce_first->shape; - std::vector<Expr> output_shape = output->shape; + std::vector<Expr> output_shape = output->shape; CHECK_EQ(reduce_first_shape.size(), 3U); CHECK_EQ(output_shape.size(), 2U); // reduce_first init - auto reduce_first_init = reduce_first->GetInitTensor(stages, target); + auto reduce_first_init = reduce_first->GetInitTensor(stages, target); int reduce_first_init_dim = stages[reduce_first_init]->axis_names().size(); - stages[reduce_first_init]->ComputeAt2(stages[reduce_first], reduce_first_init_dim - 2); + stages[reduce_first_init]->ComputeAt2(stages[reduce_first], + reduce_first_init_dim - 2); // output init - auto out_init = output->GetInitTensor(stages, target); + auto out_init = output->GetInitTensor(stages, target); int out_init_dim = stages[out_init]->axis_names().size(); stages[out_init]->ComputeAt2(stages[output], out_init_dim - 1); // reduce_first @@ -349,13 +363,14 @@ int GetThreadBindAxis(const std::vector<ir::Expr> &shape) { return thread_axis; } -int GetBlockBindAxis(const std::vector<ir::Expr> &shape, const int thread_axis) { +int GetBlockBindAxis(const std::vector<ir::Expr> &shape, + const int thread_axis) { int block_axis = 0, max_dim_size = shape[0].as_int32(); for (int idx = 0; idx <= thread_axis; ++idx) { if (max_dim_size < shape[idx].as_int32()) { if (idx < thread_axis) { max_dim_size = shape[idx].as_int32(); - block_axis = idx; + block_axis = idx; } else { if (max_dim_size == 1) { block_axis = thread_axis; @@ -371,12 +386,16 @@ void CudaReduceSchedule(poly::StageMap stages, int last_dimension_num, const common::Target &target) { int parallel_thread_num = 1; - for (int idx = output->shape.size() - 1; idx >= static_cast<int>(output->shape.size()) - last_dimension_num; --idx) { + for (int idx = output->shape.size() - 1; + idx >= static_cast<int>(output->shape.size()) - last_dimension_num; + --idx) { parallel_thread_num *= output->shape[idx].as_int32(); } int index = output->shape.size() - last_dimension_num; - for (int idx = output->shape.size() - last_dimension_num; idx < static_cast<int>(output->shape.size()) - 1; ++idx) { + for (int idx = output->shape.size() - last_dimension_num; + idx < static_cast<int>(output->shape.size()) - 1; + ++idx) { stages[output]->Fuse(index, index + 1); } @@ -397,7 +416,10 @@ void CudaReduceSchedule(poly::StageMap stages, } } -void CudaWarpReduceSchedule(poly::StageMap stages, ir::Tensor tmp_out, ir::Tensor out, const common::Target &target) { +void CudaWarpReduceSchedule(poly::StageMap stages, + ir::Tensor tmp_out, + ir::Tensor out, + const common::Target &target) { int sum_out_dim = 1; for (int idx = 0; idx < static_cast<int>(tmp_out->shape.size()) - 2; ++idx) { stages[out]->Fuse(0, 1);
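// The axis-selection logic above, restated as a plain-int sketch (hedged:
// this mirrors only the branches of GetBlockBindAxis visible in the hunk,
// with ir::Expr extents replaced by ints; the tail of the loop is elided in
// the diff and is assumed to simply fall through). The block axis is the
// largest-extent axis strictly before thread_axis; if every earlier extent
// is 1, the thread axis itself is reused for blockIdx:
int BlockBindAxisSketch(const std::vector<int> &extent, int thread_axis) {
  int block_axis = 0, max_dim_size = extent[0];
  for (int idx = 0; idx <= thread_axis; ++idx) {
    if (max_dim_size < extent[idx]) {
      if (idx < thread_axis) {
        max_dim_size = extent[idx];  // a bigger pre-thread axis wins
        block_axis = idx;
      } else if (max_dim_size == 1) {
        block_axis = thread_axis;    // all earlier extents are 1
      }
    }
  }
  return block_axis;
}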
@@ -460,8 +482,10 @@ void CudaBlockReduceSchedule(poly::StageMap stages, const common::Target &target) { int output_shape_size_without_reduce = tmp_out->shape.size() - 1; // fuse last parallel dimension - for (int idx = 0; idx < reduce_tmp_out->shape.size() - tmp_out->shape.size(); ++idx) { - stages[reduce_tmp_out]->Fuse(output_shape_size_without_reduce, output_shape_size_without_reduce + 1); + for (int idx = 0; idx < reduce_tmp_out->shape.size() - tmp_out->shape.size(); + ++idx) { + stages[reduce_tmp_out]->Fuse(output_shape_size_without_reduce, + output_shape_size_without_reduce + 1); } // fuse parallel dimension @@ -490,8 +514,11 @@ void CudaBlockReduceSchedule(poly::StageMap stages, stages[out]->Bind(0, "blockIdx.x"); } -void CudaBlockShuffleReduceSchedule( - poly::StageMap stages, ir::Tensor reshape, ir::Tensor internal, ir::Tensor out, const common::Target &target) { +void CudaBlockShuffleReduceSchedule(poly::StageMap stages, + ir::Tensor reshape, + ir::Tensor internal, + ir::Tensor out, + const common::Target &target) { int fuse_times = internal->shape.size() - 2; for (int idx = 0; idx < fuse_times; ++idx) { stages[internal]->Fuse(0, 1); @@ -559,7 +586,10 @@ void CudaTwoStepReduceSchedule(poly::StageMap stages, stages[out]->Bind(0, "blockIdx.x"); } -void SoftmaxScheduleCPU(poly::StageMap stage, const ir::Tensor &output, const ir::Tensor &temp, int axis) { +void SoftmaxScheduleCPU(poly::StageMap stage, + const ir::Tensor &output, + const ir::Tensor &temp, + int axis) { if (axis == -1) { axis += output->shape.size(); } @@ -572,8 +602,10 @@ void SoftmaxScheduleCPU(poly::StageMap stage, const ir::Tensor &output, const ir stage[temp]->ComputeAt(stage[output], 0); } -void GlobalPoolScheduleGPU(poly::StageMap stages, const std::vector<ir::Tensor> &output, const common::Target &target) { - auto &out = output[0]; +void GlobalPoolScheduleGPU(poly::StageMap stages, + const std::vector<ir::Tensor> &output, + const common::Target &target) { + auto &out = output[0]; auto &reduce = output[1]; stages[out]->Fuse(0, 1); stages[out]->Split(0, 32); @@ -583,13 +615,17 @@ void GlobalPoolScheduleGPU(poly::StageMap stages, const std::vector<ir::Tensor> stages[reduce]->SetBuffer("local"); stages[reduce]->Bind(2, "threadIdx.x"); } -void PoolScheduleCPU(poly::StageMap stages, const ir::Tensor &output, const common::Target &target) { +void PoolScheduleCPU(poly::StageMap stages, + const ir::Tensor &output, + const common::Target &target) { CHECK_GE(stages[output]->n_out_dims(), 2); stages[output]->Fuse({0, 1}); stages[output]->Parallel(0); } -void PoolScheduleGPU(poly::StageMap stages, ir::Tensor &output, const common::Target &target) { +void PoolScheduleGPU(poly::StageMap stages, + ir::Tensor &output, + const common::Target &target) { CHECK_GE(stages[output]->axis_names().size(), 4); stages[output]->Fuse({0, 1, 2, 3}); stages[output]->Split(0, 1024); @@ -642,7 +678,7 @@ void GetConv2dFactors(absl::flat_hash_map<std::string, int> *factors, } } int bn_base = GetBasicFactor(type, target); - int oc_bn = 1; + int oc_bn = 1; for (int i = bn_base; i > 1; i--) { if (oc < 1) break; if (oc % i == 0) { @@ -669,7 +705,7 @@ void GetConv2dFactors(absl::flat_hash_map<std::string, int> *factors, (*factors)["oc_bn"] = oc_bn; (*factors)["ic_bn"] = ic_bn; (*factors)["fc_bn"] = fc_bn; - int ow_bn = 1; + int ow_bn = 1; if (oh < 1) { for (int i = bn_base; i > 1; i--) { @@ -689,7 +725,7 @@ void GetConv2dFactors(absl::flat_hash_map<std::string, int> *factors, ow_bn = i; for (int j = oh; j >= 1; j--) { if (oh % j == 0 && j * ow_bn <= 16) { - oh_bn = j; + oh_bn = j; (*factors)["oh_bn"] = oh_bn; (*factors)["ow_bn"] = ow_bn; return;
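// The oc_bn/ic_bn searches in GetConv2dFactors above all follow one pattern:
// take the largest divisor of the channel count that does not exceed bn_base
// (the vector width from GetBasicFactor). A standalone sketch of that search
// with a hypothetical helper name:
int LargestDivisorAtMost(int n, int cap) {
  for (int i = cap; i > 1; i--) {
    if (n % i == 0) return i;  // first hit from the top is the largest
  }
  return 1;  // fall back to no blocking
}
// e.g. LargestDivisorAtMost(64, 16) == 16, LargestDivisorAtMost(24, 16) == 12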
@@ -708,7 +744,7 @@ void GetConv2d1x1Factors(absl::flat_hash_map<std::string, int> *factors, const Type &type, const common::Target &target) { int bn_base = GetBasicFactor(type, target); - int oc_bn = 1; + int oc_bn = 1; for (int i = bn_base; i > 1; i--) { if (oc < 1) break; if (oc % i == 0) { @@ -726,16 +762,16 @@ void GetConv2d1x1Factors(absl::flat_hash_map<std::string, int> *factors, } (*factors)["oc_bn"] = oc_bn; (*factors)["ic_bn"] = ic_bn; - int ow_bn = 1; - int oh_bn = 1; - int begin = std::min(ow, bn_base); + int ow_bn = 1; + int oh_bn = 1; + int begin = std::min(ow, bn_base); for (int i = begin; i >= 1; i--) { if (ow < 1) break; if (ow % i == 0) { ow_bn = i; for (int j = oh; j >= 1; j--) { if (oh % j == 0 && j * ow_bn <= 16) { - oh_bn = j; + oh_bn = j; (*factors)["oh_bn"] = oh_bn; (*factors)["ow_bn"] = ow_bn; return; @@ -752,8 +788,9 @@ std::string GenerateX86ConvKey(const std::vector<Expr> &input_shape, const std::vector<int> &dilations, const int &index, const std::string &model_name) { - // format: (model_name + index +)schedule_name + input_shape + weight_shape + strides + paddings + dilations - // e.g. resnet18 0 X86ScheduleConv input 1 3 224 224 weight 64 3 7 7 stride 2 2 padding 3 3 dilation 1 1 + // format: (model_name + index +)schedule_name + input_shape + weight_shape + + // strides + paddings + dilations e.g. resnet18 0 X86ScheduleConv input 1 3 + // 224 224 weight 64 3 7 7 stride 2 2 padding 3 3 dilation 1 1 std::string key; if (model_name != "") { key = model_name + " index " + std::to_string(index) + " "; @@ -789,7 +826,8 @@ std::string GenerateX86ConvKey(const std::vector<int> &input_shape, const std::vector<int> &dilations, const int &index, const std::string &model_name) { - // format: (model_name + index +)schedule_name + input_shape + weight_shape + strides + paddings + dilations + // format: (model_name + index +)schedule_name + input_shape + weight_shape + + // strides + paddings + dilations std::string key; if (model_name != "") { key = model_name + " index " + std::to_string(index) + " "; @@ -820,8 +858,8 @@ std::string GenerateX86ConvKey(const std::vector<int> &input_shape, void CreateX86SerialData(const std::string &file_name) { /** The format of serial data is: - * hash_key: schedule_name + shape of input + shape of weights + stride + padding + dilation - * value: vector of params + * hash_key: schedule_name + shape of input + shape of weights + stride + + * padding + dilation value: vector of params */ SaveSerialData(CreateX86Params(), file_name); } @@ -835,16 +873,18 @@ void Conv2d_NCHWc_1X1_Schedule_CPU(poly::StageMap stages, const common::Target &target, const std::string &key, bool do_padding) { - CHECK(target.arch == Target::Arch::X86) << "Conv2d_NCHWc_1X1_Schedule_CPU schedule only used in x86"; + CHECK(target.arch == Target::Arch::X86) + << "Conv2d_NCHWc_1X1_Schedule_CPU schedule only used in x86"; CHECK(packed_out.defined()); CHECK(input_pad.defined()); auto type = packed_out->type(); absl::flat_hash_map<std::string, int> conv2d_factors; - CHECK_EQ(packed_out->shape.size(), 5U) << "packed_out's shape size should be 5"; - Expr h_out = common::AutoSimplify(packed_out->shape[2]); - Expr w_out = common::AutoSimplify(packed_out->shape[3]); - int oh = h_out.as_int32(); - int ow = w_out.as_int32(); + CHECK_EQ(packed_out->shape.size(), 5U) + << "packed_out's shape size should be 5"; + Expr h_out = common::AutoSimplify(packed_out->shape[2]); + Expr w_out = common::AutoSimplify(packed_out->shape[3]); + int oh = h_out.as_int32(); + int ow = w_out.as_int32(); int basic_split_factor = GetBasicFactor(type, target); GetConv2dFactors(&conv2d_factors, -1, -1, -1, oh, ow, type, target, key); int oh_bn_size = conv2d_factors["oh_bn"];
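// Following the format comment in GenerateX86ConvKey above, a sketch of how
// such a key assembles (the shape values are taken from the e.g. line; the
// "index" token comes from the key-building code shown above):
//   model_name = "resnet18", index = 0
//   key = "resnet18 index 0 X86ScheduleConv input 1 3 224 224"
//         " weight 64 3 7 7 stride 2 2 padding 3 3 dilation 1 1"
// With an empty model_name the "resnet18 index 0 " prefix is simply omitted.
// GetConv2dFactors then uses this key to look up pre-tuned oh_bn/ow_bn
// factors instead of recomputing them from the divisor search.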
@@ -852,8 +892,8 @@ void Conv2d_NCHWc_1X1_Schedule_CPU(poly::StageMap stages, auto input_shape = input_pad->shape; CHECK_EQ(input_shape.size(), 5U) << "input shape size should be 5"; - Expr oc_bn = common::AutoSimplify(packed_out->shape.back()); - Expr ic_bn = common::AutoSimplify(input_shape.back()); + Expr oc_bn = common::AutoSimplify(packed_out->shape.back()); + Expr ic_bn = common::AutoSimplify(input_shape.back()); int oc_bn_size = oc_bn.as_int32(); int ic_bn_size = ic_bn.as_int32(); VLOG(3) << "oh_bn_size " << oh_bn_size; @@ -863,23 +903,28 @@ void Conv2d_NCHWc_1X1_Schedule_CPU(poly::StageMap stages, // data if (data.defined()) { - CHECK_GE(stages[data]->n_out_dims(), 3U) << "data's out_dims should be more than 3"; + CHECK_GE(stages[data]->n_out_dims(), 3U) + << "data's out_dims should be more than 3"; stages[data]->Fuse({0, 1, 2}); stages[data]->ComputeInline(); } // input_pad if (do_padding) { - CHECK_GE(stages[input_pad]->n_out_dims(), 3U) << "input_pad's out_dims should be more than 3"; + CHECK_GE(stages[input_pad]->n_out_dims(), 3U) + << "input_pad's out_dims should be more than 3"; stages[input_pad]->Fuse({0, 1, 2}); - stages[input_pad]->Vectorize(stages[input_pad]->n_out_dims() - 1, input_pad->shape.back().as_int32()); + stages[input_pad]->Vectorize(stages[input_pad]->n_out_dims() - 1, + input_pad->shape.back().as_int32()); } else { stages[input_pad]->ComputeInline(); } // weights if (weights_dilation.defined()) { - CHECK_GE(stages[weights_dilation]->n_out_dims(), 3U) << "weights_dilation's out_dims should be more than 3"; - // oc_outer, ic_outer, oh, ow, ic_inner, oc_inner -> oc_outer, oh, ic_outer, ow, ic_inner, oc_inner + CHECK_GE(stages[weights_dilation]->n_out_dims(), 3U) + << "weights_dilation's out_dims should be more than 3"; + // oc_outer, ic_outer, oh, ow, ic_inner, oc_inner -> oc_outer, oh, ic_outer, + // ow, ic_inner, oc_inner stages[weights_dilation]->Reorder({2, 1}); stages[weights_dilation]->Fuse({0, 1}); } @@ -893,44 +938,56 @@ void Conv2d_NCHWc_1X1_Schedule_CPU(poly::StageMap stages, // [batch, oc_outer, oh_outer, oh_inner, ow_outer, ow_inner, oc_inner] -> // [batch_oc_outer_oh_outer_fused, oh_inner, ow_outer, ow_inner, oc_inner] stages[packed_out]->Fuse({0, 1, 2}); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); - // CC: [batch, oh, ow, oc, ic, kh, kw] -> [batch_oc_outer_oh_outer_fused, oh_inner, ow, oc_inner, ic, kh, kw] + // CC: [batch, oh, ow, oc, ic, kh, kw] -> [batch_oc_outer_oh_outer_fused, + // oh_inner, ow, oc_inner, ic, kh, kw] stages[CC]->ComputeAt2(stages[packed_out], 0); VLOG(3) << "cache write shape: " << utils::Join(CC->shape, ", "); // temporary solution because reorder may be wrong before ComputeAt - // reorder: [batch_oc_outer_oh_outer_fused, oh_inner, ow_outer, ow_inner, oc_inner] -> - // [batch_oc_outer_oh_outer_fused, ow_outer, oh_inner, ow_inner, oc_inner] + // reorder: [batch_oc_outer_oh_outer_fused, oh_inner, ow_outer, ow_inner, + // oc_inner] -> [batch_oc_outer_oh_outer_fused, ow_outer, oh_inner, ow_inner, + // oc_inner] stages[packed_out]->Reorder({2, 1}); - stages[packed_out]->Vectorize(stages[packed_out]->n_out_dims() - 1, packed_out->shape.back().as_int32()); - VLOG(3) << "stages[packed_out]->transformed_domain()" << stages[packed_out]->transformed_domain(); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + stages[packed_out]->Vectorize(stages[packed_out]->n_out_dims() - 1, + packed_out->shape.back().as_int32()); + VLOG(3) << "stages[packed_out]->transformed_domain()" + <<
stages[packed_out]->transformed_domain(); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // CC: [batch_oc_outer_oh_outer_fused, oh_inner, ow, oc_inner, ic, kh, kw] // split ow stages[CC]->Split(2, ow_bn_size); - // reorder: [batch_oc_outer_oh_outer_fused, oh_inner, ow_outer, ow_inner, oc_inner, ic, kh, kw] -> - // [batch_oc_outer_oh_outer_fused, oh_inner, ow_outer, ow_inner, oc_inner, ic, kh, kw] + // reorder: [batch_oc_outer_oh_outer_fused, oh_inner, ow_outer, ow_inner, + // oc_inner, ic, kh, kw] -> [batch_oc_outer_oh_outer_fused, oh_inner, + // ow_outer, ow_inner, oc_inner, ic, kh, kw] stages[CC]->Reorder({2, 1}); // split ic - // CC: [batch_oc_outer_oh_outer_fused, ow_outer, oh_inner, ow_inner, oc_inner, ic, kh, kw] + // CC: [batch_oc_outer_oh_outer_fused, ow_outer, oh_inner, ow_inner, oc_inner, + // ic, kh, kw] stages[CC]->Split(5, ic_bn_size); - // reorder: [batch_oc_outer_oh_outer_fused, ow_outer, oh_inner, ow_inner, oc_inner, ic_outer, ic_inner, kh, kw] -> - // [batch_oc_outer_oh_outer_fused, ow_outer, ic_outer, ic_inner, oh_inner, ow_inner, oc_inner, kh, kw] + // reorder: [batch_oc_outer_oh_outer_fused, ow_outer, oh_inner, ow_inner, + // oc_inner, ic_outer, ic_inner, kh, kw] -> [batch_oc_outer_oh_outer_fused, + // ow_outer, ic_outer, ic_inner, oh_inner, ow_inner, oc_inner, kh, kw] auto oh_inner = stages[CC]->axis(2); auto ow_inner = stages[CC]->axis(3); auto oc_inner = stages[CC]->axis(4); auto ic_outer = stages[CC]->axis(5); auto ic_inner = stages[CC]->axis(6); stages[CC]->Reorder({ic_outer, ic_inner, oh_inner, ow_inner, oc_inner}); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); - stages[CC]->Vectorize(stages[CC]->n_out_dims() - 3, CC->shape.back().as_int32()); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); + stages[CC]->Vectorize(stages[CC]->n_out_dims() - 3, + CC->shape.back().as_int32()); // unroll ow_inner, oh_inner VLOG(3) << stages[CC]->transformed_domain(); // CC_init auto CC_init = CC->GetInitTensor(stages, target); - stages[CC_init]->Vectorize(stages[CC_init]->n_out_dims() - 1, CC_init->shape.back().as_int32()); + stages[CC_init]->Vectorize(stages[CC_init]->n_out_dims() - 1, + CC_init->shape.back().as_int32()); stages[CC]->Unroll(stages[CC]->n_out_dims() - 4); stages[CC]->Unroll(stages[CC]->n_out_dims() - 5); stages[CC_init]->Unroll(stages[CC_init]->n_out_dims() - 2); @@ -941,18 +998,20 @@ void Conv2d_NCHWc_1X1_Schedule_CPU(poly::StageMap stages, stages[res]->Split(1, oc_bn_size); stages[res]->Split(3, oh_bn_size); stages[res]->Split(5, ow_bn_size); - // reorder: [n, oc_outer, oc_inner, oh_outer, oh_inner, ow_outer, ow_inner] -> - // [n, oc_outer, oh_outer, ow_outer, oh_inner, ow_inner, oc_inner] + // reorder: [n, oc_outer, oc_inner, oh_outer, oh_inner, ow_outer, ow_inner] + // -> [n, oc_outer, oh_outer, ow_outer, oh_inner, ow_inner, oc_inner] auto oc_inner1 = stages[res]->axis(2); auto oh_outer1 = stages[res]->axis(3); auto oh_inner1 = stages[res]->axis(4); auto ow_outer1 = stages[res]->axis(5); auto ow_inner1 = stages[res]->axis(6); - stages[res]->Reorder({oh_outer1, ow_outer1, oh_inner1, ow_inner1, oc_inner1}); + stages[res]->Reorder( + {oh_outer1, ow_outer1, oh_inner1, ow_inner1, oc_inner1}); // stages[res]->Fuse({0, 1, 2}); // Todo: computeAt according to forloops' range // stages[packed_out]->ComputeAt2(stages[res], 2); - VLOG(3) << "stages[res]->transformed_domain()" << stages[res]->transformed_domain(); + VLOG(3) << 
"stages[res]->transformed_domain()" + << stages[res]->transformed_domain(); } } @@ -963,26 +1022,28 @@ void Conv2d_NCHWc_1X1_Schedule_CPU_Nofuse(poly::StageMap stages, const ir::Tensor &weights_dilation, const ir::Tensor &data, const common::Target &target) { - CHECK(target.arch == Target::Arch::X86) << "Conv2d_NCHWc_1X1_Schedule_CPU_Nofuse schedule only used in x86"; + CHECK(target.arch == Target::Arch::X86) + << "Conv2d_NCHWc_1X1_Schedule_CPU_Nofuse schedule only used in x86"; CHECK(packed_out.defined()); CHECK(input_pad.defined()); auto type = packed_out->type(); absl::flat_hash_map conv2d_factors; - CHECK_EQ(packed_out->shape.size(), 5U) << "packed_out's shape size should be 5"; - Expr h_out = common::AutoSimplify(packed_out->shape[2]); - Expr w_out = common::AutoSimplify(packed_out->shape[3]); - int oh = h_out.as_int32(); - int ow = w_out.as_int32(); + CHECK_EQ(packed_out->shape.size(), 5U) + << "packed_out's shape size should be 5"; + Expr h_out = common::AutoSimplify(packed_out->shape[2]); + Expr w_out = common::AutoSimplify(packed_out->shape[3]); + int oh = h_out.as_int32(); + int ow = w_out.as_int32(); int basic_split_factor = GetBasicFactor(type, target); GetConv2d1x1Factors(&conv2d_factors, -1, -1, oh, ow, type, target); int oh_bn_size = conv2d_factors["oh_bn"]; int ow_bn_size = conv2d_factors["ow_bn"]; auto input_shape = input_pad->shape; - int shape_size = input_shape.size(); + int shape_size = input_shape.size(); CHECK_EQ(shape_size, 5U) << "input shape size should be 5"; - Expr oc_bn = common::AutoSimplify(packed_out->shape.back()); - Expr ic_bn = common::AutoSimplify(input_shape.back()); + Expr oc_bn = common::AutoSimplify(packed_out->shape.back()); + Expr ic_bn = common::AutoSimplify(input_shape.back()); int oc_bn_size = oc_bn.as_int32(); int ic_bn_size = ic_bn.as_int32(); VLOG(3) << "ow_bn_size" << ow_bn_size; @@ -995,7 +1056,8 @@ void Conv2d_NCHWc_1X1_Schedule_CPU_Nofuse(poly::StageMap stages, } // weights if (weights_dilation.defined()) { - CHECK_GE(stages[weights_dilation]->n_out_dims(), 3U) << "weights_dilation's out_dims should be more than 3"; + CHECK_GE(stages[weights_dilation]->n_out_dims(), 3U) + << "weights_dilation's out_dims should be more than 3"; // Reorder: [oc_outer, ic_outer, oh, ow, ic_inner, oc_inner] -> // [oc_outer, oh, ic_outer, ow, ic_inner, oc_inner] stages[weights_dilation]->Reorder({2, 1}); @@ -1003,45 +1065,57 @@ void Conv2d_NCHWc_1X1_Schedule_CPU_Nofuse(poly::StageMap stages, // packed_out auto CC = stages[packed_out]->CacheWrite("global", stages, packed_out); - VLOG(3) << "stages[packed_out]->transformed_domain()" << stages[packed_out]->transformed_domain(); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + VLOG(3) << "stages[packed_out]->transformed_domain()" + << stages[packed_out]->transformed_domain(); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // packed_out: [batch, oc_outer, oh, ow, oc_inner] // split oh, ow stages[packed_out]->Split(2, oh_bn_size); stages[packed_out]->Split(4, ow_bn_size); // CC: [batch, oc_outer, oh, ow, oc_inner] - // packed_out: [batch, oc_outer, oh_outer, oh_inner, ow_outer, ow_inner, oc_inner] + // packed_out: [batch, oc_outer, oh_outer, oh_inner, ow_outer, ow_inner, + // oc_inner] stages[CC]->ComputeAt2(stages[packed_out], 2); - VLOG(3) << "stages[packed_out]->transformed_domain()" << stages[packed_out]->transformed_domain(); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + VLOG(3) << 
"stages[packed_out]->transformed_domain()" + << stages[packed_out]->transformed_domain(); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // tempory solution because reordering before computeAt may be wrong - // reorder: [batch, oc_outer, oh_outer, oh_inner, ow_outer, ow_inner, oc_inner] -> - // [batch, oc_outer, oh_outer, ow_outer, oh_inner, ow_inner, oc_inner] + // reorder: [batch, oc_outer, oh_outer, oh_inner, ow_outer, ow_inner, + // oc_inner] -> [batch, oc_outer, oh_outer, ow_outer, oh_inner, ow_inner, + // oc_inner] stages[packed_out]->Reorder({4, 3}); - stages[packed_out]->Vectorize(stages[packed_out]->n_out_dims() - 1, packed_out->shape.back().as_int32()); + stages[packed_out]->Vectorize(stages[packed_out]->n_out_dims() - 1, + packed_out->shape.back().as_int32()); // split oh, ow // CC: [batch, oc_outer, oh_outer, oh_inner, ow, oc_inner, ic, kh, kw] stages[CC]->Split(4, ow_bn_size); - // CC: [batch, oc_outer, oh_outer, oh_inner, ow_outer, ow_inner, oc_inner, ic, kh, kw] - // split ic + // CC: [batch, oc_outer, oh_outer, oh_inner, ow_outer, ow_inner, oc_inner, ic, + // kh, kw] split ic stages[CC]->Split(7, ic_bn_size); - // reorder: [batch, oc_outer, oh_outer, oh_inner, ow_outer, ow_inner, oc_inner, ic_outer, ic_inner, kh, kw] -> - // [batch, oc_outer, oh_outer, ow_outer, ic_outer, ic_inner, oh_inner, ow_inner, oc_inner, kh, kw] + // reorder: [batch, oc_outer, oh_outer, oh_inner, ow_outer, ow_inner, + // oc_inner, ic_outer, ic_inner, kh, kw] -> [batch, oc_outer, oh_outer, + // ow_outer, ic_outer, ic_inner, oh_inner, ow_inner, oc_inner, kh, kw] auto oh_inner = stages[CC]->axis(3); auto ow_outer = stages[CC]->axis(4); auto ow_inner = stages[CC]->axis(5); auto oc_inner = stages[CC]->axis(6); auto ic_outer = stages[CC]->axis(7); auto ic_inner = stages[CC]->axis(8); - stages[CC]->Reorder({ow_outer, ic_outer, ic_inner, oh_inner, ow_inner, oc_inner}); - stages[CC]->Vectorize(stages[CC]->n_out_dims() - 3, CC->shape.back().as_int32()); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + stages[CC]->Reorder( + {ow_outer, ic_outer, ic_inner, oh_inner, ow_inner, oc_inner}); + stages[CC]->Vectorize(stages[CC]->n_out_dims() - 3, + CC->shape.back().as_int32()); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // CC_init auto CC_init = CC->GetInitTensor(stages, target); - stages[CC_init]->Vectorize(stages[CC_init]->n_out_dims() - 1, CC_init->shape.back().as_int32()); + stages[CC_init]->Vectorize(stages[CC_init]->n_out_dims() - 1, + CC_init->shape.back().as_int32()); // res // n, oc, oh, ow @@ -1049,15 +1123,17 @@ void Conv2d_NCHWc_1X1_Schedule_CPU_Nofuse(poly::StageMap stages, stages[res]->Split(1, oc_bn_size); stages[res]->Split(3, oh_bn_size); stages[res]->Split(5, ow_bn_size); - // reorder: [n, oc_outer, oc_inner, oh_outer, oh_inner, ow_outer, ow_inner] -> - // [n, oc_outer, oh_outer, ow_outer, oh_inner, ow_inner, oc_inner] + // reorder: [n, oc_outer, oc_inner, oh_outer, oh_inner, ow_outer, ow_inner] + // -> [n, oc_outer, oh_outer, ow_outer, oh_inner, ow_inner, oc_inner] auto oc_inner1 = stages[res]->axis(2); auto oh_outer1 = stages[res]->axis(3); auto oh_inner1 = stages[res]->axis(4); auto ow_outer1 = stages[res]->axis(5); auto ow_inner1 = stages[res]->axis(6); - stages[res]->Reorder({oh_outer1, ow_outer1, oh_inner1, ow_inner1, oc_inner1}); - VLOG(3) << "stages[res]->transformed_domain()" << stages[res]->transformed_domain(); + stages[res]->Reorder( + {oh_outer1, ow_outer1, oh_inner1, 
// res // n, oc, oh, ow @@ -1049,15 +1123,17 @@ void Conv2d_NCHWc_1X1_Schedule_CPU_Nofuse(poly::StageMap stages, stages[res]->Split(1, oc_bn_size); stages[res]->Split(3, oh_bn_size); stages[res]->Split(5, ow_bn_size); - // reorder: [n, oc_outer, oc_inner, oh_outer, oh_inner, ow_outer, ow_inner] -> - // [n, oc_outer, oh_outer, ow_outer, oh_inner, ow_inner, oc_inner] + // reorder: [n, oc_outer, oc_inner, oh_outer, oh_inner, ow_outer, ow_inner] + // -> [n, oc_outer, oh_outer, ow_outer, oh_inner, ow_inner, oc_inner] auto oc_inner1 = stages[res]->axis(2); auto oh_outer1 = stages[res]->axis(3); auto oh_inner1 = stages[res]->axis(4); auto ow_outer1 = stages[res]->axis(5); auto ow_inner1 = stages[res]->axis(6); - stages[res]->Reorder({oh_outer1, ow_outer1, oh_inner1, ow_inner1, oc_inner1}); - VLOG(3) << "stages[res]->transformed_domain()" << stages[res]->transformed_domain(); + stages[res]->Reorder( + {oh_outer1, ow_outer1, oh_inner1, ow_inner1, oc_inner1}); + VLOG(3) << "stages[res]->transformed_domain()" + << stages[res]->transformed_domain(); } } @@ -1068,23 +1144,25 @@ void Conv2d_NCHWc_Schedule_CPU_Nofuse(poly::StageMap stages, const ir::Tensor &weights_dilation, const ir::Tensor &data, const common::Target &target) { - CHECK(target.arch == Target::Arch::X86) << "Conv2d_NCHWc_Schedule_CPU_Nofuse schedule only used in x86"; + CHECK(target.arch == Target::Arch::X86) + << "Conv2d_NCHWc_Schedule_CPU_Nofuse schedule only used in x86"; CHECK(packed_out.defined()); CHECK(input_pad.defined()); auto type = packed_out->type(); absl::flat_hash_map<std::string, int> conv2d_factors; - CHECK_EQ(packed_out->shape.size(), 5U) << "packed_out's shape size should be 5"; - Expr w_out = common::AutoSimplify(packed_out->shape[3]); - int ow = w_out.as_int32(); + CHECK_EQ(packed_out->shape.size(), 5U) + << "packed_out's shape size should be 5"; + Expr w_out = common::AutoSimplify(packed_out->shape[3]); + int ow = w_out.as_int32(); int basic_split_factor = GetBasicFactor(type, target); GetConv2dFactors(&conv2d_factors, -1, -1, -1, -1, ow, type, target); int ow_bn_size = conv2d_factors["ow_bn"]; auto input_shape = input_pad->shape; - int shape_size = input_shape.size(); + int shape_size = input_shape.size(); CHECK_EQ(shape_size, 5U) << "input shape size should be 5"; - Expr oc_bn = common::AutoSimplify(packed_out->shape.back()); - Expr ic_bn = common::AutoSimplify(input_shape.back()); + Expr oc_bn = common::AutoSimplify(packed_out->shape.back()); + Expr ic_bn = common::AutoSimplify(input_shape.back()); int oc_bn_size = oc_bn.as_int32(); int ic_bn_size = ic_bn.as_int32(); VLOG(3) << "ow_bn_size " << ow_bn_size; @@ -1097,45 +1175,55 @@ void Conv2d_NCHWc_Schedule_CPU_Nofuse(poly::StageMap stages, } // weights if (weights_dilation.defined()) { - CHECK_GE(stages[weights_dilation]->n_out_dims(), 3U) << "weights_dilation's out_dims should be more than 3"; + CHECK_GE(stages[weights_dilation]->n_out_dims(), 3U) + << "weights_dilation's out_dims should be more than 3"; // Reorder: [oc_outer, ic_outer, oh, ow, ic_inner, oc_inner] -> // [oc_outer, oh, ic_outer, ow, ic_inner, oc_inner] stages[weights_dilation]->Reorder({2, 1}); } // packed_out auto CC = stages[packed_out]->CacheWrite("global", stages, packed_out); - VLOG(3) << "stages[packed_out]->transformed_domain()" << stages[packed_out]->transformed_domain(); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + VLOG(3) << "stages[packed_out]->transformed_domain()" + << stages[packed_out]->transformed_domain(); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // packed_out: [batch, oc_outer, oh, ow, oc_inner] // split ow stages[packed_out]->Split(3, ow_bn_size); - stages[packed_out]->Vectorize(stages[packed_out]->n_out_dims() - 1, packed_out->shape.back().as_int32()); + stages[packed_out]->Vectorize(stages[packed_out]->n_out_dims() - 1, + packed_out->shape.back().as_int32()); // CC: [batch, oc_outer, oh, ow, oc_inner] // packed_out: [batch, oc_outer, oh, ow_outer, ow_inner, oc_inner] // not computeAt ow_outer but oh stages[CC]->ComputeAt2(stages[packed_out], 2); - VLOG(3) << "stages[packed_out]->transformed_domain()" << stages[packed_out]->transformed_domain(); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + VLOG(3) << "stages[packed_out]->transformed_domain()" + << stages[packed_out]->transformed_domain(); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // split ow
stages[CC]->Split(3, ow_bn_size); // CC: [batch, oc_outer, oh, ow_outer, ow_inner, oc_inner, ic, kh, kw] // split ic stages[CC]->Split(6, ic_bn_size); - // reorder: [batch, oc_outer, oh, ow_outer, ow_inner, oc_inner, ic_outer, ic_inner, kh, kw] -> - // [batch, oc_outer, oh, ow_outer, ic_outer, kh, kw, ic_inner, ow_inner, oc_inner] + // reorder: [batch, oc_outer, oh, ow_outer, ow_inner, oc_inner, ic_outer, + // ic_inner, kh, kw] -> [batch, oc_outer, oh, ow_outer, ic_outer, kh, kw, + // ic_inner, ow_inner, oc_inner] auto ow_inner = stages[CC]->axis(4); auto oc_inner = stages[CC]->axis(5); auto ic_outer = stages[CC]->axis(6); auto ic_inner = stages[CC]->axis(7); - auto kh = stages[CC]->axis(8); - auto kw = stages[CC]->axis(9); + auto kh = stages[CC]->axis(8); + auto kw = stages[CC]->axis(9); stages[CC]->Reorder({ic_outer, kh, kw, ic_inner, ow_inner, oc_inner}); - stages[CC]->Vectorize(stages[CC]->n_out_dims() - 1, CC->shape.back().as_int32()); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + stages[CC]->Vectorize(stages[CC]->n_out_dims() - 1, + CC->shape.back().as_int32()); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // CC_init auto CC_init = CC->GetInitTensor(stages, target); - stages[CC_init]->Vectorize(stages[CC_init]->n_out_dims() - 1, CC_init->shape.back().as_int32()); + stages[CC_init]->Vectorize(stages[CC_init]->n_out_dims() - 1, + CC_init->shape.back().as_int32()); // res // n, oc, oh, ow @@ -1145,11 +1233,12 @@ void Conv2d_NCHWc_Schedule_CPU_Nofuse(poly::StageMap stages, // Reorder: [n, oc_outer, oc_inner, oh, ow_outer, ow_inner] -> // [n, oc_outer, oh, ow_outer, ow_inner, oc_inner] auto oc_inner1 = stages[res]->axis(2); - auto oh1 = stages[res]->axis(3); + auto oh1 = stages[res]->axis(3); auto ow_outer1 = stages[res]->axis(4); auto ow_inner1 = stages[res]->axis(5); stages[res]->Reorder({oh1, ow_outer1, ow_inner1, oc_inner1}); - VLOG(3) << "stages[res]->transformed_domain()" << stages[res]->transformed_domain(); + VLOG(3) << "stages[res]->transformed_domain()" + << stages[res]->transformed_domain(); } } @@ -1162,18 +1251,20 @@ void Conv2d_NCHWc_Schedule_CPU(poly::StageMap stages, const common::Target &target, const std::string &key, bool do_padding) { - CHECK(target.arch == Target::Arch::X86) << "Conv2d_NCHWc_Schedule_CPU schedule only used in x86"; + CHECK(target.arch == Target::Arch::X86) + << "Conv2d_NCHWc_Schedule_CPU schedule only used in x86"; CHECK(packed_out.defined()); CHECK(input_pad.defined()); auto type = packed_out->type(); - CHECK_EQ(packed_out->shape.size(), 5U) << "packed_out's shape size should be 5"; - Expr w_out = common::AutoSimplify(packed_out->shape[3]); - int ow = w_out.as_int32(); + CHECK_EQ(packed_out->shape.size(), 5U) + << "packed_out's shape size should be 5"; + Expr w_out = common::AutoSimplify(packed_out->shape[3]); + int ow = w_out.as_int32(); auto input_shape = input_pad->shape; - int shape_size = input_shape.size(); + int shape_size = input_shape.size(); CHECK_EQ(shape_size, 5U) << "input shape size should be 5"; - Expr oc_bn = common::AutoSimplify(packed_out->shape.back()); - Expr ic_bn = common::AutoSimplify(input_shape.back()); + Expr oc_bn = common::AutoSimplify(packed_out->shape.back()); + Expr ic_bn = common::AutoSimplify(input_shape.back()); int oc_bn_size = oc_bn.as_int32(); int ic_bn_size = ic_bn.as_int32(); @@ -1190,63 +1281,77 @@ void Conv2d_NCHWc_Schedule_CPU(poly::StageMap stages, VLOG(3) << "unroll_kw " << unroll_kw; // data if (data.defined()) { - 
CHECK_GE(stages[data]->n_out_dims(), 3U) << "data's out_dims should be more than 3"; + CHECK_GE(stages[data]->n_out_dims(), 3U) + << "data's out_dims should be more than 3"; stages[data]->Fuse({0, 1, 2}); stages[data]->ComputeInline(); } // input_pad if (do_padding) { - CHECK_GE(stages[input_pad]->n_out_dims(), 3U) << "input_pad's out_dims should be more than 3"; + CHECK_GE(stages[input_pad]->n_out_dims(), 3U) + << "input_pad's out_dims should be more than 3"; stages[input_pad]->Fuse({0, 1, 2}); - stages[input_pad]->Vectorize(stages[input_pad]->n_out_dims() - 1, input_pad->shape.back().as_int32()); + stages[input_pad]->Vectorize(stages[input_pad]->n_out_dims() - 1, + input_pad->shape.back().as_int32()); } else { stages[input_pad]->ComputeInline(); } // weights if (weights_dilation.defined()) { - CHECK_GE(stages[weights_dilation]->n_out_dims(), 3U) << "weights_dilation's out_dims should be more than 3"; - // oc_outer, ic_outer, oh, ow, ic_inner, oc_inner -> oc_outer, oh, ic_outer, ow, ic_inner, oc_inner + CHECK_GE(stages[weights_dilation]->n_out_dims(), 3U) + << "weights_dilation's out_dims should be more than 3"; + // oc_outer, ic_outer, oh, ow, ic_inner, oc_inner -> oc_outer, oh, ic_outer, + // ow, ic_inner, oc_inner stages[weights_dilation]->Reorder({2, 1}); stages[weights_dilation]->Fuse({0, 1}); } // packed_out auto CC = stages[packed_out]->CacheWrite("global", stages, packed_out); - VLOG(3) << "stages[packed_out]->transformed_domain()" << stages[packed_out]->transformed_domain(); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + VLOG(3) << "stages[packed_out]->transformed_domain()" + << stages[packed_out]->transformed_domain(); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // packed_out: [batch, oc_outer, oh, ow, oc_inner] // split ow stages[packed_out]->Split(3, ow_bn_size); stages[packed_out]->Fuse({0, 1, 2}); - stages[packed_out]->Vectorize(stages[packed_out]->n_out_dims() - 1, packed_out->shape.back().as_int32()); + stages[packed_out]->Vectorize(stages[packed_out]->n_out_dims() - 1, + packed_out->shape.back().as_int32()); // CC stages[CC]->ComputeAt2(stages[packed_out], 1); VLOG(3) << "cache write shape: " << utils::Join(CC->shape, ", "); - VLOG(3) << "stages[packed_out]->transformed_domain()" << stages[packed_out]->transformed_domain(); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + VLOG(3) << "stages[packed_out]->transformed_domain()" + << stages[packed_out]->transformed_domain(); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // CC: [batch_oc_outer_oh_fused, ow_outer, ow_inner, oc_inner, ic, kh, kw] // for fused_axes' copy transform, not split ow again // split ic stages[CC]->Split(4, ic_bn_size); - // reorder: [batch_oc_outer_oh_fused, ow_outer, ow_inner, oc_inner, ic_outer, ic_inner, kh, kw] -> - // [batch_oc_outer_oh_fused, ow_outer, ic_outer, kh, kw, ic_inner, ow_inner, oc_inner] + // reorder: [batch_oc_outer_oh_fused, ow_outer, ow_inner, oc_inner, ic_outer, + // ic_inner, kh, kw] -> [batch_oc_outer_oh_fused, ow_outer, ic_outer, kh, kw, + // ic_inner, ow_inner, oc_inner] auto ow_inner = stages[CC]->axis(2); auto oc_inner = stages[CC]->axis(3); auto ic_outer = stages[CC]->axis(4); auto ic_inner = stages[CC]->axis(5); - auto kh = stages[CC]->axis(6); - auto kw = stages[CC]->axis(7); + auto kh = stages[CC]->axis(6); + auto kw = stages[CC]->axis(7); if (unroll_kw) { stages[CC]->Reorder({ic_outer, kh, ic_inner, kw, 
ow_inner, oc_inner}); stages[CC]->Unroll(kw); } else { stages[CC]->Reorder({ic_outer, kh, kw, ic_inner, ow_inner, oc_inner}); } - stages[CC]->Vectorize(stages[CC]->n_out_dims() - 1, CC->shape.back().as_int32()); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + stages[CC]->Vectorize(stages[CC]->n_out_dims() - 1, + CC->shape.back().as_int32()); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // CC_init auto CC_init = CC->GetInitTensor(stages, target); - stages[CC_init]->Vectorize(stages[CC_init]->n_out_dims() - 1, CC_init->shape.back().as_int32()); + stages[CC_init]->Vectorize(stages[CC_init]->n_out_dims() - 1, + CC_init->shape.back().as_int32()); // unroll ow_inner stages[CC]->Unroll(stages[CC]->n_out_dims() - 2); stages[CC_init]->Unroll(stages[CC_init]->n_out_dims() - 2); @@ -1259,7 +1364,7 @@ void Conv2d_NCHWc_Schedule_CPU(poly::StageMap stages, // Reorder: [n, oc_outer, oc_inner, oh, ow_outer, ow_inner] -> // [n, oc_outer, oh, ow_outer, ow_inner, oc_inner] auto oc_inner1 = stages[res]->axis(2); - auto oh1 = stages[res]->axis(3); + auto oh1 = stages[res]->axis(3); auto ow_outer1 = stages[res]->axis(4); auto ow_inner1 = stages[res]->axis(5); stages[res]->Reorder({oh1, ow_outer1, ow_inner1, oc_inner1}); @@ -1269,31 +1374,34 @@ void Conv2d_NCHWc_Schedule_CPU(poly::StageMap stages, } } -void Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse(poly::StageMap stages, - const ir::Tensor &res, - ir::Tensor &packed_out, - const ir::Tensor &input_pad, - const ir::Tensor &weights_dilation, - const ir::Tensor &data, - const common::Target &target, - bool do_padding) { - CHECK(target.arch == Target::Arch::X86) << "Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse schedule only used in x86"; +void Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse( + poly::StageMap stages, + const ir::Tensor &res, + ir::Tensor &packed_out, + const ir::Tensor &input_pad, + const ir::Tensor &weights_dilation, + const ir::Tensor &data, + const common::Target &target, + bool do_padding) { + CHECK(target.arch == Target::Arch::X86) + << "Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse schedule only used in x86"; CHECK(packed_out.defined()); CHECK(input_pad.defined()); auto type = packed_out->type(); absl::flat_hash_map<std::string, int> conv2d_factors; - CHECK_EQ(packed_out->shape.size(), 5U) << "packed_out's shape size should be 5"; - Expr w_out = common::AutoSimplify(packed_out->shape[3]); - int ow = w_out.as_int32(); + CHECK_EQ(packed_out->shape.size(), 5U) + << "packed_out's shape size should be 5"; + Expr w_out = common::AutoSimplify(packed_out->shape[3]); + int ow = w_out.as_int32(); int basic_split_factor = GetBasicFactor(type, target); GetConv2dFactors(&conv2d_factors, -1, -1, -1, -1, ow, type, target); int ow_bn_size = conv2d_factors["ow_bn"]; auto input_shape = input_pad->shape; - int shape_size = input_shape.size(); + int shape_size = input_shape.size(); CHECK_EQ(shape_size, 5U) << "input shape size should be 5"; - Expr oc_bn = common::AutoSimplify(packed_out->shape.back()); - Expr ic_bn = common::AutoSimplify(input_shape.back()); + Expr oc_bn = common::AutoSimplify(packed_out->shape.back()); + Expr ic_bn = common::AutoSimplify(input_shape.back()); int oc_bn_size = oc_bn.as_int32(); int ic_bn_size = ic_bn.as_int32(); VLOG(3) << "ow_bn_size " << ow_bn_size; @@ -1310,7 +1418,8 @@ void Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse(poly::StageMap stages, } // weights if (weights_dilation.defined()) { - CHECK_GE(stages[weights_dilation]->n_out_dims(), 3U) << "weights_dilation's out_dims
should be more than 3"; + CHECK_GE(stages[weights_dilation]->n_out_dims(), 3U) + << "weights_dilation's out_dims should be more than 3"; // Reorder: [oc_outer, ic_outer, oh, ow, ic_inner, oc_inner] -> // [oc_outer, oh, ic_outer, ow, ic_inner, oc_inner] stages[weights_dilation]->Reorder({2, 1}); @@ -1318,32 +1427,40 @@ void Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse(poly::StageMap stages, // packed_out auto CC = stages[packed_out]->CacheWrite("global", stages, packed_out); - VLOG(3) << "stages[packed_out]->transformed_domain()" << stages[packed_out]->transformed_domain(); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + VLOG(3) << "stages[packed_out]->transformed_domain()" + << stages[packed_out]->transformed_domain(); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // packed_out: [batch, oc_outer, oh, ow, oc_inner] // split ow stages[packed_out]->Split(3, ow_bn_size); - stages[packed_out]->Vectorize(stages[packed_out]->n_out_dims() - 1, packed_out->shape.back().as_int32()); + stages[packed_out]->Vectorize(stages[packed_out]->n_out_dims() - 1, + packed_out->shape.back().as_int32()); // CC: [batch, oc_outer, oh, ow, oc_inner] // packed_out: [batch, oc_outer, oh, ow_outer, ow_inner, oc_inner] stages[CC]->ComputeAt2(stages[packed_out], 3); - VLOG(3) << "stages[packed_out]->transformed_domain()" << stages[packed_out]->transformed_domain(); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + VLOG(3) << "stages[packed_out]->transformed_domain()" + << stages[packed_out]->transformed_domain(); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // CC: [batch, oc_outer, oh, ow_outer, ow_inner, oc_inner, fc, kh, kw] // batch, oc_outer, oh, ow_outer, kh, kw, ow_inner, oc_inner auto CC_ow_inner = stages[CC]->axis(4); auto CC_oc_inner = stages[CC]->axis(5); - auto CC_fc = stages[CC]->axis(6); - auto CC_kh = stages[CC]->axis(7); - auto CC_kw = stages[CC]->axis(8); + auto CC_fc = stages[CC]->axis(6); + auto CC_kh = stages[CC]->axis(7); + auto CC_kw = stages[CC]->axis(8); stages[CC]->Reorder({CC_fc, CC_kh, CC_kw, CC_ow_inner, CC_oc_inner}); - stages[CC]->Vectorize(stages[CC]->n_out_dims() - 1, CC->shape.back().as_int32()); - VLOG(3) << "stages[CC]->transformed_domain()" << stages[CC]->transformed_domain(); + stages[CC]->Vectorize(stages[CC]->n_out_dims() - 1, + CC->shape.back().as_int32()); + VLOG(3) << "stages[CC]->transformed_domain()" + << stages[CC]->transformed_domain(); // CC_init auto CC_init = CC->GetInitTensor(stages, target); - stages[CC_init]->Vectorize(stages[CC_init]->n_out_dims() - 1, CC_init->shape.back().as_int32()); + stages[CC_init]->Vectorize(stages[CC_init]->n_out_dims() - 1, + CC_init->shape.back().as_int32()); // res // n, oc, oh, ow @@ -1353,11 +1470,12 @@ void Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse(poly::StageMap stages, // Reorder: [n, oc_outer, oc_inner, oh, ow_outer, ow_inner] -> // [n, oc_outer, oh, ow_outer, ow_inner, oc_inner] auto oc_inner1 = stages[res]->axis(2); - auto oh1 = stages[res]->axis(3); + auto oh1 = stages[res]->axis(3); auto ow_outer1 = stages[res]->axis(4); auto ow_inner1 = stages[res]->axis(5); stages[res]->Reorder({oh1, ow_outer1, ow_inner1, oc_inner1}); - VLOG(3) << "stages[res]->transformed_domain()" << stages[res]->transformed_domain(); + VLOG(3) << "stages[res]->transformed_domain()" + << stages[res]->transformed_domain(); } } @@ -1371,7 +1489,9 @@ void CudaScheduleMul(poly::StageMap stages, } inline void 
InputDirectConvCudaParam( - absl::flat_hash_map<std::string, absl::flat_hash_map<std::string, std::vector<int>>> &model_data, + absl::flat_hash_map<std::string, absl::flat_hash_map<std::string, std::vector<int>>> + &model_data, const std::string &key, const std::vector<std::vector<int>> &int_data) { CHECK_EQ(int_data.size(), 6UL); @@ -1379,331 +1499,737 @@ inline void InputDirectConvCudaParam( schedule_data["rc"] = int_data[0]; schedule_data["ry"] = int_data[1]; schedule_data["rx"] = int_data[2]; - schedule_data["f"] = int_data[3]; - schedule_data["y"] = int_data[4]; - schedule_data["x"] = int_data[5]; - CHECK(model_data.count(key) == 0) << "Key " << key << "in conv cuda param already exists."; + schedule_data["f"] = int_data[3]; + schedule_data["y"] = int_data[4]; + schedule_data["x"] = int_data[5]; + CHECK(model_data.count(key) == 0) + << "Key " << key << " in conv cuda param already exists."; model_data[key] = schedule_data; } inline void InputWinogradConvCudaParam( - absl::flat_hash_map<std::string, absl::flat_hash_map<std::string, std::vector<int>>> &model_data, + absl::flat_hash_map<std::string, absl::flat_hash_map<std::string, std::vector<int>>> + &model_data, const std::string &key, const std::vector<std::vector<int>> &int_data) { CHECK_EQ(int_data.size(), 4UL); absl::flat_hash_map<std::string, std::vector<int>> schedule_data; schedule_data["rc"] = int_data[0]; - schedule_data["x"] = int_data[1]; - schedule_data["y"] = int_data[2]; - schedule_data["b"] = int_data[3]; - model_data[key] = schedule_data; + schedule_data["x"] = int_data[1]; + schedule_data["y"] = int_data[2]; + schedule_data["b"] = int_data[3]; + model_data[key] = schedule_data; } -absl::flat_hash_map<std::string, absl::flat_hash_map<std::string, std::vector<int>>> CreateCudaParams() { - absl::flat_hash_map<std::string, absl::flat_hash_map<std::string, std::vector<int>>> model_data; +absl::flat_hash_map<std::string, absl::flat_hash_map<std::string, std::vector<int>>> +CreateCudaParams() { + absl::flat_hash_map<std::string, absl::flat_hash_map<std::string, std::vector<int>>> + model_data; // The format of serial data is: - // hash_key: string = name of schedule + shape of input_pad + shape of weights + shape of output - // value: vector of params - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 3 230 230 64 3 7 7 1 64 112 112", - {{3, 1}, {7, 1}, {1, 7}, {1, 4, 8, 2}, {112, 1, 1, 1}, {1, 7, 16, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 64 56 56 64 64 1 1 1 64 56 56", - {{4, 16}, {1, 1}, {1, 1}, {1, 8, 8, 1}, {56, 1, 1, 1}, {1, 2, 28, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 64 58 58 128 64 3 3 1 128 28 28", - {{32, 2}, {1, 3}, {1, 3}, {4, 2, 16, 1}, {28, 1, 1, 1}, {1, 2, 14, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 64 56 56 128 64 1 1 1 128 28 28", - {{4, 16}, {1, 1}, {1, 1}, {2, 2, 32, 1}, {28, 1, 1, 1}, {1, 2, 14, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 128 30 30 256 128 3 3 1 256 14 14", - {{32, 4}, {1, 3}, {1, 3}, {8, 1, 16, 2}, {7, 1, 2, 1}, {1, 1, 7, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 128 28 28 256 128 1 1 1 256 14 14", - {{16, 8}, {1, 1}, {1, 1}, {8, 1, 16, 2}, {14, 1, 1, 1}, {1, 1, 14, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 256 16 16 512 256 3 3 1 512 7 7", - {{64, 4}, {1, 3}, {1, 3}, {32, 1, 16, 1}, {7, 1, 1, 1}, {1, 1, 7, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 256 14 14 512 256 1 1 1 512 7 7", - {{16, 16}, {1, 1}, {1, 1}, {16, 1, 32, 1}, {7, 1, 1, 1}, {1, 1, 7, 1}}); + // hash_key: string = name of schedule + shape of input_pad + shape of weights + // + shape of output value: vector of params + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 3 230 230 64 3 7 7 1 64 112 112", + {{3, 1}, {7, 1}, {1, 7}, {1, 4, 8, 2}, {112, 1, 1, 1}, {1, 7, 16, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 64 56 56 64 64 1 1 1 64 56 56", + {{4, 16}, {1, 1}, {1, 1}, {1, 8, 8, 1}, {56, 1, 1,
1}, {1, 2, 28, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 64 58 58 128 64 3 3 1 128 28 28", + {{32, 2}, {1, 3}, {1, 3}, {4, 2, 16, 1}, {28, 1, 1, 1}, {1, 2, 14, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 64 56 56 128 64 1 1 1 128 28 28", + {{4, 16}, {1, 1}, {1, 1}, {2, 2, 32, 1}, {28, 1, 1, 1}, {1, 2, 14, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 128 30 30 256 128 3 3 1 256 14 14", + {{32, 4}, {1, 3}, {1, 3}, {8, 1, 16, 2}, {7, 1, 2, 1}, {1, 1, 7, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 128 28 28 256 128 1 1 1 256 14 14", + {{16, 8}, {1, 1}, {1, 1}, {8, 1, 16, 2}, {14, 1, 1, 1}, {1, 1, 14, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 256 16 16 512 256 3 3 1 512 7 7", + {{64, 4}, {1, 3}, {1, 3}, {32, 1, 16, 1}, {7, 1, 1, 1}, {1, 1, 7, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 256 14 14 512 256 1 1 1 512 7 7", + {{16, 16}, {1, 1}, {1, 1}, {16, 1, 32, 1}, {7, 1, 1, 1}, {1, 1, 7, 1}}); // winograd - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 64 58 58 64 64 3 3 1 64 56 56", - {{32, 2}, {1, 3}, {1, 3}, {4, 1, 8, 2}, {28, 1, 2, 1}, {1, 2, 7, 4}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 64 58 58 64 64 3 3 1 64 56 56", + {{32, 2}, {1, 3}, {1, 3}, {4, 1, 8, 2}, {28, 1, 2, 1}, {1, 2, 7, 4}}); // winograd - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 512 9 9 512 512 3 3 1 512 7 7", - {{64, 8}, {1, 3}, {1, 3}, {32, 1, 16, 1}, {7, 1, 1, 1}, {1, 1, 7, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 512 9 9 512 512 3 3 1 512 7 7", + {{64, 8}, {1, 3}, {1, 3}, {32, 1, 16, 1}, {7, 1, 1, 1}, {1, 1, 7, 1}}); // winograd - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 256 16 16 256 256 3 3 1 256 14 14", - {{64, 4}, {1, 3}, {1, 3}, {16, 1, 16, 1}, {14, 1, 1, 1}, {1, 1, 14, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 256 16 16 256 256 3 3 1 256 14 14", + {{64, 4}, {1, 3}, {1, 3}, {16, 1, 16, 1}, {14, 1, 1, 1}, {1, 1, 14, 1}}); // winograd - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 128 30 30 128 128 3 3 1 128 28 28", - {{32, 4}, {1, 3}, {1, 3}, {8, 1, 16, 1}, {14, 1, 2, 1}, {1, 1, 7, 4}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 128 30 30 128 128 3 3 1 128 28 28", + {{32, 4}, {1, 3}, {1, 3}, {8, 1, 16, 1}, {14, 1, 2, 1}, {1, 1, 7, 4}}); // MobileNetV2 schedule params /* InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 3 226 226 32 3 3 3 1 32 112 112", - {{3, 1}, {1, 3}, {1, 3}, {-1, 2, 8, 2}, {-1, 1, 1, 7}, {-1, 1, 16, 1}}); */ - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 32 112 112 16 32 1 1 1 16 112 112", - {{-1, 4}, {-1, 1}, {-1, 1}, {-1, 2, 2, 4}, {-1, 1, 2, 1}, {-1, 1, 56, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 32 112 112 32 32 1 1 1 32 112 112", - {{-1, 4}, {-1, 1}, {-1, 1}, {-1, 1, 4, 8}, {-1, 1, 2, 1}, {-1, 7, 16, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 16 112 112 96 16 1 1 1 96 112 112", - {{-1, 4}, {-1, 1}, {-1, 1}, {-1, 4, 4, 2}, {-1, 2, 2, 1}, {-1, 1, 16, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 96 56 56 24 96 1 1 1 24 56 56", - {{-1, 4}, {-1, 1}, {-1, 1}, {-1, 3, 4, 2}, {-1, 1, 1, 1}, {-1, 1, 28, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 
1 24 56 56 144 24 1 1 1 144 56 56", - {{-1, 6}, {-1, 1}, {-1, 1}, {-1, 9, 4, 2}, {-1, 2, 1, 1}, {-1, 1, 56, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 144 56 56 24 144 1 1 1 24 56 56", - {{-1, 12}, {-1, 1}, {-1, 1}, {-1, 1, 8, 3}, {-1, 1, 1, 1}, {-1, 2, 14, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 144 28 28 32 144 1 1 1 32 28 28", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 4, 8, 1}, {-1, 1, 1, 1}, {-1, 1, 14, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 32 28 28 192 32 1 1 1 192 28 28", - {{-1, 4}, {-1, 1}, {-1, 1}, {-1, 6, 4, 1}, {-1, 2, 1, 2}, {-1, 1, 28, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 192 28 28 32 192 1 1 1 32 28 28", - {{-1, 48}, {-1, 1}, {-1, 1}, {-1, 4, 8, 1}, {-1, 1, 1, 1}, {-1, 1, 28, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 192 14 14 64 192 1 1 1 64 14 14", - {{-1, 12}, {-1, 1}, {-1, 1}, {-1, 1, 8, 2}, {-1, 2, 1, 1}, {-1, 1, 14, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 64 14 14 384 64 1 1 1 384 14 14", - {{-1, 4}, {-1, 1}, {-1, 1}, {-1, 2, 4, 3}, {-1, 1, 7, 1}, {-1, 1, 7, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 384 14 14 64 384 1 1 1 64 14 14", - {{-1, 48}, {-1, 1}, {-1, 1}, {-1, 2, 16, 1}, {-1, 1, 2, 1}, {-1, 1, 7, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 384 14 14 96 384 1 1 1 96 14 14", - {{-1, 12}, {-1, 1}, {-1, 1}, {-1, 2, 6, 1}, {-1, 1, 2, 1}, {-1, 1, 7, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 96 14 14 576 96 1 1 1 576 14 14", - {{-1, 6}, {-1, 1}, {-1, 1}, {-1, 1, 6, 6}, {-1, 1, 7, 1}, {-1, 1, 7, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 576 14 14 96 576 1 1 1 96 14 14", - {{-1, 24}, {-1, 1}, {-1, 1}, {-1, 1, 8, 3}, {-1, 1, 2, 1}, {-1, 1, 7, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 576 7 7 160 576 1 1 1 160 7 7", - {{-1, 36}, {-1, 1}, {-1, 1}, {-1, 2, 2, 2}, {-1, 1, 7, 1}, {-1, 1, 7, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 160 7 7 960 160 1 1 1 960 7 7", - {{-1, 16}, {-1, 1}, {-1, 1}, {-1, 6, 4, 1}, {-1, 1, 7, 1}, {-1, 1, 7, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 960 7 7 160 960 1 1 1 160 7 7", - {{-1, 60}, {-1, 1}, {-1, 1}, {-1, 2, 4, 1}, {-1, 1, 7, 1}, {-1, 1, 7, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 960 7 7 320 960 1 1 1 320 7 7", - {{-1, 20}, {-1, 1}, {-1, 1}, {-1, 2, 2, 2}, {-1, 1, 7, 1}, {-1, 1, 7, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 320 7 7 1280 320 1 1 1 1280 7 7", - {{-1, 16}, {-1, 1}, {-1, 1}, {-1, 2, 16, 1}, {-1, 7, 1, 1}, {-1, 1, 7, 1}}); + "CudaDirectConvSchedule 1 3 226 226 32 3 3 3 1 32 + 112 112", + {{3, 1}, {1, 3}, {1, 3}, {-1, 2, 8, 2}, {-1, 1, 1, + 7}, {-1, 1, 16, 1}}); */ + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 32 112 112 16 32 1 1 1 16 112 112", + {{-1, 4}, + {-1, 1}, + {-1, 1}, + {-1, 2, 2, 4}, + {-1, 1, 2, 1}, + {-1, 1, 56, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 32 112 112 32 32 1 1 1 32 112 112", + {{-1, 4}, + {-1, 1}, + {-1, 1}, + {-1, 1, 4, 8}, + {-1, 1, 2, 1}, + {-1, 7, 16, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 16 112 112 96 16 1 1 1 96 112 112", + {{-1, 4}, + {-1, 1}, + {-1, 1}, + {-1, 4, 4, 2}, + {-1, 2, 2, 1}, + {-1, 1, 16, 1}}); + InputDirectConvCudaParam( + 
model_data, + "CudaDirectConvSchedule 1 96 56 56 24 96 1 1 1 24 56 56", + {{-1, 4}, + {-1, 1}, + {-1, 1}, + {-1, 3, 4, 2}, + {-1, 1, 1, 1}, + {-1, 1, 28, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 24 56 56 144 24 1 1 1 144 56 56", + {{-1, 6}, + {-1, 1}, + {-1, 1}, + {-1, 9, 4, 2}, + {-1, 2, 1, 1}, + {-1, 1, 56, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 144 56 56 24 144 1 1 1 24 56 56", + {{-1, 12}, + {-1, 1}, + {-1, 1}, + {-1, 1, 8, 3}, + {-1, 1, 1, 1}, + {-1, 2, 14, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 144 28 28 32 144 1 1 1 32 28 28", + {{-1, 8}, + {-1, 1}, + {-1, 1}, + {-1, 4, 8, 1}, + {-1, 1, 1, 1}, + {-1, 1, 14, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 32 28 28 192 32 1 1 1 192 28 28", + {{-1, 4}, + {-1, 1}, + {-1, 1}, + {-1, 6, 4, 1}, + {-1, 2, 1, 2}, + {-1, 1, 28, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 192 28 28 32 192 1 1 1 32 28 28", + {{-1, 48}, + {-1, 1}, + {-1, 1}, + {-1, 4, 8, 1}, + {-1, 1, 1, 1}, + {-1, 1, 28, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 192 14 14 64 192 1 1 1 64 14 14", + {{-1, 12}, + {-1, 1}, + {-1, 1}, + {-1, 1, 8, 2}, + {-1, 2, 1, 1}, + {-1, 1, 14, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 64 14 14 384 64 1 1 1 384 14 14", + {{-1, 4}, {-1, 1}, {-1, 1}, {-1, 2, 4, 3}, {-1, 1, 7, 1}, {-1, 1, 7, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 384 14 14 64 384 1 1 1 64 14 14", + {{-1, 48}, + {-1, 1}, + {-1, 1}, + {-1, 2, 16, 1}, + {-1, 1, 2, 1}, + {-1, 1, 7, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 384 14 14 96 384 1 1 1 96 14 14", + {{-1, 12}, + {-1, 1}, + {-1, 1}, + {-1, 2, 6, 1}, + {-1, 1, 2, 1}, + {-1, 1, 7, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 96 14 14 576 96 1 1 1 576 14 14", + {{-1, 6}, {-1, 1}, {-1, 1}, {-1, 1, 6, 6}, {-1, 1, 7, 1}, {-1, 1, 7, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 576 14 14 96 576 1 1 1 96 14 14", + {{-1, 24}, + {-1, 1}, + {-1, 1}, + {-1, 1, 8, 3}, + {-1, 1, 2, 1}, + {-1, 1, 7, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 576 7 7 160 576 1 1 1 160 7 7", + {{-1, 36}, + {-1, 1}, + {-1, 1}, + {-1, 2, 2, 2}, + {-1, 1, 7, 1}, + {-1, 1, 7, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 160 7 7 960 160 1 1 1 960 7 7", + {{-1, 16}, + {-1, 1}, + {-1, 1}, + {-1, 6, 4, 1}, + {-1, 1, 7, 1}, + {-1, 1, 7, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 960 7 7 160 960 1 1 1 160 7 7", + {{-1, 60}, + {-1, 1}, + {-1, 1}, + {-1, 2, 4, 1}, + {-1, 1, 7, 1}, + {-1, 1, 7, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 960 7 7 320 960 1 1 1 320 7 7", + {{-1, 20}, + {-1, 1}, + {-1, 1}, + {-1, 2, 2, 2}, + {-1, 1, 7, 1}, + {-1, 1, 7, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 320 7 7 1280 320 1 1 1 1280 7 7", + {{-1, 16}, + {-1, 1}, + {-1, 1}, + {-1, 2, 16, 1}, + {-1, 7, 1, 1}, + {-1, 1, 7, 1}}); // EfficientNet schedule params - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 3 228 228 32 3 3 3 1 32 113 113", - {{-1, 1}, {-1, 1}, {-1, 3}, {-1, 32, 1, 1}, {-1, 1, 1, 1}, {-1, 1, 113, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 3 228 228 32 3 3 3 1 32 113 113", + {{-1, 1}, + {-1, 1}, + {-1, 
3}, + {-1, 32, 1, 1}, + {-1, 1, 1, 1}, + {-1, 1, 113, 1}}); InputDirectConvCudaParam(model_data, "CudaDirectConvSchedule 1 32 1 1 8 32 1 1 1 8 1 1", - {{-1, 16}, {-1, 1}, {-1, 1}, {-1, 1, 4, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 8 1 1 32 8 1 1 1 32 1 1", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 1, 8, 4}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); + {{-1, 16}, + {-1, 1}, + {-1, 1}, + {-1, 1, 4, 1}, + {-1, 1, 1, 1}, + {-1, 1, 1, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 8 1 1 32 8 1 1 1 32 1 1", + {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 1, 8, 4}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); InputDirectConvCudaParam(model_data, "CudaDirectConvSchedule 1 96 1 1 4 96 1 1 1 4 1 1", - {{-1, 48}, {-1, 1}, {-1, 1}, {-1, 1, 4, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); + {{-1, 48}, + {-1, 1}, + {-1, 1}, + {-1, 1, 4, 1}, + {-1, 1, 1, 1}, + {-1, 1, 1, 1}}); InputDirectConvCudaParam(model_data, "CudaDirectConvSchedule 1 4 1 1 96 4 1 1 1 96 1 1", - {{-1, 2}, {-1, 1}, {-1, 1}, {-1, 12, 1, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); + {{-1, 2}, + {-1, 1}, + {-1, 1}, + {-1, 12, 1, 1}, + {-1, 1, 1, 1}, + {-1, 1, 1, 1}}); InputDirectConvCudaParam(model_data, "CudaDirectConvSchedule 1 144 1 1 6 144 1 1 1 6 1 1", - {{-1, 48}, {-1, 1}, {-1, 1}, {-1, 1, 6, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 6 1 1 144 6 1 1 1 144 1 1", - {{-1, 2}, {-1, 1}, {-1, 1}, {-1, 2, 8, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 144 28 28 40 144 1 1 1 40 28 28", - {{-1, 16}, {-1, 1}, {-1, 1}, {-1, 5, 8, 1}, {-1, 1, 1, 1}, {-1, 1, 28, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 40 28 28 240 40 1 1 1 240 28 28", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 2, 8, 3}, {-1, 4, 1, 1}, {-1, 1, 28, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 240 1 1 10 240 1 1 1 10 1 1", - {{-1, 60}, {-1, 1}, {-1, 1}, {-1, 1, 5, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 10 1 1 240 10 1 1 1 240 1 1", - {{-1, 10}, {-1, 1}, {-1, 1}, {-1, 1, 40, 3}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 240 28 28 40 240 1 1 1 40 28 28", - {{-1, 16}, {-1, 1}, {-1, 1}, {-1, 1, 8, 5}, {-1, 1, 1, 1}, {-1, 1, 28, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 240 14 14 80 240 1 1 1 80 14 14", - {{-1, 20}, {-1, 1}, {-1, 1}, {-1, 2, 8, 1}, {-1, 1, 2, 1}, {-1, 1, 7, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 80 14 14 480 80 1 1 1 480 14 14", - {{-1, 16}, {-1, 1}, {-1, 1}, {-1, 2, 8, 3}, {-1, 1, 7, 1}, {-1, 1, 7, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 480 1 1 20 480 1 1 1 20 1 1", - {{-1, 60}, {-1, 1}, {-1, 1}, {-1, 1, 4, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 20 1 1 480 20 1 1 1 480 1 1", - {{-1, 5}, {-1, 1}, {-1, 1}, {-1, 1, 32, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 480 14 14 80 480 1 1 1 80 14 14", - {{-1, 40}, {-1, 1}, {-1, 1}, {-1, 2, 8, 1}, {-1, 1, 2, 1}, {-1, 1, 14, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 480 14 14 112 480 1 1 1 112 14 14", - {{-1, 20}, {-1, 1}, {-1, 1}, {-1, 1, 8, 2}, {-1, 1, 2, 1}, {-1, 1, 7, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 112 14 14 672 112 1 
1 1 672 14 14", - {{-1, 14}, {-1, 1}, {-1, 1}, {-1, 1, 7, 6}, {-1, 1, 7, 1}, {-1, 1, 7, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 672 1 1 28 672 1 1 1 28 1 1", - {{-1, 28}, {-1, 1}, {-1, 1}, {-1, 1, 7, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 28 1 1 672 28 1 1 1 672 1 1", - {{-1, 28}, {-1, 1}, {-1, 1}, {-1, 1, 16, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 672 14 14 112 672 1 1 1 112 14 14", - {{-1, 14}, {-1, 1}, {-1, 1}, {-1, 2, 4, 2}, {-1, 1, 2, 1}, {-1, 1, 7, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 672 7 7 192 672 1 1 1 192 7 7", - {{-1, 28}, {-1, 1}, {-1, 1}, {-1, 1, 2, 3}, {-1, 1, 7, 1}, {-1, 1, 7, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 192 7 7 1152 192 1 1 1 1152 7 7", - {{-1, 24}, {-1, 1}, {-1, 1}, {-1, 1, 12, 3}, {-1, 7, 1, 1}, {-1, 1, 7, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 1152 1 1 48 1152 1 1 1 48 1 1", - {{-1, 576}, {-1, 1}, {-1, 1}, {-1, 1, 3, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 48 1 1 1152 48 1 1 1 1152 1 1", - {{-1, 12}, {-1, 1}, {-1, 1}, {-1, 1, 32, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 1152 7 7 192 1152 1 1 1 192 7 7", - {{-1, 36}, {-1, 1}, {-1, 1}, {-1, 1, 2, 6}, {-1, 1, 7, 1}, {-1, 1, 7, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 1152 7 7 320 1152 1 1 1 320 7 7", - {{-1, 12}, {-1, 1}, {-1, 1}, {-1, 1, 2, 4}, {-1, 1, 7, 1}, {-1, 1, 7, 1}}); + {{-1, 48}, + {-1, 1}, + {-1, 1}, + {-1, 1, 6, 1}, + {-1, 1, 1, 1}, + {-1, 1, 1, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 6 1 1 144 6 1 1 1 144 1 1", + {{-1, 2}, {-1, 1}, {-1, 1}, {-1, 2, 8, 1}, {-1, 1, 1, 1}, {-1, 1, 1, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 144 28 28 40 144 1 1 1 40 28 28", + {{-1, 16}, + {-1, 1}, + {-1, 1}, + {-1, 5, 8, 1}, + {-1, 1, 1, 1}, + {-1, 1, 28, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 40 28 28 240 40 1 1 1 240 28 28", + {{-1, 8}, + {-1, 1}, + {-1, 1}, + {-1, 2, 8, 3}, + {-1, 4, 1, 1}, + {-1, 1, 28, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 240 1 1 10 240 1 1 1 10 1 1", + {{-1, 60}, + {-1, 1}, + {-1, 1}, + {-1, 1, 5, 1}, + {-1, 1, 1, 1}, + {-1, 1, 1, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 10 1 1 240 10 1 1 1 240 1 1", + {{-1, 10}, + {-1, 1}, + {-1, 1}, + {-1, 1, 40, 3}, + {-1, 1, 1, 1}, + {-1, 1, 1, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 240 28 28 40 240 1 1 1 40 28 28", + {{-1, 16}, + {-1, 1}, + {-1, 1}, + {-1, 1, 8, 5}, + {-1, 1, 1, 1}, + {-1, 1, 28, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 240 14 14 80 240 1 1 1 80 14 14", + {{-1, 20}, + {-1, 1}, + {-1, 1}, + {-1, 2, 8, 1}, + {-1, 1, 2, 1}, + {-1, 1, 7, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 80 14 14 480 80 1 1 1 480 14 14", + {{-1, 16}, + {-1, 1}, + {-1, 1}, + {-1, 2, 8, 3}, + {-1, 1, 7, 1}, + {-1, 1, 7, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 480 1 1 20 480 1 1 1 20 1 1", + {{-1, 60}, + {-1, 1}, + {-1, 1}, + {-1, 1, 4, 1}, + {-1, 1, 1, 1}, + {-1, 1, 1, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 
1 20 1 1 480 20 1 1 1 480 1 1", + {{-1, 5}, + {-1, 1}, + {-1, 1}, + {-1, 1, 32, 1}, + {-1, 1, 1, 1}, + {-1, 1, 1, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 480 14 14 80 480 1 1 1 80 14 14", + {{-1, 40}, + {-1, 1}, + {-1, 1}, + {-1, 2, 8, 1}, + {-1, 1, 2, 1}, + {-1, 1, 14, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 480 14 14 112 480 1 1 1 112 14 14", + {{-1, 20}, + {-1, 1}, + {-1, 1}, + {-1, 1, 8, 2}, + {-1, 1, 2, 1}, + {-1, 1, 7, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 112 14 14 672 112 1 1 1 672 14 14", + {{-1, 14}, + {-1, 1}, + {-1, 1}, + {-1, 1, 7, 6}, + {-1, 1, 7, 1}, + {-1, 1, 7, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 672 1 1 28 672 1 1 1 28 1 1", + {{-1, 28}, + {-1, 1}, + {-1, 1}, + {-1, 1, 7, 1}, + {-1, 1, 1, 1}, + {-1, 1, 1, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 28 1 1 672 28 1 1 1 672 1 1", + {{-1, 28}, + {-1, 1}, + {-1, 1}, + {-1, 1, 16, 1}, + {-1, 1, 1, 1}, + {-1, 1, 1, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 672 14 14 112 672 1 1 1 112 14 14", + {{-1, 14}, + {-1, 1}, + {-1, 1}, + {-1, 2, 4, 2}, + {-1, 1, 2, 1}, + {-1, 1, 7, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 672 7 7 192 672 1 1 1 192 7 7", + {{-1, 28}, + {-1, 1}, + {-1, 1}, + {-1, 1, 2, 3}, + {-1, 1, 7, 1}, + {-1, 1, 7, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 192 7 7 1152 192 1 1 1 1152 7 7", + {{-1, 24}, + {-1, 1}, + {-1, 1}, + {-1, 1, 12, 3}, + {-1, 7, 1, 1}, + {-1, 1, 7, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 1152 1 1 48 1152 1 1 1 48 1 1", + {{-1, 576}, + {-1, 1}, + {-1, 1}, + {-1, 1, 3, 1}, + {-1, 1, 1, 1}, + {-1, 1, 1, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 48 1 1 1152 48 1 1 1 1152 1 1", + {{-1, 12}, + {-1, 1}, + {-1, 1}, + {-1, 1, 32, 1}, + {-1, 1, 1, 1}, + {-1, 1, 1, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 1152 7 7 192 1152 1 1 1 192 7 7", + {{-1, 36}, + {-1, 1}, + {-1, 1}, + {-1, 1, 2, 6}, + {-1, 1, 7, 1}, + {-1, 1, 7, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 1152 7 7 320 1152 1 1 1 320 7 7", + {{-1, 12}, + {-1, 1}, + {-1, 1}, + {-1, 1, 2, 4}, + {-1, 1, 7, 1}, + {-1, 1, 7, 1}}); // FaceDet schedule params /* InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 3 242 322 16 3 3 3 1 16 120 160", - {{-1, 1}, {-1, 3}, {-1, 3}, {-1, 2, 4, 2}, {-1, 1, 1, 5}, {-1, 1, 32, 1}}); */ - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 16 120 160 32 16 1 1 1 32 120 160", - {{-1, 4}, {-1, 1}, {-1, 1}, {-1, 8, 4, 1}, {-1, 1, 1, 1}, {-1, 5, 32, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 32 60 80 32 32 1 1 1 32 60 80", - {{-1, 4}, {-1, 1}, {-1, 1}, {-1, 8, 4, 1}, {-1, 3, 1, 1}, {-1, 1, 40, 1}}); + "CudaDirectConvSchedule 1 3 242 322 16 3 3 3 1 + 16 120 160", + {{-1, 1}, {-1, 3}, {-1, 3}, {-1, 2, 4, 2}, {-1, + 1, 1, 5}, {-1, 1, 32, 1}}); */ + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 16 120 160 32 16 1 1 1 32 120 160", + {{-1, 4}, + {-1, 1}, + {-1, 1}, + {-1, 8, 4, 1}, + {-1, 1, 1, 1}, + {-1, 5, 32, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 32 60 80 32 32 1 1 1 32 60 80", + {{-1, 4}, + {-1, 1}, + {-1, 1}, + {-1, 8, 4, 1}, + {-1, 3, 1, 1}, + {-1, 1, 40, 1}}); /* 
InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 32 30 40 64 32 1 1 1 64 30 40", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 2, 8, 2}, {-1, 1, 1, 3}, {-1, 1, 20, 1}}); */ - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 64 30 40 64 64 1 1 1 64 30 40", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 2, 8, 2}, {-1, 1, 2, 1}, {-1, 5, 8, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 64 30 40 8 64 1 1 1 8 30 40", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 2, 4, 1}, {-1, 1, 2, 1}, {-1, 1, 8, 1}}); + "CudaDirectConvSchedule 1 32 30 40 64 32 1 1 1 + 64 30 40", + {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 2, 8, 2}, {-1, + 1, 1, 3}, {-1, 1, 20, 1}}); */ + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 64 30 40 64 64 1 1 1 64 30 40", + {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 2, 8, 2}, {-1, 1, 2, 1}, {-1, 5, 8, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 64 30 40 8 64 1 1 1 8 30 40", + {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 2, 4, 1}, {-1, 1, 2, 1}, {-1, 1, 8, 1}}); /* InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 8 32 42 12 8 3 3 1 12 30 40", - {{-1, 4}, {-1, 3}, {-1, 3}, {-1, 1, 12, 1}, {-1, 1, 1, 3}, {-1, 1, 10, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 8 32 42 16 8 3 3 1 16 30 40", - {{-1, 8}, {-1, 3}, {-1, 3}, {-1, 1, 16, 1}, {-1, 3, 1, 2}, {-1, 1, 4, 2}}); */ - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 16 36 46 16 16 3 3 1 16 30 40", - {{-1, 16}, {-1, 1}, {-1, 1}, {-1, 2, 8, 1}, {-1, 1, 2, 1}, {-1, 1, 8, 1}}); + "CudaDirectConvSchedule 1 8 32 42 12 8 3 3 1 12 + 30 40", + {{-1, 4}, {-1, 3}, {-1, 3}, {-1, 1, 12, 1}, + {-1, 1, 1, 3}, {-1, 1, 10, 1}}); InputDirectConvCudaParam(model_data, + "CudaDirectConvSchedule 1 8 32 42 16 8 3 3 1 16 + 30 40", + {{-1, 8}, {-1, 3}, {-1, 3}, {-1, 1, 16, 1}, + {-1, 3, 1, 2}, {-1, 1, 4, 2}}); */ + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 16 36 46 16 16 3 3 1 16 30 40", + {{-1, 16}, + {-1, 1}, + {-1, 1}, + {-1, 2, 8, 1}, + {-1, 1, 2, 1}, + {-1, 1, 8, 1}}); /* InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 16 34 44 16 16 3 3 1 16 30 40", - {{-1, 4}, {-1, 3}, {-1, 3}, {-1, 1, 4, 2}, {-1, 3, 2, 1}, {-1, 1, 20, 1}}); */ + "CudaDirectConvSchedule 1 16 34 44 16 16 3 3 1 16 + 30 40", + {{-1, 4}, {-1, 3}, {-1, 3}, {-1, 1, 4, 2}, {-1, 3, + 2, 1}, {-1, 1, 20, 1}}); */ /* InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 12 32 42 16 12 3 3 1 16 30 40", - {{-1, 4}, {-1, 3}, {-1, 3}, {-1, 1, 16, 1}, {-1, 1, 2, 3}, {-1, 1, 2, 2}}); */ + "CudaDirectConvSchedule 1 12 32 42 16 12 3 3 1 + 16 30 40", + {{-1, 4}, {-1, 3}, {-1, 3}, {-1, 1, 16, 1}, + {-1, 1, 2, 3}, {-1, 1, 2, 2}}); */ /* InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 16 40 50 16 16 3 3 1 16 30 40", - {{-1, 4}, {-1, 1}, {-1, 3}, {-1, 1, 1, 8}, {-1, 1, 3, 1}, {-1, 1, 40, 1}}); */ + "CudaDirectConvSchedule 1 16 40 50 16 16 3 3 1 16 + 30 40", + {{-1, 4}, {-1, 1}, {-1, 3}, {-1, 1, 1, 8}, {-1, 1, + 3, 1}, {-1, 1, 40, 1}}); */ /* InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 48 30 40 64 48 1 1 1 64 30 40", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 2, 8, 2}, {-1, 1, 1, 3}, {-1, 1, 20, 1}}); */ - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 64 30 40 12 64 1 1 1 12 30 40", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 1, 4, 3}, {-1, 1, 3, 1}, {-1, 1, 10, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 64 30 40 6 64 1 1 1 6 30 40", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 3, 2, 1}, {-1, 
1, 3, 1}, {-1, 1, 10, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 64 15 20 128 64 1 1 1 128 15 20", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 2, 8, 2}, {-1, 1, 3, 1}, {-1, 1, 10, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 128 15 20 128 128 1 1 1 128 15 20", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 4, 8, 1}, {-1, 1, 3, 1}, {-1, 1, 10, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 128 15 20 8 128 1 1 1 8 15 20", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 1, 8, 1}, {-1, 1, 1, 1}, {-1, 1, 10, 2}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 128 15 20 4 128 1 1 1 4 15 20", - {{-1, 16}, {-1, 1}, {-1, 1}, {-1, 1, 4, 1}, {-1, 1, 1, 1}, {-1, 1, 20, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 128 8 10 256 128 1 1 1 256 8 10", - {{-1, 16}, {-1, 1}, {-1, 1}, {-1, 1, 16, 2}, {-1, 1, 8, 1}, {-1, 1, 2, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 256 8 10 256 256 1 1 1 256 8 10", - {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 4, 8, 1}, {-1, 1, 8, 1}, {-1, 1, 2, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 256 8 10 64 256 1 1 1 64 8 10", - {{-1, 16}, {-1, 1}, {-1, 1}, {-1, 1, 16, 1}, {-1, 1, 8, 1}, {-1, 2, 1, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 256 8 10 8 256 1 1 1 8 8 10", - {{-1, 32}, {-1, 1}, {-1, 1}, {-1, 1, 8, 1}, {-1, 1, 2, 1}, {-1, 1, 2, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 256 8 10 4 256 1 1 1 4 8 10", - {{-1, 32}, {-1, 1}, {-1, 1}, {-1, 1, 4, 1}, {-1, 1, 4, 1}, {-1, 1, 2, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 64 4 5 256 64 1 1 1 256 4 5", - {{-1, 16}, {-1, 1}, {-1, 1}, {-1, 1, 8, 1}, {-1, 1, 4, 1}, {-1, 1, 5, 1}}); - InputDirectConvCudaParam(model_data, - "CudaDirectConvSchedule 1 256 6 7 12 256 3 3 1 12 4 5", - {{-1, 32}, {-1, 3}, {-1, 3}, {-1, 1, 4, 1}, {-1, 1, 4, 1}, {-1, 1, 1, 1}}); + "CudaDirectConvSchedule 1 48 30 40 64 48 1 1 1 64 + 30 40", + {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 2, 8, 2}, {-1, + 1, 1, 3}, {-1, 1, 20, 1}}); */ + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 64 30 40 12 64 1 1 1 12 30 40", + {{-1, 8}, + {-1, 1}, + {-1, 1}, + {-1, 1, 4, 3}, + {-1, 1, 3, 1}, + {-1, 1, 10, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 64 30 40 6 64 1 1 1 6 30 40", + {{-1, 8}, + {-1, 1}, + {-1, 1}, + {-1, 3, 2, 1}, + {-1, 1, 3, 1}, + {-1, 1, 10, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 64 15 20 128 64 1 1 1 128 15 20", + {{-1, 8}, + {-1, 1}, + {-1, 1}, + {-1, 2, 8, 2}, + {-1, 1, 3, 1}, + {-1, 1, 10, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 128 15 20 128 128 1 1 1 128 15 20", + {{-1, 8}, + {-1, 1}, + {-1, 1}, + {-1, 4, 8, 1}, + {-1, 1, 3, 1}, + {-1, 1, 10, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 128 15 20 8 128 1 1 1 8 15 20", + {{-1, 8}, + {-1, 1}, + {-1, 1}, + {-1, 1, 8, 1}, + {-1, 1, 1, 1}, + {-1, 1, 10, 2}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 128 15 20 4 128 1 1 1 4 15 20", + {{-1, 16}, + {-1, 1}, + {-1, 1}, + {-1, 1, 4, 1}, + {-1, 1, 1, 1}, + {-1, 1, 20, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 128 8 10 256 128 1 1 1 256 8 10", + {{-1, 16}, + {-1, 1}, + {-1, 1}, + {-1, 1, 16, 2}, + {-1, 1, 8, 1}, + {-1, 1, 2, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 256 8 10 256 256 1 1 1 
256 8 10", + {{-1, 8}, {-1, 1}, {-1, 1}, {-1, 4, 8, 1}, {-1, 1, 8, 1}, {-1, 1, 2, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 256 8 10 64 256 1 1 1 64 8 10", + {{-1, 16}, + {-1, 1}, + {-1, 1}, + {-1, 1, 16, 1}, + {-1, 1, 8, 1}, + {-1, 2, 1, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 256 8 10 8 256 1 1 1 8 8 10", + {{-1, 32}, + {-1, 1}, + {-1, 1}, + {-1, 1, 8, 1}, + {-1, 1, 2, 1}, + {-1, 1, 2, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 256 8 10 4 256 1 1 1 4 8 10", + {{-1, 32}, + {-1, 1}, + {-1, 1}, + {-1, 1, 4, 1}, + {-1, 1, 4, 1}, + {-1, 1, 2, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 64 4 5 256 64 1 1 1 256 4 5", + {{-1, 16}, + {-1, 1}, + {-1, 1}, + {-1, 1, 8, 1}, + {-1, 1, 4, 1}, + {-1, 1, 5, 1}}); + InputDirectConvCudaParam( + model_data, + "CudaDirectConvSchedule 1 256 6 7 12 256 3 3 1 12 4 5", + {{-1, 32}, + {-1, 3}, + {-1, 3}, + {-1, 1, 4, 1}, + {-1, 1, 4, 1}, + {-1, 1, 1, 1}}); InputDirectConvCudaParam(model_data, "CudaDirectConvSchedule 1 256 6 7 6 256 3 3 1 6 4 5", - {{-1, 32}, {-1, 3}, {-1, 3}, {-1, 1, 2, 1}, {-1, 1, 4, 1}, {-1, 1, 1, 1}}); + {{-1, 32}, + {-1, 3}, + {-1, 3}, + {-1, 1, 2, 1}, + {-1, 1, 4, 1}, + {-1, 1, 1, 1}}); #ifndef CINN_WITH_CUDNN - InputWinogradConvCudaParam(model_data, - "CudaWinogradConvSchedule 1 512 9 9 512 512 3 3 1 512 7 7", - {{32, 16}, {1, 1, 8, 2}, {8, 1, 16, 4}, {16, 1, 1, 1}}); - InputWinogradConvCudaParam(model_data, - "CudaWinogradConvSchedule 1 256 6 7 12 256 3 3 1 12 4 5", - {{-1, 256}, {-1, 1, 6, 1}, {-1, 1, 6, 1}, {-1, 1, 1, 1}}); - InputWinogradConvCudaParam(model_data, - "CudaWinogradConvSchedule 1 256 6 7 6 256 3 3 1 12 4 5", - {{-1, 256}, {-1, 1, 6, 1}, {-1, 1, 6, 1}, {-1, 1, 1, 1}}); - InputWinogradConvCudaParam(model_data, - "CudaWinogradConvSchedule 1 12 32 42 16 12 3 3 1 16 30 40", - {{-1, 12}, {-1, 2, 30, 1}, {-1, 4, 2, 2}, {-1, 1, 1, 1}}); - InputWinogradConvCudaParam(model_data, - "CudaWinogradConvSchedule 1 8 32 42 12 8 3 3 1 12 30 40", - {{-1, 8}, {-1, 2, 30, 1}, {-1, 1, 2, 6}, {-1, 1, 1, 1}}); - InputWinogradConvCudaParam(model_data, - "CudaWinogradConvSchedule 1 8 32 42 16 8 3 3 1 16 30 40", - {{-1, 4}, {-1, 2, 30, 1}, {-1, 1, 4, 4}, {-1, 1, 1, 1}}); + InputWinogradConvCudaParam( + model_data, + "CudaWinogradConvSchedule 1 512 9 9 512 512 3 3 1 512 7 7", + {{32, 16}, {1, 1, 8, 2}, {8, 1, 16, 4}, {16, 1, 1, 1}}); + InputWinogradConvCudaParam( + model_data, + "CudaWinogradConvSchedule 1 256 6 7 12 256 3 3 1 12 4 5", + {{-1, 256}, {-1, 1, 6, 1}, {-1, 1, 6, 1}, {-1, 1, 1, 1}}); + InputWinogradConvCudaParam( + model_data, + "CudaWinogradConvSchedule 1 256 6 7 6 256 3 3 1 12 4 5", + {{-1, 256}, {-1, 1, 6, 1}, {-1, 1, 6, 1}, {-1, 1, 1, 1}}); + InputWinogradConvCudaParam( + model_data, + "CudaWinogradConvSchedule 1 12 32 42 16 12 3 3 1 16 30 40", + {{-1, 12}, {-1, 2, 30, 1}, {-1, 4, 2, 2}, {-1, 1, 1, 1}}); + InputWinogradConvCudaParam( + model_data, + "CudaWinogradConvSchedule 1 8 32 42 12 8 3 3 1 12 30 40", + {{-1, 8}, {-1, 2, 30, 1}, {-1, 1, 2, 6}, {-1, 1, 1, 1}}); + InputWinogradConvCudaParam( + model_data, + "CudaWinogradConvSchedule 1 8 32 42 16 8 3 3 1 16 30 40", + {{-1, 4}, {-1, 2, 30, 1}, {-1, 1, 4, 4}, {-1, 1, 1, 1}}); #endif return model_data; } -void CreateCudaSerialData(const std::string &file_name) { SaveSerialData(CreateCudaParams(), file_name); } +void CreateCudaSerialData(const std::string &file_name) { + SaveSerialData(CreateCudaParams(), file_name); +} int GetMaxSplitter(int a, int b) { 
while (a % b > 0) { @@ -1712,8 +2238,11 @@ int GetMaxSplitter(int a, int b) { return b; } -void LoadSerialData(absl::flat_hash_map>> *params, - const std::string &file_name) { +void LoadSerialData( + absl::flat_hash_map>> + *params, + const std::string &file_name) { proto::ModelData read_model_data; std::fstream input(file_name, std::ios::in | std::ios::binary); if (!read_model_data.ParseFromIstream(&input)) { @@ -1739,7 +2268,9 @@ void LoadSerialData(absl::flat_hash_map>> &model_data, + const absl::flat_hash_map< + std::string, + absl::flat_hash_map>> &model_data, const std::string &file_name) { proto::ModelData write_model_data; for (auto &i : model_data) { @@ -1749,15 +2280,16 @@ void SaveSerialData( for (auto &k : j.second) { write_vector_data.add_data(std::to_string(k)); } - auto data_map = write_schedule_data.mutable_data(); + auto data_map = write_schedule_data.mutable_data(); (*data_map)[j.first] = write_vector_data; } - auto model_map = write_model_data.mutable_data(); + auto model_map = write_model_data.mutable_data(); (*model_map)[i.first] = write_schedule_data; std::string test_write1; write_schedule_data.SerializeToString(&test_write1); } - std::fstream output(file_name, std::ios::out | std::ios::trunc | std::ios::binary); + std::fstream output(file_name, + std::ios::out | std::ios::trunc | std::ios::binary); std::string test_write; write_model_data.SerializeToString(&test_write); if (!write_model_data.SerializeToOstream(&output)) { @@ -1767,7 +2299,9 @@ void SaveSerialData( output.close(); } -void CudaScheduleDepthwiseConv(poly::StageMap stages, ir::Tensor &output, const common::Target &target) { +void CudaScheduleDepthwiseConv(poly::StageMap stages, + ir::Tensor &output, + const common::Target &target) { auto OL = stages[output]->CacheWrite("local", stages, output); stages[output]->Bind(0, "blockIdx.x"); stages[output]->Bind(1, "blockIdx.y"); @@ -1784,22 +2318,27 @@ void CudaScheduleConv(poly::StageMap stages, ir::Tensor &output, const common::Target &target) { auto &res = ScheduleParam::get_cuda_instance().GetParam(); - int n = output->shape[0].as_int32(); - int c = output->shape[1].as_int32(); + int n = output->shape[0].as_int32(); + int c = output->shape[1].as_int32(); optim::Simplify(&(output->shape[2])); int h = output->shape[2].as_int32(); optim::Simplify(&(output->shape[3])); - int w = output->shape[3].as_int32(); + int w = output->shape[3].as_int32(); int rc = input_pad->shape[1].as_int32(); - std::string key = - "CudaDirectConvSchedule " + std::to_string(input_pad->shape[0].as_int32()) + " " + - std::to_string(input_pad->shape[1].as_int32()) + " " + std::to_string(input_pad->shape[2].as_int32()) + " " + - std::to_string(input_pad->shape[3].as_int32()) + " " + std::to_string(weights->shape[0].as_int32()) + " " + - std::to_string(weights->shape[1].as_int32()) + " " + std::to_string(weights->shape[2].as_int32()) + " " + - std::to_string(weights->shape[3].as_int32()) + " " + std::to_string(output->shape[0].as_int32()) + " " + - std::to_string(output->shape[1].as_int32()) + " " + std::to_string(output->shape[2].as_int32()) + " " + - std::to_string(output->shape[3].as_int32()); + std::string key = "CudaDirectConvSchedule " + + std::to_string(input_pad->shape[0].as_int32()) + " " + + std::to_string(input_pad->shape[1].as_int32()) + " " + + std::to_string(input_pad->shape[2].as_int32()) + " " + + std::to_string(input_pad->shape[3].as_int32()) + " " + + std::to_string(weights->shape[0].as_int32()) + " " + + std::to_string(weights->shape[1].as_int32()) + " " + + 
std::to_string(weights->shape[2].as_int32()) + " " + + std::to_string(weights->shape[3].as_int32()) + " " + + std::to_string(output->shape[0].as_int32()) + " " + + std::to_string(output->shape[1].as_int32()) + " " + + std::to_string(output->shape[2].as_int32()) + " " + + std::to_string(output->shape[3].as_int32()); if (res.count(key) == 0) { VLOG(3) << "Didn't find saved param, key is: " << key; } else { @@ -1811,27 +2350,27 @@ void CudaScheduleConv(poly::StageMap stages, if (stages[weights]->has_expression()) { stages[weights]->ComputeInline(); } - int f_inner = GetInnerSplitter(c, h); - int block_z = SplitEven(c / f_inner); + int f_inner = GetInnerSplitter(c, h); + int block_z = SplitEven(c / f_inner); int thread_z = c / f_inner / block_z; int rc_factor = SplitEven(rc); while (w * thread_z > 1024 && thread_z % 2 == 0) { thread_z = thread_z / 2; - f_inner = f_inner * 2; + f_inner = f_inner * 2; } CHECK_LE(w * thread_z, 1024) << "Wrong Param of Conv2d!"; auto OL = stages[output]->CacheWrite("local", stages, output); - auto tx = stages[output]->axis(3); - auto by = stages[output]->axis(2); + auto tx = stages[output]->axis(3); + auto by = stages[output]->axis(2); auto tem_fi = stages[output]->Split(1, f_inner); - auto &tem = std::get<0>(tem_fi); - auto &fi = std::get<1>(tem_fi); + auto &tem = std::get<0>(tem_fi); + auto &fi = std::get<1>(tem_fi); auto bz_tz = stages[output]->Split(1, thread_z); - auto &bz = std::get<0>(bz_tz); - auto &tz = std::get<1>(bz_tz); + auto &bz = std::get<0>(bz_tz); + auto &tz = std::get<1>(bz_tz); stages[output]->Reorder({bz, by, tz, tx, fi}); stages[output]->Bind(1, "blockIdx.z"); @@ -1858,9 +2397,9 @@ void CudaScheduleConv2(poly::StageMap stages, auto KR = stages[weights]->CacheRead("shared", readers, stages); auto OL = stages[output]->CacheWrite("local", stages, output); - auto &x_param = res[key]["x"]; - auto &y_param = res[key]["y"]; - auto &f_param = res[key]["f"]; + auto &x_param = res[key]["x"]; + auto &y_param = res[key]["y"]; + auto &f_param = res[key]["f"]; auto &rx_param = res[key]["rx"]; auto &ry_param = res[key]["ry"]; auto &rc_param = res[key]["rc"]; @@ -1915,7 +2454,8 @@ void CudaScheduleConv2(poly::StageMap stages, } else if (stages[PR]->n_out_dims() == 19) { stages[PR]->Fuse({13, 14, 15, 16, 17, 18}); } else { - LOG(FATAL) << "PR number of output dims is wrong: " << stages[PR]->n_out_dims(); + LOG(FATAL) << "PR number of output dims is wrong: " + << stages[PR]->n_out_dims(); } if (stages[KR]->n_out_dims() == 18) { @@ -1923,21 +2463,24 @@ void CudaScheduleConv2(poly::StageMap stages, } else if (stages[KR]->n_out_dims() == 19) { stages[KR]->Fuse({13, 14, 15, 16, 17, 18}); } else { - LOG(FATAL) << "KR number of output dims is wrong: " << stages[KR]->n_out_dims(); + LOG(FATAL) << "KR number of output dims is wrong: " + << stages[KR]->n_out_dims(); } int thread_z = f_param[2]; int thread_x = x_param[2]; if (stages[PR]->GetDimRange(13) <= thread_z) { stages[PR]->Bind(13, "threadIdx.z"); } else { - stages[PR]->Split(13, GetMaxSplitter(stages[PR]->GetDimRange(13), thread_z)); + stages[PR]->Split(13, + GetMaxSplitter(stages[PR]->GetDimRange(13), thread_z)); stages[PR]->Bind(14, "threadIdx.z"); stages[PR]->Unroll(13); } if (stages[KR]->GetDimRange(13) <= thread_x) { stages[KR]->Bind(13, "threadIdx.x"); } else { - stages[KR]->Split(13, GetMaxSplitter(stages[KR]->GetDimRange(13), thread_x)); + stages[KR]->Split(13, + GetMaxSplitter(stages[KR]->GetDimRange(13), thread_x)); stages[KR]->Bind(14, "threadIdx.x"); stages[KR]->Unroll(13); } @@ -1975,38 +2518,43 @@ 
void CudaScheduleConv2(poly::StageMap stages, void CudaScheduleWinogradConv(poly::StageMap wino_stages, std::vector &all_tensors, const common::Target &target) { - auto &res = ScheduleParam::get_cuda_instance().GetParam(); + auto &res = ScheduleParam::get_cuda_instance().GetParam(); auto &wino_weights_dilation = all_tensors[0]; - auto &wino_input_pad = all_tensors[1]; - auto &wino_A = all_tensors[2]; - auto &wino_B = all_tensors[3]; - auto &wino_G = all_tensors[4]; - auto &kernel_pack = all_tensors[5]; - auto &input_tile = all_tensors[6]; - auto &data_pack = all_tensors[7]; - auto &bgemm = all_tensors[8]; - auto &inverse = all_tensors[9]; - auto &wino_conv = all_tensors[10]; + auto &wino_input_pad = all_tensors[1]; + auto &wino_A = all_tensors[2]; + auto &wino_B = all_tensors[3]; + auto &wino_G = all_tensors[4]; + auto &kernel_pack = all_tensors[5]; + auto &input_tile = all_tensors[6]; + auto &data_pack = all_tensors[7]; + auto &bgemm = all_tensors[8]; + auto &inverse = all_tensors[9]; + auto &wino_conv = all_tensors[10]; std::string key = - "CudaWinogradConvSchedule " + std::to_string(wino_input_pad->shape[0].as_int32()) + " " + - std::to_string(wino_input_pad->shape[1].as_int32()) + " " + std::to_string(wino_input_pad->shape[2].as_int32()) + - " " + std::to_string(wino_input_pad->shape[3].as_int32()) + " " + + "CudaWinogradConvSchedule " + + std::to_string(wino_input_pad->shape[0].as_int32()) + " " + + std::to_string(wino_input_pad->shape[1].as_int32()) + " " + + std::to_string(wino_input_pad->shape[2].as_int32()) + " " + + std::to_string(wino_input_pad->shape[3].as_int32()) + " " + std::to_string(wino_weights_dilation->shape[0].as_int32()) + " " + std::to_string(wino_weights_dilation->shape[1].as_int32()) + " " + std::to_string(wino_weights_dilation->shape[2].as_int32()) + " " + std::to_string(wino_weights_dilation->shape[3].as_int32()) + " " + - std::to_string(wino_conv->shape[0].as_int32()) + " " + std::to_string(wino_conv->shape[1].as_int32()) + " " + - std::to_string(wino_conv->shape[2].as_int32()) + " " + std::to_string(wino_conv->shape[3].as_int32()); + std::to_string(wino_conv->shape[0].as_int32()) + " " + + std::to_string(wino_conv->shape[1].as_int32()) + " " + + std::to_string(wino_conv->shape[2].as_int32()) + " " + + std::to_string(wino_conv->shape[3].as_int32()); VLOG(1) << "Key in CudaScheduleWinogradConv is : " << key; CHECK_GT(res.count(key), 0); auto &rc_param = res[key]["rc"]; - auto &x_param = res[key]["x"]; - auto &y_param = res[key]["y"]; - auto &b_param = res[key]["b"]; + auto &x_param = res[key]["x"]; + auto &y_param = res[key]["y"]; + auto &b_param = res[key]["b"]; wino_stages[wino_B]->ComputeInline(); - auto data_l = wino_stages[data_pack]->CacheWrite("local", wino_stages, data_pack); + auto data_l = + wino_stages[data_pack]->CacheWrite("local", wino_stages, data_pack); wino_stages[data_pack]->Split(3, 1); wino_stages[data_pack]->Fuse({2, 3}); @@ -2126,8 +2674,9 @@ int MaxFactorLessThan(int a, int b) { void CudaScheduleInjectiveWithVectorize(poly::Stage *stage, const std::vector &output_shape, const common::Target &target) { - int dims = stage->n_out_dims(); - int prod_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + int dims = stage->n_out_dims(); + int prod_size = std::accumulate( + output_shape.begin(), output_shape.end(), 1, std::multiplies()); int num_thread = target.max_num_threads(); int last_shape = stage->GetDimRange(stage->n_out_dims() - 1); // determine the factor of vectorize @@ -2162,7 +2711,9 @@ void 
CudaScheduleInjectiveWithVectorize(poly::Stage *stage, stage->Split(bind_idx, gcd(stage->GetDimRange(bind_idx), num_thread)); ++bind_idx; } - while (bind_idx > 0 && stage->GetDimRange(bind_idx - 1) * stage->GetDimRange(bind_idx) < num_thread) { + while (bind_idx > 0 && + stage->GetDimRange(bind_idx - 1) * stage->GetDimRange(bind_idx) < + num_thread) { stage->Fuse(bind_idx - 1, bind_idx); --bind_idx; } @@ -2185,13 +2736,18 @@ void CudaScheduleInjectiveWithVectorize(poly::Stage *stage, stage->Bind(bind_idx, block_idx); --bind_idx; } - VLOG(5) << "CudaScheduleInjectiveWithVectorize tensor:" << stage->tensor()->name << ", vector_width:" << vector_width - << ", prod_size:" << prod_size << ", shape:[" << utils::Join(output_shape, ",") << "]" + VLOG(5) << "CudaScheduleInjectiveWithVectorize tensor:" + << stage->tensor()->name << ", vector_width:" << vector_width + << ", prod_size:" << prod_size << ", shape:[" + << utils::Join(output_shape, ",") << "]" << ", range:" << range_str_fn(); } -void CudaScheduleInjective(poly::Stage *stage, const std::vector &output_shape, const common::Target &target) { - CHECK_EQ(stage->n_out_dims(), stage->n_in_dims()) << "The dims of op are not equal"; +void CudaScheduleInjective(poly::Stage *stage, + const std::vector &output_shape, + const common::Target &target) { + CHECK_EQ(stage->n_out_dims(), stage->n_in_dims()) + << "The dims of op are not equal"; if (FLAGS_cinn_use_cuda_vectorize) { CudaScheduleInjectiveWithVectorize(stage, output_shape, target); return; @@ -2202,7 +2758,8 @@ void CudaScheduleInjective(poly::Stage *stage, const std::vector &output_sh } int num_thread = target.max_num_threads(); - int prod_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + int prod_size = std::accumulate( + output_shape.begin(), output_shape.end(), 1, std::multiplies()); if (prod_size <= num_thread) { stage->Bind(0, "threadIdx.x"); return; @@ -2211,7 +2768,8 @@ void CudaScheduleInjective(poly::Stage *stage, const std::vector &output_sh if (new_num_thread % 32 != 0) { new_num_thread = MaxFactorLessThan(prod_size, num_thread); } - if (new_num_thread == 1) LOG(FATAL) << "prod_size out of range: " << prod_size; + if (new_num_thread == 1) + LOG(FATAL) << "prod_size out of range: " << prod_size; CHECK_GT(prod_size, new_num_thread); stage->Split(0, new_num_thread); diff --git a/paddle/cinn/hlir/pe/schedule.h b/paddle/cinn/hlir/pe/schedule.h index 9190146d679c5..3c05084335c8d 100644 --- a/paddle/cinn/hlir/pe/schedule.h +++ b/paddle/cinn/hlir/pe/schedule.h @@ -42,15 +42,22 @@ class ScheduleParam { static ScheduleParam instance{common::Target::Arch::X86}; return instance; } - absl::flat_hash_map>> &GetParam() { + absl::flat_hash_map>> + &GetParam() { return param_data; } - absl::flat_hash_map> &operator[](const std::string &key) { return param_data[key]; } + absl::flat_hash_map> &operator[]( + const std::string &key) { + return param_data[key]; + } int Count(const std::string &key) { return param_data.count(key); } private: ScheduleParam(common::Target::Arch arch); - absl::flat_hash_map>> param_data; + absl::flat_hash_map>> + param_data; }; int GetInnerSplitter(int origin, int other_axis); @@ -63,7 +70,9 @@ int GetBasicFactor(const Type &type, const common::Target &target); int GetBetterSplitFactor(int shape, int split_factor); -int GetArrayPackingFactor(int shape, const Type &type, const common::Target &target); +int GetArrayPackingFactor(int shape, + const Type &type, + const common::Target &target); void ScheduleInjectiveCPU(poly::Stage *stage, 
const std::vector &output_shape, @@ -75,7 +84,9 @@ void ScheduleInjectiveCPU1(poly::Stage *stage, const common::Target &target, bool vectorizable = true); -void MatmulScheduleCUDA(poly::StageMap stages, const ir::Tensor &output, const common::Target &target); +void MatmulScheduleCUDA(poly::StageMap stages, +                        const ir::Tensor &output, +                        const common::Target &target); void MatmulScheduleCPU(poly::StageMap stage, const ir::Tensor &output, @@ -87,7 +98,10 @@ void MulScheduleCPU(poly::StageMap stage, const ir::Tensor &input_tensor, const common::Target &target); -void SoftmaxScheduleCPU(poly::StageMap stage, const ir::Tensor &output, const ir::Tensor &temp, int axis = -1); +void SoftmaxScheduleCPU(poly::StageMap stage, +                        const ir::Tensor &output, +                        const ir::Tensor &temp, +                        int axis = -1); void GetConv2dFactors(absl::flat_hash_map *factors, int oc, @@ -98,7 +112,7 @@ void GetConv2dFactors(absl::flat_hash_map *factors, const Type &type, const common::Target &target, const std::string &key = "", -                      bool import_params = true); +                      bool import_params = true); void GetConv2d1x1Factors(absl::flat_hash_map *factors, int oc, @@ -117,9 +131,15 @@ void Conv2d_NCHWc_Schedule_CPU(poly::StageMap stages, const common::Target &target, const std::string &key, bool do_padding); -void GlobalPoolScheduleGPU(poly::StageMap stages, const std::vector &output, const common::Target &target); -void PoolScheduleCPU(poly::StageMap stages, const ir::Tensor &output, const common::Target &target); -void PoolScheduleGPU(poly::StageMap stages, ir::Tensor &output, const common::Target &target); +void GlobalPoolScheduleGPU(poly::StageMap stages, +                           const std::vector &output, +                           const common::Target &target); +void PoolScheduleCPU(poly::StageMap stages, +                     const ir::Tensor &output, +                     const common::Target &target); +void PoolScheduleGPU(poly::StageMap stages, +                     ir::Tensor &output, +                     const common::Target &target); void Conv2d_NCHWc_Schedule_CPU_Nofuse(poly::StageMap stages, const ir::Tensor &res, @@ -147,14 +167,15 @@ void Conv2d_NCHWc_1X1_Schedule_CPU_Nofuse(poly::StageMap stages, const ir::Tensor &data, const common::Target &target); -void Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse(poly::StageMap stages, -                                                const ir::Tensor &res, -                                                ir::Tensor &packed_out, -                                                const ir::Tensor &input_pad, -                                                const ir::Tensor &weights_dilation, -                                                const ir::Tensor &data, -                                                const common::Target &target, -                                                bool do_padding); +void Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse( +    poly::StageMap stages, +    const ir::Tensor &res, +    ir::Tensor &packed_out, +    const ir::Tensor &input_pad, +    const ir::Tensor &weights_dilation, +    const ir::Tensor &data, +    const common::Target &target, +    bool do_padding); void CudaScheduleMul(poly::StageMap stages, ir::Tensor output, @@ -162,17 +183,26 @@ void CudaScheduleMul(poly::StageMap stages, const common::Target &target); // reduce schedules. 
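// The declarations below bind reductions at increasing CUDA granularity: a
// plain reduce over the trailing dimensions, a warp-level reduce,
// block-level variants (internal, full, and shuffle-based), and, as the name
// suggests, a two-step reduce that splits the work into two passes.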
-void CudaReduceSchedule(poly::StageMap stages, ir::Tensor output, int last_dimension_num, const common::Target &target); +void CudaReduceSchedule(poly::StageMap stages, + ir::Tensor output, + int last_dimension_num, + const common::Target &target); -void CudaWarpReduceSchedule(poly::StageMap stages, ir::Tensor tmp_out, ir::Tensor out, const common::Target &target); +void CudaWarpReduceSchedule(poly::StageMap stages, + ir::Tensor tmp_out, + ir::Tensor out, + const common::Target &target); void CudaBlockReduceInternalSchedule(poly::StageMap stages, ir::Tensor tmp_out, ir::Tensor out, const common::Target &target); -void CudaBlockReduceSchedule( - poly::StageMap stages, ir::Tensor reduce_tmp_out, ir::Tensor tmp_out, ir::Tensor out, const common::Target &target); +void CudaBlockReduceSchedule(poly::StageMap stages, + ir::Tensor reduce_tmp_out, + ir::Tensor tmp_out, + ir::Tensor out, + const common::Target &target); void CudaBlockShuffleReduceSchedule(poly::StageMap stages, ir::Tensor reduce_reshape, @@ -187,7 +217,9 @@ void CudaTwoStepReduceSchedule(poly::StageMap stages, ir::Tensor out, const common::Target &target); -void CudaScheduleDepthwiseConv(poly::StageMap stages, ir::Tensor &output, const common::Target &target); +void CudaScheduleDepthwiseConv(poly::StageMap stages, + ir::Tensor &output, + const common::Target &target); void CudaScheduleConv(poly::StageMap stages, ir::Tensor &input_pad, @@ -206,7 +238,9 @@ void CudaScheduleConv2(poly::StageMap stages, const common::Target &target, const std::string &key); -void CudaScheduleInjective(poly::Stage *stage, const std::vector &output_shape, const common::Target &target); +void CudaScheduleInjective(poly::Stage *stage, + const std::vector &output_shape, + const common::Target &target); void CudaSplitSchedule(common::CINNValuePack *arg_pack, const std::vector> &output_shapes, @@ -220,7 +254,7 @@ std::string GenerateX86ConvKey(const std::vector &input_shape, const std::vector &strides, const std::vector &paddings, const std::vector &dilations, - const int &index = 0, + const int &index = 0, const std::string &model_name = ""); std::string GenerateX86ConvKey(const std::vector &input_shape, @@ -228,20 +262,27 @@ std::string GenerateX86ConvKey(const std::vector &input_shape, const std::vector &strides, const std::vector &paddings, const std::vector &dilations, - const int &index = 0, + const int &index = 0, const std::string &model_name = ""); void CreateX86SerialData(const std::string &file_name = "default_serial.log"); -void LoadSerialData(absl::flat_hash_map>> *params, - const std::string &file_name = "default_serial.log"); +void LoadSerialData( + absl::flat_hash_map>> + *params, + const std::string &file_name = "default_serial.log"); void SaveSerialData( - const absl::flat_hash_map>> &model_data, + const absl::flat_hash_map< + std::string, + absl::flat_hash_map>> &model_data, const std::string &file_name = "default_serial.log"); int GetMaxSplitter(int a, int b); -absl::flat_hash_map>> CreateCudaParams(); +absl::flat_hash_map>> +CreateCudaParams(); } // namespace pe } // namespace hlir diff --git a/paddle/cinn/hlir/pe/transform.cc b/paddle/cinn/hlir/pe/transform.cc index 2e39bb0cc936e..38e3ba3a39541 100644 --- a/paddle/cinn/hlir/pe/transform.cc +++ b/paddle/cinn/hlir/pe/transform.cc @@ -36,10 +36,12 @@ using cinn::lang::Compute; using ir::Tensor; namespace utils { -std::vector> GetMatmulNewShapes(const std::vector>& inputs_shape, - bool trans_x, - bool trans_y) { - CHECK_EQ(inputs_shape.size(), 2UL) << "The matmul should only have two 
inputs."; +std::vector> GetMatmulNewShapes( + const std::vector>& inputs_shape, + bool trans_x, + bool trans_y) { + CHECK_EQ(inputs_shape.size(), 2UL) + << "The matmul should only have two inputs."; const auto &x_shape = inputs_shape[0], &y_shape = inputs_shape[1]; CHECK(!x_shape.empty()) << "The shape of matmul input 'x' should not empty."; CHECK(!y_shape.empty()) << "The shape of matmul input 'y' should not empty."; @@ -57,14 +59,15 @@ std::vector> GetMatmulNewShapes(const std::vector> new_shape(3); auto& new_x_shape = new_shape[0]; auto& new_y_shape = new_shape[1]; - auto& out_shape = new_shape[2]; + auto& out_shape = new_shape[2]; int x_dim = x_shape.size(), y_dim = y_shape.size(); int max_dim = std::max(x_shape.size(), y_shape.size()); int out_dim = max_dim >= 3 ? 3 : (max_dim <= 2 ? 2 : max_dim); auto get_input_shape = [out_dim](const std::vector& old_shape) { - CHECK_GE(old_shape.size(), 2UL) << "The shape of matmul input should greater equal 2"; + CHECK_GE(old_shape.size(), 2UL) + << "The shape of matmul input should greater equal 2"; std::vector res; res.resize(out_dim, 1); // [a, b, m, d] -> [a*b, m, d] @@ -79,15 +82,20 @@ std::vector> GetMatmulNewShapes(const std::vector{x_shape[0], 1} : std::vector{1, x_shape[0]}; - new_y_shape = trans_y ? std::vector{1, y_shape[0]} : std::vector{y_shape[0], 1}; - out_shape = {1}; + << "The matmul input X's numbers must be equal to Y's numbers,when " + "X/Y's dims =1. But here " + << matmul_info(); + + new_x_shape = trans_x ? std::vector{x_shape[0], 1} + : std::vector{1, x_shape[0]}; + new_y_shape = trans_y ? std::vector{1, y_shape[0]} + : std::vector{y_shape[0], 1}; + out_shape = {1}; } else if (x_dim == 1) { // vector * matrix int y_K = trans_y ? y_shape[max_dim - 1] : y_shape[max_dim - 2]; - CHECK_EQ(y_K, x_shape[0]) << "The K dimension of Y:" << y_K << " should equal to X.shape[0]:" << x_shape[0] + CHECK_EQ(y_K, x_shape[0]) << "The K dimension of Y:" << y_K + << " should equal to X.shape[0]:" << x_shape[0] << ". But here " << matmul_info(); // set x shape for broadcast @@ -115,7 +123,8 @@ std::vector> GetMatmulNewShapes(const std::vector> GetMatmulNewShapes(const std::vector [1, c, m] * [a*b, m, d] new_x_shape = get_input_shape(x_shape); @@ -162,10 +172,13 @@ std::vector> GetMatmulNewShapes(const std::vector= 0 && y_pos >= 0) { - CHECK(x_shape[x_pos] == y_shape[y_pos] || x_shape[x_pos] == 1 || y_shape[y_pos] == 1) - << "Input X and Y's batch dimension should be same or 1. But here " << matmul_info(); + CHECK(x_shape[x_pos] == y_shape[y_pos] || x_shape[x_pos] == 1 || + y_shape[y_pos] == 1) + << "Input X and Y's batch dimension should be same or 1. But here " + << matmul_info(); - out_shape[out_pos] = (x_shape[x_pos] == 1) ? y_shape[y_pos] : x_shape[x_pos]; + out_shape[out_pos] = + (x_shape[x_pos] == 1) ? 
y_shape[y_pos] : x_shape[x_pos]; out_pos--; x_pos--; @@ -183,10 +196,11 @@ std::vector> GetMatmulNewShapes(const std::vector> GetMulNewShapes(const std::vector>& inputs_shape, - int x_num_col_dims, - int y_num_col_dims, - bool is_infer) { +std::vector> GetMulNewShapes( + const std::vector>& inputs_shape, + int x_num_col_dims, + int y_num_col_dims, + bool is_infer) { CHECK_EQ(inputs_shape.size(), 2UL) << "The mul should only have two inputs."; const auto &x_shape = inputs_shape[0], &y_shape = inputs_shape[1]; CHECK(!x_shape.empty()) << "The shape of mul input 'x' should not empty."; @@ -197,7 +211,8 @@ std::vector> GetMulNewShapes(const std::vector ss << std::boolalpha << "mul(X:" << "[" << cinn::utils::Join(x_shape, ", ") << "], Y:" << "[" << cinn::utils::Join(y_shape, ", ") << "]" - << ", x_num_col_dims=" << x_num_col_dims << ", y_num_col_dims=" << y_num_col_dims << ")"; + << ", x_num_col_dims=" << x_num_col_dims + << ", y_num_col_dims=" << y_num_col_dims << ")"; return ss.str(); }; VLOG(4) << "Try infer " << mul_info() << "'s correct shape"; @@ -205,7 +220,7 @@ std::vector> GetMulNewShapes(const std::vector std::vector> new_shape(3); auto& new_x_shape = new_shape[0]; auto& new_y_shape = new_shape[1]; - auto& out_shape = new_shape[2]; + auto& out_shape = new_shape[2]; auto flatten_shape = [&](const std::vector& shape, int num_col_dims) { if (shape.size() <= 2) { @@ -216,8 +231,11 @@ std::vector> GetMulNewShapes(const std::vector num_col_dims += shape.size(); } - CHECK_GT(num_col_dims, 0) << "The [num_col_dims] should not be 0 in " << mul_info() << "! Please check."; - CHECK_LT(num_col_dims, shape.size()) << "The [num_col_dims] > rank(input) in " << mul_info() << "! Please check."; + CHECK_GT(num_col_dims, 0) << "The [num_col_dims] should not be 0 in " + << mul_info() << "! Please check."; + CHECK_LT(num_col_dims, shape.size()) + << "The [num_col_dims] > rank(input) in " << mul_info() + << "! Please check."; std::vector res(2, 1); for (int i = 0; i < num_col_dims; ++i) { @@ -249,26 +267,33 @@ std::vector> GetMulNewShapes(const std::vector } } // namespace utils -std::vector Matmul( - const Tensor& A, const Tensor& B, bool trans_a, bool trans_b, float alpha, const std::string& name) { +std::vector Matmul(const Tensor& A, + const Tensor& B, + bool trans_a, + bool trans_b, + float alpha, + const std::string& name) { std::vector shape_A = A->shape; std::vector shape_B = B->shape; - int a_dim = shape_A.size(); - int b_dim = shape_B.size(); - CHECK(a_dim == 3U || a_dim == 2U) << "tensor_A's dim should be 2 or 3 while current dim is " << a_dim; - CHECK(b_dim == 3U || b_dim == 2U) << "tensor_B's dim should be 2 or 3 while current dim is " << b_dim; + int a_dim = shape_A.size(); + int b_dim = shape_B.size(); + CHECK(a_dim == 3U || a_dim == 2U) + << "tensor_A's dim should be 2 or 3 while current dim is " << a_dim; + CHECK(b_dim == 3U || b_dim == 2U) + << "tensor_B's dim should be 2 or 3 while current dim is " << b_dim; CHECK_EQ(a_dim, b_dim) << "tensor_A's dim should be same with tensor_B"; - Expr x_width = trans_a ? shape_A[a_dim - 2] : shape_A.back(); + Expr x_width = trans_a ? shape_A[a_dim - 2] : shape_A.back(); Expr y_height = trans_b ? shape_B.back() : shape_B[b_dim - 2]; - Expr M = trans_a ? shape_A.back() : shape_A[a_dim - 2]; - Expr N = trans_b ? shape_B[b_dim - 2] : shape_B.back(); - CHECK(is_zero(x_width - y_height)) << "matrix multiplication requires x_width to be same with y_height"; + Expr M = trans_a ? shape_A.back() : shape_A[a_dim - 2]; + Expr N = trans_b ? 
shape_B[b_dim - 2] : shape_B.back(); + CHECK(is_zero(x_width - y_height)) + << "matrix multiplication requires x_width to be same with y_height"; std::vector output_shape; std::vector out; if (a_dim == 3) { int max_batch = std::max(shape_A[0].as_int32(), shape_B[0].as_int32()); - output_shape = {Expr(max_batch), M, N}; + output_shape = {Expr(max_batch), M, N}; } else { output_shape = {M, N}; } @@ -279,7 +304,8 @@ std::vector Matmul( int out_dim = indice.size(); std::vector A_indice; std::vector B_indice; - CHECK(out_dim == 3U || out_dim == 2U) << "indice size should be 2 or 3 while current dim is " << out_dim; + CHECK(out_dim == 3U || out_dim == 2U) + << "indice size should be 2 or 3 while current dim is " << out_dim; if (out_dim == 3U) { // batch A_indice.push_back(indice[0]); @@ -301,7 +327,9 @@ std::vector Matmul( if (alpha != 1) { auto res = Compute( output_shape, - [=](const std::vector& indice) { return temp(indice) * ir::Cast::Make(temp->type(), Expr(alpha)); }, + [=](const std::vector& indice) { + return temp(indice) * ir::Cast::Make(temp->type(), Expr(alpha)); + }, name); return {res, temp}; } else { @@ -315,8 +343,8 @@ ir::Tensor Reshape(const ir::Tensor& A, const std::string& name) { std::vector new_expr_shape; std::vector A_expr_shape = A->shape; - int input_total_size = 1; - int output_total_size = 1; + int input_total_size = 1; + int output_total_size = 1; for (auto& i : A_expr_shape) { CHECK(i.is_constant()) << "Input tensor's shape should be constant value."; input_total_size *= static_cast(i.get_constant()); @@ -326,15 +354,17 @@ ir::Tensor Reshape(const ir::Tensor& A, new_expr_shape.push_back(Expr(i)); } CHECK_EQ(input_total_size, output_total_size) - << "In op reshape, the input tensor and output tensor's total size should be equal, please check!"; + << "In op reshape, the input tensor and output tensor's total size " + "should be equal, please check!"; auto out = Identity(A->Reshape(new_expr_shape, stages), name).front(); return out; } -std::vector Split(const ir::Tensor& A, - int axis, - const std::vector>& output_shapes, - const std::vector& names) { +std::vector Split( + const ir::Tensor& A, + int axis, + const std::vector>& output_shapes, + const std::vector& names) { if (axis < 0) axis += A->shape.size(); auto output_size = output_shapes.size(); @@ -359,7 +389,7 @@ std::vector Split(const ir::Tensor& A, res[i] = Compute( out_shape[i], [=](const std::vector& indice) { - auto temp = indice; + auto temp = indice; temp[axis] = common::AutoSimplify(temp[axis] + Expr(start[i])); return A(temp); }, @@ -368,16 +398,22 @@ std::vector Split(const ir::Tensor& A, return res; } -ir::Tensor Concat(const ir::Tensor& A, const ir::Tensor& B, int axis, const std::string& name) { +ir::Tensor Concat(const ir::Tensor& A, + const ir::Tensor& B, + int axis, + const std::string& name) { if (axis < 0) axis += A->shape.size(); - CHECK_EQ(A->shape.size(), B->shape.size()) << "Dimensions of inputs A and B in Concat should be equal! Please check."; + CHECK_EQ(A->shape.size(), B->shape.size()) + << "Dimensions of inputs A and B in Concat should be equal! 
Please " + "check."; std::vector output_shape = A->shape; - Expr pivot = A->shape[axis]; - output_shape[axis] = common::AutoSimplify(output_shape[axis] + B->shape[axis]); - auto res = Compute( + Expr pivot = A->shape[axis]; + output_shape[axis] = + common::AutoSimplify(output_shape[axis] + B->shape[axis]); + auto res = Compute( output_shape, [=](const std::vector& indice) { - auto indice_B = indice; + auto indice_B = indice; indice_B[axis] = indice_B[axis] - pivot; return ir::Select::Make(indice[axis] < pivot, A(indice), B(indice_B)); }, @@ -385,31 +421,39 @@ ir::Tensor Concat(const ir::Tensor& A, const ir::Tensor& B, int axis, const std: return res; } -ir::Tensor Concat(const std::vector& input_tensors, int axis, const std::string& name) { +ir::Tensor Concat(const std::vector& input_tensors, + int axis, + const std::string& name) { int input_size = input_tensors.size(); CHECK_GE(input_size, 2U) << "Concat should have at least 2 input tensors"; std::vector output_shape = input_tensors[0]->shape; - int input_dim = output_shape.size(); - CHECK(axis >= -input_dim && axis < input_dim) << "Concat's axis should be in [-R, R)" - << ", but get axis: " << axis << ", R: " << input_dim; + int input_dim = output_shape.size(); + CHECK(axis >= -input_dim && axis < input_dim) + << "Concat's axis should be in [-R, R)" + << ", but get axis: " << axis << ", R: " << input_dim; if (axis < 0) axis += output_shape.size(); for (int i = 1; i < input_size; i++) { CHECK_EQ(input_tensors[i]->shape.size(), input_dim) - << "Dimensions of inputs tensors in Concat should be equal! Please check."; - output_shape[axis] = common::AutoSimplify(output_shape[axis] + input_tensors[i]->shape[axis]); + << "Dimensions of inputs tensors in Concat should be equal! Please " + "check."; + output_shape[axis] = common::AutoSimplify(output_shape[axis] + + input_tensors[i]->shape[axis]); } auto res = Compute( output_shape, [=](const std::vector& indice) { - auto ret = input_tensors[0](indice); + auto ret = input_tensors[0](indice); Expr accumulate_shape = Expr(0); for (int i = 0; i < input_size - 1; i++) { - accumulate_shape = common::AutoSimplify(accumulate_shape + input_tensors[i]->shape[axis]); + accumulate_shape = common::AutoSimplify( + accumulate_shape + input_tensors[i]->shape[axis]); std::vector new_indice = indice; - new_indice[axis] = indice[axis] - accumulate_shape; - ret = ir::Select::Make(indice[axis] < accumulate_shape, ret, input_tensors[i + 1](new_indice)); + new_indice[axis] = indice[axis] - accumulate_shape; + ret = ir::Select::Make(indice[axis] < accumulate_shape, + ret, + input_tensors[i + 1](new_indice)); } return ret; }, @@ -426,30 +470,33 @@ std::vector MatmulV2(const Tensor& A, const common::Target& target) { std::vector shape_A = A->shape; std::vector shape_B = B->shape; - int a_dim = shape_A.size(); - int b_dim = shape_B.size(); - CHECK(a_dim == 3U || a_dim == 2U) << "tensor_A's dim should be 2 or 3 while current dim is " << a_dim; - CHECK(b_dim == 3U || b_dim == 2U) << "tensor_B's dim should be 2 or 3 while current dim is " << b_dim; + int a_dim = shape_A.size(); + int b_dim = shape_B.size(); + CHECK(a_dim == 3U || a_dim == 2U) + << "tensor_A's dim should be 2 or 3 while current dim is " << a_dim; + CHECK(b_dim == 3U || b_dim == 2U) + << "tensor_B's dim should be 2 or 3 while current dim is " << b_dim; CHECK_EQ(a_dim, b_dim) << "tensor_A's dim should be same with tensor_B"; - Expr x_width = trans_a ? shape_A[a_dim - 2] : shape_A.back(); + Expr x_width = trans_a ? 
shape_A[a_dim - 2] : shape_A.back(); Expr y_height = trans_b ? shape_B.back() : shape_B[b_dim - 2]; - Expr M = trans_a ? shape_A.back() : shape_A[a_dim - 2]; - Expr N = trans_b ? shape_B[b_dim - 2] : shape_B.back(); - CHECK(is_zero(x_width - y_height)) << "matrix multiplication requires x_width to be same with y_height"; + Expr M = trans_a ? shape_A.back() : shape_A[a_dim - 2]; + Expr N = trans_b ? shape_B[b_dim - 2] : shape_B.back(); + CHECK(is_zero(x_width - y_height)) + << "matrix multiplication requires x_width to be same with y_height"; Var reduce_k(x_width, UniqName("reduce_k")); std::vector output_shape; std::vector out; if (a_dim == 3) { int max_batch = std::max(shape_A[0].as_int32(), shape_B[0].as_int32()); - output_shape = {Expr(max_batch), M, N}; + output_shape = {Expr(max_batch), M, N}; } else { output_shape = {M, N}; } // array packing int shape_B_N = N.as_int32(); - int bn = GetArrayPackingFactor(shape_B_N, B->type(), target); + int bn = GetArrayPackingFactor(shape_B_N, B->type(), target); // {N / bn, K, bn} std::vector packedB_shape = {Expr(shape_B_N / bn), y_height, Expr(bn)}; if (b_dim == 3) { @@ -460,7 +507,9 @@ std::vector MatmulV2(const Tensor& A, [=](const std::vector& indice) { std::vector indice_b; int indice_dim = indice.size(); - CHECK_GE(indice_dim, 3) << "packedB's dim should be at least 3 while current dim is " << indice_dim; + CHECK_GE(indice_dim, 3) + << "packedB's dim should be at least 3 while current dim is " + << indice_dim; if (indice_dim == 4) { // batch indice_b.push_back(indice[0]); @@ -481,7 +530,8 @@ std::vector MatmulV2(const Tensor& A, std::vector indice_a; std::vector indice_b; int out_dim = indice.size(); - CHECK(out_dim == 3U || out_dim == 2U) << "indice size should be 2 or 3 while current dim is " << out_dim; + CHECK(out_dim == 3U || out_dim == 2U) + << "indice size should be 2 or 3 while current dim is " << out_dim; if (out_dim == 3) { // batch indice_a.push_back(indice[0]); @@ -498,7 +548,9 @@ std::vector MatmulV2(const Tensor& A, if (alpha == 1) { return lang::ReduceSum(A(indice_a) * packedB(indice_b), {reduce_k}); } else { - return lang::ReduceSum(A(indice_a) * packedB(indice_b) * ir::Cast::Make(A->type(), Expr(alpha)), {reduce_k}); + return lang::ReduceSum(A(indice_a) * packedB(indice_b) * + ir::Cast::Make(A->type(), Expr(alpha)), + {reduce_k}); } }, UniqName("matmulV2_out")); @@ -512,25 +564,30 @@ std::vector MatmulMKL(const Tensor& A, float alpha, const std::string& name, const common::Target& target) { - CHECK(target.arch == Target::Arch::X86) << "mkl should be used in the cpu environment"; + CHECK(target.arch == Target::Arch::X86) + << "mkl should be used in the cpu environment"; std::vector shape_A = A->shape; std::vector shape_B = B->shape; - int a_dim = shape_A.size(); - int b_dim = shape_B.size(); - CHECK(a_dim == 3U || a_dim == 2U) << "tensor_A's dim should be 2 or 3 while current dim is " << a_dim; - CHECK(b_dim == 3U || b_dim == 2U) << "tensor_B's dim should be 2 or 3 while current dim is " << b_dim; + int a_dim = shape_A.size(); + int b_dim = shape_B.size(); + CHECK(a_dim == 3U || a_dim == 2U) + << "tensor_A's dim should be 2 or 3 while current dim is " << a_dim; + CHECK(b_dim == 3U || b_dim == 2U) + << "tensor_B's dim should be 2 or 3 while current dim is " << b_dim; CHECK_EQ(a_dim, b_dim) << "tensor_A's dim should be same with tensor_B"; if (a_dim == 3U) { CHECK_EQ(shape_A.front(), shape_B.front()) - << "tensor A and B's batch size should be same but current batch sizes are " << shape_A.front() << " and " - << 
shape_B.front(); + << "tensor A and B's batch size should be same but current batch sizes " + "are " + << shape_A.front() << " and " << shape_B.front(); } - Expr x_width = trans_a ? shape_A[a_dim - 2] : shape_A.back(); + Expr x_width = trans_a ? shape_A[a_dim - 2] : shape_A.back(); Expr y_height = trans_b ? shape_B.back() : shape_B[b_dim - 2]; - Expr M = trans_a ? shape_A.back() : shape_A[a_dim - 2]; - Expr N = trans_b ? shape_B[b_dim - 2] : shape_B.back(); - CHECK(is_zero(x_width - y_height)) << "matrix multiplication requires x_width to be same with y_height"; + Expr M = trans_a ? shape_A.back() : shape_A[a_dim - 2]; + Expr N = trans_b ? shape_B[b_dim - 2] : shape_B.back(); + CHECK(is_zero(x_width - y_height)) + << "matrix multiplication requires x_width to be same with y_height"; ir::Tensor call; if (a_dim == 2U) { @@ -587,7 +644,7 @@ std::vector MatmulMKL(const Tensor& A, } int GetMulFactor(int shape, const Type& type, const common::Target& target) { - int split_base = GetBasicFactor(type, target); + int split_base = GetBasicFactor(type, target); int split_factor = 1; for (size_t i = split_base; i >= 1; --i) { if (shape % i == 0) { @@ -598,37 +655,50 @@ int GetMulFactor(int shape, const Type& type, const common::Target& target) { return split_factor; } -std::vector MulBase(const Tensor& A, const Tensor& B, const std::string& name, const common::Target& target) { +std::vector MulBase(const Tensor& A, + const Tensor& B, + const std::string& name, + const common::Target& target) { std::vector output_shape; - CHECK_EQ(A->shape.size(), 2U) << "tensor_A's shape size should be two while current shape size is " - << A->shape.size(); - CHECK_EQ(B->shape.size(), 2U) << "tensor_B's shape size should be two while current shape size is " - << B->shape.size(); - CHECK_EQ(A->shape[1], B->shape[1]) << "tensor_A's last shape should be same with tensor_B"; + CHECK_EQ(A->shape.size(), 2U) + << "tensor_A's shape size should be two while current shape size is " + << A->shape.size(); + CHECK_EQ(B->shape.size(), 2U) + << "tensor_B's shape size should be two while current shape size is " + << B->shape.size(); + CHECK_EQ(A->shape[1], B->shape[1]) + << "tensor_A's last shape should be same with tensor_B"; output_shape.push_back(A->shape[0]); output_shape.push_back(B->shape[0]); if (target.arch == Target::Arch::X86) { - int reduce_dim = A->shape[1].as_int32(); + int reduce_dim = A->shape[1].as_int32(); int split_factor = GetMulFactor(reduce_dim, A->type(), target); - Var reduce_k_first(ir::Cast::Make(A->shape[1]->type(), Expr(reduce_dim / split_factor)), - UniqName("reduce_k_first")); + Var reduce_k_first( + ir::Cast::Make(A->shape[1]->type(), Expr(reduce_dim / split_factor)), + UniqName("reduce_k_first")); auto mul_reduce_first = Compute( {A->shape[0], B->shape[0], Expr(split_factor)}, [=](const std::vector& indice) { - CHECK_EQ(indice.size(), 3U) << "indice size should be three while current size is " << indice.size(); - return lang::ReduceSum(A({indice[0], reduce_k_first * Expr(split_factor) + indice[2]}) * - B({indice[1], reduce_k_first * Expr(split_factor) + indice[2]}), - {reduce_k_first}); + CHECK_EQ(indice.size(), 3U) + << "indice size should be three while current size is " + << indice.size(); + return lang::ReduceSum( + A({indice[0], reduce_k_first * Expr(split_factor) + indice[2]}) * + B({indice[1], + reduce_k_first * Expr(split_factor) + indice[2]}), + {reduce_k_first}); }, UniqName("mul_reduce_k_first")); - Var reduce_k_second(ir::Cast::Make(A->shape[1]->type(), Expr(split_factor)), 
UniqName("reduce_k_second")); + Var reduce_k_second(ir::Cast::Make(A->shape[1]->type(), Expr(split_factor)), + UniqName("reduce_k_second")); return {Compute( output_shape, [=](const std::vector& indice) { std::vector new_indice = indice; new_indice.push_back(reduce_k_second); - return lang::ReduceSum(mul_reduce_first(new_indice), {reduce_k_second}); + return lang::ReduceSum(mul_reduce_first(new_indice), + {reduce_k_second}); }, name), mul_reduce_first}; @@ -639,7 +709,9 @@ std::vector MulBase(const Tensor& A, const Tensor& B, const std::string& [=](const std::vector& indice) { std::vector A_indice; std::vector B_indice; - CHECK_EQ(indice.size(), 2U) << "indice size should be two while current size is " << indice.size(); + CHECK_EQ(indice.size(), 2U) + << "indice size should be two while current size is " + << indice.size(); A_indice.push_back(indice[0]); B_indice.push_back(indice[1]); A_indice.push_back(reduce_k); @@ -661,8 +733,10 @@ std::vector Mul(const Tensor& A, [=](const std::vector& indice) { std::vector A_indice; std::vector B_indice; - A_indice.insert(A_indice.begin(), indice.begin(), indice.begin() + x_num_col_dims); - B_indice.insert(B_indice.begin(), indice.begin() + x_num_col_dims, indice.end()); + A_indice.insert( + A_indice.begin(), indice.begin(), indice.begin() + x_num_col_dims); + B_indice.insert( + B_indice.begin(), indice.begin() + x_num_col_dims, indice.end()); A_indice.push_back(axis_k); B_indice.push_back(axis_k); return lang::ReduceSum(A(A_indice) * B(B_indice), {axis_k}); @@ -670,21 +744,31 @@ std::vector Mul(const Tensor& A, name)}; } -std::vector MulMKL(const Tensor& A, const Tensor& B, const std::string& name, const common::Target& target) { - CHECK(target.arch == Target::Arch::X86) << "mkl should be used in the cpu environment"; +std::vector MulMKL(const Tensor& A, + const Tensor& B, + const std::string& name, + const common::Target& target) { + CHECK(target.arch == Target::Arch::X86) + << "mkl should be used in the cpu environment"; std::vector shape_A = A->shape; std::vector shape_B = B->shape; - int a_dim = shape_A.size(); - int b_dim = shape_B.size(); - CHECK_EQ(a_dim, 2U) << "tensor_A's shape size should be two while current shape size is " << A->shape.size(); - CHECK_EQ(b_dim, 2U) << "tensor_B's shape size should be two while current shape size is " << B->shape.size(); + int a_dim = shape_A.size(); + int b_dim = shape_B.size(); + CHECK_EQ(a_dim, 2U) + << "tensor_A's shape size should be two while current shape size is " + << A->shape.size(); + CHECK_EQ(b_dim, 2U) + << "tensor_B's shape size should be two while current shape size is " + << B->shape.size(); // A: [M, K], B: [N, K] - Expr x_width = shape_A[1]; + Expr x_width = shape_A[1]; Expr y_height = shape_B[1]; - Expr M = shape_A[0]; - Expr N = shape_B[0]; - CHECK(is_zero(x_width - y_height)) << "matrix multiplication requires x_width to be same with y_height"; - CHECK_EQ(A->shape[1], B->shape[1]) << "tensor_A's last shape should be same with tensor_B"; + Expr M = shape_A[0]; + Expr N = shape_B[0]; + CHECK(is_zero(x_width - y_height)) + << "matrix multiplication requires x_width to be same with y_height"; + CHECK_EQ(A->shape[1], B->shape[1]) + << "tensor_A's last shape should be same with tensor_B"; auto call = Compute( {Expr(1)}, @@ -711,14 +795,15 @@ std::vector MulMKL(const Tensor& A, const Tensor& B, const std::string& return {out, call}; } -void GetLayoutTransformInfo(const ir::Layout& src_layout, - const ir::Layout& dst_layout, - absl::flat_hash_map>* split_index_map) { +void 
GetLayoutTransformInfo( + const ir::Layout& src_layout, + const ir::Layout& dst_layout, + absl::flat_hash_map>* split_index_map) { CHECK_GT(dst_layout.ndims(), src_layout.ndims()); int offset = 'A' - 'a'; CHECK_EQ(dst_layout.axis_names().size(), dst_layout.ndims()); for (int i = dst_layout.ndims() - 1; i >= 0; i--) { - char axis_name = dst_layout.axis_names(i); + char axis_name = dst_layout.axis_names(i); char prim_axis_name = axis_name; if (axis_name >= 'a' && axis_name <= 'z') { prim_axis_name += offset; @@ -732,17 +817,19 @@ void GetLayoutTransformInfo(const ir::Layout& src_layout, (*split_index_map)[src_primal_index] = {dst_primal_index, i, factor}; } else { int src_primal_index = src_layout.axis_names().find(prim_axis_name); - if (split_index_map->find(src_primal_index) != split_index_map->end()) continue; + if (split_index_map->find(src_primal_index) != split_index_map->end()) + continue; CHECK(src_primal_index != src_layout.axis_names().npos); (*split_index_map)[src_primal_index] = {i}; } } } -std::vector InferShapeLayoutTransform(const std::vector& input_shapes, - const ir::Layout& old_layout, - const ir::Layout& new_layout, - absl::flat_hash_map>* split_index_map) { +std::vector InferShapeLayoutTransform( + const std::vector& input_shapes, + const ir::Layout& old_layout, + const ir::Layout& new_layout, + absl::flat_hash_map>* split_index_map) { int src_dim = old_layout.ndims(); int dst_dim = new_layout.ndims(); std::vector output_shape(dst_dim); @@ -756,15 +843,15 @@ std::vector InferShapeLayoutTransform(const std::vector& input_shape for (int i = 0; i < src_dim; i++) { CHECK(split_index_map->find(i) != split_index_map->end()); if ((*split_index_map)[i].size() == 3) { - int dst_prim_index = (*split_index_map)[i][0]; - int dst_sub_index = (*split_index_map)[i][1]; - int factor = (*split_index_map)[i][2]; - Expr chunk_shape = common::AutoSimplify(input_shapes[i] / factor); - Expr block_shape = Expr(factor); + int dst_prim_index = (*split_index_map)[i][0]; + int dst_sub_index = (*split_index_map)[i][1]; + int factor = (*split_index_map)[i][2]; + Expr chunk_shape = common::AutoSimplify(input_shapes[i] / factor); + Expr block_shape = Expr(factor); output_shape[dst_prim_index] = chunk_shape; - output_shape[dst_sub_index] = block_shape; + output_shape[dst_sub_index] = block_shape; } else if ((*split_index_map)[i].size() == 1) { - int dst_prim_index = (*split_index_map)[i][0]; + int dst_prim_index = (*split_index_map)[i][0]; output_shape[dst_prim_index] = input_shapes[i]; } } @@ -774,14 +861,15 @@ std::vector InferShapeLayoutTransform(const std::vector& input_shape CHECK(split_index_map->find(i) != split_index_map->end()); if ((*split_index_map)[i].size() == 3) { int src_prim_index = (*split_index_map)[i][0]; - int src_sub_index = (*split_index_map)[i][1]; - int factor = (*split_index_map)[i][2]; + int src_sub_index = (*split_index_map)[i][1]; + int factor = (*split_index_map)[i][2]; CHECK_GE(input_shapes.size(), src_sub_index); CHECK_EQ(input_shapes[src_sub_index].as_int32(), factor); - output_shape[i] = common::AutoSimplify(input_shapes[src_prim_index] * factor); + output_shape[i] = + common::AutoSimplify(input_shapes[src_prim_index] * factor); } else if ((*split_index_map)[i].size() == 1) { int src_prim_index = (*split_index_map)[i][0]; - output_shape[i] = input_shapes[src_prim_index]; + output_shape[i] = input_shapes[src_prim_index]; } } } @@ -793,7 +881,8 @@ ir::Tensor LayoutTransform(const Tensor& input, const std::string& src_layout, const std::string& dst_layout, const 
std::string& name) { - CHECK(src_layout != dst_layout) << "dst_layout is same with src_layout, should not do layout transform"; + CHECK(src_layout != dst_layout) + << "dst_layout is same with src_layout, should not do layout transform"; // NCHW -> NCHWxc // NCHWxc -> NCHW // OIHW -> OIHWxixo @@ -805,9 +894,10 @@ ir::Tensor LayoutTransform(const Tensor& input, int offset = 'A' - 'a'; ir::Layout old_layout(src_layout); ir::Layout new_layout(dst_layout); - int src_dim = old_layout.ndims(); - int dst_dim = new_layout.ndims(); - std::vector output_shape = InferShapeLayoutTransform(input->shape, old_layout, new_layout, &split_index_map); + int src_dim = old_layout.ndims(); + int dst_dim = new_layout.ndims(); + std::vector output_shape = InferShapeLayoutTransform( + input->shape, old_layout, new_layout, &split_index_map); CHECK_EQ(output_shape.size(), dst_dim); auto res = Compute( @@ -821,13 +911,14 @@ ir::Tensor LayoutTransform(const Tensor& input, std::vector split_infos = split_index_map.at(i); if (split_infos.size() == 3) { int prim_index = split_infos[0]; - int sub_index = split_infos[1]; - int factor = split_infos[2]; + int sub_index = split_infos[1]; + int factor = split_infos[2]; if (dst_dim > src_dim) { - new_indice[i] = common::AutoSimplify(indice[prim_index] * factor + indice[sub_index]); + new_indice[i] = common::AutoSimplify(indice[prim_index] * factor + + indice[sub_index]); } else { new_indice[prim_index] = common::AutoSimplify(indice[i] / factor); - new_indice[sub_index] = common::AutoSimplify(indice[i] % factor); + new_indice[sub_index] = common::AutoSimplify(indice[i] % factor); } } else if (split_infos.size() == 1) { @@ -847,9 +938,12 @@ ir::Tensor LayoutTransform(const Tensor& input, return {res}; } -ir::Tensor Reverse(const ir::Tensor& input, const std::vector& axis, const std::string& output_name) { +ir::Tensor Reverse(const ir::Tensor& input, + const std::vector& axis, + const std::string& output_name) { for (auto& val : axis) { - CHECK(val >= 0 && val < static_cast(input->shape.size())) << "axis should be [0,n_dim)"; + CHECK(val >= 0 && val < static_cast(input->shape.size())) + << "axis should be [0,n_dim)"; } std::vector shape = input->shape; return lang::Compute( @@ -864,10 +958,14 @@ ir::Tensor Reverse(const ir::Tensor& input, const std::vector& axis, const output_name); } -ir::Tensor Transpose(const ir::Tensor& input, const std::vector& axis, const std::string& output_name) { - CHECK_EQ(input->shape.size(), axis.size()) << "input shape size and axis size is not equal!"; +ir::Tensor Transpose(const ir::Tensor& input, + const std::vector& axis, + const std::string& output_name) { + CHECK_EQ(input->shape.size(), axis.size()) + << "input shape size and axis size is not equal!"; for (int idx = 0; idx < axis.size(); ++idx) { - CHECK(axis[idx] >= 0 && axis[idx] < axis.size()) << "axis value should be among [0,axis.size())"; + CHECK(axis[idx] >= 0 && axis[idx] < axis.size()) + << "axis value should be among [0,axis.size())"; for (int idy = idx + 1; idy < axis.size(); ++idy) { CHECK_NE(axis[idx], axis[idy]) << "axis value can't repeat!"; } @@ -938,7 +1036,8 @@ ir::Tensor Slice(const ir::Tensor& A, std::vector temp; int indice_i = 0; for (int i = 0; i < input_shape.size(); ++i) { - if (std::find(decrease_axis.cbegin(), decrease_axis.cend(), i) != decrease_axis.cend()) { + if (std::find(decrease_axis.cbegin(), decrease_axis.cend(), i) != + decrease_axis.cend()) { temp.emplace_back(0); } else { temp.emplace_back(indice[indice_i]); @@ -946,7 +1045,8 @@ ir::Tensor Slice(const 
ir::Tensor& A, } } for (int i = 0; i < axes.size(); i++) { - temp[axes[i]] = temp[axes[i]] * Expr(strides[i]) + Expr(new_starts[i]); + temp[axes[i]] = + temp[axes[i]] * Expr(strides[i]) + Expr(new_starts[i]); } return A(temp); }, @@ -960,9 +1060,12 @@ ir::Tensor SliceAssign(const ir::Tensor& input, const std::vector& ends, const std::vector& strides, const std::string& output_name) { - CHECK_EQ(axes.size(), starts.size()) << "axes's size is not equal to starts's size!"; - CHECK_EQ(axes.size(), ends.size()) << "axes's size is not equal to starts's size!"; - CHECK_EQ(axes.size(), strides.size()) << "axes's size is not equal to strides's size!"; + CHECK_EQ(axes.size(), starts.size()) + << "axes's size is not equal to starts's size!"; + CHECK_EQ(axes.size(), ends.size()) + << "axes's size is not equal to ends's size!"; + CHECK_EQ(axes.size(), strides.size()) + << "axes's size is not equal to strides's size!"; std::vector input_shape; for (const auto& shape : input->shape) { @@ -972,18 +1075,22 @@ ir::Tensor SliceAssign(const ir::Tensor& input, std::vector new_ends(ends); std::vector new_strides(strides); for (int i = 0; i < axes.size(); i++) { - CHECK_LT(axes[i], input->shape.size()) << "axes should less than input's shape size"; + CHECK_LT(axes[i], input->shape.size()) + << "axes should be less than input's shape size"; if (new_starts[i] < 0) { new_starts[i] = input_shape[axes[i]] + new_starts[i]; - CHECK_GE(new_starts[i], 0) << "The value of [starts] should not less than " << -input_shape[axes[i]]; + CHECK_GE(new_starts[i], 0) + << "The value of [starts] should not be less than " + << -input_shape[axes[i]]; } if (new_starts[i] > input_shape[axes[i]]) { new_starts[i] = input_shape[axes[i]]; } if (new_ends[i] < 0) { new_ends[i] = input_shape[axes[i]] + new_ends[i]; - CHECK_GE(new_ends[i], 0) << "The value of [ends] should not less than " << -input_shape[axes[i]]; + CHECK_GE(new_ends[i], 0) << "The value of [ends] should not be less than " + << -input_shape[axes[i]]; } if (new_ends[i] > input_shape[axes[i]]) { new_ends[i] = input_shape[axes[i]];
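For reference, the bounds normalization in the hunks above can be mirrored in a few lines of plain C++. This is an illustrative sketch only, not part of the patch; NormalizeSliceRange is a hypothetical helper name, and it models a single dimension of extent dim with int bounds:

#include <cassert>

// Mirrors SliceAssign's normalization: wrap negative indices, clamp to
// [0, dim], and rewrite a negative-stride range (end, start] as the
// positive-stride half-open range [end + 1, start + 1).
void NormalizeSliceRange(int dim, int* start, int* end, int* stride) {
  assert(*stride != 0);
  if (*start < 0) *start += dim;
  if (*start > dim) *start = dim;
  if (*end < 0) *end += dim;
  if (*end > dim) *end = dim;
  if (*stride < 0) {
    assert(*start > *end);
    int tmp = *start;
    *start = *end + 1;  // the new start should not contain the old end
    *end = tmp + 1;     // the new end should contain the old start
    *stride = -*stride;
  } else {
    assert(*start < *end);
  }
}

int main() {
  // dim = 10, start = 8, end = 2, stride = -1 selects {8, 7, ..., 3};
  // after normalization: start = 3, end = 9, stride = 1 -- the same set.
  int start = 8, end = 2, stride = -1;
  NormalizeSliceRange(10, &start, &end, &stride);
  return !(start == 3 && end == 9 && stride == 1);
}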
Please Check."; if (strides[i] < 0) { - CHECK_GT(new_starts[i], new_ends[i]) << "[starts] should greater than [ends] when [strides] < 0"; + CHECK_GT(new_starts[i], new_ends[i]) + << "[starts] should greater than [ends] when [strides] < 0"; // if strides > 0, the range is [starts, ends) // but if strides < 0, the range is (ends, starts] - auto tmp = new_starts[i]; - new_starts[i] = new_ends[i] + 1; // the new starts should not contain ends[i] - new_ends[i] = tmp + 1; // the new ends should contain starts[i] + auto tmp = new_starts[i]; + new_starts[i] = + new_ends[i] + 1; // the new starts should not contain ends[i] + new_ends[i] = tmp + 1; // the new ends should contain starts[i] new_strides[i] = -new_strides[i]; } else { - CHECK_LT(new_starts[i], new_ends[i]) << "[ends] shoould greater than [starts] when [strides] > 0"; + CHECK_LT(new_starts[i], new_ends[i]) + << "[ends] shoould greater than [starts] when [strides] > 0"; } } @@ -1009,7 +1119,7 @@ ir::Tensor SliceAssign(const ir::Tensor& input, auto output_tensor = Compute( input->shape, [=](const std::vector& indice) { - ir::Expr is_assigned = ir::Expr(true); + ir::Expr is_assigned = ir::Expr(true); std::vector tmp_indice = indice; for (int idx = 0; idx < axes.size(); ++idx) { // get input axis to be assigned @@ -1030,7 +1140,8 @@ ir::Tensor SliceAssign(const ir::Tensor& input, // check start <= axis < ends auto inside = ir::And::Make(ge, lt); // check (axis - starts) % strides == 0 - auto mod = ir::EQ::Make(ir::Mod::Make(out_axis, Expr(new_strides[idx])), Expr(0)); + auto mod = ir::EQ::Make( + ir::Mod::Make(out_axis, Expr(new_strides[idx])), Expr(0)); // check start <= axis < ends and (axis - starts) % strides == 0 is_assigned = ir::And::Make(is_assigned, ir::And::Make(inside, mod)); // update axis for assign tensor @@ -1047,9 +1158,11 @@ ir::Tensor Gather(const ir::Tensor& x, const std::vector& output_shape, int axis, const std::string& name) { - CHECK_EQ(x->shape.size(), index->shape.size()) << "The rank of x and index must be same."; + CHECK_EQ(x->shape.size(), index->shape.size()) + << "The rank of x and index must be same."; // The implementation details are explained below. - // If output_shape = [2, 4, 3] and axis = 0, `Compute` can be translated as the following code: + // If output_shape = [2, 4, 3] and axis = 0, `Compute` can be translated as + // the following code: // { // for (i, 0, 2) // { @@ -1068,10 +1181,11 @@ ir::Tensor Gather(const ir::Tensor& x, // 1) indice is got from `output_shape` // 2) transformed_indice is used in the input `x` std::vector transformed_indice = indice; - // The element type of index maybe int64, but the index type is limited to int32 in CINN. - // See the below link for more details: + // The element type of index maybe int64, but the index type is limited + // to int32 in CINN. See the below link for more details: // https://github.com/PaddlePaddle/CINN/blob/85ab4981a38926dc5c1dbf672762cec335d2b857/cinn/ir/ir.cc#L477 - transformed_indice[axis] = ir::Cast::Make(common::Int(32), index(indice)); + transformed_indice[axis] = + ir::Cast::Make(common::Int(32), index(indice)); return x(transformed_indice); }, name); @@ -1084,7 +1198,8 @@ ir::Tensor ScatterAssign(const ir::Tensor& input, const common::Target& target, const int axis, const std::string& output_name) { - CHECK_EQ(index->type(), common::Int(32)) << "Param [Index] of ScatterAssign only support int32 ! Please Check.\n"; + CHECK_EQ(index->type(), common::Int(32)) + << "Param [Index] of ScatterAssign only support int32 ! 
Please Check.\n"; std::string extern_fun_name; if (target.arch == common::Target::Arch::NVGPU) { extern_fun_name.assign("cinn_cuda_find_int"); @@ -1103,13 +1218,15 @@ ir::Tensor ScatterAssign(const ir::Tensor& input, // find whether indice[axis] in Index, // then return id if found Index[id] == indice[axis] // else return -1 - auto id = lang::CallExtern(extern_fun_name, {index, index->shape[0], indice[pos_axis]}); + auto id = lang::CallExtern(extern_fun_name, + {index, index->shape[0], indice[pos_axis]}); std::vector indice_updates = indice; - indice_updates[pos_axis] = id; + indice_updates[pos_axis] = id; // check wheter Index[id] == cur_index and return by check result - return ir::Select::Make(ir::EQ::Make(id, Expr(-1)), input(indice), updates(indice_updates)); + return ir::Select::Make( + ir::EQ::Make(id, Expr(-1)), input(indice), updates(indice_updates)); }, UniqName(output_name)); return res; @@ -1121,17 +1238,22 @@ ir::Tensor ScatterAdd(const ir::Tensor& input, const common::Target& target, const int axis, const std::string& output_name) { - CHECK_EQ(target.arch, common::Target::Arch::NVGPU) << "Op IndexAdd only support NVGPU now ! Please Check.\n"; + CHECK_EQ(target.arch, common::Target::Arch::NVGPU) + << "Op IndexAdd only support NVGPU now ! Please Check.\n"; - CHECK_EQ(index->type(), common::Int(32)) << "Param [index] of IndexAdd only support int32 ! Please Check.\n"; - CHECK_EQ(index->shape.size(), 1) << "The dimension of param [index] of IndexAdd should be 1 ! Please Check.\n"; + CHECK_EQ(index->type(), common::Int(32)) + << "Param [index] of IndexAdd only support int32 ! Please Check.\n"; + CHECK_EQ(index->shape.size(), 1) << "The dimension of param [index] of " + "IndexAdd should be 1 ! Please Check.\n"; CHECK_EQ(input->type(), updates->type()) - << "Please ensure that the data types for input and updates are identical.\n"; + << "Please ensure that the data types for input and updates are " + "identical.\n"; auto pos_axis = axis; if (pos_axis < 0) pos_axis += input->shape.size(); CHECK(pos_axis >= 0 && pos_axis < input->shape.size()) - << "Param [axis] of IndexAdd should satisfy 0 <= axis < input.shape ! Please Check.\n"; + << "Param [axis] of IndexAdd should satisfy 0 <= axis < input.shape ! " + "Please Check.\n"; // compute each dimension's stride, it is used for indice2offset. 
@@ -1121,17 +1238,22 @@ ir::Tensor ScatterAdd(const ir::Tensor& input, const common::Target& target, const int axis, const std::string& output_name) { - CHECK_EQ(target.arch, common::Target::Arch::NVGPU) << "Op IndexAdd only support NVGPU now ! Please Check.\n"; + CHECK_EQ(target.arch, common::Target::Arch::NVGPU) + << "Op IndexAdd only supports NVGPU now ! Please Check.\n"; - CHECK_EQ(index->type(), common::Int(32)) << "Param [index] of IndexAdd only support int32 ! Please Check.\n"; - CHECK_EQ(index->shape.size(), 1) << "The dimension of param [index] of IndexAdd should be 1 ! Please Check.\n"; + CHECK_EQ(index->type(), common::Int(32)) + << "Param [index] of IndexAdd only supports int32 ! Please Check.\n"; + CHECK_EQ(index->shape.size(), 1) << "The dimension of param [index] of " "IndexAdd should be 1 ! Please Check.\n"; CHECK_EQ(input->type(), updates->type()) - << "Please ensure that the data types for input and updates are identical.\n"; + << "Please ensure that the data types for input and updates are " "identical.\n"; auto pos_axis = axis; if (pos_axis < 0) pos_axis += input->shape.size(); CHECK(pos_axis >= 0 && pos_axis < input->shape.size()) - << "Param [axis] of IndexAdd should satisfy 0 <= axis < input.shape ! " "Please Check.\n"; + << "Param [axis] of IndexAdd should satisfy 0 <= axis < input.shape ! " "Please Check.\n"; // compute each dimension's stride, it is used for indice2offset. // for shape=[1,2,3,4], strides=[2*3*4,3*4,4*1,1]=[24, 12, 4, 1] @@ -1153,11 +1275,12 @@ ir::Tensor ScatterAdd(const ir::Tensor& input, return offset; }; - const std::string& extern_func_name = GetExternFuncName(target, input->type(), "index_add"); + const std::string& extern_func_name = + GetExternFuncName(target, input->type(), "index_add"); - // assume shape=[1,2,3], axis=1, `cinn_cuda_index_add` extern function do following compute: - // out[i][j][k] = input[i][j][k] - // for l in range(index.size()): + // assume shape=[1,2,3], axis=1: `cinn_cuda_index_add` does the following: + // out[i][j][k] = input[i][j][k] + // for l in range(index.size()): // if index[l] == j: // out[i][j][k] += update[i][l][k] auto output = Compute( 
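A standalone sketch of the index_add semantics described in the ScatterAdd comment above (illustrative only, not part of the patch; IndexAdd3D is a hypothetical name). The row-major offset (i * J + j) * K + k matches the indice2offset computation with strides {J * K, K, 1}:

#include <cstdio>
#include <vector>

// out[i][j][k] = input[i][j][k]; then for every l with index[l] == j,
// out[i][j][k] += update[i][l][k]. L is update's extent along axis 1.
std::vector<float> IndexAdd3D(const std::vector<float>& input,
                              const std::vector<float>& update,
                              const std::vector<int>& index,
                              int I, int J, int K, int L) {
  std::vector<float> out = input;  // start from a copy of input
  for (int i = 0; i < I; ++i) {
    for (int l = 0; l < L; ++l) {
      int j = index[l];  // update row l accumulates into output row index[l]
      for (int k = 0; k < K; ++k) {
        out[(i * J + j) * K + k] += update[(i * L + l) * K + k];
      }
    }
  }
  return out;
}

int main() {
  // input shape = [1, 2, 3], update shape = [1, 2, 3], index = {1, 1}:
  // both update rows are added into output row j = 1.
  std::vector<float> input(6, 0.f), update = {1, 1, 1, 2, 2, 2};
  std::vector<int> index = {1, 1};
  std::vector<float> out = IndexAdd3D(input, update, index, 1, 2, 3, 2);
  for (float v : out) std::printf("%.0f ", v);  // 0 0 0 3 3 3
  std::printf("\n");
  return 0;
}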
diff --git a/paddle/cinn/hlir/pe/transform.h b/paddle/cinn/hlir/pe/transform.h index 55534c1e029c2..e6dffa42e803b 100644 --- a/paddle/cinn/hlir/pe/transform.h +++ b/paddle/cinn/hlir/pe/transform.h @@ -28,14 +28,16 @@ namespace hlir { namespace pe { namespace utils { -std::vector> GetMatmulNewShapes(const std::vector>& inputs_shape, - bool trans_x, - bool trans_y); - -std::vector> GetMulNewShapes(const std::vector>& inputs_shape, - int x_num_col_dims, - int y_num_col_dims, - bool is_infer = false); +std::vector> GetMatmulNewShapes( + const std::vector>& inputs_shape, + bool trans_x, + bool trans_y); + +std::vector> GetMulNewShapes( + const std::vector>& inputs_shape, + int x_num_col_dims, + int y_num_col_dims, + bool is_infer = false); } // namespace utils /** @@ -51,12 +53,13 @@ std::vector> GetMulNewShapes(const std::vector * * @return the output tensors */ -std::vector Matmul(const ir::Tensor& A, - const ir::Tensor& B, - bool trans_a = false, - bool trans_b = false, - float alpha = 1, - const std::string& name = UniqName("T_Transform_Matmul_out")); +std::vector Matmul( + const ir::Tensor& A, + const ir::Tensor& B, + bool trans_a = false, + bool trans_b = false, + float alpha = 1, + const std::string& name = UniqName("T_Transform_Matmul_out")); // realized by sharing buffer ir::Tensor Reshape(const ir::Tensor& A, @@ -66,28 +69,30 @@ ir::Tensor Reshape(const ir::Tensor& A, ir::Tensor Concat(const ir::Tensor& A, const ir::Tensor& B, - int axis = 0, + int axis = 0, const std::string& name = UniqName("T_Transform_Concat_out")); ir::Tensor Concat(const std::vector& input_tensors, - int axis = 0, + int axis = 0, const std::string& name = UniqName("T_Transform_Concat_out")); -std::vector MatmulV2(const ir::Tensor& A, - const ir::Tensor& B, - bool trans_a = false, - bool trans_b = false, - float alpha = 1, - const std::string& name = UniqName("T_Transform_MatmulV2_out"), - const common::Target& target = common::DefaultHostTarget()); - -std::vector MatmulMKL(const ir::Tensor& A, - const ir::Tensor& B, - bool trans_a = false, - bool trans_b = false, - float alpha = 1, - const std::string& name = UniqName("T_Transform_MatmulMKL_out"), - const common::Target& target = common::DefaultHostTarget()); +std::vector MatmulV2( + const ir::Tensor& A, + const ir::Tensor& B, + bool trans_a = false, + bool trans_b = false, + float alpha = 1, + const std::string& name = UniqName("T_Transform_MatmulV2_out"), + const common::Target& target = common::DefaultHostTarget()); + +std::vector MatmulMKL( + const ir::Tensor& A, + const ir::Tensor& B, + bool trans_a = false, + bool trans_b = false, + float alpha = 1, + const std::string& name = UniqName("T_Transform_MatmulMKL_out"), + const common::Target& target = common::DefaultHostTarget()); int GetMulFactor(int shape, const Type& type, const common::Target& target); @@ -100,12 +105,14 @@ int GetMulFactor(int shape, const Type& type, const common::Target& target); * @param target if target is x86, we will split the reduce axis * * @return the output tensors -Notes: this mul only support two-dims-tensor after flattening [M, K] * [N, K], K is the reduce axis +Notes: this mul only supports two-dim tensors after flattening [M, K] * [N, K]; K +is the reduce axis */ -std::vector MulBase(const ir::Tensor& A, - const ir::Tensor& B, - const std::string& name = UniqName("T_Transform_MulBase_out"), - const common::Target& target = common::DefaultHostTarget()); +std::vector MulBase( + const ir::Tensor& A, + const ir::Tensor& B, + const std::string& name = UniqName("T_Transform_MulBase_out"), + const common::Target& target = common::DefaultHostTarget()); std::vector Mul(const ir::Tensor& A, const ir::Tensor& B, @@ -114,20 +121,23 @@ std::vector Mul(const ir::Tensor& A, const ir::Var& axis_k, const std::string& name); -std::vector MulMKL(const ir::Tensor& A, - const ir::Tensor& B, - const std::string& name = UniqName("T_Transform_MulMKL_out"), - const common::Target& target = common::DefaultHostTarget()); +std::vector MulMKL( + const ir::Tensor& A, + const ir::Tensor& B, + const std::string& name = UniqName("T_Transform_MulMKL_out"), + const common::Target& target = common::DefaultHostTarget()); -ir::Tensor LayoutTransform(const ir::Tensor& input, - const std::string& src_layout, - const std::string& dst_layout, - const std::string& name = UniqName("T_LayoutTransform_out")); +ir::Tensor LayoutTransform( + const ir::Tensor& input, + const std::string& src_layout, + const std::string& dst_layout, + const std::string& name = UniqName("T_LayoutTransform_out")); -std::vector InferShapeLayoutTransform(const std::vector& input_shapes, - const ir::Layout& old_layout, - const ir::Layout& new_layout, - absl::flat_hash_map>* split_index_map); +std::vector InferShapeLayoutTransform( + const std::vector& input_shapes, + const ir::Layout& old_layout, + const ir::Layout& new_layout, + absl::flat_hash_map>* split_index_map); /** * @brief Perform meta op Reverse @@ -145,9 +155,10 @@ ir::Tensor Reverse(const ir::Tensor& input, * @param axis tranpsoe axis * @param output_name the name of the output tensor */ -ir::Tensor Transpose(const ir::Tensor& input, - const std::vector& axis, - const std::string& output_name = UniqName("T_Transpose_out")); +ir::Tensor Transpose( + const ir::Tensor& input, + const std::vector& axis, + const std::string& output_name = UniqName("T_Transpose_out")); /** * @brief Perform meta op Split * @param A The input tensor * @param axis select axis * @param output_name the name of the output tensor */ -std::vector Split(const ir::Tensor& A, - int axis, - const std::vector>& output_shapes, - const std::vector& names); +std::vector Split( + const ir::Tensor& A, + int axis, + const std::vector>& output_shapes, + const std::vector& names); ir::Tensor Slice(const ir::Tensor& A, const std::vector& starts, @@ -179,13 +191,14 @@ ir::Tensor Slice(const ir::Tensor& A, * @param strides select reigon strides * @param output_name the name of the output tensor */ -ir::Tensor SliceAssign(const ir::Tensor& input, - const ir::Tensor& assign, - const std::vector& axes, - const std::vector& starts, - const std::vector& ends, - const std::vector& strides, - const std::string& output_name = UniqName("T_Transform_SliceAssign_out")); +ir::Tensor SliceAssign( + const ir::Tensor& input, + const ir::Tensor& 
assign, + const std::vector& axes, + const std::vector& starts, + const std::vector& ends, + const std::vector& strides, + const std::string& output_name = UniqName("T_Transform_SliceAssign_out")); /** * @brief Perform meta op Split * @param A The input tensor @@ -196,7 +209,7 @@ ir::Tensor SliceAssign(const ir::Tensor& input, ir::Tensor Gather(const ir::Tensor& x, const ir::Tensor& index, const std::vector& output_shape, - int axis = 0, + int axis = 0, const std::string& name = UniqName("T_Transform_Gather_out")); /** @@ -206,12 +219,13 @@ ir::Tensor Gather(const ir::Tensor& x, * @param indexs The indexs tensor * @param output_name the name of the output tensor */ -ir::Tensor ScatterAssign(const ir::Tensor& input, - const ir::Tensor& updates, - const ir::Tensor& index, - const common::Target& target, - const int axis = 0, - const std::string& output_name = UniqName("T_Transform_ScatterAssign_out")); +ir::Tensor ScatterAssign( + const ir::Tensor& input, + const ir::Tensor& updates, + const ir::Tensor& index, + const common::Target& target, + const int axis = 0, + const std::string& output_name = UniqName("T_Transform_ScatterAssign_out")); /** * @brief Perform meta op ScatterAdd diff --git a/paddle/cinn/ir/buffer.cc b/paddle/cinn/ir/buffer.cc index c72849e268446..ef9227a2d128c 100755 --- a/paddle/cinn/ir/buffer.cc +++ b/paddle/cinn/ir/buffer.cc @@ -27,12 +27,15 @@ namespace ir { std::string TensorGetBufferName(const _Tensor_ *tensor) { CHECK(!tensor->name.empty()); CHECK(!utils::Startswith(tensor->name, "_")) - << "the name with prefix _ is not allowed for tensor. Current tensor's name is: " << tensor->name; + << "the name with prefix _ is not allowed for tensor. Current tensor's " + "name is: " + << tensor->name; return "_" + tensor->name; } std::string BufferGetTensorName(const _Buffer_ *buffer) { CHECK(!buffer->name.empty()); - CHECK(utils::Startswith(buffer->name, "_")) << "buffer's name should start with _"; + CHECK(utils::Startswith(buffer->name, "_")) + << "buffer's name should start with _"; return buffer->name.substr(1); } @@ -52,29 +55,29 @@ Buffer _Buffer_::Make(Var data, CHECK(dtype.valid()); CHECK(!dtype.is_unk()); CHECK(!dtype.is_void()); - auto *node = common::make_shared<_Buffer_>(); - node->shape = shape; - node->strides = strides; - node->elem_offset = elem_offset; - node->name = name; - node->scope = scope; + auto *node = common::make_shared<_Buffer_>(); + node->shape = shape; + node->strides = strides; + node->elem_offset = elem_offset; + node->name = name; + node->scope = scope; node->data_alignment = data_alignment; - node->offset_factor = offset_factor; - node->target = target; - node->dtype = dtype; + node->offset_factor = offset_factor; + node->target = target; + node->dtype = dtype; return Buffer(node); } Buffer _Buffer_::Make(const std::string &name, const std::vector &shape) { - auto *node = common::make_shared<_Buffer_>(); - node->name = name; + auto *node = common::make_shared<_Buffer_>(); + node->name = name; node->shape = shape; node->dtype = Void(); return Buffer(node); } Buffer _Buffer_::Make() { - auto *node = common::make_shared<_Buffer_>(); + auto *node = common::make_shared<_Buffer_>(); node->dtype = Void(); return Buffer(node); } @@ -85,11 +88,14 @@ void _Buffer_::BindTo(const Tensor &tensor) { BindTo(tensor.As<_Tensor_>()); } void _Buffer_::BindTo(const _Tensor_ *tensor) { if (name.empty()) name = TensorGetBufferName(tensor); if (type().is_unk()) set_type(tensor->type()); - CHECK(!tensor->shape.empty()) << "Tensor should have shape to bind to a 
Buffer"; + CHECK(!tensor->shape.empty()) + << "Tensor should have shape to bind to a Buffer"; shape = tensor->shape; binded_tensors_names_.insert(tensor->name); } -void _Buffer_::Unbind(const _Tensor_ *tensor) { binded_tensors_names_.erase(tensor->name); } +void _Buffer_::Unbind(const _Tensor_ *tensor) { + binded_tensors_names_.erase(tensor->name); +} Var _Buffer_::buffer_addr() const { auto thetype = type().ElementOf(); @@ -114,12 +120,13 @@ void _Buffer_::Verify() const { Expr Buffer::DestroyExpr() const { auto *node = operator->(); - return runtime::IntrinsicCall( - Void(), runtime::intrinsic::buffer_destroy, {ir::_Var_::Make(node->name, node->type())}); + return runtime::IntrinsicCall(Void(), + runtime::intrinsic::buffer_destroy, + {ir::_Var_::Make(node->name, node->type())}); } Expr _BufferRange_::Make(const Expr &buffer, const std::vector &ranges) { - auto node = make_shared<_BufferRange_>(); + auto node = make_shared<_BufferRange_>(); node->buffer = buffer; node->ranges = ranges; return Expr(node); @@ -129,7 +136,7 @@ void _BufferRange_::Verify() const { CHECK(buffer_ptr); } Expr _BufferRange_::Copy() const { - auto node = make_shared<_BufferRange_>(); + auto node = make_shared<_BufferRange_>(); node->buffer = buffer; node->ranges = ranges; node->set_type(type()); @@ -137,27 +144,31 @@ Expr _BufferRange_::Copy() const { } bool BufferRange::operator==(const BufferRange &x) const { - auto this_buffer = operator->()->buffer.As<_Buffer_>(); + auto this_buffer = operator->()->buffer.As<_Buffer_>(); auto other_buffer = x->buffer.As<_Buffer_>(); CHECK(this_buffer); CHECK(other_buffer); if (this_buffer != other_buffer) return false; if (x->ranges.size() != operator->()->ranges.size()) return false; for (int i = 0; i < x->ranges.size(); i++) { - Var this_range = operator->()->ranges[i]; + Var this_range = operator->()->ranges[i]; Var other_range = x->ranges[i]; - if (!is_zero(this_range->lower_bound - other_range->lower_bound)) return false; - if (!is_zero(this_range->upper_bound - other_range->upper_bound)) return false; + if (!is_zero(this_range->lower_bound - other_range->lower_bound)) + return false; + if (!is_zero(this_range->upper_bound - other_range->upper_bound)) + return false; } return true; } -bool BufferRange::operator!=(const BufferRange &x) const { return !(*this == x); } +bool BufferRange::operator!=(const BufferRange &x) const { + return !(*this == x); +} BufferRange &BufferRange::operator=(_BufferRange_ *x) { *this = BufferRange(x); return *this; } BufferRange &BufferRange::operator=(const _BufferRange_ *x) { - auto node = make_shared<_BufferRange_>(); + auto node = make_shared<_BufferRange_>(); node->buffer = x->buffer; node->ranges = x->ranges; node->set_type(x->type()); diff --git a/paddle/cinn/ir/buffer.h b/paddle/cinn/ir/buffer.h index 308af03286b29..7e80b6de9297f 100755 --- a/paddle/cinn/ir/buffer.h +++ b/paddle/cinn/ir/buffer.h @@ -42,9 +42,10 @@ std::string BufferGetTensorName(const _Buffer_* buffer); /** * Buffer is a symbolic multi-dimensional data structure, it is a node in IR. - * It is a composition of primitive symbolic types, used to specify the memory layout of the Tensor used in the program - * input. User can create a buffer and bind to multiple Tensors to specify that the tensors are not inlined and persist - * data to this buffer. + * It is a composition of primitive symbolic types, used to specify the memory + * layout of the Tensor used in the program input. 
User can create a buffer and + * bind to multiple Tensors to specify that the tensors are not inlined and + * persist data to this buffer. */ class Buffer : public IrNodeRef { public: @@ -88,7 +89,8 @@ class _Buffer_ : public ExprNode<_Buffer_> { MemoryType memory_type{MemoryType::Heap}; //! The data type of the elements. - //! This is different from `type`, a buffer's type should always be `cinn_buffer_t*`. + //! This is different from `type`, a buffer's type should always be + //! `cinn_buffer_t*`. Type dtype; _Buffer_() : elem_offset(Expr(0)) { set_type(type_of()); } @@ -104,13 +106,14 @@ class _Buffer_ : public ExprNode<_Buffer_> { int offset_factor, Target target = UnkTarget()); - static Buffer Make(const std::string& name, const std::vector& shape = {}); + static Buffer Make(const std::string& name, + const std::vector& shape = {}); static Buffer Make(const std::string& name, Type type) { CHECK(!type.is_void()); CHECK(!type.is_unk()); - auto n = make_shared<_Buffer_>(); - n->name = name; + auto n = make_shared<_Buffer_>(); + n->name = name; n->dtype = type; return Buffer(n); } @@ -118,14 +121,19 @@ class _Buffer_ : public ExprNode<_Buffer_> { //! Make an empty buffer. static Buffer Make(); - bool is_on_gpu() const { return memory_type == MemoryType::GPULocal || memory_type == MemoryType::GPUShared; } + bool is_on_gpu() const { + return memory_type == MemoryType::GPULocal || + memory_type == MemoryType::GPUShared; + } bool is_on_host() const { return !is_on_gpu(); } void BindTo(const Tensor& tensor); void BindTo(const _Tensor_* tensor); void Unbind(const _Tensor_* tensor); - const std::set& binded_tensor_names() const { return binded_tensors_names_; } + const std::set& binded_tensor_names() const { + return binded_tensors_names_; + } Var buffer_addr() const; @@ -138,18 +146,23 @@ class _Buffer_ : public ExprNode<_Buffer_> { static const IrNodeTy _node_type_ = IrNodeTy::_Buffer_; // Copy the meta infos to other. - void CopyMeta(_Buffer_* other) const { other->binded_tensors_names_ = binded_tensors_names_; } + void CopyMeta(_Buffer_* other) const { + other->binded_tensors_names_ = binded_tensors_names_; + } private: std::set binded_tensors_names_; }; -static bool operator<(const ir::Buffer& a, const ir::Buffer& b) { return a->name < b->name; } +static bool operator<(const ir::Buffer& a, const ir::Buffer& b) { + return a->name < b->name; +} // represents the multi-dimension ranges of the buffer struct _BufferRange_ : public ExprNode<_BufferRange_> { Expr buffer; - // For every range, it starts from var's lower_bound and ends at var's upper_bound. + // For every range, it starts from var's lower_bound and ends at var's + // upper_bound. 
std::vector ranges; _BufferRange_() = default; @@ -184,7 +197,9 @@ struct BufferRange : public IrNodeRef { const _BufferRange_* operator->() const { return get(); } _BufferRange_* operator->() { return get(); } - const _BufferRange_* get() const { return static_cast(ptr()); } + const _BufferRange_* get() const { + return static_cast(ptr()); + } _BufferRange_* get() { return static_cast<_BufferRange_*>(ptr()); } }; diff --git a/paddle/cinn/ir/buffer_test.cc b/paddle/cinn/ir/buffer_test.cc index b04db4a891134..9dd4c489c999d 100644 --- a/paddle/cinn/ir/buffer_test.cc +++ b/paddle/cinn/ir/buffer_test.cc @@ -36,7 +36,8 @@ TEST(Buffer, basic) { std::vector shape({Expr(100), Expr(20)}); Var i("i"), j("j"); std::vector strides({Expr(0), Expr(0)}); - auto buffer = _Buffer_::Make(ptr, ptr->type(), shape, strides, Expr(0), "buf", "", 0, 0); + auto buffer = _Buffer_::Make( + ptr, ptr->type(), shape, strides, Expr(0), "buf", "", 0, 0); // Check shared ASSERT_EQ(ref_count(buffer.get()).val(), 1); @@ -70,7 +71,7 @@ TEST(Buffer, bind_to_multiple_tensors) { Target target; target.arch = Target::Arch ::X86; target.bits = Target::Bit ::k32; - target.os = Target::OS ::Linux; + target.os = Target::OS ::Linux; ir::Module::Builder builder("module1", target); builder.AddFunction(funcs); @@ -78,7 +79,8 @@ TEST(Buffer, bind_to_multiple_tensors) { backends::CodeGenC codegen(target); codegen.SetInlineBuiltinCodes(false); - auto out = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + auto out = + codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); std::cout << "codegen C:" << std::endl << out << std::endl; } diff --git a/paddle/cinn/ir/collect_ir_nodes.cc b/paddle/cinn/ir/collect_ir_nodes.cc index 34ecda9aa39b3..afbf99e59d9c1 100644 --- a/paddle/cinn/ir/collect_ir_nodes.cc +++ b/paddle/cinn/ir/collect_ir_nodes.cc @@ -25,7 +25,7 @@ namespace ir { namespace { struct IrNodesCollector : public IRVisitor { - using teller_t = std::function; + using teller_t = std::function; using handler_t = std::function; teller_t teller; @@ -78,9 +78,11 @@ struct IrNodesCollector : public IRVisitor { }; struct IrNodesWithoutTensorCollector : public IrNodesCollector { - using teller_t = std::function; + using teller_t = std::function; using handler_t = std::function; - IrNodesWithoutTensorCollector(teller_t teller, handler_t handler, bool uniq_target) + IrNodesWithoutTensorCollector(teller_t teller, + handler_t handler, + bool uniq_target) : IrNodesCollector(std::move(teller), std::move(handler), uniq_target) {} void Visit(const _Tensor_* expr) override { @@ -93,49 +95,68 @@ struct IrNodesWithoutTensorCollector : public IrNodesCollector { } // namespace -std::set CollectIRNodes(Expr expr, std::function&& teller, bool uniq_target) { +std::set CollectIRNodes(Expr expr, + std::function&& teller, + bool uniq_target) { std::set exprs; - IrNodesCollector::handler_t handler = [&](const Expr* x) { exprs.insert(*x); }; - IrNodesCollector collector(std::move(teller), std::move(handler), uniq_target); + IrNodesCollector::handler_t handler = [&](const Expr* x) { + exprs.insert(*x); + }; + IrNodesCollector collector( + std::move(teller), std::move(handler), uniq_target); collector.Visit(&expr); return exprs; } -std::vector CollectIRNodesInOrder(Expr expr, std::function&& teller) { +std::vector CollectIRNodesInOrder( + Expr expr, std::function&& teller) { std::vector exprs; - IrNodesWithoutTensorCollector::handler_t handler = [&](const Expr* x) { exprs.push_back(*x); }; - IrNodesWithoutTensorCollector 
collector(std::move(teller), std::move(handler), false); + IrNodesWithoutTensorCollector::handler_t handler = [&](const Expr* x) { + exprs.push_back(*x); + }; + IrNodesWithoutTensorCollector collector( + std::move(teller), std::move(handler), false); collector.Visit(&expr); return exprs; } -std::set CollectIRNodesWithoutTensor(Expr expr, std::function&& teller, bool uniq_target) { +std::set CollectIRNodesWithoutTensor( + Expr expr, std::function&& teller, bool uniq_target) { std::set exprs; - IrNodesWithoutTensorCollector::handler_t handler = [&](const Expr* x) { exprs.insert(*x); }; - IrNodesWithoutTensorCollector collector(std::move(teller), std::move(handler), uniq_target); + IrNodesWithoutTensorCollector::handler_t handler = [&](const Expr* x) { + exprs.insert(*x); + }; + IrNodesWithoutTensorCollector collector( + std::move(teller), std::move(handler), uniq_target); collector.Visit(&expr); return exprs; } -std::map CollectTensorMap(Expr x, std::function&& extra_teller) { +std::map CollectTensorMap( + Expr x, std::function&& extra_teller) { std::map tensor_map; - auto tensors = CollectIRNodes(x, [&](const Expr* x) { return x->as_tensor() && extra_teller(x); }); + auto tensors = CollectIRNodes( + x, [&](const Expr* x) { return x->as_tensor() && extra_teller(x); }); for (auto& e : tensors) { - auto* t = e.as_tensor(); + auto* t = e.as_tensor(); tensor_map[t->name] = e; } return tensor_map; } -std::set CollectLoadTensors(Expr x, std::function&& teller) { +std::set CollectLoadTensors(Expr x, + std::function&& teller) { if (!x.defined()) return std::set(); struct Mutator : public ir::IRMutator { std::function teller; std::set exprs; - Mutator(std::function&& teller) : teller(std::move(teller)) {} + Mutator(std::function&& teller) + : teller(std::move(teller)) {} - void operator()(const Expr* expr) { ir::IRMutator::Visit(expr, expr); } + void operator()(const Expr* expr) { + ir::IRMutator::Visit(expr, expr); + } void Visit(const Load* op, const Expr* expr) override { if (teller(&op->tensor)) { @@ -149,13 +170,17 @@ std::set CollectLoadTensors(Expr x, std::function&& tel return mutator.exprs; } -std::set CollectStoreTensors(Expr x, std::function&& teller) { +std::set CollectStoreTensors(Expr x, + std::function&& teller) { struct Mutator : public ir::IRMutator { std::function teller; std::set exprs; - Mutator(std::function&& teller) : teller(std::move(teller)) {} + Mutator(std::function&& teller) + : teller(std::move(teller)) {} - void operator()(const Expr* expr) { ir::IRMutator::Visit(expr, expr); } + void operator()(const Expr* expr) { + ir::IRMutator::Visit(expr, expr); + } void Visit(const Store* op, const Expr* expr) override { if (teller(&op->tensor)) { @@ -169,7 +194,8 @@ std::set CollectStoreTensors(Expr x, std::function&& te return mutator.exprs; } -std::set CollectReferencedTensors(Expr x, const std::function& teller) { +std::set CollectReferencedTensors( + Expr x, const std::function& teller) { auto handle0 = teller; auto handle1 = teller; diff --git a/paddle/cinn/ir/collect_ir_nodes.h b/paddle/cinn/ir/collect_ir_nodes.h index 0f888aaed7468..75ed3fa9e64f4 100755 --- a/paddle/cinn/ir/collect_ir_nodes.h +++ b/paddle/cinn/ir/collect_ir_nodes.h @@ -22,35 +22,48 @@ namespace ir { /** * Collect the IR Nodes(without duplication) in the expression. 
*/ -std::set CollectIRNodes(Expr x, std::function&& teller, bool uniq_target = false); +std::set CollectIRNodes(Expr x, + std::function&& teller, + bool uniq_target = false); /** - * Collect the IR Nodes(without duplication and tensor's compute body) in the expression. + * Collect the IR Nodes(without duplication and tensor's compute body) in the + * expression. */ -std::set CollectIRNodesWithoutTensor(Expr x, std::function&& teller, bool uniq_target = false); +std::set CollectIRNodesWithoutTensor( + Expr x, + std::function&& teller, + bool uniq_target = false); /** * Collect the IR Nodes from Block. */ -std::vector CollectIRNodesInOrder(Expr block, std::function&& teller); +std::vector CollectIRNodesInOrder( + Expr block, std::function&& teller); /** * Collect the tensors in Load nodes. */ -std::set CollectLoadTensors(Expr x, std::function&& teller); +std::set CollectLoadTensors(Expr x, + std::function&& teller); /** * Collect the tensors in Store nodes. */ -std::set CollectStoreTensors(Expr x, std::function&& teller); +std::set CollectStoreTensors(Expr x, + std::function&& teller); /** * Collect both the Store and Load nodes. */ -std::set CollectReferencedTensors(Expr x, const std::function& teller); +std::set CollectReferencedTensors( + Expr x, const std::function& teller); std::map CollectTensorMap( - Expr x, std::function&& extra_teller = [](const Expr* x) { return true; }); + Expr x, + std::function&& extra_teller = [](const Expr* x) { + return true; + }); } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/collect_ir_nodes_test.cc b/paddle/cinn/ir/collect_ir_nodes_test.cc index ed0e818801afa..82441b4a005c7 100644 --- a/paddle/cinn/ir/collect_ir_nodes_test.cc +++ b/paddle/cinn/ir/collect_ir_nodes_test.cc @@ -23,10 +23,12 @@ namespace ir { TEST(CollectIRNodes, basic0) { Expr C = Expr(1) + 2; - auto exprs = CollectIRNodes(C, [](const Expr* x) { return x->As(); }); + auto exprs = + CollectIRNodes(C, [](const Expr* x) { return x->As(); }); ASSERT_EQ(exprs.size(), 1UL); - auto ints = CollectIRNodes(C, [](const Expr* x) { return x->As(); }); + auto ints = + CollectIRNodes(C, [](const Expr* x) { return x->As(); }); ASSERT_EQ(ints.size(), 2UL); } @@ -45,13 +47,15 @@ TEST(CollectIRNodes, basic) { LOG(INFO) << "fn:\n" << fn; - auto tensors = CollectIRNodes(fn, [](const Expr* x) { return x->as_tensor(); }); + auto tensors = + CollectIRNodes(fn, [](const Expr* x) { return x->as_tensor(); }); ASSERT_EQ(tensors.size(), 5UL); auto fn_body = fn.As()->body; LOG(INFO) << "fn.body:\n" << fn_body; - auto tensors2 = CollectIRNodes(fn_body, [](const Expr* x) { return x->as_tensor(); }); - auto exprs = CollectIRNodes(fn_body, [](const Expr* x) { return x; }); + auto tensors2 = + CollectIRNodes(fn_body, [](const Expr* x) { return x->as_tensor(); }); + auto exprs = CollectIRNodes(fn_body, [](const Expr* x) { return x; }); } } // namespace ir diff --git a/paddle/cinn/ir/function_definition.h b/paddle/cinn/ir/function_definition.h index b77d1cc303d20..db09a98746cde 100644 --- a/paddle/cinn/ir/function_definition.h +++ b/paddle/cinn/ir/function_definition.h @@ -29,11 +29,13 @@ struct DefinitionContents; struct FunctionContents; /** - * A Function definition which can either represent a init or an update definition. + * A Function definition which can either represent an init or an update + * definition. 
*/ class Definition { public: - explicit Definition(const std::shared_ptr<DefinitionContents>& contents) : contents_(contents) {} + explicit Definition(const std::shared_ptr<DefinitionContents>& contents) + : contents_(contents) {} private: std::shared_ptr<DefinitionContents> contents_; diff --git a/paddle/cinn/ir/intrinsic_ops.cc b/paddle/cinn/ir/intrinsic_ops.cc index bf3bd4302a965..19e9c19f4a190 100644 --- a/paddle/cinn/ir/intrinsic_ops.cc +++ b/paddle/cinn/ir/intrinsic_ops.cc @@ -27,7 +27,8 @@ const Type& IntrinsicOp::GetOutputType(int offset) const { return output_types_[offset]; } -void IntrinsicOp::Verify(llvm::ArrayRef<Type> input_types, llvm::ArrayRef<Type> output_types) const { +void IntrinsicOp::Verify(llvm::ArrayRef<Type> input_types, + llvm::ArrayRef<Type> output_types) const { CHECK_EQ(input_types.size(), input_types_.size()); CHECK_EQ(output_types.size(), output_types_.size()); @@ -47,7 +48,8 @@ void IntrinsicOp::Verify(llvm::ArrayRef<Expr> inputs) const { } } -void IntrinsicOp::Verify(llvm::ArrayRef<Expr> inputs, llvm::ArrayRef<Expr> outputs) const { +void IntrinsicOp::Verify(llvm::ArrayRef<Expr> inputs, + llvm::ArrayRef<Expr> outputs) const { llvm::SmallVector<Type, 4> input_types, output_types; for (auto& e : inputs) input_types.push_back(e.type()); for (auto& e : outputs) output_types.push_back(e.type()); @@ -90,8 +92,8 @@ Expr intrinsics::BufferCreate::Make(Expr buffer) { Expr intrinsics::GetAddr::Make(Expr data) { auto* n = new GetAddr; n->set_type(data.type().PointerOf()); - n->data = data; - n->input_types_ = {data.type()}; + n->data = data; + n->input_types_ = {data.type()}; n->output_types_ = {data.type().PointerOf()}; return Expr(n); } @@ -111,12 +113,15 @@ Expr intrinsics::ArgsConstruct::Make(Var var, llvm::ArrayRef<Expr> args) { return Expr(n); } -Expr intrinsics::BuiltinIntrin::Make( - const std::string& name, llvm::ArrayRef<Expr> args, llvm::Intrinsic::ID id, int64_t arg_nums, const Type& type) { +Expr intrinsics::BuiltinIntrin::Make(const std::string& name, + llvm::ArrayRef<Expr> args, + llvm::Intrinsic::ID id, + int64_t arg_nums, + const Type& type) { auto* n = new BuiltinIntrin; n->name = name; n->args.assign(args.begin(), args.end()); - n->id = id; + n->id = id; n->arg_nums = arg_nums; CHECK(!type.is_unk()); n->type_ = type; diff --git a/paddle/cinn/ir/intrinsic_ops.h b/paddle/cinn/ir/intrinsic_ops.h index 9f9ac71acbe11..6b8b433ad2100 100644 --- a/paddle/cinn/ir/intrinsic_ops.h +++ b/paddle/cinn/ir/intrinsic_ops.h @@ -24,8 +24,9 @@ #include "paddle/cinn/common/type.h" #include "paddle/cinn/ir/ir.h" -//! This file defines some intrinsic IR nodes, this is similar to the MLIR operations, we try to expose some underlying -//! opaque operations to IR system to help more intuitive codegen. +//! This file defines some intrinsic IR nodes, this is similar to the MLIR +//! operations, we try to expose some underlying opaque operations to IR system +//! to help more intuitive codegen.
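The Verify overloads above funnel into one check: the Expr-based forms collect each operand's type and hand the lists to the Type-based form, whose visible CHECKs compare sizes (the elided body presumably compares elements too). A compact sketch of that reduction, using assert in place of the CHECK macros and an invented Sig struct:

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

using Type = std::string;  // stand-in for cinn::common::Type

struct Sig {
  std::vector<Type> inputs;
  std::vector<Type> outputs;

  // Sizes first, then element-wise equality against the stored signature.
  void Verify(const std::vector<Type>& in,
              const std::vector<Type>& out) const {
    assert(in.size() == inputs.size());
    assert(out.size() == outputs.size());
    for (std::size_t i = 0; i < in.size(); ++i) assert(in[i] == inputs[i]);
    for (std::size_t i = 0; i < out.size(); ++i) assert(out[i] == outputs[i]);
  }
};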
namespace cinn::ir { @@ -49,7 +50,9 @@ enum class IntrinsicKind { class IntrinsicOp : public IrNode { public: - IntrinsicOp(IntrinsicKind kind, llvm::ArrayRef input_types, llvm::ArrayRef output_types) + IntrinsicOp(IntrinsicKind kind, + llvm::ArrayRef input_types, + llvm::ArrayRef output_types) : kind_(kind), input_types_(input_types.begin(), input_types.end()), output_types_(output_types.begin(), output_types.end()) {} @@ -60,11 +63,17 @@ class IntrinsicOp : public IrNode { void AddInputType(const Type& type) { input_types_.push_back(type); } void AddOutputType(const Type& type) { output_types_.push_back(type); } - const llvm::SmallVectorImpl& input_types() const { return input_types_; } - const llvm::SmallVectorImpl& output_types() const { return input_types_; } - - //! Verify the \p input_types and \p output_types matches the signature of this operation. - void Verify(llvm::ArrayRef input_types, llvm::ArrayRef output_types) const; + const llvm::SmallVectorImpl& input_types() const { + return input_types_; + } + const llvm::SmallVectorImpl& output_types() const { + return input_types_; + } + + //! Verify the \p input_types and \p output_types matches the signature of + //! this operation. + void Verify(llvm::ArrayRef input_types, + llvm::ArrayRef output_types) const; void Verify(llvm::ArrayRef inputs, llvm::ArrayRef outputs) const; void Verify(llvm::ArrayRef inputs) const; @@ -92,11 +101,15 @@ namespace intrinsics { struct BufferGetDataHandle : public IntrinsicOp { // signature: (cinn_buffer_t*) -> (void*) BufferGetDataHandle() - : IntrinsicOp(IntrinsicKind::kBufferGetDataHandle, {type_of()}, {type_of()}) {} + : IntrinsicOp(IntrinsicKind::kBufferGetDataHandle, + {type_of()}, + {type_of()}) {} static Expr Make(Expr buffer); - static bool classof(const IntrinsicOp* s) { return s->getKind() == IntrinsicKind::kBufferGetDataHandle; } + static bool classof(const IntrinsicOp* s) { + return s->getKind() == IntrinsicKind::kBufferGetDataHandle; + } Expr buffer; }; @@ -107,11 +120,15 @@ struct BufferGetDataHandle : public IntrinsicOp { struct BufferGetDataConstHandle : public IntrinsicOp { // signature: (cinn_buffer_t*) -> (const void*) BufferGetDataConstHandle() - : IntrinsicOp(IntrinsicKind::kBufferGetDataConstHandle, {type_of()}, {type_of()}) {} + : IntrinsicOp(IntrinsicKind::kBufferGetDataConstHandle, + {type_of()}, + {type_of()}) {} static Expr Make(Expr buffer); - static bool classof(const IntrinsicOp* s) { return s->getKind() == IntrinsicKind::kBufferGetDataConstHandle; } + static bool classof(const IntrinsicOp* s) { + return s->getKind() == IntrinsicKind::kBufferGetDataConstHandle; + } Expr buffer; }; @@ -127,11 +144,15 @@ struct BufferGetDataConstHandle : public IntrinsicOp { */ struct PodValueToX : public IntrinsicOp { // signature: (cinn_pod_value_t*) -> (X), X is some pod type. 
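Each concrete intrinsic implements classof by testing the stored IntrinsicKind tag; that is the hook llvm::dyn_cast uses, as the intrinsic_ops_test.cc hunk further below exercises. (One thing the reflow preserves verbatim: output_types() returns input_types_, which looks like a copy-paste slip rather than intent.) A self-contained sketch of this kind-tag RTTI scheme without the LLVM dependency, where dyn_cast is a hand-rolled stand-in:

#include <cassert>

enum class Kind { kBufferCreate, kGetAddr };

struct Op {
  explicit Op(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }
  Kind kind;
};

struct BufferCreate : Op {
  BufferCreate() : Op(Kind::kBufferCreate) {}
  // The discriminator dyn_cast consults, mirroring classof above.
  static bool classof(const Op* s) {
    return s->getKind() == Kind::kBufferCreate;
  }
};

// Minimal dyn_cast: null unless the kind tag matches.
template <typename T>
T* dyn_cast(Op* op) {
  return T::classof(op) ? static_cast<T*>(op) : nullptr;
}

int main() {
  BufferCreate bc;
  Op* base = &bc;
  assert(dyn_cast<BufferCreate>(base) != nullptr);  // tag matches
  return 0;
}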
- PodValueToX() : IntrinsicOp(IntrinsicKind::kPodValueToX, {type_of()}, {}) {} + PodValueToX() + : IntrinsicOp( + IntrinsicKind::kPodValueToX, {type_of()}, {}) {} static Expr Make(Expr pod_value_ptr, const Type& type); - static bool classof(const IntrinsicOp* s) { return s->getKind() == IntrinsicKind::kPodValueToX; } + static bool classof(const IntrinsicOp* s) { + return s->getKind() == IntrinsicKind::kPodValueToX; + } Expr pod_value_ptr; }; @@ -141,11 +162,15 @@ struct PodValueToX : public IntrinsicOp { */ struct BufferCreate : public IntrinsicOp { // signature: (cinn_buffer_t*) -> void - BufferCreate() : IntrinsicOp(IntrinsicKind::kBufferCreate, {type_of()}, {}) {} + BufferCreate() + : IntrinsicOp( + IntrinsicKind::kBufferCreate, {type_of()}, {}) {} static Expr Make(Expr buffer); - static bool classof(const IntrinsicOp* s) { return s->getKind() == IntrinsicKind::kBufferCreate; } + static bool classof(const IntrinsicOp* s) { + return s->getKind() == IntrinsicKind::kBufferCreate; + } Expr buffer; }; @@ -159,7 +184,9 @@ struct GetAddr : public IntrinsicOp { static Expr Make(Expr data); - static bool classof(const IntrinsicOp* s) { return s->getKind() == IntrinsicKind::kGetAddr; } + static bool classof(const IntrinsicOp* s) { + return s->getKind() == IntrinsicKind::kGetAddr; + } Expr data; }; @@ -172,7 +199,9 @@ struct ArgsConstruct : public IntrinsicOp { static Expr Make(Var var, llvm::ArrayRef args); - static bool classof(const IntrinsicOp* s) { return s->getKind() == IntrinsicKind::kArgsConstruct; } + static bool classof(const IntrinsicOp* s) { + return s->getKind() == IntrinsicKind::kArgsConstruct; + } Var var; llvm::SmallVector args; @@ -184,10 +213,15 @@ struct ArgsConstruct : public IntrinsicOp { struct BuiltinIntrin : public IntrinsicOp { BuiltinIntrin() : IntrinsicOp(IntrinsicKind::kBuiltinIntrin, {}, {}) {} - static Expr Make( - const std::string& name, llvm::ArrayRef args, llvm::Intrinsic::ID id, int64_t arg_nums, const Type& type); + static Expr Make(const std::string& name, + llvm::ArrayRef args, + llvm::Intrinsic::ID id, + int64_t arg_nums, + const Type& type); - static bool classof(const IntrinsicOp* s) { return s->getKind() == IntrinsicKind::kBuiltinIntrin; } + static bool classof(const IntrinsicOp* s) { + return s->getKind() == IntrinsicKind::kBuiltinIntrin; + } std::string name; llvm::SmallVector args; diff --git a/paddle/cinn/ir/intrinsic_ops_test.cc b/paddle/cinn/ir/intrinsic_ops_test.cc index 185dbcefc477e..7af2f93db4e46 100644 --- a/paddle/cinn/ir/intrinsic_ops_test.cc +++ b/paddle/cinn/ir/intrinsic_ops_test.cc @@ -21,7 +21,7 @@ namespace cinn::ir { TEST(IntrinsicOp, basic) { Expr buffer(1); buffer->set_type(type_of()); - auto op = intrinsics::BufferGetDataHandle::Make(buffer); + auto op = intrinsics::BufferGetDataHandle::Make(buffer); auto* ptr = op.As(); ASSERT_TRUE(ptr); auto* obj = llvm::dyn_cast(ptr); diff --git a/paddle/cinn/ir/ir.cc b/paddle/cinn/ir/ir.cc index 112b19f8d7910..f9a3b2655396b 100755 --- a/paddle/cinn/ir/ir.cc +++ b/paddle/cinn/ir/ir.cc @@ -44,7 +44,8 @@ Expr Cast::Make(Type t, Expr v) { void Cast::Verify() const { if (v().type() == type()) - LOG(WARNING) << "Found a Cast Node casting a value to the same type, this is not reasonable"; + LOG(WARNING) << "Found a Cast Node casting a value to the same type, this " + "is not reasonable"; } Expr Add::Make(Expr a, Expr b) { @@ -57,7 +58,8 @@ Add::Add(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) {} void BinaryNodeVerify(const Expr &a, const Expr &b, absl::string_view ir_name) { CHECK(a.defined()); 
CHECK(b.defined()); - CHECK_EQ(a.type(), b.type()) << "The operands' types of the node [" << ir_name << "] don't match"; + CHECK_EQ(a.type(), b.type()) + << "The operands' types of the node [" << ir_name << "] don't match"; } void Add::Verify() const { BinaryNodeVerify(a(), b(), "Add"); } @@ -192,7 +194,7 @@ Expr Let::Make(Expr symbol, Expr body) { CHECK(body.type().valid()); } n->symbol = symbol; - n->body = body; + n->body = body; n->set_type(n->symbol->type()); return Expr(n); } @@ -212,22 +214,25 @@ Expr _Var_::Make(const std::string &name, const Type &type) { return Expr(node); } -Expr _Var_::Make(Expr lower_bound, Expr upper_bound, const std::string &name, bool is_reduce_axis) { - auto *n = make_shared<_Var_>(); - n->lower_bound = lower_bound; - n->upper_bound = upper_bound; +Expr _Var_::Make(Expr lower_bound, + Expr upper_bound, + const std::string &name, + bool is_reduce_axis) { + auto *n = make_shared<_Var_>(); + n->lower_bound = lower_bound; + n->upper_bound = upper_bound; n->is_reduce_axis = is_reduce_axis; - n->name = name; + n->name = name; n->set_type(lower_bound.type()); return Expr(n); } Expr _Var_::Copy() const { - auto *n = make_shared<_Var_>(); - n->name = name; + auto *n = make_shared<_Var_>(); + n->name = name; n->is_reduce_axis = is_reduce_axis; - n->lower_bound = lower_bound; - n->upper_bound = upper_bound; + n->lower_bound = lower_bound; + n->upper_bound = upper_bound; n->set_type(type()); return Expr(n); } @@ -248,26 +253,29 @@ Expr For::Make(Var loop_var, CHECK(loop_var.defined()); CHECK(min.defined()); CHECK(extent.defined()); - node->loop_var = loop_var; - node->min = min; - node->extent = extent; + node->loop_var = loop_var; + node->min = min; + node->extent = extent; node->device_api = device_api; - node->body = body; + node->body = body; node->set_for_type(for_type); node->set_vectorize_info(vector_info); node->set_bind_info(bind_info); if (node->is_vectorized()) CHECK(node->vectorize_info().valid()); - if (node->is_binded() && bind_info.offset >= 0) CHECK(node->bind_info().valid()); + if (node->is_binded() && bind_info.offset >= 0) + CHECK(node->bind_info().valid()); return Expr(node); } std::vector For::expr_fields() { return {&min, &extent, &body}; } -std::vector For::expr_fields() const { return {&min, &extent, &body}; } +std::vector For::expr_fields() const { + return {&min, &extent, &body}; +} Expr Block::Make(const std::vector &stmts) { - auto node = make_shared(); + auto node = make_shared(); node->stmts = stmts; return Expr(node); } @@ -287,12 +295,12 @@ Expr ScheduleBlock::Make(const std::vector &iter_vars, const std::vector &write_buffers, const std::string &name, Expr body) { - auto node = make_shared(); - node->iter_vars = iter_vars; - node->read_buffers = read_buffers; + auto node = make_shared(); + node->iter_vars = iter_vars; + node->read_buffers = read_buffers; node->write_buffers = write_buffers; - node->name = name; - node->body = body; + node->name = name; + node->body = body; return Expr(node); } void ScheduleBlock::Verify() const { @@ -310,9 +318,10 @@ std::vector ScheduleBlock::expr_fields() const { return res; } -Expr ScheduleBlockRealize::Make(const std::vector &iter_values, const Expr &schedule_block) { - auto node = make_shared(); - node->iter_values = iter_values; +Expr ScheduleBlockRealize::Make(const std::vector &iter_values, + const Expr &schedule_block) { + auto node = make_shared(); + node->iter_values = iter_values; node->schedule_block = schedule_block; return Expr(node); } @@ -342,22 +351,30 @@ Expr IfThenElse::Make(Expr 
condition, Expr true_case, Expr false_case) { } IfThenElse::IfThenElse(Expr condition, Expr true_case, Expr false_case) - : ExprNode(Type()), condition(condition), true_case(true_case), false_case(false_case) { + : ExprNode(Type()), + condition(condition), + true_case(true_case), + false_case(false_case) { CHECK(condition.defined()); CHECK(true_case.defined()); } -std::vector IfThenElse::expr_fields() { return {&condition, &true_case, &false_case}; } -std::vector IfThenElse::expr_fields() const { return {&condition, &true_case, &false_case}; } +std::vector IfThenElse::expr_fields() { + return {&condition, &true_case, &false_case}; +} +std::vector IfThenElse::expr_fields() const { + return {&condition, &true_case, &false_case}; +} Expr Store::Make(Expr tensor, Expr value, const std::vector &indices) { CHECK(tensor.As<_Tensor_>()) << "tensor should be _Tensor_ type"; - auto node = make_shared(); - node->tensor = tensor; - node->value = value; + auto node = make_shared(); + node->tensor = tensor; + node->value = value; node->indices = indices; if (tensor->type() != Void()) { - node->set_type(tensor->type().ElementOf().with_lanes(node->index().type().lanes())); + node->set_type( + tensor->type().ElementOf().with_lanes(node->index().type().lanes())); } return Expr(node); } @@ -394,18 +411,24 @@ std::vector Store::expr_fields() const { void Store::Verify() const { CHECK(tensor.defined()); } -Expr Alloc::Make(Expr dest, Type type, const std::vector &extents, Expr condition, Expr body) { +Expr Alloc::Make(Expr dest, + Type type, + const std::vector &extents, + Expr condition, + Expr body) { auto node = make_shared(); CHECK(dest.As<_Buffer_>()) << "Alloc destination only supports Buffer"; node->destination = dest; - node->extents = extents; - node->condition = condition; - node->body = body; + node->extents = extents; + node->condition = condition; + node->body = body; node->set_type(type); return Expr(node); } -int32_t Alloc::ConstantAllocationSize() const { return ConstantAllocationSize(extents); } +int32_t Alloc::ConstantAllocationSize() const { + return ConstantAllocationSize(extents); +} int32_t Alloc::ConstantAllocationSize(const std::vector &extents) { int32_t res{1}; @@ -450,12 +473,12 @@ Expr Call::Make(Type type, CHECK(read_args[i].defined()); } - auto node = common::make_shared(type); - node->name = name; - node->read_args = read_args; - node->write_args = write_args; - node->call_type = call_type; - node->func = func; + auto node = common::make_shared(type); + node->name = name; + node->read_args = read_args; + node->write_args = write_args; + node->call_type = call_type; + node->func = func; node->value_index = value_index; node->set_type(type); node->attrs = attrs; @@ -484,13 +507,13 @@ Expr PolyFor::Make(Var iterator, Expr body, VectorizeInfo vectorize_info, BindInfo bind_info) { - auto n = make_shared(); - n->iterator = iterator; - n->init = init_val; - n->condition = condition; - n->inc = inc; + auto n = make_shared(); + n->iterator = iterator; + n->init = init_val; + n->condition = condition; + n->inc = inc; n->device_api = device_api; - n->body = body; + n->body = body; n->set_for_type(for_type); n->set_vectorize_info(vectorize_info); n->set_bind_info(bind_info); @@ -500,8 +523,12 @@ Expr PolyFor::Make(Var iterator, return Expr(n); } -std::vector PolyFor::expr_fields() { return {&init, &condition, &inc, &body}; } -std::vector PolyFor::expr_fields() const { return {&init, &condition, &inc, &body}; } +std::vector PolyFor::expr_fields() { + return {&init, &condition, &inc, &body}; 
+} +std::vector PolyFor::expr_fields() const { + return {&init, &condition, &inc, &body}; +} Expr PolyFor::ExtractExtent() const { auto nodes = CollectIRNodes(condition, [&](const Expr *e) { @@ -522,7 +549,8 @@ Expr PolyFor::ExtractExtent() const { if (le_n) { if (le_n->a() != Expr(iterator)) return Expr(); auto *le_b_int = le_n->b().As(); - if (le_b_int) return Expr(make_shared(Int(32), le_b_int->value + 1)); + if (le_b_int) + return Expr(make_shared(Int(32), le_b_int->value + 1)); return Add::Make(le_n->b(), Expr(1)); } @@ -533,7 +561,9 @@ Expr PolyFor::ExtractExtent() const { return Expr(); } -bool Var::operator==(const Var &o) const { return o->name == operator->()->name; } +bool Var::operator==(const Var &o) const { + return o->name == operator->()->name; +} bool Var::operator!=(const Var &o) const { return !(*this == o); } Var &Var::operator=(_Var_ *x) { @@ -550,8 +580,8 @@ Expr Load::Make(Expr tensor, const std::vector &indices) { CHECK(tensor->type().valid()); CHECK(!indices.empty()); for (auto &idx : indices) CHECK_EQ(idx.type().ElementOf(), Int(32)); - auto node = make_shared(); - node->tensor = tensor; + auto node = make_shared(); + node->tensor = tensor; node->indices = indices; node->set_type(node->type()); return Expr(node); @@ -584,7 +614,8 @@ Expr Load::index() const { if (is_addr_tensor()) { auto *tensor_n = tensor.As<_Tensor_>(); CHECK(tensor_n); - VLOG(3) << "Begin Load::index IndiceToAbsOffset of tensor: " << this->name(); + VLOG(3) << "Begin Load::index IndiceToAbsOffset of tensor: " + << this->name(); if (indices.size() == 1) { return indices[0]; } @@ -609,12 +640,15 @@ void Load::Verify() const { CHECK(!indices.empty()) << "At least one indice is needed"; for (auto &indice : indices) { CHECK(indice.defined()); - CHECK(indice.type().ElementOf() == type_of() || indice.type().ElementOf() == type_of()) + CHECK(indice.type().ElementOf() == type_of() || + indice.type().ElementOf() == type_of()) << "get type " << indice.type() << " vs (int64 or int32)"; } } -bool LoadStoreAddrMnger::is_addr_tensor() const { return tensor.As<_Tensor_>(); } +bool LoadStoreAddrMnger::is_addr_tensor() const { + return tensor.As<_Tensor_>(); +} bool LoadStoreAddrMnger::is_addr_scalar() const { return !is_addr_tensor(); } Expr Ramp::Make(Expr base, Expr stride, int lanes) { @@ -625,10 +659,10 @@ Expr Ramp::Make(Expr base, Expr stride, int lanes) { CHECK_EQ(stride.type(), Int(32)); CHECK_GT(lanes, 0); - auto *n = make_shared(); - n->base = base; + auto *n = make_shared(); + n->base = base; n->stride = stride; - n->lanes = lanes; + n->lanes = lanes; Type type(base.type().type(), base.type().bits(), lanes); n->set_type(type); return Expr(n); @@ -638,7 +672,7 @@ Expr Broadcast::Make(Expr value, int lanes) { CHECK(value.defined()); CHECK(value.type().valid()); - auto *n = make_shared(); + auto *n = make_shared(); n->value = value; n->lanes = lanes; @@ -648,13 +682,15 @@ Expr Broadcast::Make(Expr value, int lanes) { return Expr(n); } -Type Broadcast::type() const { return value.type().ElementOf().with_lanes(lanes); } +Type Broadcast::type() const { + return value.type().ElementOf().with_lanes(lanes); +} Expr Sum::Make(const std::vector &vs) { CHECK(!vs.empty()); if (vs.size() == 1) return vs.front(); - auto *n = make_shared(); + auto *n = make_shared(); auto type = vs.front().type(); for (auto &v : vs) CHECK_EQ(v.type(), type) << vs.front() << " " << v; @@ -668,7 +704,7 @@ Expr Sum::Make(const std::vector &vs) { Expr Product::Make(const std::vector &vs) { CHECK_GE(vs.size(), 1); - auto *n = 
make_shared(); + auto *n = make_shared(); auto type = vs.front().type(); for (auto &v : vs) CHECK_EQ(v.type(), type); @@ -681,31 +717,35 @@ Expr Product::Make(const std::vector &vs) { Expr FracOp::Make(Expr n, Expr d) { auto *node = make_shared(); - node->a() = n; - node->b() = d; + node->a() = n; + node->b() = d; return Expr(node); } ir::Module _Module_::Make(const std::string &name, Target target) { - auto n = make_shared<_Module_>(); - n->name = name; + auto n = make_shared<_Module_>(); + n->name = name; n->target = target; return ir::Module(n); } -Expr PrimitiveNode::Make(const std::string &name, const std::map &attrs) { - auto *n = make_shared(); - n->name = name; +Expr PrimitiveNode::Make(const std::string &name, + const std::map &attrs) { + auto *n = make_shared(); + n->name = name; n->attrs = attrs; return Expr(n); } -Expr Reduce::Make(Reduce::ReduceType reduce_type, Expr init, Expr body, const std::vector &reduce_aixs) { +Expr Reduce::Make(Reduce::ReduceType reduce_type, + Expr init, + Expr body, + const std::vector &reduce_aixs) { CHECK(body.defined()); CHECK(init.defined()); - auto n = common::make_shared(); - n->init = init; - n->body = body; + auto n = common::make_shared(); + n->init = init; + n->body = body; n->reduce_type = reduce_type; n->reduce_axis.append(reduce_aixs.begin(), reduce_aixs.end()); CHECK(body.type().valid()); @@ -746,7 +786,8 @@ void Select::Verify() const { CHECK(condition.defined()); CHECK(true_value.defined()); CHECK(false_value.defined()); - CHECK(condition.type().is_bool()) << "Select Node's condition should be a boolean"; + CHECK(condition.type().is_bool()) + << "Select Node's condition should be a boolean"; CHECK_EQ(true_value.type(), false_value.type()) << "Select Node's true_value and false_value should have the same type"; } @@ -802,12 +843,14 @@ void MultiOperandVerify(llvm::ArrayRef operands) { } void Product::Verify() const { - CHECK_GT(operands().size(), 1UL) << "Product node should have more than 1 operands"; + CHECK_GT(operands().size(), 1UL) + << "Product node should have more than 1 operands"; MultiOperandVerify(operands()); } void Sum::Verify() const { - CHECK_GT(operands().size(), 1UL) << "Sum node should have more than 1 operands"; + CHECK_GT(operands().size(), 1UL) + << "Sum node should have more than 1 operands"; MultiOperandVerify(operands()); } diff --git a/paddle/cinn/ir/ir.h b/paddle/cinn/ir/ir.h index 515d20c5909a2..49013139e51ea 100644 --- a/paddle/cinn/ir/ir.h +++ b/paddle/cinn/ir/ir.h @@ -46,8 +46,8 @@ class Module; using common::Object; using common::Shared; -// NOTE attr_t only support POD, can not contain Expr or other IR nodes, or the IRVisitor or IRCopy on PrimitiveNode -// will result in undefined behavior. +// NOTE attr_t only support POD, can not contain Expr or other IR nodes, or the +// IRVisitor or IRCopy on PrimitiveNode will result in undefined behavior. 
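The NOTE above is the invariant that keeps IRVisitor and IRCopy safe on PrimitiveNode: because every attr_t alternative is a plain value, copying an attribute map can never alias or dangle into the IR tree. A sketch of the idea with std::variant standing in for the absl::variant the header uses (the alternative list here is illustrative, not the header's exact one):

#include <cassert>
#include <map>
#include <string>
#include <variant>

// Plain-value alternatives only: no Expr, no pointers into the IR.
using attr_t = std::variant<int, float, bool, std::string>;

int main() {
  std::map<std::string, attr_t> attrs;
  attrs["axis"] = 1;
  attrs["scale"] = 0.5f;
  attrs["vectorizable"] = false;

  // Copying is safe precisely because no alternative can reference
  // an IR node that a later pass might rewrite or free.
  std::map<std::string, attr_t> copy = attrs;
  assert(std::get<int>(copy["axis"]) == 1);
  return 0;
}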
using attr_t = absl::variant; /** @@ -71,7 +71,9 @@ struct Cast : public ExprNode { static const IrNodeTy _node_type_ = IrNodeTy::Cast; std::vector expr_fields() override { return {&operand(0)}; } - std::vector expr_fields() const override { return {&operand(0)}; } + std::vector expr_fields() const override { + return {&operand(0)}; + } }; /** @@ -349,17 +351,21 @@ struct Call : public ExprNode { const std::vector& read_args, const std::vector& write_args, CallType call_type, - FunctionRef func = FunctionRef(), - int value_index = 0, + FunctionRef func = FunctionRef(), + int value_index = 0, const std::map& attrs = {}); void Verify() const override; - inline size_t total_args_count() const { return read_args.size() + write_args.size(); } + inline size_t total_args_count() const { + return read_args.size() + write_args.size(); + } inline bool is_extern_call() const { return call_type == CallType::Extern; } inline bool is_cinn_call() const { return call_type == CallType::CINN; } - inline bool is_intrinsic_call() const { return call_type == CallType::Intrinsic; } + inline bool is_intrinsic_call() const { + return call_type == CallType::Intrinsic; + } inline bool is_isl_call() const { return call_type == CallType::ISL; } std::vector expr_fields() override; @@ -385,11 +391,15 @@ struct _Var_ : public ExprNode<_Var_> { std::string tag; _Var_() = default; - _Var_(const std::string& name, Type type) : ExprNode<_Var_>(type), name(name) {} + _Var_(const std::string& name, Type type) + : ExprNode<_Var_>(type), name(name) {} static Expr Make(const std::string& name, const Type& type); //! Make a reduce axis. - static Expr Make(Expr lower_bound, Expr upper_bound, const std::string& name, bool is_reduce); + static Expr Make(Expr lower_bound, + Expr upper_bound, + const std::string& name, + bool is_reduce); void Verify() const override; @@ -402,11 +412,17 @@ struct _Var_ : public ExprNode<_Var_> { struct Var : public IrNodeRef { Var() = default; explicit Var(IrNode* n) : IrNodeRef(n) {} - explicit Var(const std::string& name_hint, Type t = type_of()) : Var(_Var_::Make(name_hint, t).ptr()) {} - Var(Expr lower_bound, Expr upper_bound, const std::string& name, bool is_reduce = false) + explicit Var(const std::string& name_hint, Type t = type_of()) + : Var(_Var_::Make(name_hint, t).ptr()) {} + Var(Expr lower_bound, + Expr upper_bound, + const std::string& name, + bool is_reduce = false) : Var(_Var_::Make(lower_bound, upper_bound, name, is_reduce)) {} - Var(int upper_bound, const std::string& name) : Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false)) {} - Var(Expr upper_bound, const std::string& name) : Var(_Var_::Make(Expr(0), upper_bound, name, false)) {} + Var(int upper_bound, const std::string& name) + : Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false)) {} + Var(Expr upper_bound, const std::string& name) + : Var(_Var_::Make(Expr(0), upper_bound, name, false)) {} operator Expr() { return Expr(get()); } operator Expr() const { @@ -449,7 +465,10 @@ struct Reduce : public ExprNode { //! The type of the reduce operation. ReduceType reduce_type; - static Expr Make(ReduceType reduce_type, Expr init, Expr body, const std::vector& reduce_aixs); + static Expr Make(ReduceType reduce_type, + Expr init, + Expr body, + const std::vector& reduce_aixs); Type type() const override { return body.type().ElementOf(); } @@ -462,7 +481,8 @@ struct Reduce : public ExprNode { }; /** - * Evaluates `true_value` and `false_value` then selects between them based on `condition`. 
+ * Evaluates `true_value` and `false_value` then selects between them based on + * `condition`. */ struct Select : public ExprNode { Expr false_value; Select(Expr condition, Expr true_value, Expr false_value) - : ExprNode(true_value.type()), + condition(condition), + true_value(true_value), + false_value(false_value) { CHECK_EQ(true_value.type(), false_value.type()); CHECK(condition.type().is_bool()); } @@ -487,8 +510,12 @@ struct Select : public ExprNode(); - return Compare(lhs->condition, rhs->condition) && Compare(lhs->true_value, rhs->true_value) && + return Compare(lhs->condition, rhs->condition) && + Compare(lhs->true_value, rhs->true_value) && Compare(lhs->false_value, rhs->false_value); } bool IrEqualVisitor::Visit(const IfThenElse* lhs, const Expr* other) { auto* rhs = other->As(); - return Compare(lhs->condition, rhs->condition) && Compare(lhs->true_case, rhs->true_case) && + return Compare(lhs->condition, rhs->condition) && + Compare(lhs->true_case, rhs->true_case) && Compare(lhs->false_case, rhs->false_case); } @@ -164,31 +180,36 @@ bool IrEqualVisitor::Visit(const Block* lhs, const Expr* other) { bool IrEqualVisitor::Visit(const Call* lhs, const Expr* other) { auto* rhs = other->As(); return lhs->name == rhs->name && Compare(lhs->read_args, rhs->read_args) && - Compare(lhs->write_args, rhs->write_args) && Compare(lhs->attrs, rhs->attrs) && - lhs->call_type == rhs->call_type; + Compare(lhs->write_args, rhs->write_args) && + Compare(lhs->attrs, rhs->attrs) && lhs->call_type == rhs->call_type; // TODO(CtfGo): Compare `func` field } bool IrEqualVisitor::Visit(const _Var_* lhs, const Expr* other) { auto* rhs = other->As<_Var_>(); - return lhs->name == rhs->name && Compare(lhs->lower_bound, rhs->lower_bound) && + return lhs->name == rhs->name && + Compare(lhs->lower_bound, rhs->lower_bound) && Compare(lhs->upper_bound, rhs->upper_bound) && lhs->tag == rhs->tag; } bool IrEqualVisitor::Visit(const Load* lhs, const Expr* other) { auto* rhs = other->As(); - return Compare(lhs->tensor, rhs->tensor) && Compare(lhs->indices, rhs->indices); + return Compare(lhs->tensor, rhs->tensor) && + Compare(lhs->indices, rhs->indices); } bool IrEqualVisitor::Visit(const Store* lhs, const Expr* other) { auto* rhs = other->As(); - return Compare(lhs->tensor, rhs->tensor) && Compare(lhs->indices, rhs->indices); + return Compare(lhs->tensor, rhs->tensor) && + Compare(lhs->indices, rhs->indices); } bool IrEqualVisitor::Visit(const Alloc* lhs, const Expr* other) { auto* rhs = other->As(); - return Compare(lhs->destination, rhs->destination) && Compare(lhs->extents, rhs->extents) && - Compare(lhs->condition, rhs->condition) && Compare(lhs->body, rhs->body); + return Compare(lhs->destination, rhs->destination) && + Compare(lhs->extents, rhs->extents) && + Compare(lhs->condition, rhs->condition) && + Compare(lhs->body, rhs->body); } bool IrEqualVisitor::Visit(const Free* lhs, const Expr* other) { @@ -198,10 +219,14 @@ bool IrEqualVisitor::Visit(const Free* lhs, const Expr* other) { bool IrEqualVisitor::Visit(const _Buffer_* lhs, const Expr* other) { auto* rhs = other->As<_Buffer_>(); - return Compare(lhs->shape, rhs->shape) && Compare(lhs->strides, rhs->strides) && lhs->name == rhs->name && - lhs->scope == rhs->scope && Compare(lhs->elem_offset, rhs->elem_offset) && - lhs->offset_factor == rhs->offset_factor && lhs->target == rhs->target && - lhs->data_alignment == rhs->data_alignment && lhs->memory_type == rhs->memory_type && lhs->dtype == rhs->dtype; + return Compare(lhs->shape, rhs->shape) && + 
Compare(lhs->strides, rhs->strides) && lhs->name == rhs->name && + lhs->scope == rhs->scope && + Compare(lhs->elem_offset, rhs->elem_offset) && + lhs->offset_factor == rhs->offset_factor && + lhs->target == rhs->target && + lhs->data_alignment == rhs->data_alignment && + lhs->memory_type == rhs->memory_type && lhs->dtype == rhs->dtype; } bool IrEqualVisitor::Visit(const _Tensor_* lhs, const Expr* other) { @@ -212,22 +237,28 @@ bool IrEqualVisitor::Visit(const _Tensor_* lhs, const Expr* other) { bool IrEqualVisitor::Visit(const _LoweredFunc_* lhs, const Expr* other) { auto* rhs = other->As<_LoweredFunc_>(); if (lhs->name != rhs->name) { - VLOG(6) << "Not equal, lhs name=" << lhs->name << ", rhs name=" << rhs->name; + VLOG(6) << "Not equal, lhs name=" << lhs->name + << ", rhs name=" << rhs->name; return false; } - auto compare_args_fn = [this](const std::vector& largs, const std::vector& rargs) -> bool { + auto compare_args_fn = [this](const std::vector& largs, + const std::vector& rargs) -> bool { if (largs.size() != rargs.size()) { - VLOG(6) << "Not equal, lhs args size=" << largs.size() << ", rhs args size=" << rargs.size(); + VLOG(6) << "Not equal, lhs args size=" << largs.size() + << ", rhs args size=" << rargs.size(); return false; } for (auto i = 0; i < largs.size(); ++i) { const Argument& a = largs.at(i); const Argument& b = rargs.at(i); - bool equal = a.io == b.io; - equal = equal && (!a.is_var() && !b.is_var() || a.is_var() && b.is_var() && Compare(a.var_arg(), b.var_arg())); + bool equal = a.io == b.io; + equal = equal && + (!a.is_var() && !b.is_var() || + a.is_var() && b.is_var() && Compare(a.var_arg(), b.var_arg())); equal = equal && (!a.is_buffer() && !b.is_buffer() || - a.is_buffer() && b.is_buffer() && Compare(a.buffer_arg(), b.buffer_arg())); + a.is_buffer() && b.is_buffer() && + Compare(a.buffer_arg(), b.buffer_arg())); if (!equal) { VLOG(6) << "Not equal at Argument index=" << i; return false; @@ -236,18 +267,23 @@ bool IrEqualVisitor::Visit(const _LoweredFunc_* lhs, const Expr* other) { return true; }; - return compare_args_fn(lhs->args, rhs->args) && Compare(lhs->temp_bufs, rhs->temp_bufs) && + return compare_args_fn(lhs->args, rhs->args) && + Compare(lhs->temp_bufs, rhs->temp_bufs) && Compare(lhs->body, rhs->body) && lhs->device_api == rhs->device_api && - Compare(lhs->alloc_output_buffer_exprs, rhs->alloc_output_buffer_exprs) && - Compare(lhs->dealloc_output_buffer_exprs, rhs->dealloc_output_buffer_exprs) && + Compare(lhs->alloc_output_buffer_exprs, + rhs->alloc_output_buffer_exprs) && + Compare(lhs->dealloc_output_buffer_exprs, + rhs->dealloc_output_buffer_exprs) && Compare(lhs->buffer_data_cast_exprs, rhs->buffer_data_cast_exprs) && Compare(lhs->argument_prepare_exprs, rhs->argument_prepare_exprs); } bool IrEqualVisitor::Visit(const _Module_* lhs, const Expr* other) { auto* rhs = other->As<_Module_>(); - return lhs->name == rhs->name && lhs->target == rhs->target && Compare(lhs->buffers, rhs->buffers) && - Compare(lhs->functions, rhs->functions) && Compare(lhs->submodules, rhs->submodules); + return lhs->name == rhs->name && lhs->target == rhs->target && + Compare(lhs->buffers, rhs->buffers) && + Compare(lhs->functions, rhs->functions) && + Compare(lhs->submodules, rhs->submodules); } bool IrEqualVisitor::Visit(const Let* lhs, const Expr* other) { @@ -257,13 +293,15 @@ bool IrEqualVisitor::Visit(const Let* lhs, const Expr* other) { bool IrEqualVisitor::Visit(const Reduce* lhs, const Expr* other) { auto* rhs = other->As(); - return Compare(lhs->init, rhs->init) && 
Compare(lhs->body, rhs->body) && lhs->reduce_type == rhs->reduce_type; + return Compare(lhs->init, rhs->init) && Compare(lhs->body, rhs->body) && + lhs->reduce_type == rhs->reduce_type; // TODO(CtfGo): compare `reduce_axis` field } bool IrEqualVisitor::Visit(const Ramp* lhs, const Expr* other) { auto* rhs = other->As(); - return Compare(lhs->base, rhs->base) && Compare(lhs->stride, rhs->stride) && lhs->lanes == rhs->lanes; + return Compare(lhs->base, rhs->base) && Compare(lhs->stride, rhs->stride) && + lhs->lanes == rhs->lanes; } bool IrEqualVisitor::Visit(const Broadcast* lhs, const Expr* other) { @@ -288,12 +326,14 @@ bool IrEqualVisitor::Visit(const Sum* lhs, const Expr* other) { bool IrEqualVisitor::Visit(const PrimitiveNode* lhs, const Expr* other) { auto* rhs = other->As(); - return lhs->name == rhs->name && Compare(lhs->arguments, rhs->arguments) && Compare(lhs->attrs, rhs->attrs); + return lhs->name == rhs->name && Compare(lhs->arguments, rhs->arguments) && + Compare(lhs->attrs, rhs->attrs); } bool IrEqualVisitor::Visit(const IntrinsicOp* lhs, const Expr* other) { auto* rhs = other->As(); - return lhs->getKind() == rhs->getKind() && lhs->input_types() == rhs->input_types() && + return lhs->getKind() == rhs->getKind() && + lhs->input_types() == rhs->input_types() && lhs->output_types() == rhs->output_types(); // TODO(CtfGo): Compare every derived class of IntrinsicOp separately } @@ -305,14 +345,17 @@ bool IrEqualVisitor::Visit(const _BufferRange_* lhs, const Expr* other) { bool IrEqualVisitor::Visit(const ScheduleBlock* lhs, const Expr* other) { auto* rhs = other->As(); - return Compare(lhs->name, rhs->name, allow_name_suffix_diff_) && Compare(lhs->iter_vars, rhs->iter_vars) && - Compare(lhs->read_buffers, rhs->read_buffers) && Compare(lhs->write_buffers, rhs->write_buffers) && + return Compare(lhs->name, rhs->name, allow_name_suffix_diff_) && + Compare(lhs->iter_vars, rhs->iter_vars) && + Compare(lhs->read_buffers, rhs->read_buffers) && + Compare(lhs->write_buffers, rhs->write_buffers) && Compare(lhs->attrs, rhs->attrs) && Compare(lhs->body, rhs->body); } bool IrEqualVisitor::Visit(const ScheduleBlockRealize* lhs, const Expr* other) { auto* rhs = other->As(); - return Compare(lhs->iter_values, rhs->iter_values) && Compare(lhs->schedule_block, rhs->schedule_block); + return Compare(lhs->iter_values, rhs->iter_values) && + Compare(lhs->schedule_block, rhs->schedule_block); } } // namespace ir diff --git a/paddle/cinn/ir/ir_compare.h b/paddle/cinn/ir/ir_compare.h index 3b69b13d53235..75e9bcf2dccb0 100644 --- a/paddle/cinn/ir/ir_compare.h +++ b/paddle/cinn/ir/ir_compare.h @@ -21,16 +21,21 @@ namespace cinn { namespace ir { -// Determine whether two ir AST trees are euqal by comparing their struct and fields of each node through dfs visitor +// Determine whether two ir AST trees are euqal by comparing their struct and +// fields of each node through dfs visitor class IrEqualVisitor : public IRVisitorBase { public: - explicit IrEqualVisitor(bool allow_name_suffix_diff = false) : allow_name_suffix_diff_(allow_name_suffix_diff) {} + explicit IrEqualVisitor(bool allow_name_suffix_diff = false) + : allow_name_suffix_diff_(allow_name_suffix_diff) {} // Return true if they are euqal, otherwise false; bool Compare(const Expr& lhs, const Expr& rhs); private: - bool Compare(const std::string& lhs, const std::string& rhs, bool allow_name_suffix_diff = false); - bool Compare(const std::map& lhs, const std::map& rhs); + bool Compare(const std::string& lhs, + const std::string& rhs, + bool 
allow_name_suffix_diff = false); + bool Compare(const std::map& lhs, + const std::map& rhs); template bool Compare(const std::vector& lhs, const std::vector& rhs); diff --git a/paddle/cinn/ir/ir_compare_test.cc b/paddle/cinn/ir/ir_compare_test.cc index bb3c54c8d8e57..a4c374dc59960 100644 --- a/paddle/cinn/ir/ir_compare_test.cc +++ b/paddle/cinn/ir/ir_compare_test.cc @@ -37,13 +37,34 @@ TEST(TestIrCompare, SingleFunction) { {M, N}, [&](Var i, Var j) { return A(i, j) + ir::Expr(2.f); }, "C"); cinn::common::Context::Global().ResetNameId(); - auto funcs_1 = lang::LowerVec("add_const", poly::CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); + auto funcs_1 = lang::LowerVec("add_const", + poly::CreateStages({A, B}), + {A, B}, + {}, + {}, + nullptr, + target, + true); cinn::common::Context::Global().ResetNameId(); - auto funcs_2 = lang::LowerVec("add_const", poly::CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); + auto funcs_2 = lang::LowerVec("add_const", + poly::CreateStages({A, B}), + {A, B}, + {}, + {}, + nullptr, + target, + true); cinn::common::Context::Global().ResetNameId(); - auto funcs_3 = lang::LowerVec("add_const", poly::CreateStages({A, C}), {A, C}, {}, {}, nullptr, target, true); + auto funcs_3 = lang::LowerVec("add_const", + poly::CreateStages({A, C}), + {A, C}, + {}, + {}, + nullptr, + target, + true); ASSERT_EQ(funcs_1.size(), 1); ASSERT_EQ(funcs_2.size(), 1); @@ -114,10 +135,12 @@ TEST(TestIrCompare, SingleFunction) { ASSERT_TRUE(compartor.Compare(funcs_1.front(), funcs_1.front())); IrEqualVisitor compartor_allow_suffix_diff(true); // they are euqal if allowing suffix of name different - ASSERT_TRUE(compartor_allow_suffix_diff.Compare(funcs_1.front(), funcs_2.front())); + ASSERT_TRUE( + compartor_allow_suffix_diff.Compare(funcs_1.front(), funcs_2.front())); ASSERT_FALSE(compartor.Compare(funcs_1.front(), funcs_3.front())); - ASSERT_FALSE(compartor_allow_suffix_diff.Compare(funcs_1.front(), funcs_3.front())); + ASSERT_FALSE( + compartor_allow_suffix_diff.Compare(funcs_1.front(), funcs_3.front())); } } // namespace ir diff --git a/paddle/cinn/ir/ir_mutator.h b/paddle/cinn/ir/ir_mutator.h index 9cfaac27e47b1..9a7fa33756f4d 100755 --- a/paddle/cinn/ir/ir_mutator.h +++ b/paddle/cinn/ir/ir_mutator.h @@ -102,7 +102,8 @@ void IRMutator::Visit(const IfThenElse *expr, T op) { auto *node = op->template As(); IRVisitorBase::Visit(&node->condition, &node->condition); IRVisitorBase::Visit(&node->true_case, &node->true_case); - if (node->false_case.defined()) IRVisitorBase::Visit(&node->false_case, &node->false_case); + if (node->false_case.defined()) + IRVisitorBase::Visit(&node->false_case, &node->false_case); } template void IRMutator::Visit(const Block *expr, T op) { @@ -164,7 +165,8 @@ void IRMutator::Visit(const Alloc *expr, T op) { IRVisitorBase::Visit(&e, &e); } - if (node->condition.defined()) IRVisitorBase::Visit(&node->condition, &node->condition); + if (node->condition.defined()) + IRVisitorBase::Visit(&node->condition, &node->condition); if (node->body.defined()) { Expr body(node->body); IRVisitorBase::Visit(&node->body, &body); @@ -204,12 +206,14 @@ template void IRMutator::Visit(const Let *expr, T op) { auto *node = op->template As(); IRVisitorBase::Visit(&node->symbol, &node->symbol); - if (node->body.defined()) IRVisitorBase::Visit(&node->body, &node->body); + if (node->body.defined()) + IRVisitorBase::Visit(&node->body, &node->body); } template void IRMutator::Visit(const Reduce *expr, T op) { auto *node = op->template As(); - if (node->init.defined()) 
IRVisitorBase::Visit(&node->init, &node->init); + if (node->init.defined()) + IRVisitorBase::Visit(&node->init, &node->init); CHECK(node->body.defined()); IRVisitorBase::Visit(&node->body, &node->body); } diff --git a/paddle/cinn/ir/ir_operators.cc b/paddle/cinn/ir/ir_operators.cc index 09a0274dfdf38..407c771e070d7 100644 --- a/paddle/cinn/ir/ir_operators.cc +++ b/paddle/cinn/ir/ir_operators.cc @@ -32,12 +32,13 @@ Expr operator<<(Expr a, Expr b) { CHECK(b.type().is_int() || b.type().is_uint()); auto int_a = a.As(); auto int_b = b.As(); - Type t_a = a.type(); - Type t_b = b.type(); + Type t_a = a.type(); + Type t_b = b.type(); if (t_a.is_index_type() && t_b.is_index_type()) { if (int_b) { CHECK(int_b->value >= 0 && int_b->value < t_a.bits()) - << "Shift amount must be non-negative and less than " << t_a.bits() << " for type " << t_a << std::endl; + << "Shift amount must be non-negative and less than " << t_a.bits() + << " for type " << t_a << std::endl; if (int_b->value == 0) return a; } if (int_a && int_b) { @@ -52,12 +53,13 @@ Expr operator>>(Expr a, Expr b) { CHECK(b.type().is_int() || b.type().is_uint()); auto int_a = a.As(); auto int_b = b.As(); - Type t_a = a.type(); - Type t_b = b.type(); + Type t_a = a.type(); + Type t_b = b.type(); if (t_a.is_index_type() && t_b.is_index_type()) { if (int_b) { CHECK(int_b->value >= 0 && int_b->value < t_a.bits()) - << "Shift amount must be non-negative and less than " << t_a.bits() << " for type " << t_a << std::endl; + << "Shift amount must be non-negative and less than " << t_a.bits() + << " for type " << t_a << std::endl; if (int_b->value == 0) return a; } if (int_a && int_b) { @@ -72,8 +74,8 @@ Expr operator|(Expr a, Expr b) { CHECK(b.type().is_int() || b.type().is_uint()); auto int_a = a.As(); auto int_b = b.As(); - Type t_a = a.type(); - Type t_b = b.type(); + Type t_a = a.type(); + Type t_b = b.type(); if (t_a.is_index_type() && t_b.is_index_type()) { if (int_a && int_b) { return Expr(int_a->value | int_b->value); @@ -95,8 +97,8 @@ Expr operator&(Expr a, Expr b) { CHECK(b.type().is_int() || b.type().is_uint()); auto int_a = a.As(); auto int_b = b.As(); - Type t_a = a.type(); - Type t_b = b.type(); + Type t_a = a.type(); + Type t_b = b.type(); if (t_a.is_index_type() && t_b.is_index_type()) { if (int_a && int_b) { return Expr(int_a->value & int_b->value); @@ -109,7 +111,8 @@ Expr operator&(Expr a, Expr b) { auto func_name = hlir::GetExternFuncName(target, t_a, "bitwise_and"); return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}}); } else { - LOG(FATAL) << "Unsupport arch: " << target.arch_str() << " for bitwise_and."; + LOG(FATAL) << "Unsupport arch: " << target.arch_str() + << " for bitwise_and."; } } @@ -118,8 +121,8 @@ Expr operator^(Expr a, Expr b) { CHECK(b.type().is_int() || b.type().is_uint()); auto int_a = a.As(); auto int_b = b.As(); - Type t_a = a.type(); - Type t_b = b.type(); + Type t_a = a.type(); + Type t_b = b.type(); if (t_a.is_index_type() && t_b.is_index_type()) { if (int_a && int_b) { return Expr(int_a->value ^ int_b->value); @@ -132,7 +135,8 @@ Expr operator^(Expr a, Expr b) { auto func_name = hlir::GetExternFuncName(target, t_a, "bitwise_xor"); return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}}); } else { - LOG(FATAL) << "Unsupport arch: " << target.arch_str() << " for bitwise_xor."; + LOG(FATAL) << "Unsupport arch: " << target.arch_str() + << " for bitwise_xor."; } } @@ -145,7 +149,8 @@ Expr operator~(Expr a) { auto func_name = hlir::GetExternFuncName(target, a->type(), "bitwise_not"); 
return lang::CallExtern(func_name, {a}, {{"vectorizable", false}}); } else { - LOG(FATAL) << "Unsupport arch: " << target.arch_str() << " for bitwise_not."; + LOG(FATAL) << "Unsupport arch: " << target.arch_str() + << " for bitwise_not."; } } diff --git a/paddle/cinn/ir/ir_operators.h b/paddle/cinn/ir/ir_operators.h index b0a2eab109526..16c6901a1da61 100644 --- a/paddle/cinn/ir/ir_operators.h +++ b/paddle/cinn/ir/ir_operators.h @@ -22,85 +22,105 @@ namespace cinn { namespace ir { //-- left hand -- -template ::value>::type> +template ::value>::type> Expr operator+(Expr a, POD b) { return Add::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator-(Expr a, POD b) { return Sub::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator*(Expr a, POD b) { return Mul::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator/(Expr a, POD b) { return Div::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator%(Expr a, POD b) { return Mod::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator<(Expr a, POD b) { return LT::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator<=(Expr a, POD b) { return LE::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator>(Expr a, POD b) { return GT::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator>=(Expr a, POD b) { return GE::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator==(Expr a, POD b) { return EQ::Make(Expr(a), Expr(b)); } //- right hand -- -template ::value>::type> +template ::value>::type> Expr operator+(POD a, Expr b) { return Add::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator-(POD a, Expr b) { return Sub::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator*(POD a, Expr b) { return Mul::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator/(POD a, Expr b) { return Div::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator%(POD a, Expr b) { return Mod::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator<(POD a, Expr b) { return LT::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator<=(POD a, Expr b) { return LE::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator>(POD a, Expr b) { return GT::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator>=(POD a, Expr b) { return GE::Make(Expr(a), Expr(b)); } -template ::value>::type> +template ::value>::type> Expr operator==(POD a, Expr b) { return EQ::Make(Expr(a), Expr(b)); } diff --git a/paddle/cinn/ir/ir_printer.cc b/paddle/cinn/ir/ir_printer.cc index 7bcd404ec6201..068dc93807ccd 100644 --- a/paddle/cinn/ir/ir_printer.cc +++ b/paddle/cinn/ir/ir_printer.cc @@ -33,7 +33,8 @@ using common::bfloat16; using common::float16; void IrPrinter::Print(Expr e) { IRVisitor::Visit(&e); } -void IrPrinter::Print(const std::vector &exprs, const std::string &splitter) { +void IrPrinter::Print(const std::vector &exprs, + const std::string &splitter) { for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) { Print(exprs[i]); os_ << splitter; @@ -80,7 +81,8 @@ void 
IrPrinter::Visit(const FloatImm *x) { } else if (std::isnan(x->value)) { os_ << "cinn::common::raw_uint16_to_float16(0x7e00)"; } else { - os_ << "(float16)" << std::setprecision(std::numeric_limits::max_digits10) + os_ << "(float16)" + << std::setprecision(std::numeric_limits::max_digits10) << static_cast(x->value) << "f"; } } else if (x->type().is_bfloat16()) { @@ -89,16 +91,19 @@ void IrPrinter::Visit(const FloatImm *x) { } else if (std::isnan(x->value)) { os_ << "cinn::common::raw_uint16_to_bfloat16(0x7FC0)"; } else { - os_ << "(bfloat16)" << std::setprecision(std::numeric_limits::max_digits10) + os_ << "(bfloat16)" + << std::setprecision(std::numeric_limits::max_digits10) << static_cast(x->value) << "f"; } } else if (x->type().is_float(32)) { - os_ << std::setprecision(std::numeric_limits::max_digits10) << std::showpoint << x->value; + os_ << std::setprecision(std::numeric_limits::max_digits10) + << std::showpoint << x->value; if (std::isfinite(x->value)) { os_ << "f"; } } else if (x->type().is_float(64)) { - os_ << std::setprecision(std::numeric_limits::max_digits10) << std::showpoint << x->value; + os_ << std::setprecision(std::numeric_limits::max_digits10) + << std::showpoint << x->value; } else { LOG(FATAL) << "Not support float type: " << x->type(); } @@ -151,9 +156,10 @@ void IrPrinter::Visit(const For *x) { } else if (x->is_binded()) { auto &bind_info = x->bind_info(); if (bind_info.valid()) { - char axis_name = 'x' + bind_info.offset; - auto for_type = bind_info.for_type; - std::string prefix = for_type == ForType::GPUBlock ? "blockIdx." : "threadIdx."; + char axis_name = 'x' + bind_info.offset; + auto for_type = bind_info.for_type; + std::string prefix = + for_type == ForType::GPUBlock ? "blockIdx." : "threadIdx."; os() << "thread_bind[" << prefix << axis_name << "] for ("; } else { os() << "thread_bind[invalid info] for ("; @@ -334,11 +340,13 @@ void IrPrinter::DecIndent() { indent_ -= indent_unit; } void IrPrinter::Visit(const _Buffer_ *x) { std::vector dim_names; - std::transform(x->shape.begin(), x->shape.end(), std::back_inserter(dim_names), [&](const Expr &x) { - return utils::GetStreamCnt(x); - }); + std::transform(x->shape.begin(), + x->shape.end(), + std::back_inserter(dim_names), + [&](const Expr &x) { return utils::GetStreamCnt(x); }); - os_ << "_Buffer_<" << x->type() << ": " << utils::Join(dim_names, ",") << ">(" << x->name << ")"; + os_ << "_Buffer_<" << x->type() << ": " << utils::Join(dim_names, ",") << ">(" + << x->name << ")"; } void IrPrinter::Visit(const _Tensor_ *x) { os_ << "Tensor("; @@ -502,7 +510,7 @@ void IrPrinter::Visit(const ScheduleBlockRealize *x) { DoIndent(); os() << "{\n"; // print block vars and bindings - auto iter_vars = schedule_block->iter_vars; + auto iter_vars = schedule_block->iter_vars; auto iter_values = x->iter_values; CHECK_EQ(iter_vars.size(), iter_values.size()); IncIndent(); diff --git a/paddle/cinn/ir/ir_printer.h b/paddle/cinn/ir/ir_printer.h index cf6b41e75fbae..8d7e80a1d1982 100644 --- a/paddle/cinn/ir/ir_printer.h +++ b/paddle/cinn/ir/ir_printer.h @@ -35,7 +35,8 @@ struct IrPrinter : public IRVisitor { //! Emit an expression on the output stream. void Print(Expr e); //! Emit a expression list with , splitted. - void Print(const std::vector &exprs, const std::string &splitter = ", "); + void Print(const std::vector &exprs, + const std::string &splitter = ", "); //! 
Emit a binary operator template void PrintBinaryOp(const std::string &op, const BinaryOpNode *x); @@ -68,7 +69,8 @@ std::ostream &operator<<(std::ostream &os, const std::vector &a); std::ostream &operator<<(std::ostream &os, const Module &m); template -void IrPrinter::PrintBinaryOp(const std::string &op, const BinaryOpNode *x) { +void IrPrinter::PrintBinaryOp(const std::string &op, + const BinaryOpNode *x) { os_ << "("; Print(x->a()); os_ << " " + op + " "; diff --git a/paddle/cinn/ir/ir_schedule.cc b/paddle/cinn/ir/ir_schedule.cc index 06699e6442082..b7e4946fa2b26 100644 --- a/paddle/cinn/ir/ir_schedule.cc +++ b/paddle/cinn/ir/ir_schedule.cc @@ -52,7 +52,8 @@ class ScheduleImpl { ScheduleImpl() = default; explicit ScheduleImpl(const ModuleExpr& module_expr, bool debug_flag = false) : module_expr_(module_expr), debug_flag_(debug_flag) {} - explicit ScheduleImpl(ModuleExpr&& module_expr) : module_expr_(std::move(module_expr)) {} + explicit ScheduleImpl(ModuleExpr&& module_expr) + : module_expr_(std::move(module_expr)) {} //! Set the debug flag. void SetDebugFlag(bool debug_flag) { debug_flag_ = debug_flag; } @@ -62,7 +63,9 @@ class ScheduleImpl { void MergeExprs(); - void SetExprs(const std::vector& exprs) { module_expr_.SetExprs(exprs); } + void SetExprs(const std::vector& exprs) { + module_expr_.SetExprs(exprs); + } bool HasBlock(const std::string& block_name) const; @@ -72,23 +75,33 @@ class ScheduleImpl { std::vector GetChildBlocks(const Expr& expr) const; Expr GetBlock(const std::string& block_name) const; std::vector Split(const Expr& loop, const std::vector& factors); - std::vector SamplePerfectTile(utils::LinearRandomEngine::StateType* rand_seed, - const Expr& loop, - int n, - int max_innermost_factor); + std::vector SamplePerfectTile( + utils::LinearRandomEngine::StateType* rand_seed, + const Expr& loop, + int n, + int max_innermost_factor); Expr Fuse(const std::vector& loops); Expr Fuse(const std::string& block_name, const std::vector& loops_index); Expr Fuse(const Expr& block, const std::vector& loops_index); void ComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops); void SimpleComputeAt(const Expr& block, const Expr& loop); - void ReverseComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops); + void ReverseComputeAt(const Expr& block, + const Expr& loop, + bool keep_unit_loops); Expr GetRootBlock(const Expr& expr) const; - Expr CacheRead(const Expr& block, int read_buffer_index, const std::string& memory_type); - Expr CacheWrite(const Expr& block, int write_buffer_index, const std::string& memory_type); + Expr CacheRead(const Expr& block, + int read_buffer_index, + const std::string& memory_type); + Expr CacheWrite(const Expr& block, + int write_buffer_index, + const std::string& memory_type); void SyncThreads(const Expr& ir_node, bool after_node = true); - void SetBuffer(Expr& block, const std::string& memory_type, bool fixed = false); + void SetBuffer(Expr& block, + const std::string& memory_type, + bool fixed = false); Expr Reorder(const std::vector& loops); - Expr Reorder(const std::string& block_name, const std::vector& loops_index); + Expr Reorder(const std::string& block_name, + const std::vector& loops_index); Expr Reorder(const Expr& block, const std::vector& loops_index); DeviceAPI GetDeviceAPI() const; void MutateForType(const Expr& loop, ForType for_type, int factor = -1); @@ -102,9 +115,11 @@ class ScheduleImpl { Expr AddUnitLoop(const Expr& block) const; void Annotate(const Expr& block, const std::string& key, const attr_t& value); 
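In the Split implementation that follows, substitute_value rebuilds the original loop index from the new loop variables in Horner form, one level at a time (substitute_value = v_i + substitute_value * factor_i), and an IfThenElse guard is added when the factor product overshoots the extent. A small numeric check of that reconstruction, where Flatten is an invented helper:

#include <cassert>
#include <cstddef>
#include <vector>

// Recombine per-level indices into the flat index, the same
// accumulation Split applies symbolically to its new loop vars.
int Flatten(const std::vector<int>& vars, const std::vector<int>& factors) {
  int flat = 0;
  for (std::size_t i = 0; i < vars.size(); ++i)
    flat = flat * factors[i] + vars[i];
  return flat;
}

int main() {
  // Splitting extent 12 by factors {3, 4}: outer in [0, 3), inner in
  // [0, 4), and the original index is i = outer * 4 + inner.
  assert(Flatten({2, 3}, {3, 4}) == 11);  // last iteration -> i = 11
  assert(Flatten({0, 0}, {3, 4}) == 0);   // first iteration -> i = 0
  return 0;
}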
void Unannotate(Expr& block, const std::string& key); - void FlattenLoops(const std::vector& loops, const bool force_flat = false); + void FlattenLoops(const std::vector& loops, + const bool force_flat = false); void CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target); - void CopyTransformAndLoopInfo(const std::string& block_name, const std::string& block_target_name); + void CopyTransformAndLoopInfo(const std::string& block_name, + const std::string& block_target_name); Expr SampleCategorical(utils::LinearRandomEngine::StateType* rand_seed, const std::vector& candidates, const std::vector& probs); @@ -116,38 +131,52 @@ class ScheduleImpl { bool debug_flag_{false}; }; -std::vector ScheduleImpl::Split(const Expr& loop, const std::vector& factors) { - CHECK(loop.As()) << "Expr param of Split must be For node! Please check."; +std::vector ScheduleImpl::Split(const Expr& loop, + const std::vector& factors) { + CHECK(loop.As()) + << "Expr param of Split must be For node! Please check."; auto* for_node = loop.As(); - CHECK(common::is_zero(for_node->min)) << "The For node must start with 0! Please check."; - CHECK(for_node->extent.is_constant()) << "The For node's extent must be constant! Please check."; + CHECK(common::is_zero(for_node->min)) + << "The For node must start with 0! Please check."; + CHECK(for_node->extent.is_constant()) + << "The For node's extent must be constant! Please check."; int tot_extent = for_node->extent.get_constant(); - VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, " << tot_extent << ") to (" - << cinn::utils::Join(factors, ", ") << ") at loop:\n" + VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, " + << tot_extent << ") to (" << cinn::utils::Join(factors, ", ") + << ") at loop:\n" << loop; auto processed_factors = ValidateFactors(factors, tot_extent); - int prod_size = std::accumulate(processed_factors.begin(), processed_factors.end(), 1, std::multiplies()); + int prod_size = std::accumulate(processed_factors.begin(), + processed_factors.end(), + 1, + std::multiplies()); std::vector new_loop_vars; Expr substitute_value(0); for (int i = 0; i < processed_factors.size(); ++i) { Var temp_var(common::UniqName(for_node->loop_var->name)); - substitute_value = Expr(temp_var) + substitute_value * Expr(processed_factors[i]); + substitute_value = + Expr(temp_var) + substitute_value * Expr(processed_factors[i]); new_loop_vars.push_back(temp_var); } substitute_value = common::AutoSimplify(substitute_value); - Expr new_node = optim::IRCopy(for_node->body); + Expr new_node = optim::IRCopy(for_node->body); ReplaceExpr(&new_node, {for_node->loop_var}, {substitute_value}); std::vector splited_loops; splited_loops.resize(processed_factors.size()); if (tot_extent < prod_size) { - new_node = IfThenElse::Make(LT::Make(substitute_value, for_node->extent), new_node); + new_node = IfThenElse::Make(LT::Make(substitute_value, for_node->extent), + new_node); } for (int i = processed_factors.size() - 1; i >= 0; i--) { if (!new_node.As()) new_node = Block::Make({new_node}); - new_node = For::Make( - new_loop_vars[i], Expr(0), Expr(processed_factors[i]), for_node->for_type(), for_node->device_api, new_node); + new_node = For::Make(new_loop_vars[i], + Expr(0), + Expr(processed_factors[i]), + for_node->for_type(), + for_node->device_api, + new_node); splited_loops[i] = new_node; } @@ -160,21 +189,26 @@ Expr ScheduleImpl::Fuse(const std::vector& loops) { VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n"); std::vector 
for_nodes; std::vector loop_vars; - CHECK(!loops.empty()) << "The loops param of Fuse should not be empty! Please check."; + CHECK(!loops.empty()) + << "The loops param of Fuse should not be empty! Please check."; for (const Expr& it_loop : loops) { - CHECK(it_loop.As()) << "Expr param of Fuse must be For node! Please check."; + CHECK(it_loop.As()) + << "Expr param of Fuse must be For node! Please check."; if (!for_nodes.empty()) { - CHECK(for_nodes.back()->body.As()) << "The body of for node is not Block!"; - CHECK_EQ(for_nodes.back()->body.As()->stmts.size(), 1U) << "The Block'size of for node is not 1!"; + CHECK(for_nodes.back()->body.As()) + << "The body of for node is not Block!"; + CHECK_EQ(for_nodes.back()->body.As()->stmts.size(), 1U) + << "The Block's size of for node is not 1!"; CHECK_EQ(for_nodes.back()->body.As()->stmts[0], it_loop) - << "The For nodes in loops param of Fuse must be adjacent! Please check."; + << "The For nodes in loops param of Fuse must be adjacent! Please " + "check."; } for_nodes.push_back(it_loop.As()); loop_vars.push_back(it_loop.As()->loop_var); } std::string suffix; - suffix = for_nodes[0]->loop_var->name; + suffix = for_nodes[0]->loop_var->name; int loops_number = for_nodes.size(); for (int i = 1; i < loops_number; ++i) { suffix += "_" + for_nodes[i]->loop_var->name; @@ -186,7 +220,7 @@ Expr ScheduleImpl::Fuse(const std::vector& loops) { Expr fused_expr(fused_var); for (int i = loops_number - 1; i > 0; i--) { substitute_value[i] = Mod::Make(fused_expr, for_nodes[i]->extent); - fused_expr = Div::Make(fused_expr, for_nodes[i]->extent); + fused_expr = Div::Make(fused_expr, for_nodes[i]->extent); } substitute_value[0] = fused_expr; @@ -200,51 +234,66 @@ Expr ScheduleImpl::Fuse(const std::vector& loops) { fused_extent = common::AutoSimplify(fused_extent); if (!fused_body.As()) fused_body = Block::Make({fused_body}); - Expr new_stmt = - For::Make(fused_var, Expr(0), fused_extent, for_nodes[0]->for_type(), for_nodes[0]->device_api, fused_body); + Expr new_stmt = For::Make(fused_var, + Expr(0), + fused_extent, + for_nodes[0]->for_type(), + for_nodes[0]->device_api, + fused_body); this->Replace(loops[0], new_stmt); VLOG(3) << "After fuse, ir is:\n" << new_stmt; return new_stmt; } -Expr ScheduleImpl::Fuse(const std::string& block_name, const std::vector& loops_index) { +Expr ScheduleImpl::Fuse(const std::string& block_name, + const std::vector& loops_index) { std::vector all_loops = this->GetLoops(block_name); std::vector loops_expr; loops_expr.reserve(loops_index.size()); for (int i = 0; i < loops_index.size(); ++i) { - if (i > 0) CHECK_EQ(loops_index[i - 1] + 1, loops_index[i]) << "Loops index in Fuse shoule be continuous!"; + if (i > 0) + CHECK_EQ(loops_index[i - 1] + 1, loops_index[i]) + << "Loops index in Fuse should be continuous!"; } for (int i : loops_index) { - CHECK_LT(i, (int)all_loops.size()) << "The loop index in Fuse should be less than total loop's number."; + CHECK_LT(i, (int)all_loops.size()) + << "The loop index in Fuse should be less than total loop's number."; CHECK_GE(i, 0) << "The loop index in Fuse should be >= 0."; loops_expr.emplace_back(all_loops[i]); } return this->Fuse(loops_expr); } -Expr ScheduleImpl::Fuse(const Expr& block, const std::vector& loops_index) { +Expr ScheduleImpl::Fuse(const Expr& block, + const std::vector& loops_index) { std::vector all_loops = this->GetLoops(block); std::vector loops_expr; loops_expr.reserve(loops_index.size()); for (int i = 0; i < loops_index.size(); ++i) { - if (i > 0) CHECK_EQ(loops_index[i - 1] + 
1, loops_index[i]) << "Loops index in Fuse shoule be continuous!"; + if (i > 0) + CHECK_EQ(loops_index[i - 1] + 1, loops_index[i]) + << "Loops index in Fuse should be continuous!"; } for (int i : loops_index) { - CHECK_LT(i, (int)all_loops.size()) << "The loop index in Fuse should be less than total loop's number."; + CHECK_LT(i, (int)all_loops.size()) + << "The loop index in Fuse should be less than total loop's number."; CHECK_GE(i, 0) << "The loop index in Fuse should be >= 0."; loops_expr.emplace_back(all_loops[i]); } return this->Fuse(loops_expr); } -void ScheduleImpl::MutateForType(const Expr& loop, ForType for_type, int factor) { +void ScheduleImpl::MutateForType(const Expr& loop, + ForType for_type, + int factor) { auto* for_node = loop.As(); CHECK(for_node) << "loop param must be For node! Please check."; - CHECK(for_node->is_serial()) << "loop is not serial, current forloop type is " - << static_cast(for_node->for_type()) << ", and it cannot become " - << static_cast(for_type); - auto loop_copy = optim::IRCopy(loop); + CHECK(for_node->is_serial()) + << "loop is not serial, current forloop type is " + << static_cast(for_node->for_type()) << ", and it cannot become " + << static_cast(for_type); + auto loop_copy = optim::IRCopy(loop); auto* new_for_node = loop_copy.As(); CHECK(new_for_node); new_for_node->set_for_type(for_type); @@ -258,19 +307,28 @@ void ScheduleImpl::MutateForType(const Expr& loop, ForType for_type, int factor) this->Replace(loop, loop_copy); } -void ScheduleImpl::Parallel(const Expr& loop) { MutateForType(loop, ForType::Parallel); } +void ScheduleImpl::Parallel(const Expr& loop) { + MutateForType(loop, ForType::Parallel); +} void ScheduleImpl::Vectorize(const Expr& loop, int factor) { CHECK_GT(factor, 0) << "vectorize factor should be more than 0"; MutateForType(loop, ForType::Vectorized, factor); } -void ScheduleImpl::Unroll(const Expr& loop) { MutateForType(loop, ForType::Unrolled); } +void ScheduleImpl::Unroll(const Expr& loop) { + MutateForType(loop, ForType::Unrolled); +} void ScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) { - static std::set thread_axes = { - "blockIdx.x", "blockIdx.y", "blockIdx.z", "threadIdx.x", "threadIdx.y", "threadIdx.z"}; - CHECK(thread_axes.count(thread_axis)) << "thread_axis " << thread_axis << " is not supported"; + static std::set thread_axes = {"blockIdx.x", + "blockIdx.y", + "blockIdx.z", + "threadIdx.x", + "threadIdx.y", + "threadIdx.z"}; + CHECK(thread_axes.count(thread_axis)) + << "thread_axis " << thread_axis << " is not supported"; int offset = thread_axis.back() - 'x'; if (thread_axis[0] == 'b') { MutateForType(loop, ForType::GPUBlock, offset); @@ -282,7 +340,8 @@ void ScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) { // The struct used to mutate new rfactor forloop and its schedule block. 
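// rfactor in a nutshell: the reduction is split into a partial-result
// tensor computed under the rfactor loop, plus a final write-back that
// folds the partials into the original output. An illustrative sketch
// (tensor names and shapes are made up):
//
//   // before:              for (k) B[i] = B[i] + A[i][k];
//   // after, rf_axis = 0:  for (k) rf_B[k][i] = rf_B[k][i] + A[i][k];
//   //                      for (k) B[i] = B[i] + rf_B[k][i];
//
// RfMutator produces the "rf_"-prefixed partial stage; FinalMutator
// (further below) rewires the write-back stage.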
struct RfMutator : public ir::IRMutator<> { public: - RfMutator(const Expr& rf_loop, const int& rf_axis) : rf_loop_(rf_loop), rf_axis_(rf_axis) {} + RfMutator(const Expr& rf_loop, const int& rf_axis) + : rf_loop_(rf_loop), rf_axis_(rf_axis) {} void operator()(Expr* expr) { auto* rf_for = rf_loop_.As(); CHECK(rf_for); @@ -299,9 +358,9 @@ struct RfMutator : public ir::IRMutator<> { CHECK(node); auto* schedule_block = node->schedule_block.As(); CHECK(schedule_block); - old_output_name_ = schedule_block->name; - find_tensor_ = false; - auto& block_vars = schedule_block->iter_vars; + old_output_name_ = schedule_block->name; + find_tensor_ = false; + auto& block_vars = schedule_block->iter_vars; auto& iter_values = node->iter_values; CHECK(old_rf_loop_var_.defined()); CHECK(new_rf_loop_var_.defined()); @@ -310,21 +369,27 @@ struct RfMutator : public ir::IRMutator<> { for (int i = 0; i < iter_values.size(); ++i) { // substitute the old rfactor loop var to new rfactor loop var if (ContainVar({iter_values[i]}, old_rf_loop_var_->name)) { - CHECK_EQ(rf_index, -1) << "only one block var can bind the rfactor loop var"; - CHECK(iter_values[i].As<_Var_>()) << "rfactor loop var not support composite bindings"; + CHECK_EQ(rf_index, -1) + << "only one block var can bind the rfactor loop var"; + CHECK(iter_values[i].As<_Var_>()) + << "rfactor loop var not support composite bindings"; rf_index = i; - optim::ReplaceVarWithExpr(&iter_values[i], old_rf_loop_var_, new_rf_loop_var_); + optim::ReplaceVarWithExpr( + &iter_values[i], old_rf_loop_var_, new_rf_loop_var_); new_rf_itervar_ = block_vars[i]; } } // create new rfactor block var if not exist if (rf_index == -1) { - new_rf_itervar_ = Var(cinn::UniqName("i" + std::to_string(block_vars.size()))); + new_rf_itervar_ = + Var(cinn::UniqName("i" + std::to_string(block_vars.size()))); iter_values.push_back(new_rf_loop_var_); block_vars.push_back(new_rf_itervar_); } IRMutator::Visit(&node->schedule_block, &node->schedule_block); - CHECK(find_tensor_) << "not find the store tensor with the schedule block name " << old_output_name_; + CHECK(find_tensor_) + << "not find the store tensor with the schedule block name " + << old_output_name_; schedule_block->name = "rf_" + old_output_name_; } @@ -336,11 +401,12 @@ struct RfMutator : public ir::IRMutator<> { CHECK(tensor); if (tensor->name == "rf_" + old_output_name_) { int size = node->indices.size(); - CHECK_LE(rf_axis_, size) << "rf_axis should not be greater than indice size " << size; + CHECK_LE(rf_axis_, size) + << "rf_axis should not be greater than indice size " << size; CHECK(new_rf_itervar_.defined()); CHECK(!ContainVar(node->indices, new_rf_itervar_->name)) - << "original output tensor " << old_output_name_ << " should not have the new rfactor index " - << new_rf_itervar_; + << "original output tensor " << old_output_name_ + << " should not have the new rfactor index " << new_rf_itervar_; node->indices.insert(node->indices.begin() + rf_axis_, new_rf_itervar_); } } @@ -354,25 +420,30 @@ struct RfMutator : public ir::IRMutator<> { if (tensor->name == old_output_name_) { find_tensor_ = true; tensor->name = "rf_" + tensor->name; - int size = node->indices.size(); - CHECK_LE(rf_axis_, size) << "rf_axis should not be greater than indice size " << size; + int size = node->indices.size(); + CHECK_LE(rf_axis_, size) + << "rf_axis should not be greater than indice size " << size; CHECK(!ContainVar(node->indices, new_rf_itervar_->name)) - << "original output tensor " << old_output_name_ << " should not have the new rfactor 
index " - << new_rf_itervar_; + << "original output tensor " << old_output_name_ + << " should not have the new rfactor index " << new_rf_itervar_; node->indices.insert(node->indices.begin() + rf_axis_, new_rf_itervar_); auto* rf_for = rf_loop_.As(); CHECK(rf_for); CHECK(is_zero(rf_for->min)) << "rfactor loop's min should be zero"; - auto extent = common::AutoSimplify(rf_for->extent); - auto& shape = tensor->shape; + auto extent = common::AutoSimplify(rf_for->extent); + auto& shape = tensor->shape; auto& domain = tensor->domain; - CHECK_LE(rf_axis_, shape.size()) << "rf_axis should not be greater than tensor shape size " << shape.size(); - CHECK_LE(rf_axis_, domain.size()) << "rf_axis should not be greater than tensor domain size " << domain.size(); + CHECK_LE(rf_axis_, shape.size()) + << "rf_axis should not be greater than tensor shape size " + << shape.size(); + CHECK_LE(rf_axis_, domain.size()) + << "rf_axis should not be greater than tensor domain size " + << domain.size(); shape.insert(shape.begin() + rf_axis_, extent); domain.insert(domain.begin() + rf_axis_, extent); if (tensor->buffer.defined()) { if (tensor->buffer->name.find_first_of("rf") == std::string::npos) { - tensor->buffer->name = "rf_" + tensor->buffer->name; + tensor->buffer->name = "rf_" + tensor->buffer->name; tensor->buffer->shape = shape; } } @@ -401,12 +472,20 @@ struct RfMutator : public ir::IRMutator<> { } if (rf_axis_ == 0 && depth == rf_axis_) { // insert new rfactor forloop in the rf_axis as serial loop - *expr = For::Make( - new_rf_loop_var_, rf_for->min, rf_for->extent, ForType::Serial, rf_for->device_api, Block::Make({*expr})); + *expr = For::Make(new_rf_loop_var_, + rf_for->min, + rf_for->extent, + ForType::Serial, + rf_for->device_api, + Block::Make({*expr})); } else if (depth == rf_axis_ - 1) { // insert new rfactor forloop in the rf_axis as serial loop - node->body = Block::Make( - {For::Make(new_rf_loop_var_, rf_for->min, rf_for->extent, ForType::Serial, rf_for->device_api, node->body)}); + node->body = Block::Make({For::Make(new_rf_loop_var_, + rf_for->min, + rf_for->extent, + ForType::Serial, + rf_for->device_api, + node->body)}); } depth--; } @@ -416,7 +495,7 @@ struct RfMutator : public ir::IRMutator<> { Var old_rf_loop_var_; Var new_rf_loop_var_; int rf_axis_; - int depth = -1; + int depth = -1; bool find_tensor_ = false; std::string old_output_name_; Var new_rf_itervar_; @@ -426,7 +505,9 @@ struct RfMutator : public ir::IRMutator<> { // The struct used to mutate final write-back forloop and schedule block. 
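// For each supported reduction kind the visitor swaps the combining
// operand for a load of the new rfactor tensor, e.g. (illustrative):
//
//   B[i] = B[i] + A[i][k]   ==>   B[i] = B[i] + rf_B[k][i]
//
// Only Add, Mul, Min and Max are overridden below; other reduction
// patterns are left untouched by this mutator.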
struct FinalMutator : public ir::IRMutator<> { public: - FinalMutator(const Expr& rf_loop, const int& rf_axis, const Tensor& new_rf_tensor) + FinalMutator(const Expr& rf_loop, + const int& rf_axis, + const Tensor& new_rf_tensor) : rf_loop_(rf_loop), rf_axis_(rf_axis), new_rf_tensor_(new_rf_tensor) {} void operator()(Expr* expr) { auto* rf_for = rf_loop_.As(); @@ -440,26 +521,29 @@ struct FinalMutator : public ir::IRMutator<> { CHECK(node); auto* schedule_block = node->schedule_block.As(); CHECK(schedule_block); - auto& iter_vars = schedule_block->iter_vars; + auto& iter_vars = schedule_block->iter_vars; auto& iter_values = node->iter_values; - output_name_ = schedule_block->name; + output_name_ = schedule_block->name; visit_init_block_ = output_name_.rfind("_init") != std::string::npos; if (!visit_init_block_) { for (int i = 0; i < iter_values.size(); ++i) { if (ContainVar({iter_values[i]}, old_rf_loop_var_->name)) { // record the rfactor loop var's block var - CHECK(iter_values[i].As<_Var_>()) << "not support complex reduce bindings: " << iter_values[i]; + CHECK(iter_values[i].As<_Var_>()) + << "not support complex reduce bindings: " << iter_values[i]; old_rf_iter_var_ = iter_vars[i]; break; } } } IRMutator::Visit(&node->schedule_block, &node->schedule_block); - // modify iter_vars and iter_values, erase other reduce block vars and values + // modify iter_vars and iter_values, erase other reduce block vars and + // values for (auto it = iter_values.begin(); it != iter_values.end(); ++it) { for (auto erase_var : erase_reduce_loopvars_) { if (ContainVar({*it}, erase_var)) { - CHECK((*it).As<_Var_>()) << "not support complex reduce bindings: " << *it; + CHECK((*it).As<_Var_>()) + << "not support complex reduce bindings: " << *it; iter_vars.erase(it - iter_values.begin() + iter_vars.begin()); iter_values.erase(it); --it; @@ -474,28 +558,28 @@ struct FinalMutator : public ir::IRMutator<> { auto* node = expr->As(); CHECK(node); auto& oper_b = node->b(); - oper_b = Load::Make(new_rf_tensor_, new_rf_indice_); + oper_b = Load::Make(new_rf_tensor_, new_rf_indice_); } void Visit(const Mul* op, Expr* expr) override { auto* node = expr->As(); CHECK(node); auto& oper_b = node->b(); - oper_b = Load::Make(new_rf_tensor_, new_rf_indice_); + oper_b = Load::Make(new_rf_tensor_, new_rf_indice_); } void Visit(const Min* op, Expr* expr) override { auto* node = expr->As(); CHECK(node); auto& oper_b = node->b(); - oper_b = Load::Make(new_rf_tensor_, new_rf_indice_); + oper_b = Load::Make(new_rf_tensor_, new_rf_indice_); } void Visit(const Max* op, Expr* expr) override { auto* node = expr->As(); CHECK(node); auto& oper_b = node->b(); - oper_b = Load::Make(new_rf_tensor_, new_rf_indice_); + oper_b = Load::Make(new_rf_tensor_, new_rf_indice_); } void Visit(const Store* op, Expr* expr) override { @@ -504,13 +588,16 @@ struct FinalMutator : public ir::IRMutator<> { CHECK(node); auto* tensor = node->tensor.As<_Tensor_>(); CHECK(tensor); - CHECK_EQ(tensor->name, output_name_) << "store name should be same with the schedule block name"; + CHECK_EQ(tensor->name, output_name_) + << "store name should be same with the schedule block name"; if (!visit_init_block_) { new_rf_indice_ = node->indices; CHECK_LE(rf_axis_, new_rf_indice_.size()) - << "rf_axis_ should not be greater than tensor indice size " << new_rf_indice_.size(); + << "rf_axis_ should not be greater than tensor indice size " + << new_rf_indice_.size(); CHECK(old_rf_iter_var_.defined()); - new_rf_indice_.insert(new_rf_indice_.begin() + rf_axis_, 
old_rf_iter_var_); + new_rf_indice_.insert(new_rf_indice_.begin() + rf_axis_, + old_rf_iter_var_); IRMutator::Visit(&node->value, &node->value); } } @@ -561,7 +648,8 @@ struct RfCreater : public ir::IRMutator<> { CHECK(root_block); Expr root_loop = optim::IRCopy(root_block->body); if (auto block = root_loop.As()) { - CHECK_EQ(block->stmts.size(), 1U) << "rfactor root should only have one block stmt"; + CHECK_EQ(block->stmts.size(), 1U) + << "rfactor root should only have one block stmt"; root_loop = block->stmts[0]; } auto* root_for = root_loop.As(); @@ -578,8 +666,10 @@ struct RfCreater : public ir::IRMutator<> { Expr final_forloop = optim::IRCopy(root_loop); FinalMutator final_mutator(rf_loop_, rf_axis_, new_rf_tensor); final_mutator(&final_forloop); - VLOG(3) << "After FinalMuator, final write-back forloop is\n" << final_forloop; - // combine the new created rfactor forloops with the final write-back forloops and replace + VLOG(3) << "After FinalMuator, final write-back forloop is\n" + << final_forloop; + // combine the new created rfactor forloops with the final write-back + // forloops and replace root_block->body = Block::Make({new_rf_forloop, final_forloop}); return new_rf_tensor; } @@ -611,12 +701,14 @@ struct CacheReadRewriter : public ir::IRMutator<> { void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } private: - explicit CacheReadRewriter(const Expr& root, CacheBlockInfo* info) : root_(root), info_(info) {} + explicit CacheReadRewriter(const Expr& root, CacheBlockInfo* info) + : root_(root), info_(info) {} void Visit(const ir::Block* expr, Expr* op) override { if (*op == info_->loc_block) { IRMutator::Visit(expr, op); - op->As()->stmts.insert(op->As()->stmts.begin() + info_->loc_pos, info_->cache_block); + op->As()->stmts.insert( + op->As()->stmts.begin() + info_->loc_pos, info_->cache_block); } else { IRMutator::Visit(expr, op); } @@ -642,19 +734,24 @@ struct CacheWriteRewriter : public ir::IRMutator<> { public: static Expr Rewrite(const Expr& root, CacheBlockInfo* info) { CacheWriteRewriter rewriter(root, info); - Expr new_root = optim::IRCopy(root); + Expr new_root = optim::IRCopy(root); rewriter.mutate_cache_block = true; rewriter(&info->cache_block); rewriter.mutate_cache_block = false; rewriter(&new_root); auto find_tensor = ir::CollectIRNodesWithoutTensor( new_root, - [&](const Expr* x) { return x->As() && (x->As()->tensor == Expr(info->read_tensor)); }, + [&](const Expr* x) { + return x->As() && + (x->As()->tensor == Expr(info->read_tensor)); + }, true); if (!find_tensor.empty()) { - auto find_store = ir::CollectIRNodesWithoutTensor((*find_tensor.begin()), [&](const Expr* x) { - return x->As() && (x->As()->tensor == Expr(info->write_tensor)); - }); + auto find_store = ir::CollectIRNodesWithoutTensor( + (*find_tensor.begin()), [&](const Expr* x) { + return x->As() && + (x->As()->tensor == Expr(info->write_tensor)); + }); for (auto load_ir : find_store) { load_ir.As()->tensor = Expr(info->read_tensor); } @@ -665,12 +762,14 @@ struct CacheWriteRewriter : public ir::IRMutator<> { void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } private: - explicit CacheWriteRewriter(const Expr& root, CacheBlockInfo* info) : root_(root), info_(info) {} + explicit CacheWriteRewriter(const Expr& root, CacheBlockInfo* info) + : root_(root), info_(info) {} void Visit(const ir::Block* expr, Expr* op) override { if (*op == info_->loc_block) { IRMutator::Visit(expr, op); - op->As()->stmts.insert(op->As()->stmts.begin() + info_->loc_pos, info_->cache_block); + 
op->As()->stmts.insert( + op->As()->stmts.begin() + info_->loc_pos, info_->cache_block); } else { IRMutator::Visit(expr, op); } @@ -687,9 +786,11 @@ struct CacheWriteRewriter : public ir::IRMutator<> { void Visit(const ir::Load* expr, Expr* op) override { IRMutator::Visit(expr, op); - if (op->As()->tensor == Expr(info_->write_tensor) && mutate_cache_block) { + if (op->As()->tensor == Expr(info_->write_tensor) && + mutate_cache_block) { op->As()->tensor = Expr(info_->read_tensor); - } else if (op->As()->tensor == Expr(info_->read_tensor) && mutate_cache_block) { + } else if (op->As()->tensor == Expr(info_->read_tensor) && + mutate_cache_block) { op->As()->tensor = Expr(info_->write_tensor); } } @@ -698,7 +799,8 @@ struct CacheWriteRewriter : public ir::IRMutator<> { IRMutator::Visit(expr, op); if (op->As()->tensor == Expr(info_->write_tensor)) { op->As()->tensor = Expr(info_->read_tensor); - } else if (op->As()->tensor == Expr(info_->read_tensor) && mutate_cache_block) { + } else if (op->As()->tensor == Expr(info_->read_tensor) && + mutate_cache_block) { op->As()->tensor = Expr(info_->write_tensor); } } @@ -725,21 +827,24 @@ struct ChangeBodyToBlock : public ir::IRMutator<> { private: void Visit(const ir::ScheduleBlock* expr, Expr* op) override { if (!op->As()->body.As()) { - op->As()->body = Block::Make({op->As()->body}); + op->As()->body = + Block::Make({op->As()->body}); } IRMutator::Visit(expr, op); } }; DeviceAPI ScheduleImpl::GetDeviceAPI() const { - auto exprs = this->GetModule().GetExprs(); + auto exprs = this->GetModule().GetExprs(); auto find_for_nodes = ir::CollectIRNodesWithoutTensor( exprs.front(), [&](const Expr* x) { return x->As(); }, true); CHECK(!find_for_nodes.empty()); return (*find_for_nodes.begin()).As()->device_api; } -Expr ScheduleImpl::CacheRead(const Expr& block, int read_tensor_index, const std::string& memory_type) { +Expr ScheduleImpl::CacheRead(const Expr& block, + int read_tensor_index, + const std::string& memory_type) { CHECK(block.As()); auto root = GetRootBlock(block); ChangeBodyToBlock::Change(&root); @@ -747,20 +852,27 @@ Expr ScheduleImpl::CacheRead(const Expr& block, int read_tensor_index, const std CHECK(read_expr.As()); auto tensor_indices = read_expr.As()->indices; CacheBlockInfo info; - info.read_tensor = read_expr.As()->tensor.as_tensor_ref(); + info.read_tensor = read_expr.As()->tensor.as_tensor_ref(); info.write_tensor = MakeCacheTensor(info.read_tensor, memory_type); - info.alloc = info.write_tensor; + info.alloc = info.write_tensor; - auto read_ranges = CalculateTensorRegions(block, tensor_indices, info.read_tensor, root); - auto new_block = MakeCacheBlock(read_ranges, &info, memory_type, this->GetDeviceAPI()); + auto read_ranges = + CalculateTensorRegions(block, tensor_indices, info.read_tensor, root); + auto new_block = + MakeCacheBlock(read_ranges, &info, memory_type, this->GetDeviceAPI()); FindInsertionPoint(root, &info, false); auto new_root = CacheReadRewriter::Rewrite(root, &info); - this->Replace(root.As()->schedule_block.As()->body, - new_root.As()->schedule_block.As()->body); + this->Replace( + root.As()->schedule_block.As()->body, + new_root.As() + ->schedule_block.As() + ->body); return new_block; } -Expr ScheduleImpl::CacheWrite(const Expr& block, int write_buffer_index, const std::string& memory_type) { +Expr ScheduleImpl::CacheWrite(const Expr& block, + int write_buffer_index, + const std::string& memory_type) { CHECK(block.As()); auto root = GetRootBlock(block); ChangeBodyToBlock::Change(&root); @@ -769,21 +881,27 @@ Expr 
ScheduleImpl::CacheWrite(const Expr& block, int write_buffer_index, const s Tensor write_tensor = write_expr.As()->tensor.as_tensor_ref(); auto tensor_indices = write_expr.As()->indices; CacheBlockInfo info; - info.read_tensor = MakeCacheTensor(write_tensor, memory_type); + info.read_tensor = MakeCacheTensor(write_tensor, memory_type); info.write_tensor = write_tensor; - info.alloc = info.read_tensor; - auto write_ranges = CalculateTensorRegions(block, tensor_indices, info.write_tensor, root); - auto new_block = MakeCacheBlock(write_ranges, &info, memory_type, this->GetDeviceAPI()); + info.alloc = info.read_tensor; + auto write_ranges = + CalculateTensorRegions(block, tensor_indices, info.write_tensor, root); + auto new_block = + MakeCacheBlock(write_ranges, &info, memory_type, this->GetDeviceAPI()); FindInsertionPoint(root, &info, true); auto new_root = CacheWriteRewriter::Rewrite(root, &info); - this->Replace(root.As()->schedule_block.As()->body, - new_root.As()->schedule_block.As()->body); + this->Replace( + root.As()->schedule_block.As()->body, + new_root.As() + ->schedule_block.As() + ->body); auto find_cache_block = ir::CollectIRNodesWithoutTensor( root, [&](const Expr* x) { - return x->As() && !x->As()->iter_values.empty() && + return x->As() && + !x->As()->iter_values.empty() && GetTensor(*x)->name == info.read_tensor->name; }, true); @@ -791,11 +909,13 @@ Expr ScheduleImpl::CacheWrite(const Expr& block, int write_buffer_index, const s CHECK(info.write_tensor->buffer.defined()); // Replace buffer - auto all_tensors = ir::CollectIRNodesWithoutTensor( - root, [&](const Expr* x) { return x->as_tensor() && x->as_tensor()->buffer.defined(); }); + auto all_tensors = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { + return x->as_tensor() && x->as_tensor()->buffer.defined(); + }); for (auto i : all_tensors) { - if (i.as_tensor()->name != info.write_tensor->name && i.as_tensor()->buffer.defined() && + if (i.as_tensor()->name != info.write_tensor->name && + i.as_tensor()->buffer.defined() && i.as_tensor()->buffer->name == info.write_tensor->buffer->name) { i.as_tensor()->Bind(info.read_tensor->buffer); } @@ -808,7 +928,10 @@ Expr ScheduleImpl::CacheWrite(const Expr& block, int write_buffer_index, const s struct InsertExpr : public ir::IRMutator<> { public: - static void Insert(const Expr& ir_node, const Expr& insert_node, bool after_node, Expr* expr) { + static void Insert(const Expr& ir_node, + const Expr& insert_node, + bool after_node, + Expr* expr) { InsertExpr mutator(ir_node, insert_node, after_node); mutator(expr); } @@ -816,16 +939,20 @@ struct InsertExpr : public ir::IRMutator<> { void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } private: - explicit InsertExpr(const Expr& ir_node, const Expr& insert_node, bool after_node) + explicit InsertExpr(const Expr& ir_node, + const Expr& insert_node, + bool after_node) : ir_node_(ir_node), insert_node_(insert_node), after_node_(after_node) {} void Visit(const ir::Block* expr, Expr* op) override { for (int i = 0; i < expr->stmts.size(); i++) { if (expr->stmts[i] == ir_node_) { if (after_node_) { - op->As()->stmts.insert(op->As()->stmts.begin() + i + 1, insert_node_); + op->As()->stmts.insert( + op->As()->stmts.begin() + i + 1, insert_node_); } else { - op->As()->stmts.insert(op->As()->stmts.begin() + i, insert_node_); + op->As()->stmts.insert( + op->As()->stmts.begin() + i, insert_node_); } return; } @@ -836,9 +963,11 @@ struct InsertExpr : public ir::IRMutator<> { void Visit(const ir::For* expr, Expr* op) override { if 
(expr->body == ir_node_) { if (after_node_) - op->As()->body = ir::Block::Make({op->As()->body, insert_node_}); + op->As()->body = + ir::Block::Make({op->As()->body, insert_node_}); else - op->As()->body = ir::Block::Make({insert_node_, op->As()->body}); + op->As()->body = + ir::Block::Make({insert_node_, op->As()->body}); return; } IRMutator::Visit(expr, op); @@ -865,13 +994,16 @@ void ScheduleImpl::SyncThreads(const Expr& ir_node, bool after_node) { * @param tgt_stmt The For node we want. */ void ScheduleImpl::Replace(const Expr& src_sref, const Expr& tgt_stmt) { - CHECK(src_sref.As() || src_sref.As() || src_sref.As()); - CHECK(tgt_stmt.As() || tgt_stmt.As() || tgt_stmt.As()); + CHECK(src_sref.As() || src_sref.As() || + src_sref.As()); + CHECK(tgt_stmt.As() || tgt_stmt.As() || + tgt_stmt.As()); if (src_sref == tgt_stmt) { return; } struct ForLoopMutator : public ir::IRMutator<> { - ForLoopMutator(const Expr& source, const Expr& target) : source_(source), target_(target) {} + ForLoopMutator(const Expr& source, const Expr& target) + : source_(source), target_(target) {} void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } @@ -916,36 +1048,40 @@ Expr ScheduleImpl::Reorder(const std::vector& loops) { VLOG(4) << "Before Reorder, ir is:\n" << loops[0]; std::set loop_set = CollectLoopsToSet(loops); - auto boundary = GetBoundaryOfReorderRange(loop_set); - Expr top = boundary.first; - Expr bottom = boundary.second; - std::vector chain = GetLoopsInRange(top, bottom); - std::vector if_nodes = GetIfThenElseInRange(top, bottom); - Expr new_loop = ConstructNewLoopChain(chain, loops, loop_set, if_nodes); + auto boundary = GetBoundaryOfReorderRange(loop_set); + Expr top = boundary.first; + Expr bottom = boundary.second; + std::vector chain = GetLoopsInRange(top, bottom); + std::vector if_nodes = GetIfThenElseInRange(top, bottom); + Expr new_loop = ConstructNewLoopChain(chain, loops, loop_set, if_nodes); this->Replace(top, new_loop); VLOG(4) << "After Reorder, ir is:\n" << new_loop; return new_loop; } -Expr ScheduleImpl::Reorder(const std::string& block_name, const std::vector& loops_index) { +Expr ScheduleImpl::Reorder(const std::string& block_name, + const std::vector& loops_index) { std::vector all_loops = this->GetLoops(block_name); std::vector loops_expr; loops_expr.reserve(loops_index.size()); for (int i : loops_index) { - CHECK_LT(i, (int)all_loops.size()) << "The loop index in Reorder should be less than total loop's number."; + CHECK_LT(i, (int)all_loops.size()) + << "The loop index in Reorder should be less than total loop's number."; CHECK_GE(i, 0) << "The loop index in Reorder should be >= 0."; loops_expr.emplace_back(all_loops[i]); } return this->Reorder(loops_expr); } -Expr ScheduleImpl::Reorder(const Expr& block, const std::vector& loops_index) { +Expr ScheduleImpl::Reorder(const Expr& block, + const std::vector& loops_index) { std::vector all_loops = this->GetLoops(block); std::vector loops_expr; loops_expr.reserve(loops_index.size()); for (int i : loops_index) { - CHECK_LT(i, (int)all_loops.size()) << "The loop index in Reorder should be less than total loop's number."; + CHECK_LT(i, (int)all_loops.size()) + << "The loop index in Reorder should be less than total loop's number."; CHECK_GE(i, 0) << "The loop index in Reorder should be >= 0."; loops_expr.emplace_back(all_loops[i]); } @@ -956,7 +1092,11 @@ Expr ScheduleImpl::GetRootBlock(const Expr& expr) const { auto exprs = this->GetModule().GetExprs(); for (auto& it_expr : exprs) { auto find_expr = 
ir::CollectIRNodesWithoutTensor( - it_expr, [&](const Expr* x) { return x->node_type() == expr.node_type() && *x == expr; }, true); + it_expr, + [&](const Expr* x) { + return x->node_type() == expr.node_type() && *x == expr; + }, + true); if (!find_expr.empty()) { CHECK(it_expr.As()); CHECK_EQ(it_expr.As()->stmts.size(), 1U); @@ -964,22 +1104,29 @@ Expr ScheduleImpl::GetRootBlock(const Expr& expr) const { return it_expr.As()->stmts[0]; } } - LOG(FATAL) << "Didn't find expr \n" << expr << "in ScheduleImpl:\n" << exprs[0]; + LOG(FATAL) << "Didn't find expr \n" + << expr << "in ScheduleImpl:\n" + << exprs[0]; } // The struct used to reconstruct the new For node to replace the old For node. struct LoopReconstructor : public ir::IRMutator<> { public: - explicit LoopReconstructor(const Expr& root, const Expr& block, const Expr& loop) + explicit LoopReconstructor(const Expr& root, + const Expr& block, + const Expr& loop) : root_(root), block_(block), loop_(loop) {} void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } - /* \param inserted_pos The position index of the new_loop_ body `stmts` to be inserted: + /* \param inserted_pos The position index of the new_loop_ body `stmts` to be + * inserted: * - `index = -1` means inserted into the tail * - otherwise, it should be a index between [0, stmts size) */ - std::string MakeNewLoop(const std::vector& iter_ranges, bool keep_unit_loops, int inserted_pos = -1) { + std::string MakeNewLoop(const std::vector& iter_ranges, + bool keep_unit_loops, + int inserted_pos = -1) { int n_iters = iter_ranges.size(); std::vector loop_vars; std::vector loop_extents; @@ -991,7 +1138,8 @@ struct LoopReconstructor : public ir::IRMutator<> { for (int i = 0; i < n_iters; ++i) { const auto& range = iter_ranges[i]; if (keep_unit_loops || range.extent != Expr(1)) { - std::string var_name = common::UniqName("ax" + std::to_string(loop_vars.size())); + std::string var_name = + common::UniqName("ax" + std::to_string(loop_vars.size())); new_var_names.push_back(var_name); Var var(var_name, Int(32)); loop_vars.push_back(var); @@ -1001,15 +1149,21 @@ struct LoopReconstructor : public ir::IRMutator<> { iter_values.push_back(common::AutoSimplify(range.min)); } } - auto schedule_block_node = block_.As()->schedule_block; - new_block_ = ScheduleBlockRealize::Make(std::move(iter_values), std::move(schedule_block_node)); - Expr loop_body = new_block_; + auto schedule_block_node = + block_.As()->schedule_block; + new_block_ = ScheduleBlockRealize::Make(std::move(iter_values), + std::move(schedule_block_node)); + Expr loop_body = new_block_; for (int i = static_cast(loop_vars.size()) - 1; i >= 0; --i) { - auto loop_var = loop_vars[i]; + auto loop_var = loop_vars[i]; auto loop_extent = loop_extents[i]; if (!loop_body.As()) loop_body = Block::Make({loop_body}); - loop_body = For::Make( - loop_var, Expr(0), loop_extent, ForType::Serial, loop_.As()->device_api, std::move(loop_body)); + loop_body = For::Make(loop_var, + Expr(0), + loop_extent, + ForType::Serial, + loop_.As()->device_api, + std::move(loop_body)); } new_loop_ = optim::IRCopy(loop_); @@ -1023,13 +1177,17 @@ struct LoopReconstructor : public ir::IRMutator<> { } return false; }); - auto find_store = ir::CollectIRNodesWithoutTensor(new_loop_, [](const Expr* x) { return x->As(); }); + auto find_store = ir::CollectIRNodesWithoutTensor( + new_loop_, [](const Expr* x) { return x->As(); }); for (auto store : find_store) { - store.As()->tensor = tensors_map.at(store.As()->tensor.as_tensor()->name); + store.As()->tensor = + 
tensors_map.at(store.As()->tensor.as_tensor()->name); } - auto find_load = ir::CollectIRNodesWithoutTensor(new_loop_, [](const Expr* x) { return x->As(); }); + auto find_load = ir::CollectIRNodesWithoutTensor( + new_loop_, [](const Expr* x) { return x->As(); }); for (auto load : find_load) { - load.As()->tensor = tensors_map.at(load.As()->tensor.as_tensor()->name); + load.As()->tensor = + tensors_map.at(load.As()->tensor.as_tensor()->name); } InsertBlock(new_loop_, loop_body, inserted_pos); @@ -1048,55 +1206,65 @@ struct LoopReconstructor : public ir::IRMutator<> { Expr new_loop_{nullptr}; /*! \brief The new block realize to the moved block */ Expr new_block_{nullptr}; - /*! \brief The plan to remove the given block by replacing this loop/block in the AST */ + /*! \brief The plan to remove the given block by replacing this loop/block in + * the AST */ Expr source_expr{nullptr}; - /*! \brief The plan to remove the given block by replacing to this loop/block in the AST */ + /*! \brief The plan to remove the given block by replacing to this loop/block + * in the AST */ Expr target_expr{nullptr}; }; struct FixLocalBufferSize : public ir::IRMutator<> { public: - FixLocalBufferSize(const std::string& tensor_name) : tensor_name_(tensor_name) {} + FixLocalBufferSize(const std::string& tensor_name) + : tensor_name_(tensor_name) {} void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } private: void Visit(const ir::Store* expr, Expr* op) override { if (op->As()->tensor.As<_Tensor_>()->name == tensor_name_) { - op->As()->tensor.As<_Tensor_>()->shape = {Expr(1)}; - op->As()->tensor.As<_Tensor_>()->domain = {Expr(1)}; + op->As()->tensor.As<_Tensor_>()->shape = {Expr(1)}; + op->As()->tensor.As<_Tensor_>()->domain = {Expr(1)}; op->As()->tensor.As<_Tensor_>()->buffer->shape = {Expr(1)}; - op->As()->indices = {Expr(0)}; + op->As()->indices = {Expr(0)}; } IRMutator::Visit(expr, op); } void Visit(const ir::Load* expr, Expr* op) override { if (op->As()->tensor.As<_Tensor_>()->name == tensor_name_) { - op->As()->tensor.As<_Tensor_>()->shape = {Expr(1)}; - op->As()->tensor.As<_Tensor_>()->domain = {Expr(1)}; + op->As()->tensor.As<_Tensor_>()->shape = {Expr(1)}; + op->As()->tensor.As<_Tensor_>()->domain = {Expr(1)}; op->As()->tensor.As<_Tensor_>()->buffer->shape = {Expr(1)}; - op->As()->indices = {Expr(0)}; + op->As()->indices = {Expr(0)}; } IRMutator::Visit(expr, op); } std::string tensor_name_; }; -void ScheduleImpl::SetBuffer(Expr& block, const std::string& memory_type, bool fixed) { +void ScheduleImpl::SetBuffer(Expr& block, + const std::string& memory_type, + bool fixed) { CHECK(block.As()); auto find_tensor = ir::CollectIRNodesWithoutTensor( block, [&](const Expr* x) { return x->As(); }, true); - CHECK_EQ(find_tensor.size(), 1U) << "One block should only have one Store node!(except for root block)"; + CHECK_EQ(find_tensor.size(), 1U) + << "One block should only have one Store node!(except for root block)"; auto& tensor = (*find_tensor.begin()).As()->tensor; - tensor.as_tensor_ref()->WithBuffer(memory_type, "_" + tensor.as_tensor_ref()->name + "_temp_buffer"); + tensor.as_tensor_ref()->WithBuffer( + memory_type, "_" + tensor.as_tensor_ref()->name + "_temp_buffer"); auto exprs = this->GetModule().GetExprs(); for (auto& it_expr : exprs) { - auto find_tensor = ir::CollectIRNodesWithoutTensor(it_expr, [&](const Expr* x) { - return x->as_tensor() && (x->as_tensor()->name == tensor.as_tensor_ref()->name || - x->as_tensor()->name == tensor.as_tensor_ref()->name + "__reduce_init"); - }); + auto find_tensor = 
+ ir::CollectIRNodesWithoutTensor(it_expr, [&](const Expr* x) { + return x->as_tensor() && + (x->as_tensor()->name == tensor.as_tensor_ref()->name || + x->as_tensor()->name == + tensor.as_tensor_ref()->name + "__reduce_init"); + }); for (auto& t : find_tensor) { CHECK(t.as_tensor()); t.as_tensor_ref()->Bind(tensor.as_tensor_ref()->buffer); @@ -1105,7 +1273,9 @@ void ScheduleImpl::SetBuffer(Expr& block, const std::string& memory_type, bool f // if buffer type == "local" if (memory_type == "local" && fixed) { - FixLocalBufferSize mutator(block.As()->schedule_block.As()->name); + FixLocalBufferSize mutator(block.As() + ->schedule_block.As() + ->name); auto root = GetRootBlock(block); mutator(&root); } @@ -1117,21 +1287,32 @@ void ScheduleImpl::MergeExprs() { CHECK(exprs[0].As()); CHECK_EQ(exprs[0].As()->stmts.size(), 1U); CHECK(exprs[0].As()->stmts[0].As()); - CHECK(exprs[0].As()->stmts[0].As()->schedule_block.As()); + CHECK(exprs[0] + .As() + ->stmts[0] + .As() + ->schedule_block.As()); std::vector merged_block; - merged_block.push_back( - exprs[0].As()->stmts[0].As()->schedule_block.As()->body); + merged_block.push_back(exprs[0] + .As() + ->stmts[0] + .As() + ->schedule_block.As() + ->body); VLOG(3) << "Before merging, exprs[0] is : " << exprs[0]; for (int i = 1; i < exprs.size(); ++i) { auto root_block = ir::CollectIRNodesWithoutTensor( exprs[i], [&](const Expr* x) { - return x->As() && x->As()->iter_values.empty(); + return x->As() && + x->As()->iter_values.empty(); }, true); CHECK_EQ(root_block.size(), 1U); for (auto& it_block : root_block) { - auto& block_body = it_block.As()->schedule_block.As()->body; + auto& block_body = it_block.As() + ->schedule_block.As() + ->body; merged_block.push_back(block_body); } } @@ -1139,14 +1320,20 @@ void ScheduleImpl::MergeExprs() { VLOG(3) << "in merged_block, it has " << block; } auto merged_expr = ir::Block::Make(merged_block); - exprs[0].As()->stmts[0].As()->schedule_block.As()->body = - merged_expr; + exprs[0] + .As() + ->stmts[0] + .As() + ->schedule_block.As() + ->body = merged_expr; VLOG(3) << "After merging, exprs[0] is : " << exprs[0]; exprs.erase(exprs.begin() + 1, exprs.end()); this->SetExprs(exprs); } -void ScheduleImpl::ComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops) { +void ScheduleImpl::ComputeAt(const Expr& block, + const Expr& loop, + bool keep_unit_loops) { CHECK(block.As()); CHECK(loop.As()); Expr root = this->GetRootBlock(block); @@ -1157,12 +1344,15 @@ void ScheduleImpl::ComputeAt(const Expr& block, const Expr& loop, bool keep_unit auto consumers = GetConsumers(block, root); CheckComputeAtValidation(block, loop, root); LoopReconstructor reconstructor(root, block, loop); - LeafBlockRemovalPlan remove_plan(block, &reconstructor.source_expr, &reconstructor.target_expr); + LeafBlockRemovalPlan remove_plan( + block, &reconstructor.source_expr, &reconstructor.target_expr); remove_plan(&root); - auto iter_ranges = CalculateRequiredRegions(block, loop, root, consumers); - std::string new_var_names = reconstructor.MakeNewLoop(iter_ranges, keep_unit_loops, 0); - auto sch_block_expr = block.As()->schedule_block; - sch_block_expr.As()->attrs.emplace(ir::attr::compute_at_extra_var, new_var_names); + auto iter_ranges = CalculateRequiredRegions(block, loop, root, consumers); + std::string new_var_names = + reconstructor.MakeNewLoop(iter_ranges, keep_unit_loops, 0); + auto sch_block_expr = block.As()->schedule_block; + sch_block_expr.As()->attrs.emplace( + ir::attr::compute_at_extra_var, new_var_names); 
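  // At this point LeafBlockRemovalPlan has recorded how to cut the block
  // out of its old position (source_expr -> target_expr), and MakeNewLoop
  // has rebuilt it under `loop` with iteration ranges shrunk to what the
  // consumers actually read (CalculateRequiredRegions). The two Replace
  // calls below splice both changes into the AST.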
this->Replace(reconstructor.source_expr, reconstructor.target_expr); this->Replace(reconstructor.loop_, reconstructor.new_loop_); @@ -1173,25 +1363,28 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) { CHECK(block.As()); CHECK(loop.As()); std::vector block_loops = this->GetLoops(block); - Expr root = this->GetRootBlock(block); - auto loops = GetLoopsOfExpr(loop, root); + Expr root = this->GetRootBlock(block); + auto loops = GetLoopsOfExpr(loop, root); - VLOG(3) << "Begin SimpleComputeAt of loop:\n" << loop << "\nat block:\n" << root; + VLOG(3) << "Begin SimpleComputeAt of loop:\n" + << loop << "\nat block:\n" + << root; - auto this_loop = loop; + auto this_loop = loop; auto block_name = GetTensor(block)->name; auto this_block = block; if (GetLoopExtent(loops[0]) == 1 && GetLoopExtent(block_loops[0]) != 1) { this->Split(block_loops[0], {1, -1}); this_block = this->GetBlock(block_name); - } else if (GetLoopExtent(loops[0]) != 1 && GetLoopExtent(block_loops[0]) == 1) { + } else if (GetLoopExtent(loops[0]) != 1 && + GetLoopExtent(block_loops[0]) == 1) { auto splited = this->Split(loops[0], {1, -1}); - this_loop = splited[1]; + this_loop = splited[1]; } block_loops = this->GetLoops(this_block); - root = this->GetRootBlock(this_block); - loops = GetLoopsOfExpr(this_loop, root); + root = this->GetRootBlock(this_block); + loops = GetLoopsOfExpr(this_loop, root); CHECK_LE(loops.size(), block_loops.size()); @@ -1199,29 +1392,35 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) { std::vector substitute_expr; for (int i = 0; i < loops.size(); ++i) { CHECK_EQ(GetLoopExtent(loops[i]), GetLoopExtent(block_loops[i])); - if (block_loops[i].As()->bind_info().valid() && !loops[i].As()->bind_info().valid()) { - loops[i].As()->set_bind_info(block_loops[i].As()->bind_info()); + if (block_loops[i].As()->bind_info().valid() && + !loops[i].As()->bind_info().valid()) { + loops[i].As()->set_bind_info( + block_loops[i].As()->bind_info()); } replaced_var.push_back(block_loops[i].As()->loop_var); substitute_expr.push_back(Expr(loops[i].As()->loop_var)); } - Expr result = - loops.size() < block_loops.size() ? optim::IRCopy(block_loops[loops.size()]) : optim::IRCopy(this_block); + Expr result = loops.size() < block_loops.size() + ? 
optim::IRCopy(block_loops[loops.size()]) + : optim::IRCopy(this_block); Expr new_loop = optim::IRCopy(this_loop); // Get the body of block_loop under the same loops auto body = block_loops.at(loops.size() - 1).As()->body; // collect if auto if_checker = [](const Expr* x) { return x->As(); }; - auto if_set = ir::CollectIRNodesWithoutTensor(body, if_checker); + auto if_set = ir::CollectIRNodesWithoutTensor(body, if_checker); for (auto if_expr : if_set) { auto checker = [block_name](const Expr* x) { return x->As() && - x->As()->schedule_block.As()->name == block_name; + x->As() + ->schedule_block.As() + ->name == block_name; }; if (ir::CollectIRNodesWithoutTensor(if_expr, checker, true).size() > 0) { - result = IfThenElse::Make(if_expr.As()->condition, result); + result = + IfThenElse::Make(if_expr.As()->condition, result); break; } } @@ -1229,26 +1428,41 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) { ReplaceExpr(&result, replaced_var, substitute_expr); // When there are two identical IfThenElse if (new_loop.As() && new_loop.As()->body.As() && - new_loop.As()->body.As()->stmts[0].As()) { + new_loop.As() + ->body.As() + ->stmts[0] + .As()) { auto if_then_else = new_loop.As()->body.As()->stmts[0]; if (result.As() && - if_then_else.As()->condition == result.As()->condition) { - new_loop.As()->body.As()->stmts[0].As()->true_case = - ir::Block::Make({result.As()->true_case, - new_loop.As()->body.As()->stmts[0].As()->true_case}); + if_then_else.As()->condition == + result.As()->condition) { + new_loop.As() + ->body.As() + ->stmts[0] + .As() + ->true_case = ir::Block::Make({result.As()->true_case, + new_loop.As() + ->body.As() + ->stmts[0] + .As() + ->true_case}); } else { - std::vector::iterator pos = new_loop.As()->body.As()->stmts.begin(); + std::vector::iterator pos = + new_loop.As()->body.As()->stmts.begin(); new_loop.As()->body.As()->stmts.insert(pos, result); } } else { - new_loop.As()->body = ir::Block::Make({result, new_loop.As()->body}); + new_loop.As()->body = + ir::Block::Make({result, new_loop.As()->body}); } Expr source_expr{nullptr}; Expr target_expr{nullptr}; LeafBlockRemovalPlan remove_plan( - result.As() ? block_loops[loops.size()] : this_block, &source_expr, &target_expr); + result.As() ? 
block_loops[loops.size()] : this_block, + &source_expr, + &target_expr); remove_plan(&root); this->Replace(source_expr, target_expr); @@ -1257,20 +1471,26 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) { VLOG(3) << "After SimpleComputeAt, ir is:\n" << new_loop; } -void ScheduleImpl::ReverseComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops) { +void ScheduleImpl::ReverseComputeAt(const Expr& block, + const Expr& loop, + bool keep_unit_loops) { CHECK(block.As()); CHECK(loop.As()); - Expr root = this->GetRootBlock(block); + Expr root = this->GetRootBlock(block); auto producers = GetProducers(block, root); auto consumers = GetConsumers(block, root); CheckComputeAtValidation(block, loop, root); LoopReconstructor reconstructor(root, block, loop); - LeafBlockRemovalPlan remove_plan(block, &reconstructor.source_expr, &reconstructor.target_expr); + LeafBlockRemovalPlan remove_plan( + block, &reconstructor.source_expr, &reconstructor.target_expr); remove_plan(&root); - auto iter_ranges = CalculateRequiredRegions(block, loop, root, producers, false); - std::string new_var_names = reconstructor.MakeNewLoop(iter_ranges, keep_unit_loops, -1); - auto sch_block_expr = block.As()->schedule_block; - sch_block_expr.As()->attrs.emplace(ir::attr::reverse_compute_at_extra_var, new_var_names); + auto iter_ranges = + CalculateRequiredRegions(block, loop, root, producers, false); + std::string new_var_names = + reconstructor.MakeNewLoop(iter_ranges, keep_unit_loops, -1); + auto sch_block_expr = block.As()->schedule_block; + sch_block_expr.As()->attrs.emplace( + ir::attr::reverse_compute_at_extra_var, new_var_names); this->Replace(reconstructor.source_expr, reconstructor.target_expr); this->Replace(reconstructor.loop_, reconstructor.new_loop_); return; @@ -1289,7 +1509,8 @@ void BaseInliner::Visit(const ir::Block* expr, Expr* op) { IRMutator::Visit(expr, op); } -bool BaseInliner::UpdateAndCheckIndexVars(const std::vector& indices, int expected_ndim) { +bool BaseInliner::UpdateAndCheckIndexVars(const std::vector& indices, + int expected_ndim) { int n = indices.size(); if (n != expected_ndim) { return false; @@ -1334,7 +1555,8 @@ bool ComputeInliner::BodyPatternAllowInline() { return false; } CHECK(inlined_store_.As()); - auto find_vars = ir::CollectIRNodesWithoutTensor(inlined_store_, [&](const Expr* x) { return x->as_var(); }); + auto find_vars = ir::CollectIRNodesWithoutTensor( + inlined_store_, [&](const Expr* x) { return x->as_var(); }); std::set vars_set; for (auto& i : find_vars) vars_set.insert(i.as_var_ref()); int n_vars = vars_set.size(); @@ -1363,12 +1585,13 @@ Expr ComputeInliner::ReplaceInlinedTensor(Expr* load) { void ScheduleImpl::ComputeInline(const Expr& schedule_block) { CHECK(schedule_block.As()); - Expr root = this->GetRootBlock(schedule_block); + Expr root = this->GetRootBlock(schedule_block); Expr store = CheckComputeInlineValidationAndGetStore(schedule_block, root); ComputeInliner inliner(store.As()->tensor.as_tensor_ref(), store); CHECK(inliner.BodyPatternAllowInline()); // Create a plan that removes the block to be inlined - LeafBlockRemovalPlan remove_plan(schedule_block, &inliner.src_stmt, &inliner.tgt_stmt); + LeafBlockRemovalPlan remove_plan( + schedule_block, &inliner.src_stmt, &inliner.tgt_stmt); remove_plan(&root); inliner(&root); return; @@ -1376,7 +1599,7 @@ void ScheduleImpl::ComputeInline(const Expr& schedule_block) { bool ComputeInlineChecker::Check() { Expr root = ir_schedule_.GetRootBlock(block_); - store_ = 
CheckComputeInlineValidationAndGetStore(block_, root); + store_ = CheckComputeInlineValidationAndGetStore(block_, root); IRMutator::Visit(&root, &root); return !should_skip_; } @@ -1400,7 +1623,8 @@ bool ReverseComputeInliner::BodyPatternAllowInline() { CHECK(inlined_store_.As()); CHECK(inlined_load_.As()); CHECK(target_store_.As()); - auto find_vars = ir::CollectIRNodesWithoutTensor(inlined_store_, [&](const Expr* x) { return x->as_var(); }); + auto find_vars = ir::CollectIRNodesWithoutTensor( + inlined_store_, [&](const Expr* x) { return x->as_var(); }); std::set vars_set; for (auto& i : find_vars) vars_set.insert(i.as_var_ref()); int n_vars = vars_set.size(); @@ -1451,16 +1675,21 @@ Expr ReverseComputeInliner::ReplaceTargetTensor(Expr* store) { } void ScheduleImpl::ReverseComputeInline(const Expr& schedule_block) { - Expr root = this->GetRootBlock(schedule_block); - auto exprs = CheckReverseComputeInlineValidationAndGetExprs(schedule_block, root); - Expr inlined_load = std::get<0>(exprs); + Expr root = this->GetRootBlock(schedule_block); + auto exprs = + CheckReverseComputeInlineValidationAndGetExprs(schedule_block, root); + Expr inlined_load = std::get<0>(exprs); Expr inlined_store = std::get<1>(exprs); - Expr target_store = std::get<2>(exprs); + Expr target_store = std::get<2>(exprs); ReverseComputeInliner inliner( - inlined_store.As()->tensor.as_tensor_ref(), inlined_store, inlined_load, target_store); + inlined_store.As()->tensor.as_tensor_ref(), + inlined_store, + inlined_load, + target_store); CHECK(inliner.BodyPatternAllowInline()); // Create a plan that removes the block to be inlined - LeafBlockRemovalPlan remove_plan(schedule_block, &inliner.src_stmt, &inliner.tgt_stmt); + LeafBlockRemovalPlan remove_plan( + schedule_block, &inliner.src_stmt, &inliner.tgt_stmt); remove_plan(&root); inliner(&root); inliner(&root); @@ -1477,7 +1706,9 @@ struct FindBlockParent : public ir::IRMutator<> { if (target_) return; for (auto& stmt : expr->stmts) { if (stmt.As()) { - if (stmt.As()->schedule_block.As()->name == block_name_) { + if (stmt.As() + ->schedule_block.As() + ->name == block_name_) { target_ = op; return; } @@ -1489,7 +1720,9 @@ struct FindBlockParent : public ir::IRMutator<> { void Visit(const ir::For* expr, Expr* op) override { if (target_) return; if (expr->body.As()) { - if (expr->body.As()->schedule_block.As()->name == block_name_) { + if (expr->body.As() + ->schedule_block.As() + ->name == block_name_) { target_ = op; return; } @@ -1500,7 +1733,9 @@ struct FindBlockParent : public ir::IRMutator<> { void Visit(const ir::ScheduleBlock* expr, Expr* op) override { if (target_) return; if (expr->body.As()) { - if (expr->body.As()->schedule_block.As()->name == block_name_) { + if (expr->body.As() + ->schedule_block.As() + ->name == block_name_) { target_ = op; return; } @@ -1517,8 +1752,11 @@ struct FindBlockParent : public ir::IRMutator<> { Expr ScheduleImpl::AddUnitLoop(const Expr& block) const { auto exprs = module_expr_.GetExprs(); CHECK(block.As()); - CHECK(block.As()->schedule_block.As()); - std::string block_name = block.As()->schedule_block.As()->name; + CHECK(block.As() + ->schedule_block.As()); + std::string block_name = block.As() + ->schedule_block.As() + ->name; FindBlockParent visitor(block_name); for (auto expr : exprs) { @@ -1532,29 +1770,40 @@ Expr ScheduleImpl::AddUnitLoop(const Expr& block) const { if (visitor.target_->As()) { for (auto& stmt : visitor.target_->As()->stmts) { if (stmt.As()) { - if (stmt.As()->schedule_block.As()->name == block_name) { + if 
(stmt.As() + ->schedule_block.As() + ->name == block_name) { auto block = ir::Block::Make({GetBlock(block_name)}); - auto loop = ir::For::Make(ir::Var(common::UniqName("ix")), + auto loop = ir::For::Make(ir::Var(common::UniqName("ix")), ir::Expr(0), ir::Expr(1), ir::ForType::Serial, ir::DeviceAPI::UNK, block); - stmt = loop; + stmt = loop; return loop; } } } } else if (visitor.target_->As()) { auto block = ir::Block::Make({visitor.target_->As()->body}); - auto loop = ir::For::Make( - ir::Var(common::UniqName("ix")), ir::Expr(0), ir::Expr(1), ir::ForType::Serial, ir::DeviceAPI::UNK, block); + auto loop = ir::For::Make(ir::Var(common::UniqName("ix")), + ir::Expr(0), + ir::Expr(1), + ir::ForType::Serial, + ir::DeviceAPI::UNK, + block); visitor.target_->As()->body = loop; return loop; } else if (visitor.target_->As()) { - auto block = ir::Block::Make({visitor.target_->As()->body}); - auto loop = ir::For::Make( - ir::Var(common::UniqName("ix")), ir::Expr(0), ir::Expr(1), ir::ForType::Serial, ir::DeviceAPI::UNK, block); + auto block = + ir::Block::Make({visitor.target_->As()->body}); + auto loop = ir::For::Make(ir::Var(common::UniqName("ix")), + ir::Expr(0), + ir::Expr(1), + ir::ForType::Serial, + ir::DeviceAPI::UNK, + block); visitor.target_->As()->body = loop; return loop; } else { @@ -1568,14 +1817,19 @@ std::vector ScheduleImpl::GetLoops(const Expr& block) const { std::vector result; auto exprs = module_expr_.GetExprs(); CHECK(block.As()); - CHECK(block.As()->schedule_block.As()); - std::string block_name = block.As()->schedule_block.As()->name; + CHECK(block.As() + ->schedule_block.As()); + std::string block_name = block.As() + ->schedule_block.As() + ->name; for (auto& it_expr : exprs) { ir::FindLoopsVisitor visitor(block); auto find_loops = visitor(&it_expr); if (!find_loops.empty()) { - if (!result.empty()) LOG(FATAL) << "Find block with name: \n" << block_name << " appeared in more than one AST!"; + if (!result.empty()) + LOG(FATAL) << "Find block with name: \n" + << block_name << " appeared in more than one AST!"; result = find_loops; } } @@ -1587,7 +1841,7 @@ std::vector ScheduleImpl::GetLoops(const Expr& block) const { } std::vector ScheduleImpl::GetLoops(const std::string& block_name) const { - Expr block = this->GetBlock(block_name); + Expr block = this->GetBlock(block_name); std::vector result = this->GetLoops(block); return result; } @@ -1620,7 +1874,8 @@ bool ScheduleImpl::HasBlock(const std::string& block_name) const { ir::FindBlocksVisitor visitor(block_name); auto find_blocks = visitor(&it_expr); if (!find_blocks.empty()) { - CHECK_EQ(find_blocks.size(), 1U) << "There should not be more than 1 block with identical name!"; + CHECK_EQ(find_blocks.size(), 1U) + << "There should not be more than 1 block with identical name!"; return true; } } @@ -1634,27 +1889,35 @@ Expr ScheduleImpl::GetBlock(const std::string& block_name) const { ir::FindBlocksVisitor visitor(block_name); auto find_blocks = visitor(&it_expr); if (!find_blocks.empty()) { - CHECK_EQ(find_blocks.size(), 1U) << "There should not be more than 1 block with identical name!"; + CHECK_EQ(find_blocks.size(), 1U) + << "There should not be more than 1 block with identical name!"; result = find_blocks[0]; return result; } } - LOG(FATAL) << "Didn't find a block with name " << block_name << " in this ModuleExpr!"; + LOG(FATAL) << "Didn't find a block with name " << block_name + << " in this ModuleExpr!"; } -void ScheduleImpl::Annotate(const Expr& block, const std::string& key, const attr_t& value) { +void 
ScheduleImpl::Annotate(const Expr& block, + const std::string& key, + const attr_t& value) { CHECK(block.As()); - CHECK(block.As()->schedule_block.As()); - auto copied_block = optim::IRCopy(block); - auto* schedule_block = copied_block.As()->schedule_block.As(); + CHECK(block.As() + ->schedule_block.As()); + auto copied_block = optim::IRCopy(block); + auto* schedule_block = copied_block.As() + ->schedule_block.As(); schedule_block->attrs.emplace(key, value); this->Replace(block, copied_block); } void ScheduleImpl::Unannotate(Expr& block, const std::string& ann_key) { CHECK(block.As()); - CHECK(block.As()->schedule_block.As()); - auto* schedule_block = block.As()->schedule_block.As(); + CHECK(block.As() + ->schedule_block.As()); + auto* schedule_block = block.As() + ->schedule_block.As(); if (schedule_block->attrs.count(ann_key)) { schedule_block->attrs.erase(ann_key); } else { @@ -1663,7 +1926,8 @@ void ScheduleImpl::Unannotate(Expr& block, const std::string& ann_key) { } } -void ScheduleImpl::FlattenLoops(const std::vector& loops, const bool flat_tensor) { +void ScheduleImpl::FlattenLoops(const std::vector& loops, + const bool flat_tensor) { CHECK_GT(loops.size(), 0) << "Loops can't be empty!"; VLOG(4) << "Before FlattenLoops, ir is:\n" << loops[0]; // compute loop @@ -1679,9 +1943,14 @@ void ScheduleImpl::FlattenLoops(const std::vector& loops, const bool flat_ // create new loop. auto last = loops.back().As(); - auto var = ir::Var("flat_i"); + auto var = ir::Var("flat_i"); auto _var = ir::Var("_flat_i"); - auto loop = ir::For::Make(var, ir::Expr(0), ir::Expr(extent), last->for_type(), last->device_api, last->body); + auto loop = ir::For::Make(var, + ir::Expr(0), + ir::Expr(extent), + last->for_type(), + last->device_api, + last->body); // map loop var to old loop var. auto _iter = ir::Expr(_var); @@ -1692,14 +1961,16 @@ void ScheduleImpl::FlattenLoops(const std::vector& loops, const bool flat_ loops_to_flat_var_map[loops[idx].As()->loop_var->name] = _iter; } else { // flat_i_to_loop_var.push_back(_iter / Expr(strides[idx])); - loops_to_flat_var_map[loops[idx].As()->loop_var->name] = _iter / Expr(strides[idx]); - _iter = _iter % Expr(strides[idx]); + loops_to_flat_var_map[loops[idx].As()->loop_var->name] = + _iter / Expr(strides[idx]); + _iter = _iter % Expr(strides[idx]); } } ir::FindBlocksVisitor visitor; - auto blocks = visitor(&last->body); - auto can_do_flat = [](const std::vector& indexs, const std::vector& loop_vars) { + auto blocks = visitor(&last->body); + auto can_do_flat = [](const std::vector& indexs, + const std::vector& loop_vars) { if (indexs.size() != loop_vars.size()) { return false; } @@ -1719,7 +1990,7 @@ void ScheduleImpl::FlattenLoops(const std::vector& loops, const bool flat_ // change blocks iter value/iter var for (auto& block : blocks) { - auto block_realize = block.As(); + auto block_realize = block.As(); auto schedule_block = block_realize->schedule_block.As(); // checkout loops in orders. 
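// The flattening above is a row-major linearization: with loop extents
// (e0, e1, e2) the strides are (e1*e2, e2, 1), and a flat index f unpacks
// as (sketch with symbolic extents; the code uses constant ones):
//
//   i0 = f / (e1 * e2);  f = f % (e1 * e2);
//   i1 = f / e2;
//   i2 = f % e2;
//
// which is the div/mod chain recorded in loops_to_flat_var_map above.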
@@ -1729,15 +2000,17 @@ void ScheduleImpl::FlattenLoops(const std::vector& loops, const bool flat_ for (int idx = 0; idx < block_realize->iter_values.size(); ++idx) { auto& iter = block_realize->iter_values[idx]; if (iter.is_var()) { - CHECK_EQ(iter.as_var_ref()->name, loop_vars[idx]->name) << "loops is not the same order with tensor!"; + CHECK_EQ(iter.as_var_ref()->name, loop_vars[idx]->name) + << "loops is not the same order with tensor!"; } else { CHECK(iter.As()); CHECK_EQ(iter.as_int32(), 0); } } - auto exprs = ir::CollectIRNodesInOrder(schedule_block->body, - [&](const Expr* x) { return x->As() || x->As(); }); + auto exprs = ir::CollectIRNodesInOrder( + schedule_block->body, + [&](const Expr* x) { return x->As() || x->As(); }); // reverse exprs from last to first. std::reverse(std::begin(exprs), std::end(exprs)); @@ -1748,7 +2021,8 @@ void ScheduleImpl::FlattenLoops(const std::vector& loops, const bool flat_ if (block_realize->iter_values[idx].is_var()) { var_to_replace.push_back(schedule_block->iter_vars[idx]); auto var_name = block_realize->iter_values[idx].as_var_ref()->name; - CHECK(loops_to_flat_var_map.count(var_name)) << "Can't find var name : " << var_name; + CHECK(loops_to_flat_var_map.count(var_name)) + << "Can't find var name : " << var_name; flat_i_to_loop_var.push_back(loops_to_flat_var_map[var_name]); } else { CHECK_EQ(block_realize->iter_values[idx].as_int32(), 0); @@ -1765,10 +2039,15 @@ void ScheduleImpl::FlattenLoops(const std::vector& loops, const bool flat_ if (store->is_addr_tensor()) { auto t = store->tensor.as_tensor_ref(); CHECK(!t->reduce_axis.size()); - auto tsize = std::accumulate(t->shape.begin(), t->shape.end(), 1, [](const int sum, const Expr& expr) { - return sum * expr.as_int32(); - }); - if ((!flat_tensor && !can_do_flat(store->indices, schedule_block->iter_vars)) || extent != tsize) { + auto tsize = std::accumulate(t->shape.begin(), + t->shape.end(), + 1, + [](const int sum, const Expr& expr) { + return sum * expr.as_int32(); + }); + if ((!flat_tensor && + !can_do_flat(store->indices, schedule_block->iter_vars)) || + extent != tsize) { // just replace indexs for (auto& indice : store->indices) { if (!indice.is_var()) { @@ -1788,10 +2067,15 @@ void ScheduleImpl::FlattenLoops(const std::vector& loops, const bool flat_ if (load->is_addr_tensor()) { auto t = load->tensor.as_tensor_ref(); CHECK(!t->reduce_axis.size()); - auto tsize = std::accumulate(t->shape.begin(), t->shape.end(), 1, [](const int sum, const Expr& expr) { - return sum * expr.as_int32(); - }); - if ((!flat_tensor && !can_do_flat(load->indices, schedule_block->iter_vars)) || extent != tsize) { + auto tsize = std::accumulate(t->shape.begin(), + t->shape.end(), + 1, + [](const int sum, const Expr& expr) { + return sum * expr.as_int32(); + }); + if ((!flat_tensor && + !can_do_flat(load->indices, schedule_block->iter_vars)) || + extent != tsize) { // just replace indexs for (auto& indice : load->indices) { if (!indice.is_var()) { @@ -1811,40 +2095,51 @@ void ScheduleImpl::FlattenLoops(const std::vector& loops, const bool flat_ ReplaceExpr(&schedule_block->body, var_to_replace, flat_i_to_loop_var); // update iter values - auto iter = ir::Expr(var); + auto iter = ir::Expr(var); block_realize->iter_values = {iter}; // update iter_vars schedule_block->iter_vars = {_var}; - CHECK_EQ(block_realize->iter_values.size(), schedule_block->iter_vars.size()); + CHECK_EQ(block_realize->iter_values.size(), + schedule_block->iter_vars.size()); } this->Replace(loops[0], loop); VLOG(4) << "After FlattenLoops, 
ir is:\n" << loop; } -void ScheduleImpl::CopyTransformAndLoopInfo(const std::string& block_name, const std::string& block_target_name) { - auto block = this->GetBlock(block_name); +void ScheduleImpl::CopyTransformAndLoopInfo( + const std::string& block_name, const std::string& block_target_name) { + auto block = this->GetBlock(block_name); auto block_target = this->GetBlock(block_target_name); this->CopyTransformAndLoopInfo(block, block_target); } -void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target) { +void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, + const Expr& block_target) { CHECK(block.As()); CHECK(block_target.As()); auto exprs = this->GetModule().GetExprs(); CHECK_EQ(exprs.size(), 1U); - auto expr = exprs[0]; - auto vars = block.As()->schedule_block.As()->iter_vars; - auto vars_target = block_target.As()->schedule_block.As()->iter_vars; + auto expr = exprs[0]; + auto vars = block.As() + ->schedule_block.As() + ->iter_vars; + auto vars_target = block_target.As() + ->schedule_block.As() + ->iter_vars; auto old_iter_values = block.As()->iter_values; - auto iter_values_target = block_target.As()->iter_values; + auto iter_values_target = + block_target.As()->iter_values; std::vector new_iter_values; for (int i = 0; i < vars.size() && i < vars_target.size(); ++i) { - CHECK(vars[i]->upper_bound.defined() && vars_target[i]->upper_bound.defined()); - if (vars[i]->upper_bound.is_constant() && vars_target[i]->upper_bound.is_constant() && - vars[i]->upper_bound.get_constant() == vars_target[i]->upper_bound.get_constant() && !vars[i]->is_reduce_axis && - !vars_target[i]->is_reduce_axis) { + CHECK(vars[i]->upper_bound.defined() && + vars_target[i]->upper_bound.defined()); + if (vars[i]->upper_bound.is_constant() && + vars_target[i]->upper_bound.is_constant() && + vars[i]->upper_bound.get_constant() == + vars_target[i]->upper_bound.get_constant() && + !vars[i]->is_reduce_axis && !vars_target[i]->is_reduce_axis) { new_iter_values.push_back(iter_values_target[i]); VLOG(3) << "new_iter_values.push_back " << iter_values_target[i]; } else @@ -1852,16 +2147,19 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, const Expr& block } if (new_iter_values.empty()) - LOG(FATAL) << "Cannot CopyTransformAndLoopInfo since shape[0] of source and target is not equal! " - << vars[0]->upper_bound << " v.s " << vars_target[0]->upper_bound; + LOG(FATAL) << "Cannot CopyTransformAndLoopInfo since shape[0] of source " + "and target is not equal! 
" + << vars[0]->upper_bound << " v.s " + << vars_target[0]->upper_bound; int changed_loop_num = new_iter_values.size(); std::set used_target_loop_vars; for (auto& iter_val : new_iter_values) { - auto find_partial_loop = ir::CollectIRNodesWithoutTensor(iter_val, [&](const Expr* x) { - if (x->as_var()) used_target_loop_vars.insert(x->as_var_ref()->name); - return x->as_var(); - }); + auto find_partial_loop = + ir::CollectIRNodesWithoutTensor(iter_val, [&](const Expr* x) { + if (x->as_var()) used_target_loop_vars.insert(x->as_var_ref()->name); + return x->as_var(); + }); } CHECK(!used_target_loop_vars.empty()); std::vector used_target_loops; @@ -1870,16 +2168,18 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, const Expr& block auto find_loop_var = ir::CollectIRNodesWithoutTensor( expr_copy, [&](const Expr* x) { - return x->As() && x->As()->loop_var->name == var && Contains(*x, block_target); + return x->As() && x->As()->loop_var->name == var && + Contains(*x, block_target); }, true); CHECK_EQ(find_loop_var.size(), 1U); used_target_loops.push_back(*find_loop_var.begin()); VLOG(3) << "used_target_loops push_back " << used_target_loops.back(); } - std::sort(used_target_loops.begin(), used_target_loops.end(), [&](Expr i, Expr j) { - return (utils::GetStreamCnt(i).size() > utils::GetStreamCnt(j).size()); - }); + std::sort( + used_target_loops.begin(), used_target_loops.end(), [&](Expr i, Expr j) { + return (utils::GetStreamCnt(i).size() > utils::GetStreamCnt(j).size()); + }); for (int i = new_iter_values.size(); i < old_iter_values.size(); ++i) { CHECK(old_iter_values[i].as_var()); new_iter_values.push_back(old_iter_values[i]); @@ -1888,23 +2188,27 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, const Expr& block VLOG(3) << "changed_loop_num is : " << changed_loop_num; VLOG(3) << "old_iter_values.size() is : " << old_iter_values.size(); if (changed_loop_num >= (int)old_iter_values.size()) { - new_loop = optim::IRCopy(block); + new_loop = optim::IRCopy(block); new_loop.As()->iter_values = new_iter_values; } else { CHECK(old_iter_values[changed_loop_num].as_var()); - auto old_var = old_iter_values[changed_loop_num].as_var_ref(); + auto old_var = old_iter_values[changed_loop_num].as_var_ref(); auto find_partial_loop = ir::CollectIRNodesWithoutTensor( expr, [&](const Expr* x) { - return x->As() && x->As()->loop_var->name == old_var->name && Contains(*x, block); + return x->As() && + x->As()->loop_var->name == old_var->name && + Contains(*x, block); }, true); CHECK_EQ(find_partial_loop.size(), 1U); - new_loop = optim::IRCopy(*find_partial_loop.begin()); + new_loop = optim::IRCopy(*find_partial_loop.begin()); auto find_schedule_block = ir::CollectIRNodesWithoutTensor( - new_loop, [&](const Expr* x) { return x->As(); }, true); + new_loop, + [&](const Expr* x) { return x->As(); }, + true); CHECK_EQ(find_schedule_block.size(), 1U); - Expr sch_block = (*find_schedule_block.begin()); + Expr sch_block = (*find_schedule_block.begin()); sch_block.As()->iter_values = new_iter_values; } VLOG(3) << "new_loop is : " << new_loop; @@ -1912,7 +2216,7 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, const Expr& block Expr res; if (used_target_loops.size() == 1) { auto for_loop = used_target_loops[0].As(); - res = For::Make(for_loop->loop_var, + res = For::Make(for_loop->loop_var, for_loop->min, for_loop->extent, for_loop->for_type(), @@ -1921,10 +2225,10 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, const Expr& block for_loop->vectorize_info(), 
for_loop->bind_info()); } else { - Expr outer_loop = used_target_loops.front(); - Expr inner_loop = used_target_loops.back(); + Expr outer_loop = used_target_loops.front(); + Expr inner_loop = used_target_loops.back(); inner_loop.As()->body = Block::Make({new_loop}); - res = outer_loop; + res = outer_loop; } VLOG(3) << "res is : " << res; std::vector all_loops = this->GetLoops(block); @@ -1932,14 +2236,18 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, const Expr& block this->Replace(all_loops[0], res); } -std::vector ScheduleImpl::SamplePerfectTile(utils::LinearRandomEngine::StateType* rand_seed, - const Expr& loop, - int n, - int max_innermost_factor) { - CHECK(loop.As()) << "Expr param of SamplePerfectTile should be a For loop"; +std::vector ScheduleImpl::SamplePerfectTile( + utils::LinearRandomEngine::StateType* rand_seed, + const Expr& loop, + int n, + int max_innermost_factor) { + CHECK(loop.As()) + << "Expr param of SamplePerfectTile should be a For loop"; CHECK_GE(n, 2) << "The number of tile factors should be at least 2"; - CHECK_GE(max_innermost_factor, 1) << "The max innermost factor should be at least 1"; - CHECK(common::is_zero(loop.As()->min)) << "The For loop should start from 0"; + CHECK_GE(max_innermost_factor, 1) + << "The max innermost factor should be at least 1"; + CHECK(common::is_zero(loop.As()->min)) + << "The For loop should start from 0"; int loop_extent = GetLoopExtent(loop); std::vector innermost_factors; for (int i = max_innermost_factor; i >= 1; --i) { @@ -1948,8 +2256,9 @@ std::vector ScheduleImpl::SamplePerfectTile(utils::LinearRandomEngine::Sta } } CHECK(!innermost_factors.empty()) << "No innermost factor found"; - int innermost_factor = innermost_factors[utils::SampleUniformInt(0, innermost_factors.size(), rand_seed)]; - auto result = SampleTile(rand_seed, n - 1, loop_extent / innermost_factor); + int innermost_factor = innermost_factors[utils::SampleUniformInt( + 0, innermost_factors.size(), rand_seed)]; + auto result = SampleTile(rand_seed, n - 1, loop_extent / innermost_factor); std::vector result_expr; for (auto& factor : result) { result_expr.push_back(Expr(factor)); @@ -1958,47 +2267,56 @@ std::vector ScheduleImpl::SamplePerfectTile(utils::LinearRandomEngine::Sta return result_expr; } -Expr ScheduleImpl::SampleCategorical(utils::LinearRandomEngine::StateType* rand_seed, - const std::vector& candidates, - const std::vector& probs) { +Expr ScheduleImpl::SampleCategorical( + utils::LinearRandomEngine::StateType* rand_seed, + const std::vector& candidates, + const std::vector& probs) { // check two sizes - CHECK_EQ(candidates.size(), probs.size()) << "candidates and probs must have same size."; + CHECK_EQ(candidates.size(), probs.size()) + << "candidates and probs must have same size."; int seed_idx = utils::SampleDiscreteFromDistribution(probs, rand_seed); - auto result = candidates[seed_idx]; + auto result = candidates[seed_idx]; Expr result_expr(result); return result_expr; } IRSchedule::IRSchedule() {} -IRSchedule::IRSchedule(const ModuleExpr& module_expr, utils::LinearRandomEngine::StateType rand_seed, bool debug_flag) { +IRSchedule::IRSchedule(const ModuleExpr& module_expr, + utils::LinearRandomEngine::StateType rand_seed, + bool debug_flag) { impl_ = std::make_unique(module_expr, debug_flag); this->InitSeed(rand_seed); } -IRSchedule::IRSchedule(ir::ModuleExpr&& mod_expr, ScheduleDesc&& trace, utils::LinearRandomEngine::StateType rand_seed) - : impl_(std::make_unique(std::move(mod_expr))), trace_(std::move(trace)) { 
+IRSchedule::IRSchedule(ir::ModuleExpr&& mod_expr, + ScheduleDesc&& trace, + utils::LinearRandomEngine::StateType rand_seed) + : impl_(std::make_unique(std::move(mod_expr))), + trace_(std::move(trace)) { this->InitSeed(rand_seed); } IRSchedule::IRSchedule(const IRSchedule& other) - : impl_(std::make_unique(optim::IRCopy(other.GetModule()))), trace_(other.trace_) { + : impl_(std::make_unique(optim::IRCopy(other.GetModule()))), + trace_(other.trace_) { this->InitSeed(other.ForkSeed()); } IRSchedule& IRSchedule::operator=(const IRSchedule& src) { - impl_ = std::make_unique(optim::IRCopy(src.GetModule())); + impl_ = std::make_unique(optim::IRCopy(src.GetModule())); trace_ = src.trace_; this->InitSeed(src.ForkSeed()); return *this; } -IRSchedule::IRSchedule(IRSchedule&& other) : impl_(std::move(other.impl_)), trace_(std::move(other.trace_)) { +IRSchedule::IRSchedule(IRSchedule&& other) + : impl_(std::move(other.impl_)), trace_(std::move(other.trace_)) { this->InitSeed(other.ForkSeed()); } IRSchedule& IRSchedule::operator=(IRSchedule&& src) { - impl_ = std::move(src.impl_); + impl_ = std::move(src.impl_); trace_ = std::move(src.trace_); this->InitSeed(src.ForkSeed()); return *this; @@ -2010,7 +2328,9 @@ void IRSchedule::InitSeed(utils::LinearRandomEngine::StateType rand_seed) { this->rand_seed_ = utils::LinearRandomEngine::NormalizeState(rand_seed); } -utils::LinearRandomEngine::StateType IRSchedule::ForkSeed() const { return utils::ForkRandomState(&rand_seed_); } +utils::LinearRandomEngine::StateType IRSchedule::ForkSeed() const { + return utils::ForkRandomState(&rand_seed_); +} void IRSchedule::SetExprs(const std::vector& exprs) { return impl_->SetExprs(exprs); @@ -2034,13 +2354,15 @@ void IRSchedule::MergeExprs() { std::vector IRSchedule::GetLoops(const Expr& block) const { auto results = impl_->GetLoops(block); - trace_.Append(ScheduleDesc::Step("GetLoops", {{"block", std::vector({block})}}, {}, results)); + trace_.Append(ScheduleDesc::Step( + "GetLoops", {{"block", std::vector({block})}}, {}, results)); return results; } std::vector IRSchedule::GetLoops(const std::string& block_name) const { auto results = impl_->GetLoops(block_name); - trace_.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", block_name}}, results)); + trace_.Append(ScheduleDesc::Step( + "GetLoopsWithName", {}, {{"block_name", block_name}}, results)); return results; } @@ -2052,37 +2374,52 @@ std::vector IRSchedule::GetAllBlocks() const { std::vector IRSchedule::GetChildBlocks(const Expr& expr) const { auto results = impl_->GetChildBlocks(expr); - trace_.Append(ScheduleDesc::Step("GetChildBlocks", {{"expr", std::vector({expr})}}, {}, results)); + trace_.Append(ScheduleDesc::Step( + "GetChildBlocks", {{"expr", std::vector({expr})}}, {}, results)); return results; } Expr IRSchedule::GetBlock(const std::string& block_name) const { auto result = impl_->GetBlock(block_name); - trace_.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", block_name}}, {result})); + trace_.Append(ScheduleDesc::Step( + "GetBlock", {}, {{"block_name", block_name}}, {result})); return result; } -std::vector IRSchedule::Split(const Expr& loop, const std::vector& factors) { - std::vector decision = SamplePerfectTile(loop, factors.size(), loop.As()->extent.as_int32(), factors); - auto results = Split(loop, decision); +std::vector IRSchedule::Split(const Expr& loop, + const std::vector& factors) { + std::vector decision = SamplePerfectTile( + loop, factors.size(), loop.As()->extent.as_int32(), factors); + auto results = Split(loop, 
decision); return results; } -std::vector IRSchedule::Split(const std::string& block_name, int loop_index, const std::vector& factors) { +std::vector IRSchedule::Split(const std::string& block_name, + int loop_index, + const std::vector& factors) { std::vector all_loops = this->GetLoops(block_name); Expr loop_expr; - CHECK_LT(loop_index, (int)all_loops.size()) << "The loop index in Split should be less than total loop's number."; + CHECK_LT(loop_index, (int)all_loops.size()) + << "The loop index in Split should be less than total loop's number."; CHECK_GE(loop_index, 0) << "The loop index in Split should be >= 0."; loop_expr = all_loops[loop_index]; return this->Split(loop_expr, factors); } -std::vector IRSchedule::Split(const Expr& loop, const std::vector& factors) { +std::vector IRSchedule::Split(const Expr& loop, + const std::vector& factors) { std::vector int_factors; - std::transform(factors.begin(), factors.end(), std::back_inserter(int_factors), [](Expr x) { return x.as_int32(); }); + std::transform(factors.begin(), + factors.end(), + std::back_inserter(int_factors), + [](Expr x) { return x.as_int32(); }); auto results = impl_->Split(loop, int_factors); - trace_.Append(ScheduleDesc::Step("Split", {{"loop", std::vector({loop})}, {"factors", factors}}, {}, results)); + trace_.Append(ScheduleDesc::Step( + "Split", + {{"loop", std::vector({loop})}, {"factors", factors}}, + {}, + results)); return results; } @@ -2092,76 +2429,105 @@ Expr IRSchedule::Fuse(const std::vector& loops) { return result; } -Expr IRSchedule::Fuse(const std::string& block_name, const std::vector& loops_index) { +Expr IRSchedule::Fuse(const std::string& block_name, + const std::vector& loops_index) { auto result = impl_->Fuse(block_name, loops_index); - trace_.Append( - ScheduleDesc::Step("FuseWithName", {}, {{"block_name", block_name}, {"loops_index", loops_index}}, {result})); + trace_.Append(ScheduleDesc::Step( + "FuseWithName", + {}, + {{"block_name", block_name}, {"loops_index", loops_index}}, + {result})); return result; } Expr IRSchedule::Fuse(const Expr& block, const std::vector& loops_index) { auto result = impl_->Fuse(block, loops_index); - trace_.Append(ScheduleDesc::Step( - "FuseWithBlock", {{"block", std::vector({block})}}, {{"loops_index", loops_index}}, {result})); + trace_.Append(ScheduleDesc::Step("FuseWithBlock", + {{"block", std::vector({block})}}, + {{"loops_index", loops_index}}, + {result})); return result; } -void IRSchedule::ComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops) { +void IRSchedule::ComputeAt(const Expr& block, + const Expr& loop, + bool keep_unit_loops) { impl_->ComputeAt(block, loop, keep_unit_loops); trace_.Append(ScheduleDesc::Step("ComputeAt", - {{"block", std::vector({block})}, {"loop", std::vector({loop})}}, + {{"block", std::vector({block})}, + {"loop", std::vector({loop})}}, {{"keep_unit_loops", keep_unit_loops}}, {})); } void IRSchedule::SimpleComputeAt(const Expr& block, const Expr& loop) { impl_->SimpleComputeAt(block, loop); - trace_.Append(ScheduleDesc::Step( - "SimpleComputeAt", {{"block", std::vector({block})}, {"loop", std::vector({loop})}}, {}, {})); + trace_.Append(ScheduleDesc::Step("SimpleComputeAt", + {{"block", std::vector({block})}, + {"loop", std::vector({loop})}}, + {}, + {})); } -void IRSchedule::ReverseComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops) { +void IRSchedule::ReverseComputeAt(const Expr& block, + const Expr& loop, + bool keep_unit_loops) { impl_->ReverseComputeAt(block, loop, keep_unit_loops); 
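// Every mutating primitive in this file follows the same two-step shape seen
// here: delegate the actual IR rewrite to ScheduleImpl, then append a
// ScheduleDesc::Step recording the primitive's name, its Expr inputs, its
// scalar attributes, and its outputs, so a schedule can be serialized and
// replayed step by step. A hypothetical primitive (illustrative only, not
// part of this patch) would read:
//
//   void IRSchedule::Foo(const Expr& loop, int k) {
//     impl_->Foo(loop, k);
//     trace_.Append(ScheduleDesc::Step(
//         "Foo", {{"loop", std::vector<Expr>({loop})}}, {{"k", k}}, {}));
//   }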
trace_.Append(ScheduleDesc::Step("ReverseComputeAt", - {{"block", std::vector({block})}, {"loop", std::vector({loop})}}, + {{"block", std::vector({block})}, + {"loop", std::vector({loop})}}, {{"keep_unit_loops", keep_unit_loops}}, {})); } Expr IRSchedule::GetRootBlock(const Expr& expr) const { auto result = impl_->GetRootBlock(expr); - trace_.Append(ScheduleDesc::Step("GetRootBlock", {{"expr", std::vector({expr})}}, {}, {result})); + trace_.Append(ScheduleDesc::Step( + "GetRootBlock", {{"expr", std::vector({expr})}}, {}, {result})); return result; } -Expr IRSchedule::CacheRead(const Expr& block, int read_buffer_index, const std::string& memory_type) { +Expr IRSchedule::CacheRead(const Expr& block, + int read_buffer_index, + const std::string& memory_type) { auto result = impl_->CacheRead(block, read_buffer_index, memory_type); - trace_.Append(ScheduleDesc::Step("CacheRead", - {{"block", std::vector({block})}}, - {{"read_buffer_index", read_buffer_index}, {"memory_type", memory_type}}, - {result})); + trace_.Append(ScheduleDesc::Step( + "CacheRead", + {{"block", std::vector({block})}}, + {{"read_buffer_index", read_buffer_index}, {"memory_type", memory_type}}, + {result})); return result; } -Expr IRSchedule::CacheWrite(const Expr& block, int write_buffer_index, const std::string& memory_type) { +Expr IRSchedule::CacheWrite(const Expr& block, + int write_buffer_index, + const std::string& memory_type) { auto result = impl_->CacheWrite(block, write_buffer_index, memory_type); trace_.Append(ScheduleDesc::Step("CacheWrite", {{"block", std::vector({block})}}, - {{"write_buffer_index", write_buffer_index}, {"memory_type", memory_type}}, + {{"write_buffer_index", write_buffer_index}, + {"memory_type", memory_type}}, {result})); return result; } void IRSchedule::SyncThreads(const Expr& ir_node, bool after_node) { impl_->SyncThreads(ir_node, after_node); - trace_.Append( - ScheduleDesc::Step("SyncThreads", {{"ir_node", std::vector({ir_node})}}, {{"after_node", after_node}}, {})); + trace_.Append(ScheduleDesc::Step("SyncThreads", + {{"ir_node", std::vector({ir_node})}}, + {{"after_node", after_node}}, + {})); } -void IRSchedule::SetBuffer(Expr& block, const std::string& memory_type, bool fixed) { +void IRSchedule::SetBuffer(Expr& block, + const std::string& memory_type, + bool fixed) { impl_->SetBuffer(block, memory_type, fixed); - trace_.Append(ScheduleDesc::Step( - "SetBuffer", {{"block", std::vector({block})}}, {{"memory_type", memory_type}, {"fixed", fixed}}, {})); + trace_.Append( + ScheduleDesc::Step("SetBuffer", + {{"block", std::vector({block})}}, + {{"memory_type", memory_type}, {"fixed", fixed}}, + {})); } Expr IRSchedule::Reorder(const std::vector& loops) { @@ -2170,68 +2536,95 @@ Expr IRSchedule::Reorder(const std::vector& loops) { return ret; } -Expr IRSchedule::Reorder(const std::string& block_name, const std::vector& loops_index) { +Expr IRSchedule::Reorder(const std::string& block_name, + const std::vector& loops_index) { Expr ret = impl_->Reorder(block_name, loops_index); - trace_.Append( - ScheduleDesc::Step("ReorderWithName", {}, {{"block_name", block_name}, {"loops_index", loops_index}}, {ret})); + trace_.Append(ScheduleDesc::Step( + "ReorderWithName", + {}, + {{"block_name", block_name}, {"loops_index", loops_index}}, + {ret})); return ret; } -Expr IRSchedule::Reorder(const Expr& block, const std::vector& loops_index) { +Expr IRSchedule::Reorder(const Expr& block, + const std::vector& loops_index) { Expr ret = impl_->Reorder(block, loops_index); - 
trace_.Append(ScheduleDesc::Step( - "ReorderWithBlock", {{"block", std::vector({block})}}, {{"loops_index", loops_index}}, {ret})); + trace_.Append(ScheduleDesc::Step("ReorderWithBlock", + {{"block", std::vector({block})}}, + {{"loops_index", loops_index}}, + {ret})); return ret; } void IRSchedule::Parallel(const Expr& loop) { impl_->Parallel(loop); - trace_.Append(ScheduleDesc::Step("Parallel", {{"loop", std::vector({loop})}}, {}, {})); + trace_.Append(ScheduleDesc::Step( + "Parallel", {{"loop", std::vector({loop})}}, {}, {})); } void IRSchedule::Vectorize(const Expr& loop, int factor) { impl_->Vectorize(loop, factor); - trace_.Append(ScheduleDesc::Step("Vectorize", {{"loop", std::vector({loop})}}, {{"factor", factor}}, {})); + trace_.Append(ScheduleDesc::Step("Vectorize", + {{"loop", std::vector({loop})}}, + {{"factor", factor}}, + {})); } void IRSchedule::Unroll(const Expr& loop) { impl_->Unroll(loop); - trace_.Append(ScheduleDesc::Step("Unroll", {{"loop", std::vector({loop})}}, {}, {})); + trace_.Append(ScheduleDesc::Step( + "Unroll", {{"loop", std::vector({loop})}}, {}, {})); } void IRSchedule::ComputeInline(const Expr& schedule_block) { impl_->ComputeInline(schedule_block); - trace_.Append(ScheduleDesc::Step("ComputeInline", {{"schedule_block", std::vector({schedule_block})}}, {}, {})); + trace_.Append(ScheduleDesc::Step( + "ComputeInline", + {{"schedule_block", std::vector({schedule_block})}}, + {}, + {})); } void IRSchedule::ReverseComputeInline(const Expr& schedule_block) { impl_->ReverseComputeInline(schedule_block); - trace_.Append( - ScheduleDesc::Step("ReverseComputeInline", {{"schedule_block", std::vector({schedule_block})}}, {}, {})); + trace_.Append(ScheduleDesc::Step( + "ReverseComputeInline", + {{"schedule_block", std::vector({schedule_block})}}, + {}, + {})); } void IRSchedule::Bind(const Expr& loop, const std::string& thread_axis) { impl_->Bind(loop, thread_axis); - trace_.Append(ScheduleDesc::Step("Bind", {{"loop", std::vector({loop})}}, {{"thread_axis", thread_axis}}, {})); + trace_.Append(ScheduleDesc::Step("Bind", + {{"loop", std::vector({loop})}}, + {{"thread_axis", thread_axis}}, + {})); } Expr IRSchedule::Rfactor(const Expr& rf_loop, int rf_axis) { auto result = impl_->Rfactor(rf_loop, rf_axis); - trace_.Append( - ScheduleDesc::Step("Rfactor", {{"rf_loop", std::vector({rf_loop})}}, {{"rf_axis", rf_axis}}, {result})); + trace_.Append(ScheduleDesc::Step("Rfactor", + {{"rf_loop", std::vector({rf_loop})}}, + {{"rf_axis", rf_axis}}, + {result})); return result; } -void IRSchedule::Annotate(const Expr& block, const std::string& key, const attr_t& value) { +void IRSchedule::Annotate(const Expr& block, + const std::string& key, + const attr_t& value) { impl_->Annotate(block, key, value); -#define TRACE_ANNOTATE_ITEM(data_type, step_name) \ - if (absl::holds_alternative(value)) { \ - trace_.Append(ScheduleDesc::Step(#step_name, \ - {{"block", std::vector({block})}}, \ - {{"key", key}, {"value", absl::get(value)}}, \ - {})); \ - return; \ +#define TRACE_ANNOTATE_ITEM(data_type, step_name) \ + if (absl::holds_alternative(value)) { \ + trace_.Append(ScheduleDesc::Step( \ + #step_name, \ + {{"block", std::vector({block})}}, \ + {{"key", key}, {"value", absl::get(value)}}, \ + {})); \ + return; \ } TRACE_ANNOTATE_ITEM(int, AnnotateIntAttr) TRACE_ANNOTATE_ITEM(bool, AnnotateBoolAttr) @@ -2244,48 +2637,69 @@ void IRSchedule::Annotate(const Expr& block, const std::string& key, const attr_ void IRSchedule::Unannotate(Expr& block, const std::string& key) { impl_->Unannotate(block, 
key); - trace_.Append(ScheduleDesc::Step("Unannotate", {{"block", std::vector({block})}}, {{"key", key}}, {})); + trace_.Append(ScheduleDesc::Step("Unannotate", + {{"block", std::vector({block})}}, + {{"key", key}}, + {})); } -void IRSchedule::FlattenLoops(const std::vector& loops, const bool force_flat) { +void IRSchedule::FlattenLoops(const std::vector& loops, + const bool force_flat) { impl_->FlattenLoops(loops, force_flat); - trace_.Append( - ScheduleDesc::Step("FlattenLoops", {{"loop", std::vector({loops})}}, {{"force_flat", force_flat}}, {})); + trace_.Append(ScheduleDesc::Step("FlattenLoops", + {{"loop", std::vector({loops})}}, + {{"force_flat", force_flat}}, + {})); } -void IRSchedule::CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target) { +void IRSchedule::CopyTransformAndLoopInfo(const Expr& block, + const Expr& block_target) { impl_->CopyTransformAndLoopInfo(block, block_target); - // don't support to trace, because we can't ensure both blocks are from the same ModuleExpr + // don't support to trace, because we can't ensure both blocks are from the + // same ModuleExpr } -void IRSchedule::CopyTransformAndLoopInfo(const std::string& block_name, const std::string& block_target_name) { +void IRSchedule::CopyTransformAndLoopInfo( + const std::string& block_name, const std::string& block_target_name) { impl_->CopyTransformAndLoopInfo(block_name, block_target_name); - // don't support to trace, because we can't ensure both blocks are from the same ModuleExpr + // don't support to trace, because we can't ensure both blocks are from the + // same ModuleExpr } -std::vector IRSchedule::SamplePerfectTile(const Expr& loop, - int n, - int max_innermost_factor, - const std::vector& decision) { +std::vector IRSchedule::SamplePerfectTile( + const Expr& loop, + int n, + int max_innermost_factor, + const std::vector& decision) { std::vector factors; std::vector new_decision; if (decision.empty()) { - factors = impl_->SamplePerfectTile(&rand_seed_, loop, n, max_innermost_factor); - std::transform( - factors.begin(), factors.end(), std::back_inserter(new_decision), [](Expr x) { return x.as_int32(); }); + factors = + impl_->SamplePerfectTile(&rand_seed_, loop, n, max_innermost_factor); + std::transform(factors.begin(), + factors.end(), + std::back_inserter(new_decision), + [](Expr x) { return x.as_int32(); }); } else { new_decision = decision; - std::transform(decision.begin(), decision.end(), std::back_inserter(factors), [](int x) { return Expr(x); }); + std::transform(decision.begin(), + decision.end(), + std::back_inserter(factors), + [](int x) { return Expr(x); }); } trace_.Append( ScheduleDesc::Step("SamplePerfectTile", {{"loop", std::vector({loop})}}, - {{"n", n}, {"max_innermost_factor", max_innermost_factor}, {"decision", new_decision}}, + {{"n", n}, + {"max_innermost_factor", max_innermost_factor}, + {"decision", new_decision}}, factors)); return factors; } -void IRSchedule::TagPostSchedule() { trace_.Append(ScheduleDesc::Step("TagPostSchedule", {}, {}, {})); } +void IRSchedule::TagPostSchedule() { + trace_.Append(ScheduleDesc::Step("TagPostSchedule", {}, {}, {})); +} Expr IRSchedule::SampleCategorical(const std::vector& candidates, const std::vector& probs, @@ -2301,8 +2715,12 @@ Expr IRSchedule::SampleCategorical(const std::vector& candidates, result = Expr(ndco); } } - trace_.Append(ScheduleDesc::Step( - "SampleCategorical", {}, {{"candidates", candidates}, {"probs", probs}, {"decision", new_decision}}, {result})); + trace_.Append(ScheduleDesc::Step("SampleCategorical", 
+ {}, + {{"candidates", candidates}, + {"probs", probs}, + {"decision", new_decision}}, + {result})); return result; } diff --git a/paddle/cinn/ir/ir_schedule.h b/paddle/cinn/ir/ir_schedule.h index 2361e50378303..d847e933eb54d 100644 --- a/paddle/cinn/ir/ir_schedule.h +++ b/paddle/cinn/ir/ir_schedule.h @@ -30,13 +30,14 @@ namespace cinn { namespace ir { /** - * A struct representing a module that contains Expr. This struct is only used in Schedule process. + * A struct representing a module that contains Expr. This struct is only used + * in Schedule process. */ class ModuleExpr { public: - ModuleExpr() = default; + ModuleExpr() = default; ModuleExpr(const ModuleExpr& mod_expr) = default; - ModuleExpr(ModuleExpr&& mod_expr) = default; + ModuleExpr(ModuleExpr&& mod_expr) = default; ModuleExpr& operator=(const ModuleExpr& mod_expr) = default; @@ -51,15 +52,17 @@ class ModuleExpr { void SetExprs(const std::vector<Expr>& exprs) { exprs_ = exprs; } private: - //! Exprs stored in ModuleExpr. Each one is an AST, representing a computation kernel. + //! Exprs stored in ModuleExpr. Each one is an AST, representing a computation + //! kernel. std::vector<Expr> exprs_; }; /** - * A struct containing all the schedule primitives. Each schedule primitive is a member function of IRSchedule. - * Schedule primitives are implemented by ScheduleImpl manipulating the AST - IR(Expr). - * To support serializing and replaying, each schedule primitive should append a ScheduleDesc::Step to - * the trace_ in its corresponding function implementation. + * A struct containing all the schedule primitives. Each schedule primitive is a + * member function of IRSchedule. Schedule primitives are implemented by + * ScheduleImpl manipulating the AST - IR(Expr). To support serializing and + * replaying, each schedule primitive should append a ScheduleDesc::Step to the + * trace_ in its corresponding function implementation. */ class ScheduleImpl; class IRSchedule { @@ -67,8 +70,10 @@ class IRSchedule { IRSchedule(); explicit IRSchedule(const ModuleExpr& modexpr, utils::LinearRandomEngine::StateType rand_seed = -1, - bool debug_flag = false); - IRSchedule(ir::ModuleExpr&& mod_expr, ScheduleDesc&& trace, utils::LinearRandomEngine::StateType rand_seed = -1); + bool debug_flag = false); + IRSchedule(ir::ModuleExpr&& mod_expr, + ScheduleDesc&& trace, + utils::LinearRandomEngine::StateType rand_seed = -1); IRSchedule(const IRSchedule& other); IRSchedule& operator=(const IRSchedule& src); IRSchedule(IRSchedule&& other); @@ -131,10 +136,13 @@ class IRSchedule { * @param factors The factors we used to split the loop. * @return The split loops. */ - std::vector<Expr> Split(const std::string& block_name, int loop_index, const std::vector<int>& factors); + std::vector<Expr> Split(const std::string& block_name, + int loop_index, + const std::vector<int>& factors); /** - * \brief Split a for loop into multiple loops, based on the factors, only used for deserialization of trace. + * \brief Split a for loop into multiple loops, based on the factors, only + * used for deserialization of trace. * @param loop The loop to be split. * @param factors The factors we used to split the loop. * @return The split loops. @@ -151,7 +159,8 @@ class IRSchedule { /** * \brief Fuse for loops and return the fused loop. * @param block_name Name of the block we want to modify. - * @param loops_index Indices of the loops to be fused, stored in ascending order. + * @param loops_index Indices of the loops to be fused, stored in ascending + * order. * @return The fused loop.
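 *
 * A sketch of intended use (the block name "B" and the indices are
 * illustrative, not taken from this patch):
 * \code
 *   // fuse the two outermost loops of block "B" into a single loop
 *   Expr fused = ir_sch.Fuse("B", {0, 1});
 * \endcode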
*/ Expr Fuse(const std::string& block_name, const std::vector& loops_index); @@ -159,7 +168,8 @@ class IRSchedule { /** * \brief Fuse for loops and return the fused loop. * @param block The block we want to modify. - * @param loops_index Indices of the loops to be fused, stored in ascending order. + * @param loops_index Indices of the loops to be fused, stored in ascending + * order. * @return The fused loop. */ Expr Fuse(const Expr& block, const std::vector& loops_index); @@ -170,10 +180,13 @@ class IRSchedule { * @param loop The loop we will move the block to. * @param keep_unit_loops Whether to keep the unit loop. */ - void ComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops = false); + void ComputeAt(const Expr& block, + const Expr& loop, + bool keep_unit_loops = false); /** - * \brief Move a block's location under a loop without considering their dependency. + * \brief Move a block's location under a loop without considering their + * dependency. * @param block The block we want to move its computation location. * @param loop The loop we will move the block to. */ @@ -185,7 +198,9 @@ class IRSchedule { * @param loop The loop we will move the block to. * @param keep_unit_loops Whether to keep the unit loop. */ - void ReverseComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops = false); + void ReverseComputeAt(const Expr& block, + const Expr& loop, + bool keep_unit_loops = false); /** * \brief Find an expr's root ScheduleBlockRealize node @@ -201,7 +216,9 @@ class IRSchedule { * @param memory_type String that indicates the buffer's storage scope. * @return The buffer's cache. */ - Expr CacheRead(const Expr& block, int read_buffer_index, const std::string& memory_type); + Expr CacheRead(const Expr& block, + int read_buffer_index, + const std::string& memory_type); /** * \brief Find a buffer that is being written, and create its cache. @@ -210,23 +227,29 @@ class IRSchedule { * @param memory_type String that indicates the buffer's storage scope. * @return The buffer's cache. */ - Expr CacheWrite(const Expr& block, int write_buffer_index, const std::string& memory_type); + Expr CacheWrite(const Expr& block, + int write_buffer_index, + const std::string& memory_type); /** * \brief Add SyncThreads statements in AST. * @param ir_node The insertion point in AST. - * @param after_node Whether to insert the statement after the insertion point. When it is True, we will insert the - * SyncThreads statement after the insertion IR. When it is False, we will insert the SyncThreads statement before the - * insertion IR. + * @param after_node Whether to insert the statement after the insertion + * point. When it is True, we will insert the SyncThreads statement after the + * insertion IR. When it is False, we will insert the SyncThreads statement + * before the insertion IR. */ void SyncThreads(const Expr& ir_node, bool after_node = true); /*! * \brief Set a tensor's buffer type(memory_type) * \param block The ScheduleBlockRealize corresponding to an unique tensor. - * \param memory_type The memory type we want to set. Should be "local", "shared" or "global". + * \param memory_type The memory type we want to set. Should be "local", + * "shared" or "global". */ - void SetBuffer(Expr& block, const std::string& memory_type, bool fixed = false); + void SetBuffer(Expr& block, + const std::string& memory_type, + bool fixed = false); /** * \brief Reorder the loops in the order of vector. 
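 * A sketch of intended use (illustrative names, not from this patch):
 * \code
 *   std::vector<Expr> loops = ir_sch.GetLoops("B");
 *   ir_sch.Reorder({loops[2], loops[0], loops[1]});
 * \endcode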
@@ -247,7 +270,8 @@ class IRSchedule { * stmts contain several loop chains if the reordered computation has * multiple loop chains. */ - Expr Reorder(const std::string& block_name, const std::vector<int>& loops_index); + Expr Reorder(const std::string& block_name, + const std::vector<int>& loops_index); /** * \brief Reorder the loops in the order of vector elements. @@ -314,13 +338,15 @@ class IRSchedule { //! Copy another block's schedule transform. void CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target); - void CopyTransformAndLoopInfo(const std::string& block_name, const std::string& block_target_name); + void CopyTransformAndLoopInfo(const std::string& block_name, + const std::string& block_target_name); /** - * \brief Factorize the reduction block by the given loop. The block will be split into two blocks: rfactor block and - * final write-back block. + * \brief Factorize the reduction block by the given loop. The block will be + * split into two blocks: rfactor block and final write-back block. * @param rf_loop the reduce loop to do rfactor transformation. - * @param rf_axis the axis where the new generated loop is placed in the rfactor block. + * @param rf_axis the axis where the newly generated loop is placed in the + * rfactor block. * @return The newly created rfactor tensor. * * For example, input the block: @@ -332,15 +358,14 @@ class IRSchedule { * B[i] = B[i] + A[i, j, k] * \endcode * - * If the rfactor loop is k and rf_axis is 0, the rfactor transformation is divided into 2 steps: - * 1. get the rfactor block where the reduce loop k is transformed to the serial loop with no accumulation and a new - * rfactor tensor is created. The axis k will be placed in the rf_axis of the new rf_tensor. The rf_block is as - * follows: - * \code - * for (rf_k, 0, 30) // rfactor loop k is transformed to the serial loop. - * for (i, 0, 10) // serial loop for (j, 0, 20) // reduce loop - * rf_B_init[rf_k, i] = 0 - * for (j, 0, 20) // reduce loop + * If the rfactor loop is k and rf_axis is 0, the rfactor transformation is + * divided into 2 steps: + * 1. get the rfactor block where the reduce loop k is transformed to the + * serial loop with no accumulation and a new rfactor tensor is created. The + * axis k will be placed in the rf_axis of the new rf_tensor. The rf_block is + * as follows:
+ * \code
+ * for (rf_k, 0, 30)   // rfactor loop k is transformed to the serial loop
+ *   for (i, 0, 10)    // serial loop
+ *     rf_B_init[rf_k, i] = 0
+ *     for (j, 0, 20)  // reduce loop * rf_B[rf_k, i] = rf_B[rf_k, i] + A[i, j, rf_k] * \endcode * 2. do reduction of the rfactor loop k to get the final result block: @@ -375,18 +400,18 @@ class IRSchedule { */ // Temporary solution to simplify the elementwise/broadcast/injective index. // TODO(sunli): Solve Index Simplify. - void FlattenLoops(const std::vector<Expr>& loops, const bool force_flat = false); + void FlattenLoops(const std::vector<Expr>& loops, + const bool force_flat = false);
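// The sampling primitives documented below combine with Split above to drive
// the auto-tuner's tiling search. A hedged sketch of the typical call chain
// (variable names are illustrative, not from this patch):
//
//   std::vector<Expr> factors = ir_sch.SamplePerfectTile(loop, 2, 64);
//   std::vector<Expr> tiled   = ir_sch.Split(loop, factors);
//
// SamplePerfectTile records its decision in the trace, so replaying the
// ScheduleDesc reproduces the same factors deterministically.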
/*! * \brief Sample the factors to tile a specific loop perfectly * \param loop the loop to be split * \param n the number of loop layers to split * \param max_innermost_factor the maximum factor of the innermost loop - * \param decision the decision data of the last sample, or the artificially given decision data - * \return the split factors of the loop (The larger the index, the inner the corresponding loop) - * For example, return {16,64} means the loop will be like this: - * for (i, 0, 16) { - * for (j, 0, 64) { + * \param decision the decision data of the last sample, or the artificially
+ * given decision data
+ * \return the split factors of the loop (the larger the index, the inner the
+ * corresponding loop). For example, return {16,64} means the loop will be
+ * like this:
+ * for (i, 0, 16) {
+ *   for (j, 0, 64) { * ... * } * } @@ -397,8 +422,9 @@ class IRSchedule { const std::vector<int>& decision = {}); /*! - * \brief Insert a tag in schedule_desc to mark the beginning of post processing, - * the schedule primitive itself does not make any changes to the IR. + * \brief Insert a tag in schedule_desc to mark the beginning of post + * processing; the schedule primitive itself does not make any changes to the + * IR. */ void TagPostSchedule(); @@ -406,7 +432,8 @@ * \brief Randomly sample an integer according to the given distribution. * @param candidates Candidate set of integers. * @param probs Probability distribution of candidate integer set. - * @param decision the decision data of the last sample, or the artificially given decision data. + * @param decision the decision data of the last sample, or the artificially + * given decision data. * @return Random variables sampled. */ Expr SampleCategorical(const std::vector<int>& candidates, @@ -429,7 +456,8 @@ /*! * \brief The base class of the inliner, which handles: * 1) Remove the block to be inlined - * 2) Maintain a list of index variables and their substitution of the buffer being inlined + * 2) Maintain a list of index variables and their substitution of the buffer + * being inlined */ class BaseInliner : public ir::IRMutator<> { protected: @@ -444,7 +472,8 @@ class BaseInliner : public ir::IRMutator<> { protected: //! Check if indices are valid. If so, set idx_vars_ properly. - bool UpdateAndCheckIndexVars(const std::vector<Expr>& indices, int expected_ndim); + bool UpdateAndCheckIndexVars(const std::vector<Expr>& indices, + int expected_ndim); void SetIndexSubstitution(const std::vector<Expr>& indices); @@ -455,7 +484,8 @@ class BaseInliner : public ir::IRMutator<> { Expr inlined_store_{nullptr}; //! The indices used for indexing the buffer to be inlined std::vector<Var> idx_vars_; - //! Replacing vars(idx_sub_var_) in indices to corresponding expr(idx_sub_expr_) + //! Replacing vars(idx_sub_var_) in indices to corresponding + //! expr(idx_sub_expr_) std::vector<Var> idx_sub_var_; std::vector<Expr> idx_sub_expr_; @@ -472,11 +502,13 @@ /*!
* \brief Helper to inline the producer block into its consumer(s) * The derived class implements: - * Substitute `Load` on the tensor to be inlined to its value calculation in the producer block + * Substitute `Load` on the tensor to be inlined to its value calculation in the + * producer block */ class ComputeInliner : public BaseInliner { public: - explicit ComputeInliner(const Tensor& inlined_tensor, const Expr& inlined_store) + explicit ComputeInliner(const Tensor& inlined_tensor, + const Expr& inlined_store) : BaseInliner(inlined_tensor, inlined_store) {} bool BodyPatternAllowInline(); @@ -501,7 +533,9 @@ class ReverseComputeInliner : public BaseInliner { const Expr& inlined_store, const Expr& inlined_load, const Expr& target_store) - : BaseInliner(inlined_tensor, inlined_store), inlined_load_(inlined_load), target_store_(target_store) {} + : BaseInliner(inlined_tensor, inlined_store), + inlined_load_(inlined_load), + target_store_(target_store) {} bool BodyPatternAllowInline(); @@ -548,13 +582,13 @@ class LeafBlockRemovalPlan : public ir::IRMutator<> { int block_index = -1; for (int i = 0; i < expr->stmts.size(); ++i) { auto keep_flag = find_block; - find_block = false; - auto* node = op->As(); + find_block = false; + auto* node = op->As(); IRMutator::Visit(&node->stmts[i], &node->stmts[i]); if (find_block) { if (depth == 0) { *source_expr_ = *op; - block_index = i; + block_index = i; } depth++; } @@ -569,7 +603,7 @@ class LeafBlockRemovalPlan : public ir::IRMutator<> { new_stmts.push_back(expr->stmts[i]); } auto target_block = ir::Block::Make(new_stmts); - *target_expr_ = target_block; + *target_expr_ = target_block; } } else { IRMutator::Visit(expr, op); @@ -586,7 +620,8 @@ class LeafBlockRemovalPlan : public ir::IRMutator<> { class ComputeInlineChecker : public ir::IRMutator<> { public: - ComputeInlineChecker(IRSchedule& schedule, Expr& block) : ir_schedule_(schedule), block_(block) {} + ComputeInlineChecker(IRSchedule& schedule, Expr& block) + : ir_schedule_(schedule), block_(block) {} bool Check(); @@ -595,7 +630,8 @@ class ComputeInlineChecker : public ir::IRMutator<> { private: void Visit(const ir::Load* expr, Expr* op) { // Check there is Load Expr corresponds to Store Expr - if ((store_.As()->tensor).as_tensor_ref()->name == expr->tensor.as_tensor_ref()->name) { + if ((store_.As()->tensor).as_tensor_ref()->name == + expr->tensor.as_tensor_ref()->name) { should_skip_ = false; return; } diff --git a/paddle/cinn/ir/ir_schedule_util.cc b/paddle/cinn/ir/ir_schedule_util.cc index 26ea1a6736365..a518ad7e7860f 100644 --- a/paddle/cinn/ir/ir_schedule_util.cc +++ b/paddle/cinn/ir/ir_schedule_util.cc @@ -42,9 +42,11 @@ Tensor GetTensor(const Expr& block) { CHECK(block.As()); auto find_tensor = ir::CollectIRNodesWithoutTensor( block, [&](const Expr* x) { return x->As(); }, true); - CHECK_EQ(find_tensor.size(), 1U) << "One block should only have one Store node!(except for root block)"; + CHECK_EQ(find_tensor.size(), 1U) + << "One block should only have one Store node!(except for root block)"; CHECK((*find_tensor.begin()).As()->tensor.as_tensor()); - Tensor tensor = (*find_tensor.begin()).As()->tensor.as_tensor_ref(); + Tensor tensor = + (*find_tensor.begin()).As()->tensor.as_tensor_ref(); return tensor; } @@ -52,15 +54,19 @@ Tensor GetReadTensor(const Expr& block, int index) { CHECK(block.As()); auto find_tensor = ir::CollectIRNodesWithoutTensor( block, [&](const Expr* x) { return x->As(); }, true); - CHECK_EQ(find_tensor.size(), 1U) << "One block should only have one Store node!(except 
for root block)"; + CHECK_EQ(find_tensor.size(), 1U) + << "One block should only have one Store node!(except for root block)"; std::vector res; - auto find_read_tensor = ir::CollectIRNodesWithoutTensor(block, [&](const Expr* x) { - if (x->As()) res.push_back(x->As()->tensor.as_tensor_ref()); - return x->As(); - }); + auto find_read_tensor = + ir::CollectIRNodesWithoutTensor(block, [&](const Expr* x) { + if (x->As()) + res.push_back(x->As()->tensor.as_tensor_ref()); + return x->As(); + }); CHECK_EQ(find_read_tensor.size(), res.size()); CHECK(!find_read_tensor.empty()) << "Didn't find Load tensor in block!"; - CHECK_LT(index, (int)find_read_tensor.size()) << "Index is not < read tensor's size!"; + CHECK_LT(index, (int)find_read_tensor.size()) + << "Index is not < read tensor's size!"; return res[index]; } @@ -80,7 +86,7 @@ void SetCudaAxisInfo(Expr* lowered_func) { auto func_body = lowered_func->as_lowered_func_ref()->body; CudaAxisInfo info; - auto block_nodes = ir::CollectIRNodes(func_body, [&](const Expr* x) { + auto block_nodes = ir::CollectIRNodes(func_body, [&](const Expr* x) { if (x->As() && x->As()->bind_info().valid()) { auto bind_info = x->As()->bind_info(); info.set_valid(true); @@ -88,16 +94,22 @@ void SetCudaAxisInfo(Expr* lowered_func) { CHECK(common::is_zero(x->As()->min)); CHECK(x->As()->extent.is_constant()); int range = x->As()->extent.get_constant(); - range = range > info.block_dim(bind_info.offset) ? range : info.block_dim(bind_info.offset); - VLOG(3) << "Set block dim[" << bind_info.offset << "] with range " << range; + range = range > info.block_dim(bind_info.offset) + ? range + : info.block_dim(bind_info.offset); + VLOG(3) << "Set block dim[" << bind_info.offset << "] with range " + << range; info.set_block_dim(bind_info.offset, range); } else if (bind_info.for_type == ForType::GPUBlock) { CHECK(common::is_zero(x->As()->min)); CHECK(x->As()->extent.is_constant()); int range = x->As()->extent.get_constant(); - range = range > info.grid_dim(bind_info.offset) ? range : info.grid_dim(bind_info.offset); + range = range > info.grid_dim(bind_info.offset) + ? range + : info.grid_dim(bind_info.offset); info.set_grid_dim(bind_info.offset, range); - VLOG(3) << "Set grid dim[" << bind_info.offset << "] with range " << range; + VLOG(3) << "Set grid dim[" << bind_info.offset << "] with range " + << range; } else { LOG(FATAL) << "The for loop's bind info should be gpu block or thread!"; } @@ -109,14 +121,19 @@ void SetCudaAxisInfo(Expr* lowered_func) { bool Contains(const Expr& container, const Expr& expr) { auto find_expr = ir::CollectIRNodesWithoutTensor( - container, [&](const Expr* x) { return (x->node_type() == expr.node_type() && *x == expr); }, true); + container, + [&](const Expr* x) { + return (x->node_type() == expr.node_type() && *x == expr); + }, + true); return (!find_expr.empty()); } Expr GetNextForLoop(const Expr& for_loop) { Expr result; - CHECK(for_loop.As()) << "The input of GetNextForLoop should be ir::For!"; - Expr for_body = for_loop.As()->body; + CHECK(for_loop.As()) + << "The input of GetNextForLoop should be ir::For!"; + Expr for_body = for_loop.As()->body; ir::Block* for_body_block = for_body.As(); CHECK(for_body_block) << "The for_loop's body shoule be Block!"; @@ -145,7 +162,8 @@ Expr GetNextForLoop(const Expr& for_loop) { // we will check it later in the future. 
CHECK(block_body.As()->true_case.As()); Expr true_case = block_body.As()->true_case; - if (true_case.As()->stmts.size() != 1U || !true_case.As()->stmts[0].As()) + if (true_case.As()->stmts.size() != 1U || + !true_case.As()->stmts[0].As()) return result; result = true_case.As()->stmts[0]; return result; @@ -162,14 +180,16 @@ std::vector GetIfThenElseInRange(const Expr& top, const Expr& bottom) { CHECK(bottom.As()); for (auto loop_iter = top; loop_iter != bottom;) { CHECK(loop_iter.As()); - CHECK(loop_iter.As()->body.As()) << "For node's body should be Block!"; + CHECK(loop_iter.As()->body.As()) + << "For node's body should be Block!"; auto block = loop_iter.As()->body.As(); for (Expr tmp : block->stmts) { if (tmp.As()) { if_nodes.push_back(tmp); CHECK(tmp.As()->true_case.As()); Expr true_case = tmp.As()->true_case; - CHECK(true_case.As()->stmts.size() == 1U && true_case.As()->stmts[0].As()); + CHECK(true_case.As()->stmts.size() == 1U && + true_case.As()->stmts[0].As()); tmp = true_case.As()->stmts[0]; } if (tmp.As()) { @@ -180,14 +200,18 @@ std::vector GetIfThenElseInRange(const Expr& top, const Expr& bottom) { return if_nodes; } -void ReplaceExpr(Expr* source, const std::vector& replaced, const std::vector& candidates) { +void ReplaceExpr(Expr* source, + const std::vector& replaced, + const std::vector& candidates) { CHECK_EQ(replaced.size(), candidates.size()) - << "In ReplaceExpr, the size of Vars to be replaced must be equal to the size of cadidate Exprs! Please check."; + << "In ReplaceExpr, the size of Vars to be replaced must be equal to the " + "size of cadidate Exprs! Please check."; if (replaced.empty()) return; std::map replacing_map; for (int i = 0; i < replaced.size(); ++i) { // If the Var to be replaced is equal to the candidate, we skip it. - if (candidates[i].is_var() && candidates[i].as_var_ref() == replaced[i]) continue; + if (candidates[i].is_var() && candidates[i].as_var_ref() == replaced[i]) + continue; replacing_map[replaced[i]] = candidates[i]; } MappingVarToExprMutator mapper(replacing_map); @@ -195,15 +219,20 @@ void ReplaceExpr(Expr* source, const std::vector& replaced, const std::vect return; } -std::vector ValidateFactors(const std::vector& factors, int total_extent) { - CHECK(!factors.empty()) << "The factors param of Split should not be empty! Please check."; +std::vector ValidateFactors(const std::vector& factors, + int total_extent) { + CHECK(!factors.empty()) + << "The factors param of Split should not be empty! Please check."; bool has_minus_one = false; - int product = 1; + int product = 1; for (auto& i : factors) { - CHECK(i != 0) << "The params in factors of Split should not be 0! Please check."; - CHECK(i >= -1) << "The params in factors of Split should not be less than -1! Please check."; + CHECK(i != 0) + << "The params in factors of Split should not be 0! Please check."; + CHECK(i >= -1) << "The params in factors of Split should not be less than " + "-1! Please check."; if (i == -1) { - CHECK(!has_minus_one) << "The params in factors of Split should not have more than one -1! Please check."; + CHECK(!has_minus_one) << "The params in factors of Split should not have " + "more than one -1! Please check."; has_minus_one = true; } else { product *= i; @@ -212,11 +241,14 @@ std::vector ValidateFactors(const std::vector& factors, int total_exte std::vector validated_factors = factors; if (!has_minus_one) { CHECK_GE(product, total_extent) - << "In Split, the factors' product should be equal to original loop's extent! 
Please check."; + << "In Split, the factors' product should be equal to original loop's " + "extent! Please check."; return validated_factors; } else { - CHECK_LE(product, total_extent) << "In Split, when there is -1 in factors, the other factors' product should be <= " - "original loop's extent! Please check."; + CHECK_LE(product, total_extent) + << "In Split, when there is -1 in factors, the other factors' product " + "should be <= " + "original loop's extent! Please check."; int minus_one_candidate = (int)ceil((double)total_extent / (double)product); for (int i = 0; i < validated_factors.size(); ++i) { if (validated_factors[i] == -1) { @@ -232,42 +264,52 @@ void CHECKRfactorValidation(const Expr& rf_loop, int rf_axis) { CHECK(rf_for) << "Expr param of Rfactor must be For node! Please check."; // check the rf_loop only has one schedule block auto block_nodes = ir::CollectIRNodesWithoutTensor( - rf_loop, [&](const Expr* x) { return x->As(); }, true); - CHECK_EQ(block_nodes.size(), 1U) << "Rfactor Loop should only have one schedule block"; + rf_loop, + [&](const Expr* x) { return x->As(); }, + true); + CHECK_EQ(block_nodes.size(), 1U) + << "Rfactor Loop should only have one schedule block"; auto find_store = ir::CollectIRNodesWithoutTensor( rf_loop, [&](const Expr* x) { return x->As(); }, true); CHECK_EQ(find_store.size(), 1U); auto indice = find_store.begin()->As()->indices; // check rf_axis - CHECK_LE(rf_axis, indice.size()) << "rf_axis should not be greater than store's domain size"; + CHECK_LE(rf_axis, indice.size()) + << "rf_axis should not be greater than store's domain size"; // check rfactor loop is reduce auto* sch_block_realize = block_nodes.begin()->As(); - auto* sch_block = sch_block_realize->schedule_block.As(); + auto* sch_block = sch_block_realize->schedule_block.As(); CHECK(sch_block); auto& iter_values = sch_block_realize->iter_values; - auto& iter_vars = sch_block->iter_vars; + auto& iter_vars = sch_block->iter_vars; CHECK_EQ(iter_values.size(), iter_vars.size()); auto rf_loop_var = rf_for->loop_var; Var rf_block_var; for (int i = 0; i < iter_values.size(); ++i) { if (ContainVar({iter_values[i]}, rf_loop_var->name)) { - CHECK(!rf_block_var.defined()) << "rfactor loop var can only be binded to one block var"; + CHECK(!rf_block_var.defined()) + << "rfactor loop var can only be binded to one block var"; auto iter_value = iter_values[i].As<_Var_>(); CHECK(iter_value) << "not support complex reduce bindings"; rf_block_var = iter_vars[i]; - auto it = std::find_if(indice.begin(), indice.end(), [&](const Expr& x) { + auto it = std::find_if(indice.begin(), indice.end(), [&](const Expr& x) { return x.As<_Var_>() && x.As<_Var_>()->name == rf_block_var->name; }); - CHECK(it == indice.end()) << "rfactor loop var is not reduce, please check!"; + CHECK(it == indice.end()) + << "rfactor loop var is not reduce, please check!"; } } } std::vector GetLoopsOfExpr(const Expr& expr, const Expr& root) { - auto loop_nodes = - ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { return x->As() && Contains(*x, expr); }); + auto loop_nodes = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { + return x->As() && Contains(*x, expr); + }); std::vector result(loop_nodes.begin(), loop_nodes.end()); - if (result.empty()) LOG(FATAL) << "Didn't find expr's : \n" << expr << "\n loops in root : \n" << root; + if (result.empty()) + LOG(FATAL) << "Didn't find expr's : \n" + << expr << "\n loops in root : \n" + << root; std::sort(result.begin(), result.end(), [&](Expr i, Expr j) { return 
(utils::GetStreamCnt(i).size() > utils::GetStreamCnt(j).size()); }); @@ -295,7 +337,8 @@ IterRange GetAccessedRange(const Expr& index, Expr indice_extent; Expr mod_extent(0); - if (indice_min.As() && indice_min.As()->b().is_constant()) mod_extent = indice_min.As()->b(); + if (indice_min.As() && indice_min.As()->b().is_constant()) + mod_extent = indice_min.As()->b(); if (indice_min == indice_max) { if (common::is_zero(mod_extent)) { @@ -305,24 +348,29 @@ IterRange GetAccessedRange(const Expr& index, indice_extent = mod_extent; } } else { - indice_extent = common::AutoSimplify(common::AutoSimplify(indice_max) - common::AutoSimplify(indice_min) + 1); + indice_extent = common::AutoSimplify(common::AutoSimplify(indice_max) - + common::AutoSimplify(indice_min) + 1); } if (indice_extent.is_constant() && indice_extent.get_constant() < 0) { VLOG(3) << "deduced indices are not constant"; - indice_min = indice_max; + indice_min = indice_max; indice_extent = Expr(-indice_extent.get_constant()); } - VLOG(3) << "indice_min=" << indice_min << ", indice_max=" << indice_max << ", indice_extent=" << indice_extent; + VLOG(3) << "indice_min=" << indice_min << ", indice_max=" << indice_max + << ", indice_extent=" << indice_extent; return IterRange(indice_min, indice_extent); } -std::vector CalculateTensorRegions(const Expr& block, - const std::vector& tensor_indices, - const Tensor& tensor, - const Expr& root) { +std::vector CalculateTensorRegions( + const Expr& block, + const std::vector& tensor_indices, + const Tensor& tensor, + const Expr& root) { CHECK(block.As()); - auto iter_vars = block.As()->schedule_block.As()->iter_vars; + auto iter_vars = block.As() + ->schedule_block.As() + ->iter_vars; auto iter_values = block.As()->iter_values; std::vector loop_vars; @@ -332,7 +380,8 @@ std::vector CalculateTensorRegions(const Expr& block, for (auto& loop : outer_loops) { CHECK(loop.As()); loop_vars.emplace_back(loop.As()->loop_var); - loop_ranges.emplace_back(IterRange(loop.As()->min, loop.As()->extent)); + loop_ranges.emplace_back( + IterRange(loop.As()->min, loop.As()->extent)); } std::vector result; @@ -341,11 +390,13 @@ std::vector CalculateTensorRegions(const Expr& block, ReplaceExpr(&binded_index, iter_vars, iter_values); auto range = GetAccessedRange(binded_index, loop_vars, loop_ranges); - // in generally, the range should be constant, but in some cases our AutoSimplify - // (algebraic simplification function) can't simplify completely where we use the whole - // shape in this indice as the accessed range conservatively + // in general, the range should be constant, but in some cases our + // AutoSimplify (algebraic simplification function) can't simplify it + // completely, in which case we conservatively use the whole shape in this + // indice as the accessed range if (!range.min.is_constant() || !range.extent.is_constant()) { - VLOG(3) << "deduced range is not constant, range.min=" << range.min << ", range.extent=" << range.extent; + VLOG(3) << "deduced range is not constant, range.min=" << range.min + << ", range.extent=" << range.extent; if (tensor->buffer.defined()) { CHECK_GT((int)tensor->buffer->shape.size(), i); result.emplace_back(IterRange(Expr(0), tensor->buffer->shape[i])); @@ -363,23 +414,27 @@ std::vector CalculateTensorRegions(const Expr& block, Expr GetNthAccessExpr(const Expr& block, int index, bool is_write) { CHECK(block.As()); - auto compute_body = block.As()->schedule_block.As()->body; + auto compute_body = block.As() + ->schedule_block.As() + ->body; if (is_write) { std::vector
find_store_vec; - auto find_store = ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { - if (x->As()) find_store_vec.push_back(*x); - return x->As(); - }); + auto find_store = + ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { + if (x->As()) find_store_vec.push_back(*x); + return x->As(); + }); CHECK_EQ(find_store.size(), find_store_vec.size()); CHECK_LT(index, (int)find_store.size()); Expr store_index = find_store_vec[index]; return store_index; } else { std::vector find_load_vec; - auto find_load = ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { - if (x->As()) find_load_vec.push_back(*x); - return x->As(); - }); + auto find_load = + ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { + if (x->As()) find_load_vec.push_back(*x); + return x->As(); + }); CHECK_EQ(find_load.size(), find_load_vec.size()); CHECK_LT(index, (int)find_load.size()); Expr load_index = find_load_vec[index]; @@ -406,7 +461,8 @@ Expr MakeCacheBlock(const std::vector& buffer_ranges, std::vector iter_values; // Create loop vars and block vars' binding_value for (const auto& range : buffer_ranges) { - Var loop_var(common::UniqName("cache_ax" + std::to_string(loop_vars.size()))); + Var loop_var( + common::UniqName("cache_ax" + std::to_string(loop_vars.size()))); // Var loop_var("ax" + std::to_string(loop_vars.size())); loop_vars.push_back(loop_var); iter_values.push_back(common::AutoSimplify(range.min + loop_var)); @@ -420,14 +476,19 @@ Expr MakeCacheBlock(const std::vector& buffer_ranges, Var var(Expr(0), dim, "v" + std::to_string(block_vars.size()), false); block_vars.push_back(var); } - auto body = new_tensor->tensor_store_expanded_body(); - std::vector axis_vars = common::GenDefaultAxis(new_tensor->domain.size()); - axis_vars.insert(axis_vars.end(), new_tensor->reduce_axis.begin(), new_tensor->reduce_axis.end()); + auto body = new_tensor->tensor_store_expanded_body(); + std::vector axis_vars = + common::GenDefaultAxis(new_tensor->domain.size()); + axis_vars.insert(axis_vars.end(), + new_tensor->reduce_axis.begin(), + new_tensor->reduce_axis.end()); for (int i = 0; i < axis_vars.size(); ++i) { optim::ReplaceVarWithExpr(&body, axis_vars[i], block_vars[i]); } Expr block = ir::ScheduleBlockRealize::Make( - iter_values, ir::ScheduleBlock::Make(block_vars, {}, {}, new_tensor->name, Block::Make({body}))); + iter_values, + ir::ScheduleBlock::Make( + block_vars, {}, {}, new_tensor->name, Block::Make({body}))); Expr new_body = block; for (int i = (int)loop_vars.size() - 1; i >= 0; i--) { new_body = For::Make(loop_vars[i], @@ -442,15 +503,22 @@ Expr MakeCacheBlock(const std::vector& buffer_ranges, } void FindInsertionPoint(Expr& root, CacheBlockInfo* info, bool is_write) { - Expr find_tensor = is_write ? Expr(info->write_tensor) : Expr(info->read_tensor); - auto find_produce_read = ir::CollectIRNodesWithoutTensor( - root, [&](const Expr* x) { return x->As() && x->As()->tensor == find_tensor; }); + Expr find_tensor = + is_write ? 
Expr(info->write_tensor) : Expr(info->read_tensor); + auto find_produce_read = + ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { + return x->As() && x->As()->tensor == find_tensor; + }); if (find_produce_read.empty()) { CHECK(root.As()->schedule_block.As()); - CHECK(root.As()->schedule_block.As()->body.As()); - info->loc_block = root.As()->schedule_block.As()->body; - info->loc_pos = 0; + CHECK(root.As() + ->schedule_block.As() + ->body.As()); + info->loc_block = root.As() + ->schedule_block.As() + ->body; + info->loc_pos = 0; return; } @@ -458,8 +526,11 @@ void FindInsertionPoint(Expr& root, CacheBlockInfo* info, bool is_write) { Expr producer = *(find_produce_read.begin()); CHECK(root.As()->schedule_block.As()); - CHECK(root.As()->schedule_block.As()->body.As()); - info->loc_block = root.As()->schedule_block.As()->body; + CHECK(root.As() + ->schedule_block.As() + ->body.As()); + info->loc_block = + root.As()->schedule_block.As()->body; for (int i = 0; i < (int)info->loc_block.As()->stmts.size(); ++i) { if (Contains(info->loc_block.As()->stmts[i], producer)) { info->loc_pos = i + 1; @@ -468,13 +539,15 @@ void FindInsertionPoint(Expr& root, CacheBlockInfo* info, bool is_write) { } } -const std::set CollectLoopsToSet(const std::vector& loops) { +const std::set CollectLoopsToSet( + const std::vector& loops) { std::set for_loops; for (auto& i : loops) { CHECK(i.As()) << "loops should be For node! Please check."; auto inserted = for_loops.insert(i); if (!inserted.second) { - LOG(FATAL) << "There should be no duplicate elements in loops! Please check."; + LOG(FATAL) + << "There should be no duplicate elements in loops! Please check."; } } return for_loops; @@ -483,7 +556,8 @@ const std::set CollectLoopsToSet(const std::vector& loops) // This function is used in Reorder schedule primitive. Since input loop // Expr(s) of Reorder doesn't give original for loop order, we have to // find the top (outermost) loop and bottom (innermost) among loop Expr(s) -std::pair GetBoundaryOfReorderRange(const std::set& loop_set) { +std::pair GetBoundaryOfReorderRange( + const std::set& loop_set) { Expr top = *loop_set.begin(); Expr bottom; std::set visited; @@ -499,7 +573,8 @@ std::pair GetBoundaryOfReorderRange(const std::set& // Then loop_i should be the new top if (visited.count(v_for)) { if (v_for != top) { - LOG(FATAL) << "Loops in GetBoundaryOfReorderRange is not a chain! Please check."; + LOG(FATAL) << "Loops in GetBoundaryOfReorderRange is not a chain! " + "Please check."; } top = loop_i; break; @@ -527,7 +602,9 @@ std::vector GetLoopsInRange(const Expr& top, const Expr& bottom) { CHECK(bottom.As()); for (auto loop_iter = top; loop_iter != bottom;) { Expr tmp = GetNextForLoop(loop_iter); - if (!tmp.defined()) LOG(FATAL) << "Loops in GetLoopsInReorderRange is not a chain! Please check."; + if (!tmp.defined()) + LOG(FATAL) + << "Loops in GetLoopsInReorderRange is not a chain! Please check."; chain.push_back(loop_iter); loop_iter = tmp; } @@ -576,7 +653,9 @@ Expr ConstructNewLoopChain(const std::vector& chain, // In each IfThenElse node, find the vars its condition depends on.
for (auto& if_expr : if_nodes) { CHECK(if_expr.As()); - auto var_set = ir::CollectIRNodes(if_expr.As()->condition, [&](const Expr* x) { return x->as_var(); }); + auto var_set = + ir::CollectIRNodes(if_expr.As()->condition, + [&](const Expr* x) { return x->as_var(); }); std::set var_name_set; for (auto& i : var_set) var_name_set.insert(i.as_var()->name); condition_vars.push_back(var_name_set); @@ -608,12 +687,15 @@ Expr ConstructNewLoopChain(const std::vector& chain, Expr original_temp = temp; // Here we handle the IfThenElse nodes. for (int i = 0; i < static_cast(if_nodes.size()); ++i) { - if (condition_vars[i].count(original_temp.As()->loop_var->name)) { + if (condition_vars[i].count( + original_temp.As()->loop_var->name)) { Expr temp_body = temp.As()->body; if (temp_body.As() && temp_body.As()->stmts.size() == 1U) temp_body = temp_body.As()->stmts[0]; - temp.As()->body = IfThenElse::Make( - if_nodes[i].As()->condition, temp_body, if_nodes[i].As()->false_case); + temp.As()->body = + IfThenElse::Make(if_nodes[i].As()->condition, + temp_body, + if_nodes[i].As()->false_case); temp.As()->body = Block::Make({temp.As()->body}); if_nodes.erase(if_nodes.begin() + i); condition_vars.erase(condition_vars.begin() + i); @@ -625,7 +707,8 @@ Expr ConstructNewLoopChain(const std::vector& chain, } CHECK(new_loop.defined()); - // new_loop_chain, which represents the main loop chain, now is from top to bottom. + // new_loop_chain, which represents the main loop chain, now is from top to + // bottom. std::reverse(reordered_loop_chain.begin(), reordered_loop_chain.end()); // In the main loop chain, each loop's body only contains sub_loop or bottom @@ -655,7 +738,8 @@ Expr ConstructNewLoopChain(const std::vector& chain, // Construct the complete loop chain from origin loop top to bottom. 
CHECK_EQ(chain.size(), reordered_loop_chain.size()) - << "origin loop chain size not equals reordered requirement when ConstructNewLoopChain in Reorder"; + << "origin loop chain size does not equal the reordered requirement when " + "ConstructNewLoopChain in Reorder"; std::unordered_set origin_loop_var_names; Expr ret = new_loop; @@ -674,10 +758,11 @@ Expr ConstructNewLoopChain(const std::vector& chain, // because bottom loop's body stmts have all been added const ir::For* loop_in_chain = chain[i].As(); - ir::For* reordered_in_chain = reordered_loop_chain[i].As(); + ir::For* reordered_in_chain = reordered_loop_chain[i].As(); origin_loop_var_names.insert(loop_in_chain->loop_var->name); - CHECK_EQ(origin_loop_var_names.size(), i + 1) << "Duplicate loop var name in origin Chain during Reorder"; + CHECK_EQ(origin_loop_var_names.size(), i + 1) + << "Duplicate loop var name in origin Chain during Reorder"; const ir::Block* body_block = loop_in_chain->body.As(); @@ -690,7 +775,8 @@ Expr ConstructNewLoopChain(const std::vector& chain, std::vector stmts_after_loop; for (int j = 0; j < body_block->stmts.size(); ++j) { if (body_block->stmts[j].As() && - body_block->stmts[j].As()->loop_var->name == chain[i + 1].As()->loop_var->name) { + body_block->stmts[j].As()->loop_var->name == + chain[i + 1].As()->loop_var->name) { other_stmt_body_before_loop = false; continue; } @@ -704,32 +790,36 @@ Expr ConstructNewLoopChain(const std::vector& chain, // Find the chain that other body stmts share with main loop chain std::vector reordered_indices; for (int j = 0; j < reordered_loop_chain.size(); ++j) { - if (origin_loop_var_names.count(reordered_loop_chain[j].As()->loop_var->name)) { + if (origin_loop_var_names.count( + reordered_loop_chain[j].As()->loop_var->name)) { reordered_indices.push_back(j); } } CHECK_EQ(reordered_indices.size(), origin_loop_var_names.size()) - << "Reordered chain loop var names doesn't match other stmt chain loop var names"; + << "Reordered chain loop var names don't match other stmt chain " + "loop var names"; // Add other stmts chain to root Block if other stmts exist if (!stmts_before_loop.empty()) { - Expr before_chain = ConstructOtherStmtChain(stmts_before_loop, reordered_loop_chain, reordered_indices); + Expr before_chain = ConstructOtherStmtChain( + stmts_before_loop, reordered_loop_chain, reordered_indices); if (ret.As() == nullptr) { ret = ir::Block::Make({ret}); } std::vector& inplace_stmts = ret.As()->stmts; - auto pos = inplace_stmts.begin() + add_other_chain_index; + auto pos = inplace_stmts.begin() + add_other_chain_index; inplace_stmts.insert(pos, before_chain); ++add_other_chain_index; } if (!stmts_after_loop.empty()) { - Expr after_chain = ConstructOtherStmtChain(stmts_after_loop, reordered_loop_chain, reordered_indices); + Expr after_chain = ConstructOtherStmtChain( + stmts_after_loop, reordered_loop_chain, reordered_indices); if (ret.As() == nullptr) { ret = ir::Block::Make({ret}); } std::vector& inplace_stmts = ret.As()->stmts; - auto pos = inplace_stmts.begin() + add_other_chain_index + 1; + auto pos = inplace_stmts.begin() + add_other_chain_index + 1; inplace_stmts.insert(pos, after_chain); } } @@ -745,25 +835,35 @@ std::vector GetProducers(const Expr& block, const Expr& root) { // collect all producers' tensor names std::set producer_tensor_names; - auto compute_body = block.As()->schedule_block.As()->body; - ir::CollectIRNodesWithoutTensor(compute_body, [&producer_tensor_names](const Expr* x) { - auto* load = x->As(); - if (load) { -
producer_tensor_names.insert(load->tensor.as_tensor()->name); - return true; - } - return false; - }); + auto compute_body = block.As() + ->schedule_block.As() + ->body; + ir::CollectIRNodesWithoutTensor( + compute_body, [&producer_tensor_names](const Expr* x) { + auto* load = x->As(); + if (load) { + producer_tensor_names.insert(load->tensor.as_tensor()->name); + return true; + } + return false; + }); - // traverse each of other blocks and filter those ones which contain at least one producer tensor; - auto find_blocks = ir::CollectIRNodesWithoutTensor( - root, [&block, &root](const Expr* x) { return x->As() && *x != block && *x != root; }); + // traverse each of the other blocks and keep those that contain at least + // one producer tensor; + auto find_blocks = + ir::CollectIRNodesWithoutTensor(root, [&block, &root](const Expr* x) { + return x->As() && *x != block && *x != root; + }); for (auto&& cur : find_blocks) { - auto* cur_block = cur.As()->schedule_block.As(); + auto* cur_block = cur.As() + ->schedule_block.As(); CHECK(cur_block) << "block result should be a ScheduleBlockRealize"; - auto find_stores = ir::CollectIRNodesWithoutTensor(cur_block->body, [&producer_tensor_names](const Expr* x) { - return x->As() && producer_tensor_names.count(x->As()->tensor.as_tensor()->name) > 0; - }); + auto find_stores = ir::CollectIRNodesWithoutTensor( + cur_block->body, [&producer_tensor_names](const Expr* x) { + return x->As() && + producer_tensor_names.count( + x->As()->tensor.as_tensor()->name) > 0; + }); if (!find_stores.empty()) producers.emplace_back(cur); } return producers; @@ -774,30 +874,49 @@ std::vector GetConsumers(const Expr& block, const Expr& root) { CHECK(root.As()); std::vector consumers; std::string block_tensor = GetTensor(block)->name; - auto find_block = ir::CollectIRNodesWithoutTensor( - root, [&](const Expr* x) { return x->As() && *x != block && *x != root; }); + auto find_block = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { + return x->As() && *x != block && *x != root; + }); for (auto& i : find_block) { - CHECK(i.As()->schedule_block.As()); - auto block_body = i.As()->schedule_block.As()->body; - auto find_load = ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) { - return x->As() && x->As()->tensor.as_tensor_ref()->name == block_tensor; - }); + CHECK(i.As() + ->schedule_block.As()); + auto block_body = i.As() + ->schedule_block.As() + ->body; + auto find_load = + ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) { + return x->As() && + x->As()->tensor.as_tensor_ref()->name == + block_tensor; + }); if (!find_load.empty()) consumers.emplace_back(i); } return consumers; } -void CheckComputeAtValidation(const Expr& block, const Expr& loop, const Expr& root) { +void CheckComputeAtValidation(const Expr& block, + const Expr& loop, + const Expr& root) { auto find_block = ir::CollectIRNodesWithoutTensor( - root, [&](const Expr* x) { return x->As() && *x == block; }, true); + root, + [&](const Expr* x) { + return x->As() && *x == block; + }, + true); CHECK(!find_block.empty()) << "Didn't find block in root!"; auto find_loop = ir::CollectIRNodesWithoutTensor( - root, [&](const Expr* x) { return x->As() && *x == loop; }, true); + root, + [&](const Expr* x) { return x->As() && *x == loop; }, + true); CHECK(!find_loop.empty()) << "Didn't find loop in root!"; auto find_block_in_loop = ir::CollectIRNodesWithoutTensor( - loop, [&](const Expr* x) { return x->As() && *x == block; }, true); + loop, + [&](const Expr* x) { + return x->As() && *x ==
block; + }, + true); CHECK(find_block_in_loop.empty()) << "loop should not be block's ancestor!"; } @@ -806,7 +925,8 @@ void InsertBlock(Expr& for_loop, const Expr& insertion, int index) { CHECK(for_loop.As()->body.As()); ir::Block* dst_block = for_loop.As()->body.As(); CHECK(index == -1 || index >= 0 && index < dst_block->stmts.size()) - << "index = " << index << ", it should be -1 or between [0, block stmts size)"; + << "index = " << index + << ", it should be -1 or between [0, block stmts size)"; if (index == -1) { dst_block->stmts.emplace_back(insertion); @@ -814,7 +934,8 @@ void InsertBlock(Expr& for_loop, const Expr& insertion, int index) { auto dst_it = dst_block->stmts.begin() + index; if (dst_it->As()) { auto* inserted_block = dst_it->As()->true_case.As(); - CHECK(inserted_block) << "the IfThenElse node to be inserted shuold contain a true_case block"; + CHECK(inserted_block) << "the IfThenElse node to be inserted should " "contain a true_case block"; inserted_block->stmts.insert(inserted_block->stmts.begin(), insertion); } else { dst_block->stmts.insert(dst_it, insertion); @@ -823,74 +944,99 @@ void InsertBlock(Expr& for_loop, const Expr& insertion, int index) { } IterRange RangeUnion(const IterRange& range1, const IterRange& range2) { - Expr new_min = common::AutoSimplify(Min::Make(range1.min, range2.min)); + Expr new_min = common::AutoSimplify(Min::Make(range1.min, range2.min)); Expr new_extent = common::AutoSimplify( - common::AutoSimplify(Max::Make(range1.min + range1.extent, range2.min + range2.extent)) - new_min); + common::AutoSimplify( + Max::Make(range1.min + range1.extent, range2.min + range2.extent)) - + new_min); return IterRange(new_min, new_extent); } -std::vector CalculateRequiredRegions(const Expr& block, - const Expr& loop, - const Expr& root, - const std::vector& required_blocks, - bool is_store_provided) { - CHECK(block.As()) << "Param block should be a ir::ScheduleBlockRealize node"; +std::vector CalculateRequiredRegions( + const Expr& block, + const Expr& loop, + const Expr& root, + const std::vector& required_blocks, + bool is_store_provided) { + CHECK(block.As()) + << "Param block should be an ir::ScheduleBlockRealize node"; CHECK(loop.As()) << "Param loop should be an ir::For node"; std::set provided_nodes; if (is_store_provided) { - provided_nodes = ir::CollectIRNodesWithoutTensor(block, [&](const Expr* x) { return x->As(); }); + provided_nodes = ir::CollectIRNodesWithoutTensor( + block, [&](const Expr* x) { return x->As(); }); } else { - provided_nodes = ir::CollectIRNodesWithoutTensor(block, [&](const Expr* x) { return x->As(); }); + provided_nodes = ir::CollectIRNodesWithoutTensor( + block, [&](const Expr* x) { return x->As(); }); } std::vector required_buffer_range; - // deduce accessed regions of the provided tensor in block by itering each required block + // deduce accessed regions of the provided tensor in block by iterating over + // each required block for (const Expr& pro_node : provided_nodes) { - const std::string& provided_tensor_name = is_store_provided ? pro_node.As()->tensor.as_tensor()->name - : pro_node.As()->tensor.as_tensor()->name; + const std::string& provided_tensor_name = + is_store_provided ?
pro_node.As()->tensor.as_tensor()->name + : pro_node.As()->tensor.as_tensor()->name; for (const Expr& req_block : required_blocks) { CHECK(req_block.As()); Expr block_body = - optim::IRCopy(req_block.As()->schedule_block.As()->body); - auto iter_vars = req_block.As()->schedule_block.As()->iter_vars; + optim::IRCopy(req_block.As() + ->schedule_block.As() + ->body); + auto iter_vars = req_block.As() + ->schedule_block.As() + ->iter_vars; auto iter_values = req_block.As()->iter_values; ReplaceExpr(&block_body, iter_vars, iter_values); - // Notice that we look for For nodes in loop's body instead of loop itself. + // Notice that we look for For nodes in loop's body instead of loop + // itself. auto find_loops = ir::CollectIRNodesWithoutTensor( - loop.As()->body, [&](const Expr* x) { return x->As() && Contains(*x, req_block); }); + loop.As()->body, [&](const Expr* x) { + return x->As() && Contains(*x, req_block); + }); // collect vars and their ranges of each loop under the input loop std::vector loop_vars; std::vector loop_ranges; for (const auto& for_loop : find_loops) { loop_vars.emplace_back(for_loop.As()->loop_var); - loop_ranges.emplace_back(for_loop.As()->min, for_loop.As()->extent); + loop_ranges.emplace_back(for_loop.As()->min, + for_loop.As()->extent); } std::set required_nodes; if (is_store_provided) { - required_nodes = ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) { - return x->As() && x->As()->tensor.as_tensor_ref()->name == provided_tensor_name; - }); + required_nodes = + ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) { + return x->As() && + x->As()->tensor.as_tensor_ref()->name == + provided_tensor_name; + }); } else { - required_nodes = ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) { - return x->As() && x->As()->tensor.as_tensor_ref()->name == provided_tensor_name; - }); + required_nodes = + ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) { + return x->As() && + x->As()->tensor.as_tensor_ref()->name == + provided_tensor_name; + }); } // deducing range by indices of each required node for (const Expr& req_node : required_nodes) { - const auto& indices = is_store_provided ? req_node.As()->indices : req_node.As()->indices; + const auto& indices = is_store_provided + ? 
req_node.As()->indices + : req_node.As()->indices; if (find_loops.empty()) { for (int i = 0; i < indices.size(); ++i) { if (i >= required_buffer_range.size()) required_buffer_range.emplace_back(indices[i], Expr(1)); else - required_buffer_range[i] = RangeUnion(required_buffer_range[i], IterRange(indices[i], Expr(1))); + required_buffer_range[i] = RangeUnion( + required_buffer_range[i], IterRange(indices[i], Expr(1))); } } else { for (int i = 0; i < indices.size(); ++i) { @@ -898,7 +1044,8 @@ std::vector CalculateRequiredRegions(const Expr& block, if (i >= required_buffer_range.size()) { required_buffer_range.emplace_back(std::move(range)); } else { - required_buffer_range[i] = RangeUnion(required_buffer_range[i], range); + required_buffer_range[i] = + RangeUnion(required_buffer_range[i], range); } } } @@ -907,21 +1054,30 @@ std::vector CalculateRequiredRegions(const Expr& block, } int iter_size = block.As()->iter_values.size(); - // maybe some dimensions are not accessed by consumers so we should append them + // maybe some dimensions are not accessed by consumers so we should append + // them if (iter_size > required_buffer_range.size()) { for (int i = required_buffer_range.size(); i < iter_size; ++i) { CHECK(block.As()->iter_values[i].as_var() || block.As()->iter_values[i].is_constant()); if (block.As()->iter_values[i].as_var()) { - auto find_for_loops = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { - return x->As() && x->As()->loop_var->name == - block.As()->iter_values[i].as_var_ref()->name; - }); + auto find_for_loops = + ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { + return x->As() && + x->As()->loop_var->name == + block.As() + ->iter_values[i] + .as_var_ref() + ->name; + }); CHECK_EQ(find_for_loops.size(), 1U); - required_buffer_range.emplace_back((*find_for_loops.begin()).As()->min, - (*find_for_loops.begin()).As()->extent); + required_buffer_range.emplace_back( + (*find_for_loops.begin()).As()->min, + (*find_for_loops.begin()).As()->extent); } else { - int cons = (int)block.As()->iter_values[i].is_constant(); + int cons = (int)block.As() + ->iter_values[i] + .is_constant(); required_buffer_range.emplace_back(Expr(cons), Expr(1)); } } @@ -929,9 +1085,12 @@ std::vector CalculateRequiredRegions(const Expr& block, return required_buffer_range; } -Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block, const Expr& root) { +Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block, + const Expr& root) { CHECK(schedule_block.As()); - auto compute_body = schedule_block.As()->schedule_block.As()->body; + auto compute_body = schedule_block.As() + ->schedule_block.As() + ->body; // 1. Check the schedule block to be inlined is not a reduce tensor. auto find_store = ir::CollectIRNodesWithoutTensor( compute_body, [&](const Expr* x) { return x->As(); }, true); @@ -942,21 +1101,28 @@ Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block, const E find_store = ir::CollectIRNodesWithoutTensor( root, [&](const Expr* x) { - return x->As() && (x->As()->tensor).as_tensor_ref()->name == tensor.as_tensor_ref()->name; + return x->As() && + (x->As()->tensor).as_tensor_ref()->name == + tensor.as_tensor_ref()->name; }, true); CHECK_EQ(find_store.size(), 1U); - // 3. Check there is no overlap between the buffers the schedule block reads and writes. - auto find_load = ir::CollectIRNodesWithoutTensor( - compute_body, [&](const Expr* x) { return x->As() && x->As()->tensor == tensor; }); + // 3. 
Check there is no overlap between the buffers the schedule block reads + // and writes. + auto find_load = + ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { + return x->As() && x->As()->tensor == tensor; + }); CHECK(find_load.empty()); return (*find_store.begin()); } -std::tuple CheckReverseComputeInlineValidationAndGetExprs(const Expr& schedule_block, - const Expr& root) { +std::tuple CheckReverseComputeInlineValidationAndGetExprs( + const Expr& schedule_block, const Expr& root) { CHECK(schedule_block.As()); - auto compute_body = schedule_block.As()->schedule_block.As()->body; + auto compute_body = schedule_block.As() + ->schedule_block.As() + ->body; // 1. Check the schedule block to be reverse inlined is not a reduce tensor. auto find_inlined_load = ir::CollectIRNodesWithoutTensor( compute_body, [&](const Expr* x) { return x->As(); }, true); @@ -968,17 +1134,24 @@ std::tuple CheckReverseComputeInlineValidationAndGetExprs(cons auto find_load = ir::CollectIRNodesWithoutTensor( root, [&](const Expr* x) { - return x->As() && (x->As()->tensor).as_tensor_ref()->name == tensor.as_tensor_ref()->name; + return x->As() && + (x->As()->tensor).as_tensor_ref()->name == + tensor.as_tensor_ref()->name; }, true); CHECK_EQ(find_load.size(), 1U); - // 3. Check there is no overlap between the buffers the schedule block reads and writes. - auto find_store = ir::CollectIRNodesWithoutTensor( - compute_body, [&](const Expr* x) { return x->As() && x->As()->tensor == tensor; }); + // 3. Check there is no overlap between the buffers the schedule block reads + // and writes. + auto find_store = + ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { + return x->As() && x->As()->tensor == tensor; + }); CHECK(find_store.empty()); // 4. Get store that will be inlined. - auto find_inlined_store = ir::CollectIRNodesWithoutTensor( - root, [&](const Expr* x) { return x->As() && x->As()->tensor == tensor; }); + auto find_inlined_store = + ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { + return x->As() && x->As()->tensor == tensor; + }); CHECK_EQ(find_inlined_store.size(), 1U); auto inlined_store = *find_inlined_store.begin(); // 5. Get target store. 
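As a concrete illustration of what these compute-inline validation steps guard, here is a minimal standalone sketch of the transformation itself, with plain C arrays standing in for CINN tensors (illustrative only; this is not part of the patch):

#include <cstdio>

int main() {
  const int n = 4;
  int a[n] = {1, 2, 3, 4};
  int c[n];
  // Before inlining, a producer block materializes b[i] = a[i] * 2 and a
  // consumer reads it: c[i] = b[i] + 1. Once the checks pass (exactly one
  // store of the producer tensor, not a reduce, no read/write overlap),
  // compute_inline substitutes the stored value into the consumer and the
  // intermediate buffer b is never allocated:
  for (int i = 0; i < n; ++i) c[i] = a[i] * 2 + 1;
  for (int i = 0; i < n; ++i) printf("%d ", c[i]);  // prints: 3 5 7 9
  return 0;
}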
@@ -992,7 +1165,11 @@ std::tuple CheckReverseComputeInlineValidationAndGetExprs(cons bool ContainVar(const std::vector& exprs, const std::string& var_name) { for (auto& expr : exprs) { auto find_expr = ir::CollectIRNodesWithoutTensor( - expr, [&](const Expr* x) { return x->As<_Var_>() && x->As<_Var_>()->name == var_name; }, true); + expr, + [&](const Expr* x) { + return x->As<_Var_>() && x->As<_Var_>()->name == var_name; + }, + true); if (!find_expr.empty()) return true; } return false; @@ -1016,11 +1193,13 @@ std::unordered_map PrimeFactorize(int n) { return factors; } -std::vector SampleTile(utils::LinearRandomEngine::StateType* rand_seed, int n, int extent) { +std::vector SampleTile(utils::LinearRandomEngine::StateType* rand_seed, + int n, + int extent) { std::vector tile; while (n > 1) { std::unordered_map factors = PrimeFactorize(extent); - int product = 1; + int product = 1; for (auto& factor : factors) { if (factor.second >= 1) { int num = utils::SampleUniformInt(1, factor.second + 1, rand_seed); diff --git a/paddle/cinn/ir/ir_schedule_util.h b/paddle/cinn/ir/ir_schedule_util.h index c3bc3cc4ae95b..802a134d23bc4 100644 --- a/paddle/cinn/ir/ir_schedule_util.h +++ b/paddle/cinn/ir/ir_schedule_util.h @@ -39,11 +39,14 @@ struct CompExpr { // Self-defined operator to support std::set struct CompVar { - bool operator()(const Var& left, const Var& right) const { return left->name < right->name; } + bool operator()(const Var& left, const Var& right) const { + return left->name < right->name; + } }; struct MappingVarToExprMutator : public ir::IRMutator<> { - MappingVarToExprMutator(const std::map& replacing_map) : replacing_map_(replacing_map) {} + MappingVarToExprMutator(const std::map& replacing_map) + : replacing_map_(replacing_map) {} void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } @@ -76,8 +79,9 @@ struct FindLoopsVisitor { Visit(&(expr->As()->body)); father_loops.pop_back(); } else if (expr->As()) { - if (!expr->As()->iter_values.empty() && (*expr == block_)) { - result = father_loops; + if (!expr->As()->iter_values.empty() && + (*expr == block_)) { + result = father_loops; visit_end = true; return; } else { @@ -100,14 +104,16 @@ struct FindLoopsVisitor { }; /** - * \brief Given a ScheduleBlockRealize node, return the Store tensor in its body. + * \brief Given a ScheduleBlockRealize node, return the Store tensor in its + * body. * @param block The given ScheduleBlockRealize node * @return The Store tensor in block */ Tensor GetTensor(const Expr& block); struct FindBlocksVisitor { - FindBlocksVisitor(const std::string& block_name = "") : block_name_(block_name) {} + FindBlocksVisitor(const std::string& block_name = "") + : block_name_(block_name) {} std::vector operator()(const Expr* expr) { Visit(expr); @@ -122,7 +128,8 @@ struct FindBlocksVisitor { Visit(&(expr->As()->body)); } else if (expr->As()) { if (!expr->As()->iter_values.empty()) { - auto* schedule_block = expr->As()->schedule_block.As(); + auto* schedule_block = expr->As() + ->schedule_block.As(); if (block_name_.empty() || schedule_block->name == block_name_) { result.emplace_back(*expr); } @@ -149,7 +156,8 @@ struct CacheBlockInfo { Tensor write_tensor; /*! \brief The tensor allocation to be inserted into the block signature. */ Tensor alloc; - /*! \brief The AST node whose body is where the cache stage should be inserted. */ + /*! \brief The AST node whose body is where the cache stage should be + * inserted. */ Expr loc_block; /*! \brief The index to insert the cache_read/cache_write stage. 
*/ int loc_pos; @@ -167,7 +175,8 @@ struct IterRange { }; /** - * \brief Given a ScheduleBlockRealize node, return the index-th Load tensor in its body. + * \brief Given a ScheduleBlockRealize node, return the index-th Load tensor in + * its body. * @param block The given ScheduleBlockRealize node * @param index The index of Load tensor * @return The index-th Load tensor in block @@ -182,7 +191,8 @@ Tensor GetReadTensor(const Expr& block, int index); int GetLoopExtent(const Expr& loop); /** - * \brief Given a vector of Exors, return whether they contain a var with specific name. + * \brief Given a vector of Exprs, return whether they contain a var with a + * specific name. * @param exprs The given vector of Exprs * @param var_name The name of specific var * @return Whether there is a Var with the same name as var_name @@ -219,20 +229,25 @@ Expr GetNextForLoop(const Expr& for_loop); std::vector GetIfThenElseInRange(const Expr& top, const Expr& bottom); /** - * Replace Vars in replaced to Exprs in candidates in source. Vars -> Exprs is one-to-one correspondence. + * Replace Vars in replaced with Exprs in candidates in source. Vars -> Exprs is + * a one-to-one correspondence. * @param source The Expr we will implement the change. * @param replaced The Vars to be replaced. * @param candidates The Exprs to replace Vars in replaced. */ -void ReplaceExpr(Expr* source, const std::vector& replaced, const std::vector& candidates); +void ReplaceExpr(Expr* source, + const std::vector& replaced, + const std::vector& candidates); /** - * Validate the factors param of Split. We will check if factors are validate and change -1 to positive integer. + * Validate the factors param of Split. We will check if factors are valid + * and change -1 to a positive integer. * @param factors The original factors. * @param total_extent The extent of the loop to be split. * @return return The validated factors. */ -std::vector ValidateFactors(const std::vector& factors, int total_extent); +std::vector ValidateFactors(const std::vector& factors, + int total_extent); void CHECKRfactorValidation(const Expr& rf_loop, int rf_axis); @@ -245,20 +260,21 @@ void CHECKRfactorValidation(const Expr& rf_loop, int rf_axis); std::vector GetLoopsOfExpr(const Expr& expr, const Expr& root); /** - * Given an index Expr and all vars' range, return the accessed range in this indice. + * Given an index Expr and all vars' range, return the accessed range in this + * indice. * @param index The Expr of a specified indice. * @param iter_vars The vars in expr. * @param iter_range Each var's range. - * @return return an IterRange represents the accessed range of this indice, If it is not constant, return corresponding - * tensor's shape. + * @return return an IterRange representing the accessed range of this indice; if + * it is not constant, return the corresponding tensor's shape. */ IterRange GetAccessedRange(const Expr& index, const std::vector& iter_vars, const std::vector& iter_ranges); /** - * Given a ScheduleBlockRealize, an AST root, a tensor and its tensor_indices, return the accessed buffer region of the - * tensor in block. + * Given a ScheduleBlockRealize, an AST root, a tensor and its tensor_indices, + * return the accessed buffer region of the tensor in block. * @param block The ScheduleBlockRealize. * @param tensor_indices The tensor's indices. * @param tensor The tensor. @@ -266,17 +282,19 @@ IterRange GetAccessedRange(const Expr& index, * @return return The accessed buffer region of the tensor in block.
*/ -std::vector CalculateTensorRegions(const Expr& block, - const std::vector& tensor_indices, - const Tensor& tensor, - const Expr& root); +std::vector CalculateTensorRegions( + const Expr& block, + const std::vector& tensor_indices, + const Tensor& tensor, + const Expr& root); /** * Return n-th access tensor in block * @param block The ScheduleBlockRealize. * @param index The index indicating which tensor we want to get. * @param is_write We want to get write tensor or read tensor. - * @return return The n-th access tensor in block. Should be ir::Store(is_write) or ir::Load(!is_write). + * @return return The n-th access tensor in block. Should be ir::Store(is_write) + * or ir::Load(!is_write). */ Expr GetNthAccessExpr(const Expr& block, int index, bool is_write); @@ -314,20 +332,23 @@ void FindInsertionPoint(Expr& root, CacheBlockInfo* info, bool is_write); * @param loops The given vector of For loops. * @return A set containing all the For loops in loops. */ -const std::set CollectLoopsToSet(const std::vector& loops); +const std::set CollectLoopsToSet( + const std::vector& loops); /** * \brief Given a set of For loops, return the boundary among them. * @param loop_set The given set of For loops. * @return A pair of the boundary among For loops.(The top For and bottom For) */ -std::pair GetBoundaryOfReorderRange(const std::set& loop_set); +std::pair GetBoundaryOfReorderRange( + const std::set& loop_set); /** * \brief Given two For loops, return all loops between them. * @param top The top For loop. * @param bottom The bottom For loop. - * @return A vector containing all For loops between the boundary, stored in ascending order. + * @return A vector containing all For loops between the boundary, stored in + * ascending order. */ std::vector GetLoopsInRange(const Expr& top, const Expr& bottom); @@ -361,13 +382,15 @@ std::vector GetConsumers(const Expr& block, const Expr& root); * \param loop The for node we want to put the block under in ComputeAt. * \param root The root ScheduleBlockRealize node of block and loop. */ -void CheckComputeAtValidation(const Expr& block, const Expr& loop, const Expr& root); +void CheckComputeAtValidation(const Expr& block, + const Expr& loop, + const Expr& root); /*! - * \brief Insert a new ScheduleBlockRealize in a loop's body(under its IfThenElse Node, if any) - * \param for_loop The for loop whose body we want to modify - * \param insertion The ScheduleBlockRealize we want to insert - * \param index The position index of the for_loop body `stmts` to be inserted: + * \brief Insert a new ScheduleBlockRealize in a loop's body(under its + * IfThenElse Node, if any) \param for_loop The for loop whose body we want to + * modify \param insertion The ScheduleBlockRealize we want to insert \param + * index The position index of the for_loop body `stmts` to be inserted: * - `index = -1` means inserted into the tail * - otherwise, it should be a index between [0, stmts size) */ @@ -376,22 +399,17 @@ void InsertBlock(Expr& for_loop, const Expr& insertion, int index = 0); /*! * \brief Make a union of two range. The detailed function is : * new_range.min = min(range1.min, range2.min) - * new_range.extent = max(range1.min + range1.extent, range2.min + range2.extent) - new_range.min - * Notice that the pair indicates a range's min and extent. 
- * \param range1 The first range - * \param range2 The second range - * \return The union of these two ranges + * new_range.extent = max(range1.min + range1.extent, range2.min + + * range2.extent) - new_range.min Notice that the pair indicates a + * range's min and extent. \param range1 The first range \param range2 The + * second range \return The union of these two ranges */ IterRange RangeUnion(const IterRange& range1, const IterRange& range2); /*! - * \brief Calculate the required buffer region given a block and its required blocks. - * For example, if block is : - * B[i0, j0] = A[i0, j0] - * loop is : - * for (i, 0, 64) { - * for (j, 0, 64) { - * C[i, j] = B[i, j] + * \brief Calculate the required buffer region given a block and its required + * blocks. For example, if block is : B[i0, j0] = A[i0, j0] loop is : for (i, 0, + * 64) { for (j, 0, 64) { C[i, j] = B[i, j] * } * } * And required_blocks is : @@ -403,29 +421,33 @@ IterRange RangeUnion(const IterRange& range1, const IterRange& range2); * \param block The ScheduleBlockRealize node being required * \param loop The loop where we will insert the block under it * @param root The root of the whole AST. - * \param required_blocks vector of ScheduleBlockRealize nodes that require the block - * \param is_store_provided Whether Store nodes of the block provide the tensor, - * true means it is in compute_at case, otherwise false means in reverse_compuate_at case - * \return Each index's range of block's tensor. Indicating the buffer region being required. + * \param required_blocks vector of ScheduleBlockRealize nodes that require the + * block \param is_store_provided Whether Store nodes of the block provide the + * tensor, true means it is in compute_at case, otherwise false means in + * reverse_compute_at case \return Each index's range of block's tensor. + * Indicating the buffer region being required. */ -std::vector CalculateRequiredRegions(const Expr& block, - const Expr& loop, - const Expr& root, - const std::vector& required_blocks, - bool is_store_provided = true); +std::vector CalculateRequiredRegions( + const Expr& block, + const Expr& loop, + const Expr& root, + const std::vector& required_blocks, + bool is_store_provided = true); -Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block, const Expr& root); +Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block, + const Expr& root); /*! - * \brief Check if the reverse compute inline validation passes for a given schedule block and root expression, - * and retrieve the store expression if so. - * Reverse compute inline validation ensures that the outputs of a loop nest are properly computed in reverse order. - * \param schedule_block The schedule block to check. - * \param root The root expression of the loop nest. - * \return A tuple containing the load that will be inlined, the store that will be inlined and the target store. + * \brief Check if the reverse compute inline validation passes for a given + * schedule block and root expression, and retrieve the store expression if so. + * Reverse compute inline validation ensures that the outputs of a loop nest are + * properly computed in reverse order. \param schedule_block The schedule block + * to check. \param root The root expression of the loop nest. \return A tuple + * containing the load that will be inlined, the store that will be inlined and + * the target store.
*/ -std::tuple CheckReverseComputeInlineValidationAndGetExprs(const Expr& schedule_block, - const Expr& root); +std::tuple CheckReverseComputeInlineValidationAndGetExprs( + const Expr& schedule_block, const Expr& root); /*! * \brief Get the prime factors of a number. @@ -443,6 +465,8 @@ std::unordered_map PrimeFactorize(int n); * \param n The number to be factorized. * \param dividend The dividend of the number. */ -std::vector SampleTile(utils::LinearRandomEngine::StateType* rand_seed, int n, int dividend); +std::vector SampleTile(utils::LinearRandomEngine::StateType* rand_seed, + int n, + int dividend); } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/ir_visitor.h b/paddle/cinn/ir/ir_visitor.h index 21ccbc23335c6..e5f8c21399603 100644 --- a/paddle/cinn/ir/ir_visitor.h +++ b/paddle/cinn/ir/ir_visitor.h @@ -73,7 +73,8 @@ struct IRVisitor : public IRVisitorBase { #undef __m }; -// std::set CollectIRNodes(Expr expr, std::function teller); +// std::set CollectIRNodes(Expr expr, std::function +// teller); bool operator==(Expr a, Expr b); bool operator!=(Expr a, Expr b); diff --git a/paddle/cinn/ir/layout.cc b/paddle/cinn/ir/layout.cc index 5113c065370be..f4e4585aa2145 100644 --- a/paddle/cinn/ir/layout.cc +++ b/paddle/cinn/ir/layout.cc @@ -25,8 +25,10 @@ void Layout::Verify() { for (auto& axis : axes_) { CHECK_EQ(axis->name.size(), 1U); auto axis_name = axis->name[0]; - CHECK((axis_name >= 'A' && axis_name <= 'Z') || (axis_name >= 'a' && axis_name <= 'z')); - CHECK(axis_names_.find(axis_name) == axis_names_.npos) << axis_name << " has already exsit."; + CHECK((axis_name >= 'A' && axis_name <= 'Z') || + (axis_name >= 'a' && axis_name <= 'z')); + CHECK(axis_names_.find(axis_name) == axis_names_.npos) + << axis_name << " already exists."; axis_names_ += axis_name; } int offset = 'A' - 'a'; @@ -46,12 +48,14 @@ Layout::Layout(const std::string& name) { std::vector axes; for (char c : name) { if (c >= 'A' && c <= 'Z') { - CHECK_EQ(factor, 0) << "Invalid factor " << factor << " before primal axis " << c; + CHECK_EQ(factor, 0) << "Invalid factor " << factor + << " before primal axis " << c; axes.push_back(ir::Var(std::string(1, c))); } else if (c >= '0' && c <= '9') { factor = 10 * factor + c - '0'; } else if (c >= 'a' && c <= 'z') { - CHECK_GT(factor, 0) << "Invalid factor " << factor << " for sub-axis " << c; + CHECK_GT(factor, 0) << "Invalid factor " << factor << " for sub-axis " + << c; axes.push_back(ir::Var(factor, std::string(1, c))); factor = 0; } else { diff --git a/paddle/cinn/ir/layout.h b/paddle/cinn/ir/layout.h index f71c6823fd20f..18ccb682ff57b 100644 --- a/paddle/cinn/ir/layout.h +++ b/paddle/cinn/ir/layout.h @@ -28,7 +28,10 @@ class Layout { std::string axis_names_; std::vector axes_; - Layout(const std::string& name, const std::vector& axes) : name_(name), axes_(axes) { Verify(); } + Layout(const std::string& name, const std::vector& axes) + : name_(name), axes_(axes) { + Verify(); + } explicit Layout(const std::string& name); diff --git a/paddle/cinn/ir/lowered_func.cc b/paddle/cinn/ir/lowered_func.cc index d31a959ffeecd..7505a5647f5fe 100644 --- a/paddle/cinn/ir/lowered_func.cc +++ b/paddle/cinn/ir/lowered_func.cc @@ -38,17 +38,19 @@ namespace ir { using common::bfloat16; using common::float16; -const _LoweredFunc_* LoweredFunc::operator->() const { return As<_LoweredFunc_>(); } +const _LoweredFunc_* LoweredFunc::operator->() const { + return As<_LoweredFunc_>(); +} _LoweredFunc_* LoweredFunc::operator->() { return As<_LoweredFunc_>(); } LoweredFunc
_LoweredFunc_::Make(const std::string& name, const std::vector& args, const Expr& body, const std::vector& temp_bufs) { - auto* n = make_shared<_LoweredFunc_>(); - n->name = name; - n->args = args; - n->body = body; + auto* n = make_shared<_LoweredFunc_>(); + n->name = name; + n->args = args; + n->body = body; n->temp_bufs = temp_bufs; n->CheckValid(); @@ -68,22 +70,25 @@ LoweredFunc _LoweredFunc_::Make(const std::string& name, void _LoweredFunc_::CheckValid() const { // check there is at least one output int out_count = 0; - int in_count = 0; + int in_count = 0; for (auto& arg : args) { in_count += arg.is_input(); out_count += arg.is_output(); } - CHECK_GT(out_count, 0) << "At least one output argument is needed for a function\n" << body; + CHECK_GT(out_count, 0) + << "At least one output argument is needed for a function\n" + << body; } std::vector _LoweredFunc_::expr_fields() { return {&body}; } std::vector _LoweredFunc_::expr_fields() const { return {&body}; } void _LoweredFunc_::PrepareCudaAxisInfoFromBody() { - std::set bound_for_exprs = ir::CollectIRNodes(body, [](const Expr* expr) { - const ir::For* for_expr = expr->As(); - return for_expr != nullptr && for_expr->is_binded(); - }); + std::set bound_for_exprs = + ir::CollectIRNodes(body, [](const Expr* expr) { + const ir::For* for_expr = expr->As(); + return for_expr != nullptr && for_expr->is_binded(); + }); if (bound_for_exprs.empty()) { device_api = ir::DeviceAPI::GPU; @@ -97,9 +102,11 @@ void _LoweredFunc_::PrepareCudaAxisInfoFromBody() { for (const Expr& expr : bound_for_exprs) { const ir::For* for_expr = expr.As(); if (for_expr->for_type() == ir::ForType::GPUBlock) { - cuda_axis_info.set_grid_dim(for_expr->bind_info().offset, for_expr->extent.as_int32()); + cuda_axis_info.set_grid_dim(for_expr->bind_info().offset, + for_expr->extent.as_int32()); } else if (for_expr->for_type() == ir::ForType::GPUThread) { - cuda_axis_info.set_block_dim(for_expr->bind_info().offset, for_expr->extent.as_int32()); + cuda_axis_info.set_block_dim(for_expr->bind_info().offset, + for_expr->extent.as_int32()); } } device_api = ir::DeviceAPI::GPU; @@ -107,16 +114,23 @@ void _LoweredFunc_::PrepareCudaAxisInfoFromBody() { } void _LoweredFunc_::PrepareAllocOutputBufferExprs() { - CHECK(alloc_output_buffer_exprs.empty()) << "duplicate prepare the allocate buffer for outputs"; + CHECK(alloc_output_buffer_exprs.empty()) + << "duplicate preparation of the allocate buffer for outputs"; std::set buffer_names; for (auto& arg : args) { if (arg.is_output()) { - CHECK(arg.type().valid()) << "argument [" << arg.name() << "]'s type should be set"; - if (arg.is_buffer() && !buffer_names.count(arg.name())) { // only buffer need allocation. + CHECK(arg.type().valid()) + << "argument [" << arg.name() << "]'s type should be set"; + if (arg.is_buffer() && + !buffer_names.count(arg.name())) { // only buffers need allocation.
+ buffer_names.insert(arg.name()); // Avoid duplicate alloc_output_buffer_exprs.push_back( - Alloc::Make(arg.buffer_arg(), arg.buffer_arg()->type(), arg.buffer_arg()->shape, Expr(), Expr())); + Alloc::Make(arg.buffer_arg(), + arg.buffer_arg()->type(), + arg.buffer_arg()->shape, + Expr(), + Expr())); } } } @@ -126,7 +140,8 @@ std::vector _LoweredFunc_::PrepareAllocTempBufferExprs() const { std::vector alloc_temp_buffer_exprs; for (auto& temp_buf : temp_bufs) { if (!temp_buf->shape.empty() && temp_buf->type() != Void()) { - alloc_temp_buffer_exprs.push_back(Alloc::Make(temp_buf, temp_buf->type(), temp_buf->shape, Expr(), Expr())); + alloc_temp_buffer_exprs.push_back(Alloc::Make( + temp_buf, temp_buf->type(), temp_buf->shape, Expr(), Expr())); } } return alloc_temp_buffer_exprs; @@ -146,10 +161,13 @@ std::vector _LoweredFunc_::PrepareCreateTempBufferExprs() const { std::vector create_temp_buffer_exprs; for (auto& temp_buf : temp_bufs) { if (!temp_buf->shape.empty() && temp_buf->type() != Void()) { - auto expr = ir::intrinsics::BufferCreate::Make(temp_buf); - auto buffer_ptr_type = Type().set_customized_type(common::customized_type::kbuffer_t).set_cpp_handle(); - Var variable = ir::_Var_::Make(temp_buf->name, buffer_ptr_type); - expr = ir::Let::Make(variable, expr); + auto expr = ir::intrinsics::BufferCreate::Make(temp_buf); + auto buffer_ptr_type = + Type() + .set_customized_type(common::customized_type::kbuffer_t) + .set_cpp_handle(); + Var variable = ir::_Var_::Make(temp_buf->name, buffer_ptr_type); + expr = ir::Let::Make(variable, expr); create_temp_buffer_exprs.push_back(expr); } } @@ -163,21 +181,25 @@ std::vector _LoweredFunc_::CudaPrepareAllocTempBufferExprs() const { temp_buf->name = temp_buf->name.substr(1); } if (!temp_buf->shape.empty() && temp_buf->type() != Void()) { - alloc_output_buffer_exprs.push_back(Alloc::Make(temp_buf, temp_buf->type(), temp_buf->shape, Expr(), Expr())); + alloc_output_buffer_exprs.push_back(Alloc::Make( + temp_buf, temp_buf->type(), temp_buf->shape, Expr(), Expr())); } } return alloc_output_buffer_exprs; } void _LoweredFunc_::PrepareDeallocOutputBufferExprs() { - CHECK(dealloc_output_buffer_exprs.empty()) << "duplicate prepare the allocate buffer for outputs"; + CHECK(dealloc_output_buffer_exprs.empty()) + << "duplicate preparation of the allocate buffer for outputs"; std::set buffer_names; for (auto& arg : args) { if (arg.is_output()) { - CHECK(arg.type().valid()) << "argument [" << arg.name() << "]'s type should be set"; - if (arg.is_buffer() && !buffer_names.count(arg.name())) { // only buffer need allocation. + CHECK(arg.type().valid()) + << "argument [" << arg.name() << "]'s type should be set"; + if (arg.is_buffer() && + !buffer_names.count(arg.name())) { // only buffers need allocation.
+ buffer_names.insert(arg.name()); // Avoid duplicate dealloc_output_buffer_exprs.push_back(Free::Make(arg.buffer_arg())); } } @@ -193,7 +215,9 @@ void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) { write_teller.Collect(&body); auto tensors = CollectAllTensorReference(with_expr_gen_tensor); - std::sort(tensors.begin(), tensors.end(), [](const Tensor& a, const Tensor& b) { return a->name < b->name; }); + std::sort(tensors.begin(), + tensors.end(), + [](const Tensor& a, const Tensor& b) { return a->name < b->name; }); VLOG(3) << "Function used " << tensors.size() << " buffers"; for (auto& tensor : tensors) { @@ -202,17 +226,20 @@ void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) { if (!tensor->buffer.defined()) continue; Type value_type = tensor->type().ElementOf(); - bool is_const = !write_teller.IsWrite(tensor->name); + bool is_const = !write_teller.IsWrite(tensor->name); value_type.set_cpp_handle(); value_type.set_cpp_const(is_const); Var variable = _Var_::Make(tensor->name, value_type); - Expr body = is_const ? ir::intrinsics::BufferGetDataConstHandle::Make(tensor->buffer) - : ir::intrinsics::BufferGetDataHandle::Make(tensor->buffer); + Expr body = + is_const + ? ir::intrinsics::BufferGetDataConstHandle::Make(tensor->buffer) + : ir::intrinsics::BufferGetDataHandle::Make(tensor->buffer); - Type target_type = is_const ? tensor->buffer->dtype.PointerOf().ConstOf() : tensor->buffer->dtype.PointerOf(); - body = ir::Cast::Make(target_type, body); - auto let = Let::Make(variable, body); + Type target_type = is_const ? tensor->buffer->dtype.PointerOf().ConstOf() + : tensor->buffer->dtype.PointerOf(); + body = ir::Cast::Make(target_type, body); + auto let = Let::Make(variable, body); buffer_data_cast_exprs.push_back(let); } @@ -229,7 +256,9 @@ std::vector _LoweredFunc_::CudaAliasVarExprs() const { write_teller.Collect(&body); auto tensors = CollectAllTensorReference(); - std::sort(tensors.begin(), tensors.end(), [](const Tensor& a, const Tensor& b) { return a->name < b->name; }); + std::sort(tensors.begin(), + tensors.end(), + [](const Tensor& a, const Tensor& b) { return a->name < b->name; }); for (auto& tensor : tensors) { auto* node = tensor.As(); @@ -237,15 +266,16 @@ std::vector _LoweredFunc_::CudaAliasVarExprs() const { if (!tensor->buffer.defined()) { continue; } - if (tensor->name == tensor->buffer->name.substr(1) || args_buffer.count(tensor->buffer->name) == 0) { + if (tensor->name == tensor->buffer->name.substr(1) || + args_buffer.count(tensor->buffer->name) == 0) { continue; } Type value_type = tensor->type().ElementOf(); - bool is_const = !write_teller.IsWrite(tensor->name); + bool is_const = !write_teller.IsWrite(tensor->name); value_type.set_cpp_handle(); value_type.set_cpp_const(is_const); Var variable = _Var_::Make(tensor->name, value_type); - Var body = Var(tensor->buffer->name.substr(1), value_type); + Var body = Var(tensor->buffer->name.substr(1), value_type); auto let = Let::Make(variable, body); @@ -256,22 +286,31 @@ std::vector _LoweredFunc_::CudaAliasVarExprs() const { void _LoweredFunc_::PrepareArgumentExprs() { // Seems a CINN func. 
- if (args.front().is_var() && args.front().var_arg()->type() == type_of()) return; + if (args.front().is_var() && + args.front().var_arg()->type() == type_of()) + return; // type of `void*` - auto void_ptr_array_type = Type().with_type(Type::type_t::Void).set_cpp_handle(); + auto void_ptr_array_type = + Type().with_type(Type::type_t::Void).set_cpp_handle(); // type of `cinn_buffer_t*` - auto buffer_ptr_type = Type().set_customized_type(common::customized_type::kbuffer_t).set_cpp_handle(); + auto buffer_ptr_type = + Type() + .set_customized_type(common::customized_type::kbuffer_t) + .set_cpp_handle(); // type of `const cinn_buffer_t*` auto const_buffer_ptr_type = buffer_ptr_type.with_cpp_const(); CHECK(!buffer_ptr_type.is_cpp_const()); Var args_passed_in("_args", type_of()); - auto pod_value_ptr = common::CastIfNeeded(args_passed_in, type_of()); + auto pod_value_ptr = + common::CastIfNeeded(args_passed_in, type_of()); if (FLAGS_cinn_runtime_display_debug_info) { argument_prepare_exprs.push_back(runtime::IntrinsicCall( - Void(), runtime::intrinsic::print_debug_args_repr, {pod_value_ptr, common::make_const(Int(32), args.size())})); + Void(), + runtime::intrinsic::print_debug_args_repr, + {pod_value_ptr, common::make_const(Int(32), args.size())})); } /* @@ -282,8 +321,8 @@ void _LoweredFunc_::PrepareArgumentExprs() { * int M = (int)arg[2]; */ - // We just has two kinds of argument types, first is `cinn_buffer_t*`, second is `const cinn_buffer_t*`, do not need a - // `any` type support currently. + // We have just two kinds of argument types: the first is `cinn_buffer_t*`, the + // second is `const cinn_buffer_t*`; we do not need `any` type support currently. for (int i = 0; i < args.size(); i++) { auto& arg = args[i]; // cast arg to cinn_pod_value_t* @@ -298,7 +337,7 @@ void _LoweredFunc_::PrepareArgumentExprs() { if (arg.is_buffer()) { auto buffer_type = is_const ?
const_buffer_ptr_type : buffer_ptr_type; - _arg = Var(arg.name(), buffer_type); + _arg = Var(arg.name(), buffer_type); } else if (arg.is_var()) { _arg = Var(arg.name(), arg.var_arg()->type()); } else { @@ -310,35 +349,50 @@ void _LoweredFunc_::PrepareArgumentExprs() { Expr pod_cast_expr; if (arg.is_buffer()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = ir::intrinsics::PodValueToX::Make( + load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else if (arg.type() == type_of()) { - pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else { LOG(ERROR) << "Not supported type [" << arg.type() << "]"; CINN_NOT_IMPLEMENTED @@ -350,11 +404,15 @@ void _LoweredFunc_::PrepareArgumentExprs() { } } -std::vector _LoweredFunc_::CollectAllTensorReference(bool with_expr_gen_tensor) const { +std::vector _LoweredFunc_::CollectAllTensorReference( + bool with_expr_gen_tensor) const { std::set tensor_exprs = with_expr_gen_tensor - ? 
ir::CollectIRNodes(body, [](const Expr* expr) { return expr->As(); }) - : ir::CollectIRNodesWithoutTensor(body, [](const Expr* expr) { return expr->As(); }); + ? ir::CollectIRNodes( + body, [](const Expr* expr) { return expr->As(); }) + : ir::CollectIRNodesWithoutTensor(body, [](const Expr* expr) { + return expr->As(); + }); std::vector tensors; // remove the duplicate tensor by their name. @@ -429,8 +487,10 @@ std::string Argument::human_readable() const { } std::ostream& operator<<(std::ostream& os, const CudaAxisInfo& x) { - os << ""; - os << ""; + os << ""; + os << ""; return os; } @@ -457,7 +517,7 @@ int CudaAxisInfo::block_dim(int offset) const { void CudaAxisInfo::ExtendWith(const CudaAxisInfo& other) { set_valid(true); for (int i = 0; i < 3; i++) { - grid_dims_[i] = std::max(grid_dims_[i], other.grid_dims_[i]); + grid_dims_[i] = std::max(grid_dims_[i], other.grid_dims_[i]); block_dims_[i] = std::max(block_dims_[i], other.block_dims_[i]); } } diff --git a/paddle/cinn/ir/lowered_func.h b/paddle/cinn/ir/lowered_func.h index 3efc7cfc41254..03ffacad817bd 100755 --- a/paddle/cinn/ir/lowered_func.h +++ b/paddle/cinn/ir/lowered_func.h @@ -26,8 +26,8 @@ namespace ir { class _LoweredFunc_; /** - * A struct representing an argument to a lowered function. Used for specifying the function signature of generated - * code. + * A struct representing an argument to a lowered function. Used for specifying + * the function signature of generated code. */ struct Argument { //! Input or output. @@ -39,7 +39,8 @@ struct Argument { explicit Argument(const ir::Buffer& buffer, IO io = IO::kInput); explicit Argument(const ir::Var& var, IO io = IO::kInput); - //! Set the buffer argument, all the buffer information are stored in ir::Buffer. + //! Set the buffer argument; all the buffer information is stored in + //! ir::Buffer. void set_buffer(const ir::Buffer& x); //! Set the var argument. @@ -128,8 +129,8 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> { //! The Arguments used in the body of the function. std::vector args; - //! Temporary buffers(as output), these buffers will not appear in the function's argument list, but will be used in - //! the body. + //! Temporary buffers (as output); these buffers will not appear in the + //! function's argument list, but will be used in the body. std::vector temp_bufs; //! Body of this function. @@ -140,11 +141,13 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> { CudaAxisInfo cuda_axis_info; /** - * The output buffer will be resized to the size required, we leave all the expression here. - * The allocation and deallocation expressions will insert into the head and tail of the function's body. It supports - * lazy allocation/deallocation if the corresponding intristic methods support. + * The output buffer will be resized to the size required; we leave all the + * expressions here. The allocation and deallocation expressions will be + * inserted into the head and tail of the function's body. It supports lazy + * allocation/deallocation if the corresponding intrinsic methods support it. * - * Currently, we assume that all the input and output buffers should locate in heap, no other memory type is allowed. + * Currently, we assume that all the input and output buffers are located in + * heap memory; no other memory type is allowed. */ // @{ std::vector alloc_output_buffer_exprs; @@ -190,8 +193,10 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> { void PrepareArgumentExprs(); //! Get all the Buffers the function body references. - //! 
NOTE it will return the buffers with duplicates removed(by comparing their name). - std::vector CollectAllTensorReference(bool with_expr_gen_tensor = true) const; + //! NOTE it will return the buffers with duplicates removed(by comparing their + //! name). + std::vector CollectAllTensorReference( + bool with_expr_gen_tensor = true) const; }; } // namespace ir diff --git a/paddle/cinn/ir/module.cc b/paddle/cinn/ir/module.cc index 85d846e2c0fbe..acc04bfad8e19 100644 --- a/paddle/cinn/ir/module.cc +++ b/paddle/cinn/ir/module.cc @@ -31,10 +31,12 @@ void Module::Builder::AddFunction(ir::LoweredFunc func) { } void Module::Builder::AddBuffer(ir::Buffer buffer) { - CHECK(buffer->target.defined()) << "buffer [" << buffer->name << "]'s target is undefined"; - if (std::find_if(module_->buffers.begin(), module_->buffers.end(), [&](const Expr &x) { - return x.as_buffer()->name == buffer->name; - }) == std::end(module_->buffers)) { + CHECK(buffer->target.defined()) + << "buffer [" << buffer->name << "]'s target is undefined"; + if (std::find_if( + module_->buffers.begin(), module_->buffers.end(), [&](const Expr &x) { + return x.as_buffer()->name == buffer->name; + }) == std::end(module_->buffers)) { module_->buffers.push_back(buffer); if (module_->target.arch == Target::Arch::X86) { module_->buffers.back().as_buffer()->data_alignment = 32; diff --git a/paddle/cinn/ir/module.h b/paddle/cinn/ir/module.h index df47afab3f3fb..291e1eada98f4 100644 --- a/paddle/cinn/ir/module.h +++ b/paddle/cinn/ir/module.h @@ -35,8 +35,9 @@ namespace ir { class Module : public ir::IrNodeRef { public: struct Builder { - Builder(const std::string& name, const Target& target) : module_(common::make_shared()) { - module_->name = name; + Builder(const std::string& name, const Target& target) + : module_(common::make_shared()) { + module_->name = name; module_->target = target; } diff --git a/paddle/cinn/ir/operation.cc b/paddle/cinn/ir/operation.cc index 967adf8a8d42f..44b1af64fe6b0 100644 --- a/paddle/cinn/ir/operation.cc +++ b/paddle/cinn/ir/operation.cc @@ -21,9 +21,11 @@ namespace cinn { namespace ir { -Operation PlaceholderOp::Make(const std::string &name, const std::vector &shape, Type dtype) { - auto n = make_shared(); - n->name = name; +Operation PlaceholderOp::Make(const std::string &name, + const std::vector &shape, + Type dtype) { + auto n = make_shared(); + n->name = name; n->shape = shape; n->set_type(dtype); return Operation(n); @@ -40,36 +42,40 @@ Operation ComputeOp::Make(const std::string &name, const std::vector &reduce_axis, const std::map &attrs, const std::string &tag) { - auto n = make_shared(); - n->name = name; + auto n = make_shared(); + n->name = name; n->producer_fn = handle; - n->shape = domain; + n->shape = domain; n->reduce_axis = reduce_axis; - n->tag = tag; - n->attrs = attrs; - auto axis = common::GenDefaultAxis(domain.size()); + n->tag = tag; + n->attrs = attrs; + auto axis = common::GenDefaultAxis(domain.size()); std::vector _axis; for (auto &x : axis) _axis.push_back(x); - n->body = {handle(_axis)}; + n->body = {handle(_axis)}; n->reduce_axis = reduce_axis; return Operation(n); } Operation CallOp::Make(const std::string &call_target, Expr call_op) { - auto n = make_shared(); + auto n = make_shared(); n->call_expr = call_op; return Operation(n); } -Operation PrecedingViewOp::Make(const Tensor &tensor, int preceding_axis) { return Operation(); } +Operation PrecedingViewOp::Make(const Tensor &tensor, int preceding_axis) { + return Operation(); +} -const char *PrecedingViewOp::func_type() const { 
return PrecedingViewOp::__func_type__; } +const char *PrecedingViewOp::func_type() const { + return PrecedingViewOp::__func_type__; +} const char *CallOp::func_type() const { return __func_type__; } -const char *ComputeOp::__func_type__ = "compute_op"; +const char *ComputeOp::__func_type__ = "compute_op"; const char *PlaceholderOp::__func_type__ = "placeholder_op"; -const char *CallOp::__func_type__ = "call_op"; +const char *CallOp::__func_type__ = "call_op"; const std::string &CallOp::target() const { auto *call = call_expr.As(); diff --git a/paddle/cinn/ir/operation.h b/paddle/cinn/ir/operation.h index c1aad25295e54..651c2a9a9dc5c 100644 --- a/paddle/cinn/ir/operation.h +++ b/paddle/cinn/ir/operation.h @@ -35,7 +35,9 @@ struct PlaceholderOp : public _Operation_ { //! The data type of the input. Type dtype; - static Operation Make(const std::string &name, const std::vector &shape, Type dtype); + static Operation Make(const std::string &name, + const std::vector &shape, + Type dtype); const char *func_type() const override; @@ -53,7 +55,8 @@ struct CallOp : public _Operation_ { const std::vector &write_args() const; std::vector args() const; - //! A reference to the target LoweredFunc if this CallOp calls an generated LoweredFunc. + //! A reference to the target LoweredFunc if this CallOp calls a generated + //! LoweredFunc. Expr func; // the offset int the tuple of return values. @@ -117,9 +120,9 @@ struct ComputeOp : public _Operation_ { ComputeOp::handle_t handle, const std::vector &shape, const std::vector &domain, - const std::vector &reduce_axis = {}, + const std::vector &reduce_axis = {}, const std::map &attrs = {}, - const std::string &tag = ""); + const std::string &tag = ""); const char *func_type() const override; diff --git a/paddle/cinn/ir/registry.cc b/paddle/cinn/ir/registry.cc index 03f1c50ed752f..2467329f79ae8 100644 --- a/paddle/cinn/ir/registry.cc +++ b/paddle/cinn/ir/registry.cc @@ -28,7 +28,7 @@ struct Registry::Manager { std::map functions; private: - Manager() = default; + Manager() = default; Manager(const Manager &) = delete; void operator=(Manager &) = delete; }; @@ -45,14 +45,16 @@ Registry &Registry::SetBody(lang::PackedFunc::body_t f) { Registry::Registry(const std::string &name) : name_(name) {} -/*static*/ Registry &Registry::Register(const std::string &name, bool can_override) { +/*static*/ Registry &Registry::Register(const std::string &name, + bool can_override) { auto *manager = Registry::Manager::Global(); std::lock_guard lock(manager->mu); if (manager->functions.count(name)) { - CHECK(can_override) << "Global PackedFunc[" << name << "] is already exists"; + CHECK(can_override) << "Global PackedFunc[" << name + << "] already exists"; } - auto *r = new Registry(name); + auto *r = new Registry(name); manager->functions[name] = r; return *r; } diff --git a/paddle/cinn/ir/schedule_desc.cc b/paddle/cinn/ir/schedule_desc.cc index 643423b741e05..cd8dea2fa8280 100644 --- a/paddle/cinn/ir/schedule_desc.cc +++ b/paddle/cinn/ir/schedule_desc.cc @@ -27,7 +27,8 @@ namespace cinn { namespace ir { -// ------ Following codes are about `Apply` functions registry of variaous types of ScheduleDesc::Step +// ------ Following codes are about `Apply` functions registry of various types +// of ScheduleDesc::Step class PackedStepContext; // uniformed function prototype of a scheduling operation in IRSchedule using StepApplyFunc = std::vector (*)(PackedStepContext*); @@ -55,7 +56,9 @@ class StepKindInfo { } // execute the Apply function of this type - std::vector 
Apply(PackedStepContext* context) const { return apply_func_(context); } + std::vector Apply(PackedStepContext* context) const { + return apply_func_(context); + } private: friend class PackedStepContext; @@ -74,11 +77,14 @@ class StepKindRegistry : public Registry { CINN_DISALLOW_COPY_AND_ASSIGN(StepKindRegistry); }; -// PackedStepContext is the param of a uniformed `Apply` function, which is used to be an -// auxiliary structure to interact with in/out arguments of the original scheduling function in IRSchedule +// PackedStepContext is the param of a uniformed `Apply` function, which serves +// as an auxiliary structure to interact with the in/out arguments of the +// original scheduling function in IRSchedule class PackedStepContext { public: - explicit PackedStepContext(const ScheduleDesc::Step& desc, const StepKindInfo* step_kind, IRSchedule* schedule) + explicit PackedStepContext(const ScheduleDesc::Step& desc, + const StepKindInfo* step_kind, + IRSchedule* schedule) : ir_schedule_(schedule) { Build(desc, step_kind); } @@ -111,7 +117,8 @@ class PackedStepContext { try { return absl::get(attrs_.at(idx)); } catch (absl::bad_variant_access& ex) { - LOG(FATAL) << "Attribute cast error, idx:" << idx << ", get tpye:" << typeid(AttrType).name() + LOG(FATAL) << "Attribute cast error, idx:" << idx + << ", get type:" << typeid(AttrType).name() << ", real index:" << attrs_.at(idx).index(); throw ex; } @@ -125,7 +132,9 @@ class PackedStepContext { auto arg_it = desc.inputs.find(param_name); CHECK(arg_it != desc.inputs.end()) << "Can't find param:" << param_name; auto&& args = arg_it->second; - inputs_.insert(inputs_.end(), std::make_move_iterator(args.begin()), std::make_move_iterator(args.end())); + inputs_.insert(inputs_.end(), + std::make_move_iterator(args.begin()), + std::make_move_iterator(args.end())); input_range_.emplace_back(input_idx, input_idx + args.size()); input_idx += args.size(); } @@ -134,7 +143,8 @@ class PackedStepContext { size_t attr_idx = 0; for (auto&& attr_name : step_kind->attrs_) { auto attr_it = desc.attrs.find(attr_name); - CHECK(attr_it != desc.attrs.end()) << "Can't find attribute:" << attr_name; + CHECK(attr_it != desc.attrs.end()) + << "Can't find attribute:" << attr_name; attrs_.emplace_back(attr_it->second); ++attr_idx; } @@ -146,17 +156,19 @@ class PackedStepContext { std::vector attrs_; }; -#define CINN_SPECIALIZE_ApplyCallHelper(attr_type) \ - template \ - struct ApplyCallHelper { \ - template \ - static std::vector Apply(PackedStepContext* ctx, PreviousArgs... pargs) { \ - using rf_attr_type = std::remove_reference::type; \ - using rc_attr_type = std::remove_const::type; \ - const auto& arg = ctx->AttrAt(attr_idx); \ - return ApplyCallHelper::template Apply( \ - ctx, std::forward(pargs)..., arg); \ - } \ +#define CINN_SPECIALIZE_ApplyCallHelper(attr_type) \ + template \ + struct ApplyCallHelper { \ + template \ + static std::vector Apply(PackedStepContext* ctx, \ + PreviousArgs... pargs) { \ + using rf_attr_type = std::remove_reference::type; \ + using rc_attr_type = std::remove_const::type; \ + const auto& arg = ctx->AttrAt(attr_idx); \ + return ApplyCallHelper:: \ + template Apply( \ + ctx, std::forward(pargs)..., arg); \ + } \ } template @@ -167,17 +179,26 @@ struct TypeTag {}; template struct FreeFuncConverter; -template +template struct FreeFuncConverter { - static Return Apply(IRSchedule* sch, Args... 
args) { + return (sch->*impl_fn)(std::forward(args)...); + } }; -template +template struct FreeFuncConverter { - static Return Apply(IRSchedule* sch, Args... args) { return (sch->*impl_fn)(std::forward(args)...); } + static Return Apply(IRSchedule* sch, Args... args) { + return (sch->*impl_fn)(std::forward(args)...); + } }; -// used for formatting scheduling functions with variaous function signatures to be uniformed form +// used for formatting scheduling functions with various function signatures +// into a uniform form template struct ApplyFuncImpl; @@ -199,37 +220,45 @@ struct ApplyFuncImpl { static std::vector Apply(PackedStepContext* ctx) { static_assert(in_idx == 0, "IRSchedule* must be the first argument"); IRSchedule* ir_schedule = ctx->ScheduleHandler(); - return ApplyCallHelper::template Apply(ctx, ir_schedule); + return ApplyCallHelper< + Tail...>::template Apply(ctx, + ir_schedule); } }; template struct ApplyCallHelper { template - static std::vector Apply(PackedStepContext* ctx, PreviousArgs... pargs) { + static std::vector Apply(PackedStepContext* ctx, + PreviousArgs... pargs) { auto arg = ctx->InputAt(in_idx - 1); - return ApplyCallHelper::template Apply( - ctx, std::forward(pargs)..., arg); + return ApplyCallHelper:: + template Apply( + ctx, std::forward(pargs)..., arg); } }; template struct ApplyCallHelper { template - static std::vector Apply(PackedStepContext* ctx, PreviousArgs... pargs) { + static std::vector Apply(PackedStepContext* ctx, + PreviousArgs... pargs) { auto arg = ctx->InputAt(in_idx - 1); - return ApplyCallHelper::template Apply( - ctx, std::forward(pargs)..., arg); + return ApplyCallHelper:: + template Apply( + ctx, std::forward(pargs)..., arg); } }; template struct ApplyCallHelper&, Tail...> { template - static std::vector Apply(PackedStepContext* ctx, PreviousArgs... pargs) { + static std::vector Apply(PackedStepContext* ctx, + PreviousArgs... pargs) { auto arg = ctx->InputsAt(in_idx - 1); - return ApplyCallHelper::template Apply( - ctx, std::forward(pargs)..., arg); + return ApplyCallHelper:: + template Apply( + ctx, std::forward(pargs)..., arg); } }; @@ -267,22 +296,28 @@ struct ApplyFuncImpl { template struct ApplyReturnHelper> { - static std::vector Apply(Args... args) { return impl_fn(std::forward(args)...); } + static std::vector Apply(Args... args) { + return impl_fn(std::forward(args)...); + } }; // end: base template template struct ApplyCallHelper> { template - static std::vector Apply(PackedStepContext* ctx, PreviousArgs... pargs) { + static std::vector Apply(PackedStepContext* ctx, + PreviousArgs... pargs) { static_assert(out_idx == 0, "Output is exported from return value"); - return ApplyReturnHelper::Apply(std::forward(pargs)...); + return ApplyReturnHelper::Apply( + std::forward(pargs)...); } }; }; -#define APPLY_FUNC_UNIFORM(...) ::cinn::ir::ApplyFuncImpl::Apply -#define FREE_FUNCTION_CONVERTER(...) ::cinn::ir::FreeFuncConverter::Apply +#define APPLY_FUNC_UNIFORM(...) \ + ::cinn::ir::ApplyFuncImpl::Apply +#define FREE_FUNCTION_CONVERTER(...) 
\ + ::cinn::ir::FreeFuncConverter::Apply #define CINN_BUILD_STEP_KIND(TypeName) \ static ::cinn::ir::StepKindInfo& __step_kind_registrar_##TypeName = \ @@ -480,8 +515,10 @@ CINN_BUILD_STEP_KIND(SampleCategorical) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SampleCategorical))); // clang-format on -// ------ Following codes are about member function implement of the ScheduleDesc class -void AttrVariantToProto(const utils::Attribute& attr, proto::ScheduleDesc_Attr* attr_proto) { +// ------ Following codes are about member function implementations of the +// ScheduleDesc class +void AttrVariantToProto(const utils::Attribute& attr, + proto::ScheduleDesc_Attr* attr_proto) { #define SET_DESC_SINGLE_ITEM(index, built_type, proto_type, proto_field) \ case index: \ attr_proto->set_dtype(proto::ScheduleDesc_Attr_DataType_##proto_type); \ @@ -525,9 +562,10 @@ utils::Attribute AttrProtoToVariant(const proto::ScheduleDesc_Attr& attr) { value = built_type(attr.proto_field()); \ break; -#define PARSE_DESC_REPEATED_ITEM(proto_type, proto_field, built_type) \ - case proto::ScheduleDesc_Attr_DataType_##proto_type: \ - value = built_type({attr.proto_field().begin(), attr.proto_field().end()}); \ +#define PARSE_DESC_REPEATED_ITEM(proto_type, proto_field, built_type) \ + case proto::ScheduleDesc_Attr_DataType_##proto_type: \ + value = \ + built_type({attr.proto_field().begin(), attr.proto_field().end()}); \ break; switch (attr.dtype()) { @@ -554,11 +592,15 @@ utils::Attribute AttrProtoToVariant(const proto::ScheduleDesc_Attr& attr) { // Expr hash functor, presents how to hash an Expr struct ExprHash { - size_t operator()(const Expr& e) const { return std::hash()(e.ptr()); } + size_t operator()(const Expr& e) const { + return std::hash()(e.ptr()); + } }; // Expr equal functor, presents whether a Expr pair is equal struct ExprEqual { - bool operator()(const Expr& lhs, const Expr& rhs) const { return lhs.get() == rhs.get(); } + bool operator()(const Expr& lhs, const Expr& rhs) const { + return lhs.get() == rhs.get(); + } }; void ScheduleDesc::Append(Step&& step) { steps_.emplace_back(std::move(step)); } @@ -569,7 +611,8 @@ void ScheduleDesc::Pop() { } } -void ScheduleDesc::Replay(IRSchedule* schedule, bool without_post_schedule) const { +void ScheduleDesc::Replay(IRSchedule* schedule, + bool without_post_schedule) const { ReplayWithProto(this->ToProto(), schedule, without_post_schedule); } @@ -584,16 +627,18 @@ proto::ScheduleDesc ScheduleDesc::ToProto() const { // inputs of a step must refer to Exprs resulted by preceding steps for (auto&& param2exprs : step.inputs) { const std::string& param_name = param2exprs.first; - auto* expr_desc = step_proto->add_inputs(); + auto* expr_desc = step_proto->add_inputs(); expr_desc->set_parameter(param_name); for (auto&& expr : param2exprs.second) { auto expr_it = expr2name.find(expr); - CHECK(expr_it != expr2name.end()) << "Can't find expr of param_name: " << param_name; + CHECK(expr_it != expr2name.end()) + << "Can't find expr of param_name: " << param_name; expr_desc->add_arguments(expr_it->second); } } - // each output Expr is represented by a formatted name, to be refered by suceeding steps + // each output Expr is represented by a formatted name, to be referred to by + // succeeding steps for (auto&& expr : step.outputs) { std::string local_name = "e" + std::to_string(expr2name.size()); expr2name.emplace(expr, local_name); @@ -601,7 +646,7 @@ proto::ScheduleDesc ScheduleDesc::ToProto() const { } for (auto&& attr2value : step.attrs) { - auto* attr_proto = 
step_proto->add_attrs(); + auto* attr_proto = step_proto->add_attrs(); const auto& attr_value = attr2value.second; VLOG(5) << "Attr.index:" << attr_value.index(); attr_proto->set_name(attr2value.first); @@ -611,9 +656,10 @@ proto::ScheduleDesc ScheduleDesc::ToProto() const { return desc_proto; } -std::vector ScheduleDesc::ReplayWithProto(const proto::ScheduleDesc& desc_proto, - IRSchedule* sch, - bool without_post_schedule) { +std::vector ScheduleDesc::ReplayWithProto( + const proto::ScheduleDesc& desc_proto, + IRSchedule* sch, + bool without_post_schedule) { VLOG(4) << "proto::ScheduleDesc:\n" << desc_proto.DebugString(); if (desc_proto.steps().empty()) { LOG(WARNING) << "Input proto::ScheduleDesc is empty"; @@ -649,7 +695,8 @@ std::vector ScheduleDesc::ReplayWithProto(const proto::ScheduleDesc& desc_ PackedStepContext context(step, step_kind, sch); step.outputs = step_kind->Apply(&context); - CHECK_EQ(step_proto.outputs().size(), step.outputs.size()) << "Output size not matched"; + CHECK_EQ(step_proto.outputs().size(), step.outputs.size()) + << "Output size not matched"; for (size_t i = 0; i < step.outputs.size(); ++i) { name2expr[step_proto.outputs(i)] = step.outputs.at(i); } @@ -658,7 +705,9 @@ std::vector ScheduleDesc::ReplayWithProto(const proto::ScheduleDesc& desc_ return last_outputs; } -ScheduleDesc ScheduleDesc::ForkAndUpdate(int step_idx, utils::Attribute decision, bool without_post_schedule) const { +ScheduleDesc ScheduleDesc::ForkAndUpdate(int step_idx, + utils::Attribute decision, + bool without_post_schedule) const { int n_valid_step = 0; if (!without_post_schedule) { n_valid_step = steps_.size(); @@ -671,7 +720,8 @@ ScheduleDesc ScheduleDesc::ForkAndUpdate(int step_idx, utils::Attribute decision } } } - std::vector new_steps(steps_.begin(), steps_.begin() + n_valid_step); + std::vector new_steps(steps_.begin(), + steps_.begin() + n_valid_step); new_steps[step_idx].attrs["decision"] = decision; return ScheduleDesc(std::move(new_steps)); } diff --git a/paddle/cinn/ir/schedule_desc.h b/paddle/cinn/ir/schedule_desc.h index 57c85b5391bb2..9cac7ac87816d 100644 --- a/paddle/cinn/ir/schedule_desc.h +++ b/paddle/cinn/ir/schedule_desc.h @@ -27,11 +27,12 @@ namespace cinn { namespace ir { -// A ScheduleDesc describe the scheduling process of an ir::ModuleExpr, it records -// all transform/getting operations executed by a corresponding ir::IRSchedule. -// A ScheduleDesc can be serialized to JSON format and saved to file. For deserializing, -// it can be re-applied to a new IRSchedule that is initialzied by a semantics-euqal -// original ir::ModuleExpr, and then achieves the same result. +// A ScheduleDesc describes the scheduling process of an ir::ModuleExpr; it +// records all transform/getting operations executed by a corresponding +// ir::IRSchedule. A ScheduleDesc can be serialized to JSON format and saved to +// file. For deserializing, it can be re-applied to a new IRSchedule that is +// initialized by a semantics-equal original ir::ModuleExpr, and then achieves +// the same result. class IRSchedule; // forward declartion to avoid cross-reference class ScheduleDesc { @@ -51,14 +52,17 @@ class ScheduleDesc { }; /** - * \brief Re-applied a scheduling process represented as a proto::ScheduleDesc to a new IRSchedule object. + * \brief Re-applies a scheduling process represented as a proto::ScheduleDesc + * to a new IRSchedule object. * @param desc_proto The proto of the ScheduleDesc to be re-applied. * @param sch The original IRSchedule to be replayed the description on. 
- * @param without_post_schedule Determine whether to delete the post schedules. + * @param without_post_schedule Determine whether to delete the post + * schedules. */ - static std::vector ReplayWithProto(const proto::ScheduleDesc& desc_proto, - IRSchedule* sch, - bool without_post_schedule = false); + static std::vector ReplayWithProto( + const proto::ScheduleDesc& desc_proto, + IRSchedule* sch, + bool without_post_schedule = false); ScheduleDesc() = default; @@ -73,9 +77,11 @@ class ScheduleDesc { void Pop(); /** - * \brief Replay this description to a new IRSchedule that is initialzied by a semantics-euqal original ModuleExpr. + * \brief Replay this description to a new IRSchedule that is initialized by a + * semantics-equal original ModuleExpr. * @param schedule The original IRSchedule to be replayed the description on. - * @param without_post_schedule Determine whether to delete the post schedules. + * @param without_post_schedule Determine whether to delete the post + * schedules. */ void Replay(IRSchedule* schedule, bool without_post_schedule = false) const; @@ -90,13 +96,17 @@ class ScheduleDesc { bool Empty() const { return steps_.empty(); } /** - * \brief Fork this ScheduleDesc and update a step of the new ScheduleDesc with a new decision. + * \brief Fork this ScheduleDesc and update a step of the new ScheduleDesc + * with a new decision. * @param step_idx The index of the step to be update. * @param decision The new decision. - * @param without_post_schedule Determine whether to delete the post schedules. + * @param without_post_schedule Determine whether to delete the post + * schedules. * @return The new ScheduleDesc. */ - ScheduleDesc ForkAndUpdate(int step_idx, utils::Attribute decision, bool without_post_schedule) const; + ScheduleDesc ForkAndUpdate(int step_idx, + utils::Attribute decision, + bool without_post_schedule) const; private: std::vector steps_; // all operations are recorded in order. 
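Taken together, the interface above gives ScheduleDesc a record-and-replay workflow: each scheduling primitive applied to an IRSchedule is mirrored by an appended Step, and the accumulated trace can later be re-applied to a fresh schedule. The snippet below is a minimal sketch of that workflow, assuming the MakeIRSchedule/lowered_funcs fixtures from schedule_desc_test.cc below (test helpers, not public API) and writing the vector element types out explicitly:

  // Record: run a scheduling primitive, then append an equivalent Step
  // (name, inputs, attrs, outputs) to the trace.
  ScheduleDesc trace;
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
  auto fused = ir_sch.Fuse("B", {0, 1});
  trace.Append(ScheduleDesc::Step("FuseWithName",
                                  {},
                                  {{"block_name", std::string("B")},
                                   {"loops_index", std::vector<int>({0, 1})}},
                                  {fused}));

  // Replay: re-apply the recorded steps to a fresh IRSchedule built from a
  // semantics-equal ModuleExpr; both schedules end up with identical IR.
  ir::IRSchedule replay_sch = MakeIRSchedule(lowered_funcs);
  trace.Replay(&replay_sch);

Round-tripping the same trace through ToProto() and ReplayWithProto() is exactly what the CheckTracingOutputs and CheckReplayResult helpers in the tests below verify.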
diff --git a/paddle/cinn/ir/schedule_desc_test.cc b/paddle/cinn/ir/schedule_desc_test.cc index af10af53406c6..805d79230c684 100644 --- a/paddle/cinn/ir/schedule_desc_test.cc +++ b/paddle/cinn/ir/schedule_desc_test.cc @@ -30,10 +30,11 @@ namespace cinn { namespace ir { // Return lowerd ir AST for example functions used in this test -std::vector LowerCompute(const std::vector& shape, - const Target& target, - bool need_c = false, - const std::string& operation = "elementwise-copy") { +std::vector LowerCompute( + const std::vector& shape, + const Target& target, + bool need_c = false, + const std::string& operation = "elementwise-copy") { CHECK(shape.size() == 2 || shape.size() == 3) << "shape should be 2 or 3"; std::vector domain; for (auto i = 0; i < shape.size(); ++i) { @@ -65,17 +66,29 @@ std::vector LowerCompute(const std::vector& shape, domain, [&B](Var i, Var j) { return B(i, j) + Expr(1.f); }, "C"); } else { B = Compute( - domain, [&A](Var i, Var j, Var k) { return A(i, j, k) * Expr(2.f); }, "B"); + domain, + [&A](Var i, Var j, Var k) { return A(i, j, k) * Expr(2.f); }, + "B"); C = Compute( - domain, [&B](Var i, Var j, Var k) { return B(i, j, k) + Expr(1.f); }, "C"); + domain, + [&B](Var i, Var j, Var k) { return B(i, j, k) + Expr(1.f); }, + "C"); } } if (need_c) { - return cinn::lang::LowerVec("test_func", CreateStages({A, B, C}), {A, C}, {}, {}, nullptr, target, true); + return cinn::lang::LowerVec("test_func", + CreateStages({A, B, C}), + {A, C}, + {}, + {}, + nullptr, + target, + true); } - return cinn::lang::LowerVec("test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); + return cinn::lang::LowerVec( + "test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); } // Create a new IRSchedule with copied ir::LoweredFunc AST @@ -112,17 +125,21 @@ class TestScheduleDesc : public ::testing::Test { ScheduleDesc trace; void SetUp() override { Context::Global().ResetNameId(); } - void CheckTracingOutputs(const std::vector& base, const ScheduleDesc& trace_desc) { + void CheckTracingOutputs(const std::vector& base, + const ScheduleDesc& trace_desc) { Context::Global().ResetNameId(); ir::IRSchedule replay_sch = MakeIRSchedule(lowered_funcs); - auto traced_outputs = ScheduleDesc::ReplayWithProto(trace_desc.ToProto(), &replay_sch); + auto traced_outputs = + ScheduleDesc::ReplayWithProto(trace_desc.ToProto(), &replay_sch); ASSERT_EQ(base.size(), traced_outputs.size()); for (auto i = 0; i < base.size(); ++i) { - ASSERT_EQ(utils::GetStreamCnt(base.at(i)), utils::GetStreamCnt(traced_outputs.at(i))); + ASSERT_EQ(utils::GetStreamCnt(base.at(i)), + utils::GetStreamCnt(traced_outputs.at(i))); } } - void CheckReplayResult(const ir::IRSchedule& ir_sch, const ScheduleDesc& trace_desc) { + void CheckReplayResult(const ir::IRSchedule& ir_sch, + const ScheduleDesc& trace_desc) { Context::Global().ResetNameId(); ir::IRSchedule replay_sch = MakeIRSchedule(lowered_funcs); trace_desc.Replay(&replay_sch); @@ -133,41 +150,60 @@ class TestScheduleDesc : public ::testing::Test { auto rhs_exprs = replay_sch.GetModule().GetExprs(); ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size()); for (auto i = 0; i < lhs_exprs.size(); ++i) { - ASSERT_EQ(utils::GetStreamCnt(lhs_exprs.at(i)), utils::GetStreamCnt(rhs_exprs.at(i))); + ASSERT_EQ(utils::GetStreamCnt(lhs_exprs.at(i)), + utils::GetStreamCnt(rhs_exprs.at(i))); } // check the equality of source code between them - ASSERT_EQ(utils::Trim(SourceCodeGen(ir_sch.GetModule(), lowered_funcs, target)), - 
utils::Trim(SourceCodeGen(replay_sch.GetModule(), lowered_funcs, target))); + ASSERT_EQ( + utils::Trim(SourceCodeGen(ir_sch.GetModule(), lowered_funcs, target)), + utils::Trim( + SourceCodeGen(replay_sch.GetModule(), lowered_funcs, target))); } }; TEST_F(TestScheduleDesc, Append_Replay) { - lowered_funcs = LowerCompute({32, 32}, target); + lowered_funcs = LowerCompute({32, 32}, target); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto fused = ir_sch.Fuse("B", {0, 1}); - trace.Append(ScheduleDesc::Step( - "FuseWithName", {}, {{"block_name", std::string("B")}, {"loops_index", std::vector({0, 1})}}, {fused})); + trace.Append(ScheduleDesc::Step("FuseWithName", + {}, + {{"block_name", std::string("B")}, + {"loops_index", std::vector({0, 1})}}, + {fused})); auto sample = ir_sch.SamplePerfectTile(fused, 2, 1, {4, -1}); trace.Append(ScheduleDesc::Step("SamplePerfectTile", {{"loop", std::vector({fused})}}, - {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector{4, -1}}}, + {{"n", 2}, + {"max_innermost_factor", 1}, + {"decision", std::vector{4, -1}}}, sample)); auto splited = ir_sch.Split(fused, sample); - trace.Append(ScheduleDesc::Step("Split", {{"loop", std::vector({fused})}, {"factors", sample}}, {}, splited)); + trace.Append(ScheduleDesc::Step( + "Split", + {{"loop", std::vector({fused})}, {"factors", sample}}, + {}, + splited)); auto loops = ir_sch.GetLoops("B"); - trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + trace.Append(ScheduleDesc::Step( + "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); fused = ir_sch.Fuse(loops); trace.Append(ScheduleDesc::Step("Fuse", {{"loops", loops}}, {}, {fused})); sample = ir_sch.SamplePerfectTile(fused, 2, 1, {256, -1}); trace.Append(ScheduleDesc::Step("SamplePerfectTile", {{"loop", std::vector({fused})}}, - {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector{256, -1}}}, + {{"n", 2}, + {"max_innermost_factor", 1}, + {"decision", std::vector{256, -1}}}, sample)); splited = ir_sch.Split(fused, sample); - trace.Append(ScheduleDesc::Step("Split", {{"loop", std::vector({fused})}, {"factors", sample}}, {}, splited)); + trace.Append(ScheduleDesc::Step( + "Split", + {{"loop", std::vector({fused})}, {"factors", sample}}, + {}, + splited)); // check the equality of results between the ir_sch and replaying of trace CheckTracingOutputs(splited, trace); @@ -177,9 +213,10 @@ TEST_F(TestScheduleDesc, Append_Replay) { CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); } -// Test cases with `StepKind` prefix are to check the correctness of their StepKindInfo register +// Test cases with `StepKind` prefix are to check the correctness of their +// StepKindInfo register TEST_F(TestScheduleDesc, StepKind_GetAllBlocks) { - lowered_funcs = LowerCompute({32, 32}, target); + lowered_funcs = LowerCompute({32, 32}, target); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto all_blocks = ir_sch.GetAllBlocks(); @@ -189,56 +226,70 @@ TEST_F(TestScheduleDesc, StepKind_GetAllBlocks) { } TEST_F(TestScheduleDesc, StepKind_GetChildBlocks) { - lowered_funcs = LowerCompute({32, 32, 64}, target, true); + lowered_funcs = LowerCompute({32, 32, 64}, target, true); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto block_b = ir_sch.GetBlock("B"); - trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + trace.Append(ScheduleDesc::Step( + "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); auto loops = ir_sch.GetLoops("C"); 
- trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops)); + trace.Append(ScheduleDesc::Step( + "GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops)); ir_sch.ComputeAt(block_b, loops[1]); trace.Append(ScheduleDesc::Step("ComputeAt", - {{"block", std::vector({block_b})}, {"loop", std::vector({loops[1]})}}, + {{"block", std::vector({block_b})}, + {"loop", std::vector({loops[1]})}}, {{"keep_unit_loops", false}}, {})); loops = ir_sch.GetLoops("B"); - trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + trace.Append(ScheduleDesc::Step( + "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); auto root_block = ir_sch.GetRootBlock(loops[1]); - trace.Append(ScheduleDesc::Step("GetRootBlock", {{"expr", std::vector({loops[1]})}}, {}, {root_block})); + trace.Append(ScheduleDesc::Step("GetRootBlock", + {{"expr", std::vector({loops[1]})}}, + {}, + {root_block})); auto childblocks = ir_sch.GetChildBlocks(root_block); - trace.Append(ScheduleDesc::Step("GetChildBlocks", {{"expr", std::vector({root_block})}}, {}, childblocks)); + trace.Append(ScheduleDesc::Step("GetChildBlocks", + {{"expr", std::vector({root_block})}}, + {}, + childblocks)); CheckTracingOutputs(childblocks, trace); CheckTracingOutputs(childblocks, ir_sch.GetTraceDesc()); } TEST_F(TestScheduleDesc, StepKind_GetLoops) { - lowered_funcs = LowerCompute({32, 32}, target); + lowered_funcs = LowerCompute({32, 32}, target); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto block_b = ir_sch.GetBlock("B"); - trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + trace.Append(ScheduleDesc::Step( + "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); auto loops = ir_sch.GetLoops(block_b); - trace.Append(ScheduleDesc::Step("GetLoops", {{"block", std::vector({block_b})}}, {}, loops)); + trace.Append(ScheduleDesc::Step( + "GetLoops", {{"block", std::vector({block_b})}}, {}, loops)); CheckTracingOutputs(loops, trace); CheckTracingOutputs(loops, ir_sch.GetTraceDesc()); } TEST_F(TestScheduleDesc, StepKind_GetLoopsWithName) { - lowered_funcs = LowerCompute({32, 32}, target); + lowered_funcs = LowerCompute({32, 32}, target); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto loops = ir_sch.GetLoops("B"); - trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + trace.Append(ScheduleDesc::Step( + "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); CheckTracingOutputs(loops, trace); CheckTracingOutputs(loops, ir_sch.GetTraceDesc()); } TEST_F(TestScheduleDesc, StepKind_GetBlock) { - lowered_funcs = LowerCompute({32, 32, 32}, target); + lowered_funcs = LowerCompute({32, 32, 32}, target); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto block_b = ir_sch.GetBlock("B"); - trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + trace.Append(ScheduleDesc::Step( + "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); CheckTracingOutputs({block_b}, trace); CheckTracingOutputs({block_b}, ir_sch.GetTraceDesc()); } @@ -252,16 +303,17 @@ TEST_F(TestScheduleDesc, StepKind_Split) { // test split with inputs of Expr auto loops = ir_sch_split_base.GetLoops("B"); - trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); - auto sample = ir_sch_split_base.SamplePerfectTile(loops.front(), 2, 1, {4, -1}); + 
trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", +std::string("B")}}, loops)); auto sample = +ir_sch_split_base.SamplePerfectTile(loops.front(), 2, 1, {4, -1}); trace.Append(ScheduleDesc::Step("SamplePerfectTile", - {{"loop", std::vector({loops.front()})}}, - {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector{4, -1}}}, - sample)); - auto splited = ir_sch_split_base.Split(loops.front(), sample); - trace.Append( - ScheduleDesc::Step("Split", {{"loop", std::vector({loops.front()})}, {"factors", sample}}, {}, splited)); - CheckTracingOutputs(splited, trace); + {{"loop", +std::vector({loops.front()})}}, + {{"n", 2}, {"max_innermost_factor", 1}, +{"decision", std::vector{4, -1}}}, sample)); auto splited = +ir_sch_split_base.Split(loops.front(), sample); trace.Append( + ScheduleDesc::Step("Split", {{"loop", std::vector({loops.front()})}, +{"factors", sample}}, {}, splited)); CheckTracingOutputs(splited, trace); CheckTracingOutputs(splited, ir_sch_split_base.GetTraceDesc()); // test split with inputs of int @@ -277,11 +329,12 @@ TEST_F(TestScheduleDesc, StepKind_Split) { } */ TEST_F(TestScheduleDesc, StepKind_Fuse) { - lowered_funcs = LowerCompute({32, 32, 64}, target); + lowered_funcs = LowerCompute({32, 32, 64}, target); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto loops = ir_sch.GetLoops("B"); - trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + trace.Append(ScheduleDesc::Step( + "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); auto fused = ir_sch.Fuse(loops); trace.Append(ScheduleDesc::Step("Fuse", {{"loops", loops}}, {}, {fused})); CheckTracingOutputs({fused}, trace); @@ -289,42 +342,51 @@ TEST_F(TestScheduleDesc, StepKind_Fuse) { } TEST_F(TestScheduleDesc, StepKind_FuseWithName) { - lowered_funcs = LowerCompute({32, 32, 64}, target); + lowered_funcs = LowerCompute({32, 32, 64}, target); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto fused = ir_sch.Fuse("B", {0, 1, 2}); - trace.Append(ScheduleDesc::Step( - "FuseWithName", {}, {{"block_name", std::string("B")}, {"loops_index", std::vector({0, 1, 2})}}, {fused})); + trace.Append( + ScheduleDesc::Step("FuseWithName", + {}, + {{"block_name", std::string("B")}, + {"loops_index", std::vector({0, 1, 2})}}, + {fused})); CheckTracingOutputs({fused}, trace); CheckTracingOutputs({fused}, ir_sch.GetTraceDesc()); } TEST_F(TestScheduleDesc, StepKind_FuseWithBlock) { - lowered_funcs = LowerCompute({32, 32, 64}, target); + lowered_funcs = LowerCompute({32, 32, 64}, target); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto block_b = ir_sch.GetBlock("B"); - trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + trace.Append(ScheduleDesc::Step( + "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); auto fused = ir_sch.Fuse(block_b, {0, 1, 2}); - trace.Append(ScheduleDesc::Step("FuseWithBlock", - {{"block", std::vector({block_b})}}, - {{"loops_index", std::vector({0, 1, 2})}}, - {fused})); + trace.Append( + ScheduleDesc::Step("FuseWithBlock", + {{"block", std::vector({block_b})}}, + {{"loops_index", std::vector({0, 1, 2})}}, + {fused})); CheckTracingOutputs({fused}, trace); CheckTracingOutputs({fused}, ir_sch.GetTraceDesc()); } TEST_F(TestScheduleDesc, StepKind_ComputeAt) { - lowered_funcs = LowerCompute({32, 32, 64}, target, true); + lowered_funcs = LowerCompute({32, 32, 64}, target, true); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto block_b = 
ir_sch.GetBlock("B"); - trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + trace.Append(ScheduleDesc::Step( + "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); auto loops = ir_sch.GetLoops("C"); - trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops)); + trace.Append(ScheduleDesc::Step( + "GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops)); ir_sch.ComputeAt(block_b, loops[1]); trace.Append(ScheduleDesc::Step("ComputeAt", - {{"block", std::vector({block_b})}, {"loop", std::vector({loops[1]})}}, + {{"block", std::vector({block_b})}, + {"loop", std::vector({loops[1]})}}, {{"keep_unit_loops", false}}, {})); CheckReplayResult(ir_sch, trace); @@ -332,16 +394,19 @@ TEST_F(TestScheduleDesc, StepKind_ComputeAt) { } TEST_F(TestScheduleDesc, StepKind_SimpleComputeAt) { - lowered_funcs = LowerCompute({32, 32, 64}, target, true); + lowered_funcs = LowerCompute({32, 32, 64}, target, true); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto block_b = ir_sch.GetBlock("B"); - trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + trace.Append(ScheduleDesc::Step( + "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); auto loops = ir_sch.GetLoops("C"); - trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops)); + trace.Append(ScheduleDesc::Step( + "GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops)); ir_sch.SimpleComputeAt(block_b, loops[2]); trace.Append(ScheduleDesc::Step("SimpleComputeAt", - {{"block", std::vector({block_b})}, {"loop", std::vector({loops[2]})}}, + {{"block", std::vector({block_b})}, + {"loop", std::vector({loops[2]})}}, {{"keep_unit_loops", false}}, {})); CheckReplayResult(ir_sch, trace); @@ -349,16 +414,19 @@ TEST_F(TestScheduleDesc, StepKind_SimpleComputeAt) { } TEST_F(TestScheduleDesc, StepKind_ReverseComputeAt) { - lowered_funcs = LowerCompute({32, 32, 64}, target, true); + lowered_funcs = LowerCompute({32, 32, 64}, target, true); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto block_c = ir_sch.GetBlock("C"); - trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("C")}}, {block_c})); + trace.Append(ScheduleDesc::Step( + "GetBlock", {}, {{"block_name", std::string("C")}}, {block_c})); auto loops = ir_sch.GetLoops("B"); - trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + trace.Append(ScheduleDesc::Step( + "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); ir_sch.ReverseComputeAt(block_c, loops[1]); trace.Append(ScheduleDesc::Step("ReverseComputeAt", - {{"block", std::vector({block_c})}, {"loop", std::vector({loops[1]})}}, + {{"block", std::vector({block_c})}, + {"loop", std::vector({loops[1]})}}, {{"keep_unit_loops", false}}, {})); CheckReplayResult(ir_sch, trace); @@ -366,28 +434,33 @@ TEST_F(TestScheduleDesc, StepKind_ReverseComputeAt) { } TEST_F(TestScheduleDesc, StepKind_GetRootBlock) { - lowered_funcs = LowerCompute({32, 64}, target); + lowered_funcs = LowerCompute({32, 64}, target); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto loops = ir_sch.GetLoops("B"); - trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + trace.Append(ScheduleDesc::Step( + "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); auto root_b = ir_sch.GetRootBlock(loops[1]); - 
trace.Append(ScheduleDesc::Step("GetRootBlock", {{"expr", std::vector({loops[1]})}}, {}, {root_b})); + trace.Append(ScheduleDesc::Step( + "GetRootBlock", {{"expr", std::vector({loops[1]})}}, {}, {root_b})); CheckTracingOutputs({root_b}, trace); CheckTracingOutputs({root_b}, ir_sch.GetTraceDesc()); } TEST_F(TestScheduleDesc, StepKind_CacheRead) { - lowered_funcs = LowerCompute({32, 64}, target, false, "elementwise-add_const"); + lowered_funcs = + LowerCompute({32, 64}, target, false, "elementwise-add_const"); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto block_b = ir_sch.GetBlock("B"); - trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + trace.Append(ScheduleDesc::Step( + "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); auto a_cache = ir_sch.CacheRead(block_b, 0, "local"); - trace.Append(ScheduleDesc::Step("CacheRead", - {{"block", std::vector({block_b})}}, - {{"read_buffer_index", 0}, {"memory_type", std::string("local")}}, - {a_cache})); + trace.Append(ScheduleDesc::Step( + "CacheRead", + {{"block", std::vector({block_b})}}, + {{"read_buffer_index", 0}, {"memory_type", std::string("local")}}, + {a_cache})); CheckTracingOutputs({a_cache}, trace); CheckTracingOutputs({a_cache}, ir_sch.GetTraceDesc()); CheckReplayResult(ir_sch, trace); @@ -395,16 +468,19 @@ TEST_F(TestScheduleDesc, StepKind_CacheRead) { } TEST_F(TestScheduleDesc, StepKind_CacheWrite) { - lowered_funcs = LowerCompute({32, 64}, target, false, "elementwise-add_const"); + lowered_funcs = + LowerCompute({32, 64}, target, false, "elementwise-add_const"); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto block_b = ir_sch.GetBlock("B"); - trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + trace.Append(ScheduleDesc::Step( + "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); auto b_cache = ir_sch.CacheWrite(block_b, 0, "local"); - trace.Append(ScheduleDesc::Step("CacheWrite", - {{"block", std::vector({block_b})}}, - {{"write_buffer_index", 0}, {"memory_type", std::string("local")}}, - {b_cache})); + trace.Append(ScheduleDesc::Step( + "CacheWrite", + {{"block", std::vector({block_b})}}, + {{"write_buffer_index", 0}, {"memory_type", std::string("local")}}, + {b_cache})); CheckTracingOutputs({b_cache}, trace); CheckTracingOutputs({b_cache}, ir_sch.GetTraceDesc()); CheckReplayResult(ir_sch, trace); @@ -412,227 +488,307 @@ TEST_F(TestScheduleDesc, StepKind_CacheWrite) { } TEST_F(TestScheduleDesc, StepKind_SyncThreads) { - lowered_funcs = LowerCompute({64, 32}, target, true, "elementwise-add_const"); + lowered_funcs = LowerCompute({64, 32}, target, true, "elementwise-add_const"); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); auto block_b = ir_sch.GetBlock("B"); - trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + trace.Append(ScheduleDesc::Step( + "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); auto b_cache = ir_sch.CacheWrite(block_b, 0, "local"); - trace.Append(ScheduleDesc::Step("CacheWrite", - {{"block", std::vector({block_b})}}, - {{"write_buffer_index", 0}, {"memory_type", std::string("local")}}, - {b_cache})); + trace.Append(ScheduleDesc::Step( + "CacheWrite", + {{"block", std::vector({block_b})}}, + {{"write_buffer_index", 0}, {"memory_type", std::string("local")}}, + {b_cache})); auto block_c = ir_sch.GetBlock("C"); - trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("C")}}, 
{block_c}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
   auto c_cache = ir_sch.CacheWrite(block_c, 0, "local");
-  trace.Append(ScheduleDesc::Step("CacheWrite",
-                                  {{"block", std::vector<Expr>({block_c})}},
-                                  {{"write_buffer_index", 0}, {"memory_type", std::string("local")}},
-                                  {c_cache}));
+  trace.Append(ScheduleDesc::Step(
+      "CacheWrite",
+      {{"block", std::vector<Expr>({block_c})}},
+      {{"write_buffer_index", 0}, {"memory_type", std::string("local")}},
+      {c_cache}));
   block_c = ir_sch.GetBlock("C");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
   ir_sch.SyncThreads(block_c, false);
-  trace.Append(
-      ScheduleDesc::Step("SyncThreads", {{"ir_node", std::vector<Expr>({block_c})}}, {{"after_node", false}}, {}));
+  trace.Append(ScheduleDesc::Step("SyncThreads",
+                                  {{"ir_node", std::vector<Expr>({block_c})}},
+                                  {{"after_node", false}},
+                                  {}));
   block_b = ir_sch.GetBlock("B");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
   ir_sch.SyncThreads(block_b);
-  trace.Append(
-      ScheduleDesc::Step("SyncThreads", {{"ir_node", std::vector<Expr>({block_b})}}, {{"after_node", true}}, {}));
+  trace.Append(ScheduleDesc::Step("SyncThreads",
+                                  {{"ir_node", std::vector<Expr>({block_b})}},
+                                  {{"after_node", true}},
+                                  {}));
   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
 }

 TEST_F(TestScheduleDesc, StepKind_SetBuffer) {
-  lowered_funcs = LowerCompute({32, 64}, target, false, "elementwise-add_const");
+  lowered_funcs =
+      LowerCompute({32, 64}, target, false, "elementwise-add_const");
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

   auto block_b = ir_sch.GetBlock("B");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
   ir_sch.SetBuffer(block_b, "shared", true);
-  trace.Append(ScheduleDesc::Step("SetBuffer",
-                                  {{"block", std::vector<Expr>({block_b})}},
-                                  {{"memory_type", std::string("shared")}, {"fixed", true}},
-                                  {}));
+  trace.Append(ScheduleDesc::Step(
+      "SetBuffer",
+      {{"block", std::vector<Expr>({block_b})}},
+      {{"memory_type", std::string("shared")}, {"fixed", true}},
+      {}));
   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
 }

 TEST_F(TestScheduleDesc, StepKind_Reorder) {
-  lowered_funcs         = LowerCompute({32, 64, 12}, target);
+  lowered_funcs = LowerCompute({32, 64, 12}, target);
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

   auto loops = ir_sch.GetLoops("B");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  trace.Append(ScheduleDesc::Step(
+      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
   auto sample = ir_sch.SamplePerfectTile(loops[0], 2, 1, {-1, 4});
   trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                   {{"loop", std::vector<Expr>({loops[0]})}},
-                                  {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector<int>{-1, 4}}},
+                                  {{"n", 2},
+                                   {"max_innermost_factor", 1},
+                                   {"decision", std::vector<int>{-1, 4}}},
                                   sample));
   auto splited = ir_sch.Split(loops[0], sample);
-  trace.Append(
-      ScheduleDesc::Step("Split", {{"loop", std::vector<Expr>({loops[0]})}, {"factors", sample}}, {}, splited));
+  trace.Append(ScheduleDesc::Step(
+      "Split",
+      {{"loop", std::vector<Expr>({loops[0]})}, {"factors", sample}},
+      {},
+      splited));
   loops = ir_sch.GetLoops("B");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  trace.Append(ScheduleDesc::Step(
+      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
   sample = ir_sch.SamplePerfectTile(loops[2], 2, 1, {-1, 2});
   trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                   {{"loop", std::vector<Expr>({loops[2]})}},
-                                  {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector<int>{-1, 2}}},
+                                  {{"n", 2},
+                                   {"max_innermost_factor", 1},
+                                   {"decision", std::vector<int>{-1, 2}}},
                                   sample));
   splited = ir_sch.Split(loops[2], sample);
-  trace.Append(
-      ScheduleDesc::Step("Split", {{"loop", std::vector<Expr>({loops[2]})}, {"factors", sample}}, {}, splited));
+  trace.Append(ScheduleDesc::Step(
+      "Split",
+      {{"loop", std::vector<Expr>({loops[2]})}, {"factors", sample}},
+      {},
+      splited));
   loops = ir_sch.GetLoops("B");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  trace.Append(ScheduleDesc::Step(
+      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
   Expr ret = ir_sch.Reorder({loops[4], loops[0]});
-  trace.Append(ScheduleDesc::Step("Reorder", {{"loops", std::vector<Expr>({loops[4], loops[0]})}}, {}, {ret}));
+  trace.Append(
+      ScheduleDesc::Step("Reorder",
+                         {{"loops", std::vector<Expr>({loops[4], loops[0]})}},
+                         {},
+                         {ret}));
   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
 }

 TEST_F(TestScheduleDesc, StepKind_ReorderWithBlock) {
-  lowered_funcs         = LowerCompute({32, 32, 64}, target);
+  lowered_funcs = LowerCompute({32, 32, 64}, target);
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
-  auto loops            = ir_sch.GetLoops("B");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  auto loops = ir_sch.GetLoops("B");
+  trace.Append(ScheduleDesc::Step(
+      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
   auto sample = ir_sch.SamplePerfectTile(loops[0], 2, 1, {-1, 4});
   trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                   {{"loop", std::vector<Expr>({loops[0]})}},
-                                  {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector<int>{-1, 4}}},
+                                  {{"n", 2},
+                                   {"max_innermost_factor", 1},
+                                   {"decision", std::vector<int>{-1, 4}}},
                                   sample));
   auto splited = ir_sch.Split(loops[0], sample);
-  trace.Append(
-      ScheduleDesc::Step("Split", {{"loop", std::vector<Expr>({loops[0]})}, {"factors", sample}}, {}, splited));
+  trace.Append(ScheduleDesc::Step(
+      "Split",
+      {{"loop", std::vector<Expr>({loops[0]})}, {"factors", sample}},
+      {},
+      splited));
   loops = ir_sch.GetLoops("B");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  trace.Append(ScheduleDesc::Step(
+      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
   sample = ir_sch.SamplePerfectTile(loops[2], 2, 1, {-1, 2});
   trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                   {{"loop", std::vector<Expr>({loops[2]})}},
-                                  {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector<int>{-1, 2}}},
+                                  {{"n", 2},
+                                   {"max_innermost_factor", 1},
+                                   {"decision", std::vector<int>{-1, 2}}},
                                   sample));
   splited = ir_sch.Split(loops[2], sample);
-  trace.Append(
-      ScheduleDesc::Step("Split", {{"loop", std::vector<Expr>({loops[2]})}, {"factors", sample}}, {}, splited));
+  trace.Append(ScheduleDesc::Step(
+      "Split",
+      {{"loop", std::vector<Expr>({loops[2]})}, {"factors", sample}},
+      {},
+      splited));
   auto block_b = ir_sch.GetBlock("B");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
   Expr ret = ir_sch.Reorder("B", {2, 3, 1, 4, 0});
-  trace.Append(ScheduleDesc::Step("ReorderWithBlock",
-                                  {{"block", std::vector<Expr>({block_b})}},
-                                  {{"loops_index", std::vector<int>({2, 3, 1, 4, 0})}},
-                                  {ret}));
+  trace.Append(
+      ScheduleDesc::Step("ReorderWithBlock",
+                         {{"block", std::vector<Expr>({block_b})}},
+                         {{"loops_index", std::vector<int>({2, 3, 1, 4, 0})}},
+                         {ret}));
   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
 }

 TEST_F(TestScheduleDesc, StepKind_ReorderWithName) {
-  lowered_funcs         = LowerCompute({32, 32, 64}, target);
+  lowered_funcs = LowerCompute({32, 32, 64}, target);
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

   auto loops = ir_sch.GetLoops("B");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  trace.Append(ScheduleDesc::Step(
+      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
   auto sample = ir_sch.SamplePerfectTile(loops[0], 2, 1, {-1, 4});
   trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                   {{"loop", std::vector<Expr>({loops[0]})}},
-                                  {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector<int>{-1, 4}}},
+                                  {{"n", 2},
+                                   {"max_innermost_factor", 1},
+                                   {"decision", std::vector<int>{-1, 4}}},
                                   sample));
   auto splited = ir_sch.Split(loops[0], sample);
-  trace.Append(
-      ScheduleDesc::Step("Split", {{"loop", std::vector<Expr>({loops[0]})}, {"factors", sample}}, {}, splited));
+  trace.Append(ScheduleDesc::Step(
+      "Split",
+      {{"loop", std::vector<Expr>({loops[0]})}, {"factors", sample}},
+      {},
+      splited));
   loops = ir_sch.GetLoops("B");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  trace.Append(ScheduleDesc::Step(
+      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
   sample = ir_sch.SamplePerfectTile(loops[2], 2, 1, {-1, 2});
   trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                   {{"loop", std::vector<Expr>({loops[2]})}},
-                                  {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector<int>{-1, 2}}},
+                                  {{"n", 2},
+                                   {"max_innermost_factor", 1},
+                                   {"decision", std::vector<int>{-1, 2}}},
                                   sample));
   splited = ir_sch.Split(loops[2], sample);
-  trace.Append(
-      ScheduleDesc::Step("Split", {{"loop", std::vector<Expr>({loops[2]})}, {"factors", sample}}, {}, splited));
+  trace.Append(ScheduleDesc::Step(
+      "Split",
+      {{"loop", std::vector<Expr>({loops[2]})}, {"factors", sample}},
+      {},
+      splited));
   Expr ret = ir_sch.Reorder("B", {4, 2, 3, 1, 0});
   trace.Append(
       ScheduleDesc::Step("ReorderWithName",
                          {},
-                         {{"block_name", std::string("B")}, {"loops_index", std::vector<int>({4, 2, 3, 1, 0})}},
+                         {{"block_name", std::string("B")},
+                          {"loops_index", std::vector<int>({4, 2, 3, 1, 0})}},
                          {ret}));
   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
 }

 TEST_F(TestScheduleDesc, StepKind_Parallel) {
-  lowered_funcs         = LowerCompute({32, 64}, target);
+  lowered_funcs = LowerCompute({32, 64}, target);
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

   auto loops = ir_sch.GetLoops("B");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  trace.Append(ScheduleDesc::Step(
+      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
   ir_sch.Parallel(loops[0]);
-  trace.Append(ScheduleDesc::Step("Parallel", {{"loop", std::vector<Expr>({loops[0]})}}, {}, {}));
+  trace.Append(ScheduleDesc::Step(
+      "Parallel", {{"loop", std::vector<Expr>({loops[0]})}}, {}, {}));
   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
 }

 TEST_F(TestScheduleDesc, StepKind_Vectorize) {
-  lowered_funcs         = LowerCompute({32, 64}, target);
+  lowered_funcs = LowerCompute({32, 64}, target);
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

   auto loops = ir_sch.GetLoops("B");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  trace.Append(ScheduleDesc::Step(
+      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
   ir_sch.Vectorize(loops[1], 16);
-  trace.Append(ScheduleDesc::Step("Vectorize", {{"loop", std::vector<Expr>({loops[1]})}}, {{"factor", 16}}, {}));
+  trace.Append(ScheduleDesc::Step("Vectorize",
+                                  {{"loop", std::vector<Expr>({loops[1]})}},
+                                  {{"factor", 16}},
+                                  {}));
   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
 }

 TEST_F(TestScheduleDesc, StepKind_Unroll) {
-  lowered_funcs         = LowerCompute({32, 2}, target);
+  lowered_funcs = LowerCompute({32, 2}, target);
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

   auto loops = ir_sch.GetLoops("B");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  trace.Append(ScheduleDesc::Step(
+      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
   ir_sch.Unroll(loops[1]);
-  trace.Append(ScheduleDesc::Step("Unroll", {{"loop", std::vector<Expr>({loops[1]})}}, {}, {}));
+  trace.Append(ScheduleDesc::Step(
+      "Unroll", {{"loop", std::vector<Expr>({loops[1]})}}, {}, {}));
   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
 }

 TEST_F(TestScheduleDesc, StepKind_ComputeInline) {
-  lowered_funcs = LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");
+  lowered_funcs =
+      LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

   auto block_b = ir_sch.GetBlock("B");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
   ir_sch.ComputeInline(block_b);
-  trace.Append(ScheduleDesc::Step("ComputeInline", {{"schedule_block", std::vector<Expr>({block_b})}}, {}, {}));
+  trace.Append(
+      ScheduleDesc::Step("ComputeInline",
+                         {{"schedule_block", std::vector<Expr>({block_b})}},
+                         {},
+                         {}));
   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
 }

 TEST_F(TestScheduleDesc, StepKind_ReverseComputeInline) {
-  lowered_funcs = LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");
+  lowered_funcs =
+      LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
-  auto block_c          = ir_sch.GetBlock("C");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
+  auto block_c = ir_sch.GetBlock("C");
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
   ir_sch.ReverseComputeInline(block_c);
-  trace.Append(ScheduleDesc::Step("ReverseComputeInline", {{"schedule_block", std::vector<Expr>({block_c})}}, {}, {}));
+  trace.Append(
+      ScheduleDesc::Step("ReverseComputeInline",
+                         {{"schedule_block", std::vector<Expr>({block_c})}},
+                         {},
+                         {}));
   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
 }

 TEST_F(TestScheduleDesc, StepKind_Bind) {
-  lowered_funcs         = LowerCompute({32, 128}, target);
+  lowered_funcs = LowerCompute({32, 128}, target);
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

   auto loops = ir_sch.GetLoops("B");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
-  ir_sch.Bind(loops[0], "blockIdx.x");
   trace.Append(ScheduleDesc::Step(
-      "Bind", {{"loop", std::vector<Expr>({loops[0]})}}, {{"thread_axis", std::string("blockIdx.x")}}, {}));
+      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  ir_sch.Bind(loops[0], "blockIdx.x");
+  trace.Append(ScheduleDesc::Step("Bind",
+                                  {{"loop", std::vector<Expr>({loops[0]})}},
+                                  {{"thread_axis", std::string("blockIdx.x")}},
+                                  {}));
   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
 }
@@ -646,20 +802,31 @@ TEST_F(TestScheduleDesc, StepKind_Rfactor) {
   Placeholder<float> B("B", {K, N});
   Var k(16, "k0");
   auto C = Compute(
-      {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C");
-
-  lowered_funcs =
-      cinn::lang::LowerVec("test_rfactor", CreateStages({A, B, C}), {A, B, C}, {}, {}, nullptr, target, true);
+      {M, N},
+      [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); },
+      "C");
+
+  lowered_funcs = cinn::lang::LowerVec("test_rfactor",
+                                       CreateStages({A, B, C}),
+                                       {A, B, C},
+                                       {},
+                                       {},
+                                       nullptr,
+                                       target,
+                                       true);
   cinn::common::Context::Global().ResetNameId();
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
   cinn::common::Context::Global().ResetNameId();

   auto loops = ir_sch.GetLoops("C");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops));
+  trace.Append(ScheduleDesc::Step(
+      "GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops));
   auto new_rf_tensor = ir_sch.Rfactor(loops[2], 0);
-  trace.Append(
-      ScheduleDesc::Step("Rfactor", {{"rf_loop", std::vector<Expr>({loops[2]})}}, {{"rf_axis", 0}}, {new_rf_tensor}));
+  trace.Append(ScheduleDesc::Step("Rfactor",
+                                  {{"rf_loop", std::vector<Expr>({loops[2]})}},
+                                  {{"rf_axis", 0}},
+                                  {new_rf_tensor}));
   CheckTracingOutputs({new_rf_tensor}, trace);
   CheckTracingOutputs({new_rf_tensor}, ir_sch.GetTraceDesc());
   CheckReplayResult(ir_sch, trace);
@@ -668,95 +835,115 @@ TEST_F(TestScheduleDesc, StepKind_Rfactor) {

 TEST_F(TestScheduleDesc, StepKind_MergeExprs) {
   auto funcs_0 = LowerCompute({32, 128}, target);
-  auto funcs_1 = LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");
+  auto funcs_1 =
+      LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");

-  ir::IRSchedule ir_sch =
-      ir::IRSchedule(ir::ModuleExpr({optim::IRCopy(funcs_0[0]->body), optim::IRCopy(funcs_0[0]->body)}));
+  ir::IRSchedule ir_sch = ir::IRSchedule(ir::ModuleExpr(
+      {optim::IRCopy(funcs_0[0]->body), optim::IRCopy(funcs_0[0]->body)}));
   ir_sch.MergeExprs();
   trace.Append(ScheduleDesc::Step("MergeExprs", {}, {}, {}));
-  ir::IRSchedule replay_sch =
-      ir::IRSchedule(ir::ModuleExpr({optim::IRCopy(funcs_0[0]->body), optim::IRCopy(funcs_0[0]->body)}));
+  ir::IRSchedule replay_sch = ir::IRSchedule(ir::ModuleExpr(
+      {optim::IRCopy(funcs_0[0]->body), optim::IRCopy(funcs_0[0]->body)}));
   trace.Replay(&replay_sch);

   auto lhs_exprs = ir_sch.GetModule().GetExprs();
   auto rhs_exprs = replay_sch.GetModule().GetExprs();
   ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size());
   for (auto i = 0; i < lhs_exprs.size(); ++i) {
-    ASSERT_EQ(utils::GetStreamCnt(lhs_exprs.at(i)), utils::GetStreamCnt(rhs_exprs.at(i)));
+    ASSERT_EQ(utils::GetStreamCnt(lhs_exprs.at(i)),
+              utils::GetStreamCnt(rhs_exprs.at(i)));
   }
 }

 TEST_F(TestScheduleDesc, StepKind_Annotate) {
-  lowered_funcs         = LowerCompute({32, 128}, target);
+  lowered_funcs = LowerCompute({32, 128}, target);
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

   auto block_b = ir_sch.GetBlock("B");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
   ir_sch.Annotate(block_b, "k1", int(64));
-  trace.Append(ScheduleDesc::Step("AnnotateIntAttr",
-                                  {{"block", std::vector<Expr>({block_b})}},
-                                  {{"key", std::string("k1")}, {"value", int(64)}},
-                                  {}));
+  trace.Append(
+      ScheduleDesc::Step("AnnotateIntAttr",
+                         {{"block", std::vector<Expr>({block_b})}},
+                         {{"key", std::string("k1")}, {"value", int(64)}},
+                         {}));

   block_b = ir_sch.GetBlock("B");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
   ir_sch.Annotate(block_b, "k2", bool(true));
-  trace.Append(ScheduleDesc::Step("AnnotateBoolAttr",
-                                  {{"block", std::vector<Expr>({block_b})}},
-                                  {{"key", std::string("k2")}, {"value", bool(true)}},
-                                  {}));
+  trace.Append(
+      ScheduleDesc::Step("AnnotateBoolAttr",
+                         {{"block", std::vector<Expr>({block_b})}},
+                         {{"key", std::string("k2")}, {"value", bool(true)}},
+                         {}));

   block_b = ir_sch.GetBlock("B");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
   ir_sch.Annotate(block_b, "k3", float(2.0));
-  trace.Append(ScheduleDesc::Step("AnnotateFloatAttr",
-                                  {{"block", std::vector<Expr>({block_b})}},
-                                  {{"key", std::string("k3")}, {"value", float(2.0)}},
-                                  {}));
+  trace.Append(
+      ScheduleDesc::Step("AnnotateFloatAttr",
+                         {{"block", std::vector<Expr>({block_b})}},
+                         {{"key", std::string("k3")}, {"value", float(2.0)}},
+                         {}));

   block_b = ir_sch.GetBlock("B");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
   ir_sch.Annotate(block_b, "k4", std::string("v4"));
-  trace.Append(ScheduleDesc::Step("AnnotateStringAttr",
-                                  {{"block", std::vector<Expr>({block_b})}},
-                                  {{"key", std::string("k4")}, {"value", std::string("v4")}},
-                                  {}));
+  trace.Append(ScheduleDesc::Step(
+      "AnnotateStringAttr",
+      {{"block", std::vector<Expr>({block_b})}},
+      {{"key", std::string("k4")}, {"value", std::string("v4")}},
+      {}));
   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
 }

 TEST_F(TestScheduleDesc, StepKind_Unannotate) {
-  lowered_funcs         = LowerCompute({32, 128}, target);
+  lowered_funcs = LowerCompute({32, 128}, target);
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

   auto block_b = ir_sch.GetBlock("B");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
   ir_sch.Annotate(block_b, "k1", int(64));
-  trace.Append(ScheduleDesc::Step("AnnotateIntAttr",
-                                  {{"block", std::vector<Expr>({block_b})}},
-                                  {{"key", std::string("k1")}, {"value", int(64)}},
-                                  {}));
+  trace.Append(
+      ScheduleDesc::Step("AnnotateIntAttr",
+                         {{"block", std::vector<Expr>({block_b})}},
+                         {{"key", std::string("k1")}, {"value", int(64)}},
+                         {}));

   block_b = ir_sch.GetBlock("B");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
   ir_sch.Annotate(block_b, "k2", bool(true));
-  trace.Append(ScheduleDesc::Step("AnnotateBoolAttr",
-                                  {{"block", std::vector<Expr>({block_b})}},
-                                  {{"key", std::string("k2")}, {"value", bool(true)}},
-                                  {}));
+  trace.Append(
+      ScheduleDesc::Step("AnnotateBoolAttr",
+                         {{"block", std::vector<Expr>({block_b})}},
+                         {{"key", std::string("k2")}, {"value", bool(true)}},
+                         {}));

   block_b = ir_sch.GetBlock("B");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
   ir_sch.Unannotate(block_b, "k1");
-  trace.Append(
-      ScheduleDesc::Step("Unannotate", {{"block", std::vector<Expr>({block_b})}}, {{"key", std::string("k1")}}, {}));
+  trace.Append(ScheduleDesc::Step("Unannotate",
+                                  {{"block", std::vector<Expr>({block_b})}},
+                                  {{"key", std::string("k1")}},
+                                  {}));

   block_b = ir_sch.GetBlock("B");
-  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  trace.Append(ScheduleDesc::Step(
+      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
   ir_sch.Unannotate(block_b, "k2");
-  trace.Append(
-      ScheduleDesc::Step("Unannotate", {{"block", std::vector<Expr>({block_b})}}, {{"key", std::string("k2")}}, {}));
+  trace.Append(ScheduleDesc::Step("Unannotate",
+                                  {{"block", std::vector<Expr>({block_b})}},
+                                  {{"key", std::string("k2")}},
+                                  {}));

   CheckReplayResult(ir_sch, trace);
   CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
@@ -769,19 +956,30 @@ TEST_F(TestScheduleDesc, StepKind_SamplePerfectTile) {
   Placeholder<float> A("A", {M});
   auto B = Compute(
       {M}, [&](Expr i) { return A(i) + n; }, "B");
-  lowered_funcs =
-      cinn::lang::LowerVec("test_sample_perfect_tile", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
+  lowered_funcs = cinn::lang::LowerVec("test_sample_perfect_tile",
+                                       CreateStages({A, B}),
+                                       {A, B},
+                                       {},
+                                       {},
+                                       nullptr,
+                                       target,
+                                       true);
   ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

-  auto loops            = ir_sch.GetLoops("B");
-  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  auto loops = ir_sch.GetLoops("B");
+  trace.Append(ScheduleDesc::Step(
+      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
   auto result = ir_sch.SamplePerfectTile(loops[0], 2, 64);
   std::vector<int> decision;
-  std::transform(result.begin(), result.end(), std::back_inserter(decision), [](Expr x) { return x.as_int32(); });
-  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
-                                  {{"loop", std::vector<Expr>({loops[0]})}},
-                                  {{"n", 2}, {"max_innermost_factor", 64}, {"decision", decision}},
-                                  result));
+  std::transform(
+      result.begin(), result.end(), std::back_inserter(decision), [](Expr x) {
+        return x.as_int32();
+      });
+  trace.Append(ScheduleDesc::Step(
+      "SamplePerfectTile",
+      {{"loop", std::vector<Expr>({loops[0]})}},
+      {{"n", 2}, {"max_innermost_factor", 64}, {"decision", decision}},
+      result));
   CheckTracingOutputs(result, trace);
   CheckTracingOutputs(result, ir_sch.GetTraceDesc());
   CheckReplayResult(ir_sch, trace);
@@ -789,16 +987,17 @@
 }

 TEST_F(TestScheduleDesc, StepKind_SampleCategorical) {
-  lowered_funcs         = LowerCompute({32, 32, 64}, target, true);
-  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
-  Expr ret              = ir_sch.SampleCategorical({1, 2, 3}, {1.0, 2.0, 3.0});
+  lowered_funcs = LowerCompute({32, 32, 64}, target, true);
+  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
+  Expr ret = ir_sch.SampleCategorical({1, 2, 3}, {1.0, 2.0, 3.0});
   std::vector<int> decision = {ret.as_int32()};
-  trace.Append(ScheduleDesc::Step("SampleCategorical",
-                                  {},
-                                  {{"candidates", std::vector<int>({1, 2, 3})},
-                                   {"probs", std::vector<float>({1.0, 2.0, 3.0})},
-                                   {"decision", decision}},
-                                  {ret}));
+  trace.Append(
+      ScheduleDesc::Step("SampleCategorical",
+                         {},
+                         {{"candidates", std::vector<int>({1, 2, 3})},
+                          {"probs", std::vector<float>({1.0, 2.0, 3.0})},
+                          {"decision", decision}},
+                         {ret}));
   CheckTracingOutputs({ret}, trace);
   CheckTracingOutputs({ret}, ir_sch.GetTraceDesc());
   CheckReplayResult(ir_sch, trace);
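Every test in schedule_desc_test.cc above follows the same record-and-replay contract: each IRSchedule primitive is mirrored by a ScheduleDesc::Step appended to the trace, and replaying the trace on a fresh schedule must rebuild an identical module. The sketch below condenses that check from the StepKind_MergeExprs test visible above; it is an illustration, and the CheckReplayResult helper in the actual test fixture may differ in detail.

// Sketch of the replay check used throughout these tests: replay the recorded
// trace on a second schedule built from the same exprs, then compare the
// printed modules. Helper names follow the usage visible in this diff.
void CheckReplaySketch(const ir::IRSchedule& ir_sch,
                       const ScheduleDesc& trace,
                       ir::IRSchedule* replay_sch) {
  trace.Replay(replay_sch);  // re-apply every recorded ScheduleDesc::Step
  auto lhs_exprs = ir_sch.GetModule().GetExprs();
  auto rhs_exprs = replay_sch->GetModule().GetExprs();
  ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size());
  for (size_t i = 0; i < lhs_exprs.size(); ++i) {
    // Structural equality is checked through the string form of each expr.
    ASSERT_EQ(utils::GetStreamCnt(lhs_exprs.at(i)),
              utils::GetStreamCnt(rhs_exprs.at(i)));
  }
}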
diff --git a/paddle/cinn/ir/tensor.cc b/paddle/cinn/ir/tensor.cc
index 87a14399e49b9..c5f5232e9a179 100755
--- a/paddle/cinn/ir/tensor.cc
+++ b/paddle/cinn/ir/tensor.cc
@@ -41,10 +41,10 @@ Tensor _Tensor_::Make(const std::string &name,
                       FunctionRef fn,
                       const std::vector<Var> &reduce_axis) {
   CHECK(!name.empty()) << "Tensor name is set empty";
-  auto n         = make_shared<_Tensor_>();
-  n->name        = name;
-  n->shape       = shape;
-  n->domain      = domain;
+  auto n = make_shared<_Tensor_>();
+  n->name = name;
+  n->shape = shape;
+  n->domain = domain;
   n->reduce_axis = reduce_axis;
   n->set_type(dtype);
   n->operation = fn;
@@ -59,8 +59,9 @@ std::set<std::string> _Tensor_::GetDependTensorNames() const {
   std::set<std::string> names;

   auto add_depend_tensors_from_expr = [&](Expr expr) {
-    auto tensors =
-        CollectIRNodes(expr, [&](const Expr *x) { return x->as_tensor() && x->as_tensor()->name != this->name; });
+    auto tensors = CollectIRNodes(expr, [&](const Expr *x) {
+      return x->as_tensor() && x->as_tensor()->name != this->name;
+    });
     for (auto &e : tensors) {
       names.insert(e.as_tensor()->name);
     }
@@ -82,10 +83,12 @@ std::set<std::string> _Tensor_::GetDependTensorNames() const {
 }

 Expr Tensor::operator()(const std::vector<Expr> &indices) const {
-  CHECK(!self()->is_tuple()) << "should extract a specific value from the tuple and operate on that instead";
+  CHECK(!self()->is_tuple()) << "should extract a specific value from the "
+                                "tuple and operate on that instead";
   auto *node = operator->();

-  CHECK_EQ(indices.size(), ndims()) << "number of indices not match the dimension";
+  CHECK_EQ(indices.size(), ndims())
+      << "number of indices not match the dimension";
   return Load::Make(*this, indices);
 }

@@ -100,14 +103,18 @@ const char *_Tensor_::operation_type() const {
   return operation->as<ir::_Operation_>()->func_type();
 }

-bool _Tensor_::is_compute_node() const { return std::strcmp(operation_type(), ir::ComputeOp::__func_type__) == 0; }
+bool _Tensor_::is_compute_node() const {
+  return std::strcmp(operation_type(), ir::ComputeOp::__func_type__) == 0;
+}
 bool _Tensor_::is_placeholder_node() const {
   return std::strcmp(operation_type(), ir::PlaceholderOp::__func_type__) == 0;
 }
-bool _Tensor_::is_call_node() const { return std::strcmp(operation_type(), ir::CallOp::__func_type__) == 0; }
+bool _Tensor_::is_call_node() const {
+  return std::strcmp(operation_type(), ir::CallOp::__func_type__) == 0;
+}
 bool _Tensor_::is_extern_call_node() const {
   if (std::strcmp(operation_type(), ir::CallOp::__func_type__) == 0) {
-    auto *op   = operation->as<ir::CallOp>();
+    auto *op = operation->as<ir::CallOp>();
     auto *call = op->call_expr.As<ir::Call>();
     if (call) {
       return call->is_extern_call();
@@ -139,7 +146,8 @@ void _Tensor_::InitAxis() const {
 }

 bool _Tensor_::has_expression() const {
-  return (!is_placeholder_node()) && (!is_tuple_get()) && (!is_buffer_shared_node());
+  return (!is_placeholder_node()) && (!is_tuple_get()) &&
+         (!is_buffer_shared_node());
 }

 isl::set _Tensor_::GenerateIslDomain() const {
@@ -156,7 +164,9 @@ isl::set _Tensor_::GenerateIslDomain() const {
     if (dim.is_constant()) {
       dims.emplace_back(_axis_with_reduce[i]->name, 0, dim.as_int32() - 1);
     } else {
-      dims.emplace_back(_axis_with_reduce[i]->name, Expr(0), Sub::Make(dim, common::make_const(1)));
+      dims.emplace_back(_axis_with_reduce[i]->name,
+                        Expr(0),
+                        Sub::Make(dim, common::make_const(1)));
     }
   }
 }
@@ -240,26 +250,35 @@ Expr *_Tensor_::mutable_body() {
   CINN_NOT_IMPLEMENTED
 }

-ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target) const {
-  CHECK(contains_reduce_axis()) << "InitReduction only works on a reduce tensor";
+ir::Tensor _Tensor_::InitReduction(poly::StageMap stages,
+                                   const Target &target) const {
+  CHECK(contains_reduce_axis())
+      << "InitReduction only works on a reduce tensor";
   // return if already rexists.
   std::string init_reduce_tensor_name = GenReduceInitTensorNameOf(name);
-  if (stages->Lookup(init_reduce_tensor_name)) return stages[this]->LookupCtrlDepend(init_reduce_tensor_name);
+  if (stages->Lookup(init_reduce_tensor_name))
+    return stages[this]->LookupCtrlDepend(init_reduce_tensor_name);
   // create a new init tensor.
   auto init_tensor = lang::Compute(
-      domain, [=](const std::vector<Expr> &axis) { return GetReduceInitVal(); }, init_reduce_tensor_name);
+      domain,
+      [=](const std::vector<Expr> &axis) { return GetReduceInitVal(); },
+      init_reduce_tensor_name);
   stages->InsertLazily(init_tensor);
   std::string this_transform = isl_map_to_str(stages[this]->transform().get());
-  isl::ctx this_ctx          = stages[this]->transform().ctx();
+  isl::ctx this_ctx = stages[this]->transform().ctx();
   isl::map temp_transform(this_ctx, this_transform);
-  int reduce_axis_num = this->reduce_axis.size();
-  auto dim_out_names  = poly::isl_get_dim_names(stages[this]->transform(), isl_dim_out);
-  auto dim_in_size    = isl_map_dim(stages[this]->transform().get(), isl_dim_in);
-  auto dim_in_names   = poly::isl_get_dim_names(stages[this]->transform(), isl_dim_in);
-  std::vector<std::string> reduce_axis_input = stages[this]->origin_reduce_axis_names();
-  auto origin_domain                         = stages[this]->domain();
-  auto reduce_axis_output = poly::GetRelatedOutputAxies(temp_transform, origin_domain, reduce_axis_input);
+  int reduce_axis_num = this->reduce_axis.size();
+  auto dim_out_names =
+      poly::isl_get_dim_names(stages[this]->transform(), isl_dim_out);
+  auto dim_in_size = isl_map_dim(stages[this]->transform().get(), isl_dim_in);
+  auto dim_in_names =
+      poly::isl_get_dim_names(stages[this]->transform(), isl_dim_in);
+  std::vector<std::string> reduce_axis_input =
+      stages[this]->origin_reduce_axis_names();
+  auto origin_domain = stages[this]->domain();
+  auto reduce_axis_output = poly::GetRelatedOutputAxies(
+      temp_transform, origin_domain, reduce_axis_input);
   std::set<std::string> reduce_axis_output_set;
   for (auto &i : reduce_axis_output) {
     reduce_axis_output_set.insert(i);
@@ -273,7 +292,8 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target)
     }
   }

-  temp_transform = poly::RemoveAxiesByOutputNames(temp_transform, origin_domain, reduce_axis_output);
+  temp_transform = poly::RemoveAxiesByOutputNames(
+      temp_transform, origin_domain, reduce_axis_output);

   //! When the first axis is not reduce axis, do ComputeAt.
   if (compute_at_axis >= 0) {
@@ -286,16 +306,21 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target)
   }
   //! When reduce axies are reordered to front, ComputeAt is illegal.
   //! So we just copy transform and forloopInfo.
-  isl_map_set_tuple_name(temp_transform.get(), isl_dim_in, init_reduce_tensor_name.c_str());
-  isl_map_set_tuple_name(temp_transform.get(), isl_dim_out, init_reduce_tensor_name.c_str());
+  isl_map_set_tuple_name(
+      temp_transform.get(), isl_dim_in, init_reduce_tensor_name.c_str());
+  isl_map_set_tuple_name(
+      temp_transform.get(), isl_dim_out, init_reduce_tensor_name.c_str());
   stages[init_tensor]->SetTransform(temp_transform);
-  auto init_dim_out_names = poly::isl_get_dim_names(temp_transform, isl_dim_out);
-  std::map<int, poly::StageForloopInfo> temp_forloop_info = stages[this]->forloop_infos();
+  auto init_dim_out_names =
+      poly::isl_get_dim_names(temp_transform, isl_dim_out);
+  std::map<int, poly::StageForloopInfo> temp_forloop_info =
+      stages[this]->forloop_infos();
   std::map<int, poly::StageForloopInfo> init_forloop_info;
   for (auto &i : temp_forloop_info) {
     for (int j = 0; j < init_dim_out_names.size(); j++) {
       if (i.first < 0) continue;
-      int new_i = poly::isl_get_original_axes_from_optimized_level(stages[this]->transformed_domain().get(), i.first);
+      int new_i = poly::isl_get_original_axes_from_optimized_level(
+          stages[this]->transformed_domain().get(), i.first);
       if (dim_out_names[new_i] == init_dim_out_names[j]) {
         stages[init_tensor]->AddForloopInfo(j, i.second);
       }
@@ -308,7 +333,8 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target)
   return init_tensor;
 }

-ir::Tensor _Tensor_::GetInitTensor(poly::StageMap stages, const Target &target) const {
+ir::Tensor _Tensor_::GetInitTensor(poly::StageMap stages,
+                                   const Target &target) const {
   return InitReduction(stages, target);
 }

@@ -384,11 +410,13 @@ void _Tensor_::WithBuffer(const Type &type) {
   Bind(buf);
 }

-void _Tensor_::WithBuffer(const std::string &memory_type, const std::string &buffer_name, const Type &type) {
+void _Tensor_::WithBuffer(const std::string &memory_type,
+                          const std::string &buffer_name,
+                          const Type &type) {
   Type buf_type = type.is_void() ? type_ : type;
   if (this->buffer.defined()) {
     this->buffer->dtype = buf_type;
-    this->buffer->name  = buffer_name;
+    this->buffer->name = buffer_name;
     if (memory_type == "shared") {
       this->buffer->memory_type = MemoryType::GPUShared;
     } else if (memory_type == "local") {
@@ -461,11 +489,13 @@ Tensor::Tensor(const std::string &name,
                const std::vector<Expr> &domain,
                FunctionRef fn,
                const std::vector<Var> &reduce_axis)
-    : IrNodeRef(_Tensor_::Make(name, dtype, shape, domain, fn, reduce_axis).self()) {}
+    : IrNodeRef(
+          _Tensor_::Make(name, dtype, shape, domain, fn, reduce_axis).self()) {}

 bool _Tensor_::is_tuple_get() const {
   return is_call_node() && operation.defined() &&
-         operation->as<ir::CallOp>()->func_type() == ir::CallOp::__func_type__ &&
+         operation->as<ir::CallOp>()->func_type() ==
+             ir::CallOp::__func_type__ &&
          operation->as<ir::CallOp>()->is_tuple_get;
 }

@@ -484,7 +514,8 @@ bool _Tensor_::IsDependOnStatement(absl::string_view statement) {
 std::set<std::string> _Tensor_::DependingTensorNames() {
   std::set<std::string> res;
   if (body().defined()) {
-    auto depend_tensors = ir::CollectIRNodes(body(), [](const Expr *x) -> bool { return x->as_tensor(); });
+    auto depend_tensors = ir::CollectIRNodes(
+        body(), [](const Expr *x) -> bool { return x->as_tensor(); });
     for (const auto &x : depend_tensors) {
       if (x.get() != this) {
         res.insert(x.as_tensor()->name);
@@ -514,10 +545,11 @@ bool _Tensor_::Uses(const Tensor &other) const {
   return !loads.empty();
 }

-ir::Tensor _Tensor_::Reshape(const std::vector<Expr> &shape, poly::StageMap stages) const {
+ir::Tensor _Tensor_::Reshape(const std::vector<Expr> &shape,
+                             poly::StageMap stages) const {
   CHECK(!stages[this]->inlined());
-  auto op    = BufferShareOp::Make();
-  auto n     = make_shared<_Tensor_>();
+  auto op = BufferShareOp::Make();
+  auto n = make_shared<_Tensor_>();
   auto selft = Tensor(const_cast<ir::_Tensor_ *>(this));

   {
@@ -527,11 +559,12 @@ ir::Tensor _Tensor_::Reshape(const std::vector<Expr> &shape, poly::StageMap stag
     Expr num_elements = Expr(1);
     for (auto &e : shape) num_elements = num_elements * e;

-    CHECK(MathIsZero(this_num_elements - num_elements)) << "number of elements mismatch";
+    CHECK(MathIsZero(this_num_elements - num_elements))
+        << "number of elements mismatch";
   }

-  n->name   = Context::Global().NewName(name + "_reshape");
-  n->shape  = shape;
+  n->name = Context::Global().NewName(name + "_reshape");
+  n->shape = shape;
   n->domain = shape;
   n->set_type(type());
   n->operation = op;
@@ -545,8 +578,9 @@ ir::Tensor _Tensor_::Reshape(const std::vector<Expr> &shape, poly::StageMap stag
   return t;
 }

-ir::Tensor _Tensor_::ReshapeCopied(const std::vector<Expr> &shape, poly::StageMap stages) const {
-  auto t      = ir::Tensor(const_cast<ir::_Tensor_ *>(this));
+ir::Tensor _Tensor_::ReshapeCopied(const std::vector<Expr> &shape,
+                                   poly::StageMap stages) const {
+  auto t = ir::Tensor(const_cast<ir::_Tensor_ *>(this));
   auto copied = Compute(
       domain,
       [=](const std::vector<Expr> &axis) { return t(axis); },
@@ -562,15 +596,19 @@ Shared<poly::Stage> CreateStage(Tensor tensor) {
   return poly::Stage::New(isl_domain, tensor->body(), tensor.self());
 }

-std::string GenReduceInitTensorNameOf(const std::string &tensor_name) { return tensor_name + "__reduce_init"; }
+std::string GenReduceInitTensorNameOf(const std::string &tensor_name) {
+  return tensor_name + "__reduce_init";
+}

 bool _Tensor_::is_reduce_sum() const {
   if (!contains_reduce_axis()) return false;
-  return body().As<ir::Reduce>() && body().As<ir::Reduce>()->reduce_type == ir::Reduce::ReduceType::kSum;
+  return body().As<ir::Reduce>() &&
+         body().As<ir::Reduce>()->reduce_type == ir::Reduce::ReduceType::kSum;
 }
 bool _Tensor_::is_reduce_mul() const {
   if (!contains_reduce_axis()) return false;
-  return body().As<ir::Reduce>() && body().As<ir::Reduce>()->reduce_type == ir::Reduce::ReduceType::kMul;
+  return body().As<ir::Reduce>() &&
+         body().As<ir::Reduce>()->reduce_type == ir::Reduce::ReduceType::kMul;
 }

 Expr _Tensor_::GetReduceInitVal() const {
@@ -578,7 +616,9 @@ Expr _Tensor_::GetReduceInitVal() const {
   return body().As<ir::Reduce>()->init;
 }

-bool _Tensor_::IsReduceInited(poly::StageMap stages) const { return stages->Lookup(GenReduceInitTensorNameOf(name)); }
+bool _Tensor_::IsReduceInited(poly::StageMap stages) const {
+  return stages->Lookup(GenReduceInitTensorNameOf(name));
+}

 void _Tensor_::Verify() const {
   CHECK(!shape.empty());
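The reduce-init plumbing in tensor.cc rests on a pure naming convention: the init tensor of a reduce tensor is always its name plus the "__reduce_init" suffix, which is how InitReduction and IsReduceInited locate it in a stage map. A self-contained illustration of that contract (the tensor name "C" is a hypothetical example):

#include <cassert>
#include <string>

// Copied from the GenReduceInitTensorNameOf definition in this diff: the
// init tensor of a reduction is located purely by this derived name.
std::string GenReduceInitTensorNameOf(const std::string& tensor_name) {
  return tensor_name + "__reduce_init";
}

int main() {
  assert(GenReduceInitTensorNameOf("C") == "C__reduce_init");
  // InitReduction first does stages->Lookup("C__reduce_init") and, if the
  // tensor already exists, returns it instead of creating a second one.
  return 0;
}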
diff --git a/paddle/cinn/ir/tensor.h b/paddle/cinn/ir/tensor.h
index f97409b25b04b..437e0f2c5e605 100644
--- a/paddle/cinn/ir/tensor.h
+++ b/paddle/cinn/ir/tensor.h
@@ -80,9 +80,12 @@ class Tensor : public ir::IrNodeRef {
    * A(i,j) get the [i][j] element.
    */
   // @{
-  Expr operator()(const Expr& a) const { return operator()(std::vector<Expr>({a})); }
+  Expr operator()(const Expr& a) const {
+    return operator()(std::vector<Expr>({a}));
+  }
   template <typename... Args>
-  inline typename std::enable_if<sizeof...(Args) >= 2, Expr>::type operator()(Args&&... args) const {
+  inline typename std::enable_if<sizeof...(Args) >= 2, Expr>::type
+  operator()(Args&&... args) const {
     return operator()({std::forward<Args>(args)...});
   }
   // @}
@@ -108,7 +111,8 @@ class Tensor : public ir::IrNodeRef {

 /**
  * \brief Generate the name of the reduce init tensor of \p tensor.
- * This is used for retrieving the corresponding reduction-init tensor from a stage map by name.
+ * This is used for retrieving the corresponding reduction-init tensor from a
+ * stage map by name.
  */
 std::string GenReduceInitTensorNameOf(const std::string& tensor_name);

@@ -159,7 +163,8 @@ class _Tensor_ : public ExprNode<_Tensor_> {

   bool IsReduceInited(poly::StageMap stages) const;

-  //! Tell whether this tensor represents a tuple (consists of one or multiple tensors as output of a extern Call).
+  //! Tell whether this tensor represents a tuple (consists of one or multiple
+  //! tensors as output of a extern Call).
   bool is_tuple() const;
   bool is_tuple_get() const;

@@ -172,7 +177,8 @@ class _Tensor_ : public ExprNode<_Tensor_> {
   std::set<std::string> GetDependTensorNames() const;

   /**
-   * \brief Tell whether this tensor's computation relays on a specific statement.
+   * \brief Tell whether this tensor's computation relays on a specific
+   * statement.
    * @param statement The name of a statement(equivalent to the id of tensor).
    * @return A boolean.
    */
@@ -187,13 +193,15 @@ class _Tensor_ : public ExprNode<_Tensor_> {
    * Get a new tensor with the \p shape, but the underlying buffer shared.
    * NOTE the tensor to Reshape should not be an inlined computation.
    */
-  ir::Tensor Reshape(const std::vector<Expr>& shape, poly::StageMap stages) const;
+  ir::Tensor Reshape(const std::vector<Expr>& shape,
+                     poly::StageMap stages) const;

   /**
    * Get a new tensor with the \p shape with a newly allocated buffer.
    * NOTE the tensor to Reshape should not be an inlined computation.
    */
-  ir::Tensor ReshapeCopied(const std::vector<Expr>& shape, poly::StageMap stages) const;
+  ir::Tensor ReshapeCopied(const std::vector<Expr>& shape,
+                           poly::StageMap stages) const;

   /**
    * Tell whether this tensor has same shape with \p other.
@@ -245,7 +253,9 @@ class _Tensor_ : public ExprNode<_Tensor_> {
   /**
    * Get the tensors thouse depend on the same buffer belong to this tensor.
    */
-  const std::set<std::string>& buffer_depended_tensor_names() const { return buffer_depended_tensor_names_; }
+  const std::set<std::string>& buffer_depended_tensor_names() const {
+    return buffer_depended_tensor_names_;
+  }

   static const IrNodeTy _node_type_ = IrNodeTy::_Tensor_;

@@ -267,8 +277,12 @@ class _Tensor_ : public ExprNode<_Tensor_> {

   //! Create a buffer belong to this tensor.
   void WithBuffer(const Type& type = Void());
-  void WithBuffer(const std::string& memory_type, const std::string& buffer_name = "", const Type& type = Void());
-  Tensor GetInitTensor(poly::StageMap stages, const Target& target = common::DefaultHostTarget()) const;
+  void WithBuffer(const std::string& memory_type,
+                  const std::string& buffer_name = "",
+                  const Type& type = Void());
+  Tensor GetInitTensor(
+      poly::StageMap stages,
+      const Target& target = common::DefaultHostTarget()) const;

  private:
   //! Initialize the axis field after the shape field is assigned.
@@ -282,14 +296,19 @@ class _Tensor_ : public ExprNode<_Tensor_> {
    * @param init_val The initial value.
    * @return The initializing tensor.
    */
-  ir::Tensor InitReduction(poly::StageMap stages, const Target& target = common::DefaultHostTarget()) const;
+  ir::Tensor InitReduction(
+      poly::StageMap stages,
+      const Target& target = common::DefaultHostTarget()) const;

-  //! The names of the tensors depend the same buffer and should schedule before this.
+  //! The names of the tensors depend the same buffer and should schedule before
+  //! this.
   std::set<std::string> buffer_depended_tensor_names_;

   friend Shared<poly::Stage> CreateStage(Tensor tensor);
-  friend void lang::InitReduceTensor(poly::StageMap stages, const ir::Tensor& tensor, const Target& target);
+  friend void lang::InitReduceTensor(poly::StageMap stages,
+                                     const ir::Tensor& tensor,
+                                     const Target& target);
 };

 Shared<poly::Stage> CreateStage(Tensor tensor);
@@ -300,8 +319,12 @@ class Operation : public FunctionRef {
   Operation() = default;
   explicit Operation(IrNode* n) : FunctionRef(n) {}

-  inline const _Operation_* operator->() const { return reinterpret_cast<_Operation_*>(get()); }
-  inline _Operation_* operator->() { return reinterpret_cast<_Operation_*>(get()); }
+  inline const _Operation_* operator->() const {
+    return reinterpret_cast<_Operation_*>(get());
+  }
+  inline _Operation_* operator->() {
+    return reinterpret_cast<_Operation_*>(get());
+  }

   //! Get the i-th output of the operation.
   // Tensor output(size_t i) const;
diff --git a/paddle/cinn/ir/tensor_test.cc b/paddle/cinn/ir/tensor_test.cc
index 049b3c75ae1a0..19b92ce51703a 100755
--- a/paddle/cinn/ir/tensor_test.cc
+++ b/paddle/cinn/ir/tensor_test.cc
@@ -86,8 +86,10 @@ TEST(Tensor, Reshape) {
   auto stages = CreateStages({A});
   auto A1     = A->Reshape({Expr(10), Expr(10), Expr(100)}, stages);
-  auto B      = Compute(
-      A1->shape, [=](Expr i, Expr j, Expr k) { return A1(i, j, k) * 2.f; }, "B");
+  auto B = Compute(
+      A1->shape,
+      [=](Expr i, Expr j, Expr k) { return A1(i, j, k) * 2.f; },
+      "B");

   stages->InsertLazily(B);

@@ -135,8 +137,10 @@ TEST(Tensor, ReshapeCopied) {
   auto stages = CreateStages({A});
   auto A1     = A->ReshapeCopied({Expr(10), Expr(10), Expr(100)}, stages);
-  auto B      = Compute(
-      A1->shape, [=](Expr i, Expr j, Expr k) { return A1(i, j, k) * 2.f; }, "B");
+  auto B = Compute(
+      A1->shape,
+      [=](Expr i, Expr j, Expr k) { return A1(i, j, k) * 2.f; },
+      "B");

   stages->InsertLazily(B);

@@ -189,7 +193,9 @@ TEST(Tensor, reduce) {
   {
     auto C = Compute(
         A->shape,
-        [=](const std::vector<Expr>& axis) { return lang::ReduceSum(A(reduce_axis) + 1.f, {reduce_axis}); },
+        [=](const std::vector<Expr>& axis) {
+          return lang::ReduceSum(A(reduce_axis) + 1.f, {reduce_axis});
+        },
         "C");
     ASSERT_TRUE(C->has_expression());
     ASSERT_TRUE(C->is_reduce_sum());
@@ -199,7 +205,9 @@ TEST(Tensor, reduce) {
   {
     auto C = Compute(
         A->shape,
-        [=](const std::vector<Expr>& axis) { return lang::ReduceMul(A(reduce_axis) + 1.f, {reduce_axis}); },
+        [=](const std::vector<Expr>& axis) {
+          return lang::ReduceMul(A(reduce_axis) + 1.f, {reduce_axis});
+        },
         "C");
     ASSERT_TRUE(C->has_expression());
     ASSERT_TRUE(C->is_reduce_mul());
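tensor_test.cc exercises the two reshape paths that tensor.cc distinguishes: Reshape shares the original buffer through a BufferShareOp, while ReshapeCopied materializes a fresh buffer by copying through a generated Compute. A condensed sketch of the usage, following the tests above; the 100x100 source shape is an assumption consistent with the 10x10x100 target:

// Sketch only: both calls require the tensor not to be an inlined
// computation, per the NOTE in tensor.h. Element counts must match
// (10 * 10 * 100 == 100 * 100), which Reshape checks via MathIsZero.
Placeholder<float> A("A", {Expr(100), Expr(100)});
auto stages = CreateStages({A});
// Same underlying buffer, new view:
auto A1 = A->Reshape({Expr(10), Expr(10), Expr(100)}, stages);
// Freshly allocated buffer, values copied element by element:
auto A2 = A->ReshapeCopied({Expr(10), Expr(10), Expr(100)}, stages);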
diff --git a/paddle/cinn/lang/buffer.cc b/paddle/cinn/lang/buffer.cc
index 185633f66b0dc..864adfb165cad 100644
--- a/paddle/cinn/lang/buffer.cc
+++ b/paddle/cinn/lang/buffer.cc
@@ -22,7 +22,7 @@ namespace lang {
 using ir::_Buffer_;

 Buffer::Buffer(Type type, const std::string& name) {
-  buffer_        = _Buffer_::Make();
+  buffer_ = _Buffer_::Make();
   buffer_->dtype = type;
   buffer_->set_type(type_of<cinn_buffer_t*>());
   buffer_->elem_offset = Expr(0);
diff --git a/paddle/cinn/lang/builtin.cc b/paddle/cinn/lang/builtin.cc
index 34e150ba35472..3e7ef7390cf7e 100644
--- a/paddle/cinn/lang/builtin.cc
+++ b/paddle/cinn/lang/builtin.cc
@@ -48,13 +48,22 @@ Expr logic_or(const std::vector<Expr>& conds) {
 }

 //! extern call op
-#define EXTERN_CALL_IMP(name__, target__) \
-  Expr name__(Expr e) { return ir::Call::Make(e->type(), #target__, {e}, {}, ir::CallType::Extern); }
+#define EXTERN_CALL_IMP(name__, target__)                       \
+  Expr name__(Expr e) {                                         \
+    return ir::Call::Make(                                      \
+        e->type(), #target__, {e}, {}, ir::CallType::Extern);   \
+  }

-#define EXTERN_CALL_IMP_NO_VEC(name__, target__) \
-  Expr name__(Expr e) { \
-    return ir::Call::Make( \
-        e->type(), #target__, {e}, {}, ir::CallType::Extern, ir::FunctionRef(), 0, {{"vectorizable", false}}); \
+#define EXTERN_CALL_IMP_NO_VEC(name__, target__)       \
+  Expr name__(Expr e) {                                \
+    return ir::Call::Make(e->type(),                   \
+                          #target__,                   \
+                          {e},                         \
+                          {},                          \
+                          ir::CallType::Extern,        \
+                          ir::FunctionRef(),           \
+                          0,                           \
+                          {{"vectorizable", false}});  \
   }

 EXTERN_CALL_IMP(Exp, exp);
@@ -87,11 +96,13 @@ EXTERN_CALL_IMP(Popc, popc);
 #undef EXTERN_CALL_IMP
 #undef EXTERN_CALL_IMP_NO_VEC

-#define EXTERN_BINARY_CALL_IMP(name__, target__) \
-  Expr name__(Expr a, Expr b) { \
-    CHECK_EQ(a.type(), b.type()) << #name__ << "'s inputs type not equal, where a:" << a.type() \
-                                 << " but b:" << b.type(); \
-    return ir::Call::Make(a->type(), #target__, {a, b}, {}, ir::CallType::Extern); \
+#define EXTERN_BINARY_CALL_IMP(name__, target__)                        \
+  Expr name__(Expr a, Expr b) {                                         \
+    CHECK_EQ(a.type(), b.type())                                        \
+        << #name__ << "'s inputs type not equal, where a:" << a.type()  \
+        << " but b:" << b.type();                                       \
+    return ir::Call::Make(                                              \
+        a->type(), #target__, {a, b}, {}, ir::CallType::Extern);        \
   }

 EXTERN_BINARY_CALL_IMP(Remainder, mod)
@@ -106,7 +117,9 @@ Expr Zero(const Type& type) { return ir::Zero(type); }
 Expr One(const Type& type) { return ir::One(type); }

 Expr FloorDivide(Expr a, Expr b) {
-  CHECK_EQ(a.type(), b.type()) << "FloorDivide's inputs type not equal, where a:" << a.type() << " but b:" << b.type();
+  CHECK_EQ(a.type(), b.type())
+      << "FloorDivide's inputs type not equal, where a:" << a.type()
+      << " but b:" << b.type();
   if (a.type().is_float()) {
     return Floor(a / b);
   } else if (a.type().is_uint()) {
@@ -114,8 +127,10 @@ Expr FloorDivide(Expr a, Expr b) {
     return a / b;
   } else {
     auto div = a / b;
     auto mod = a % b;
-    auto ret = ir::Select::Make(
-        ir::EQ::Make(mod, common::make_const(a.type(), 0)), div, div - common::make_const(a.type(), 1));
+    auto ret =
+        ir::Select::Make(ir::EQ::Make(mod, common::make_const(a.type(), 0)),
+                         div,
+                         div - common::make_const(a.type(), 1));
     return ir::Select::Make((a > 0 && b > 0) || (a < 0 && b < 0), div, ret);
   }
 }
@@ -193,7 +208,7 @@ Expr Epsilon(const Type& type) {
 }

 Expr Abs(Expr e) {
-  Type type      = e->type();
+  Type type = e->type();
   Type bool_type = Bool(type.lanes());
   if (type.is_uint()) {
     return e;
diff --git a/paddle/cinn/lang/builtin.h b/paddle/cinn/lang/builtin.h
index 4ee302ee6eae3..b18c3ad1308a2 100644
--- a/paddle/cinn/lang/builtin.h
+++ b/paddle/cinn/lang/builtin.h
@@ -82,12 +82,12 @@ inline Expr Sigmoid(Expr e) {
 }

 inline Expr Sign(Expr e) {
-  auto zero    = Zero(e->type());
-  auto one     = One(e->type());
+  auto zero = Zero(e->type());
+  auto one = One(e->type());
   auto neg_one = ir::Cast::Make(e->type(), Expr(-1));
-  auto ret0    = ir::Select::Make(ir::EQ::Make(e, zero), zero, e);
-  auto ret1    = ir::Select::Make(e > zero, one, ret0);
-  auto ret2    = ir::Select::Make(e < zero, neg_one, ret1);
+  auto ret0 = ir::Select::Make(ir::EQ::Make(e, zero), zero, e);
+  auto ret1 = ir::Select::Make(e > zero, one, ret0);
+  auto ret2 = ir::Select::Make(e < zero, neg_one, ret1);
   return ret2;
 }

@@ -108,13 +108,15 @@ inline Expr Relu(Expr e, double threshold = 0.0) {
 }

 inline Expr Relu6(Expr e, double threshold = 0.0) {
-  return ir::Min::Make(ir::Max::Make(e, ir::Cast::Make(e->type(), Expr(threshold))),
-                       ir::Cast::Make(e->type(), Expr(6.0)));
+  return ir::Min::Make(
+      ir::Max::Make(e, ir::Cast::Make(e->type(), Expr(threshold))),
+      ir::Cast::Make(e->type(), Expr(6.0)));
 }

 inline Expr LeakyRelu(Expr e, double alpha) {
   auto zero = Zero(e->type());
-  return ir::Select::Make(e > zero, e, e * ir::Cast::Make(e->type(), Expr(alpha)));
+  return ir::Select::Make(
+      e > zero, e, e * ir::Cast::Make(e->type(), Expr(alpha)));
 }

 inline Expr LeakyRelu(Expr e, Expr alpha) {
@@ -122,39 +124,51 @@ inline Expr LeakyRelu(Expr e, Expr alpha) {
   return ir::Select::Make(e > zero, e, e * alpha);
 }

-inline Expr ReduceSum(Expr e, const std::vector<Var>& reduce_axis, Expr initial = Expr()) {
+inline Expr ReduceSum(Expr e,
+                      const std::vector<Var>& reduce_axis,
+                      Expr initial = Expr()) {
   if (!initial.defined()) {
     initial = Zero(e->type());
   }
   return ir::Reduce::Make(ir::Reduce::kSum, initial, e, reduce_axis);
 }
-inline Expr ReduceMul(Expr e, const std::vector<Var>& reduce_axis, Expr initial = Expr()) {
+inline Expr ReduceMul(Expr e,
+                      const std::vector<Var>& reduce_axis,
+                      Expr initial = Expr()) {
   if (!initial.defined()) {
     initial = One(e->type());
   }
   return ir::Reduce::Make(ir::Reduce::kMul, initial, e, reduce_axis);
 }
-inline Expr ReduceMax(Expr e, const std::vector<Var>& reduce_axis, Expr initial = Expr()) {
+inline Expr ReduceMax(Expr e,
+                      const std::vector<Var>& reduce_axis,
+                      Expr initial = Expr()) {
   if (!initial.defined()) {
     initial = min_value(e.type());
   }
   return ir::Reduce::Make(ir::Reduce::kMax, initial, e, reduce_axis);
 }
-inline Expr ReduceMin(Expr e, const std::vector<Var>& reduce_axis, Expr initial = Expr()) {
+inline Expr ReduceMin(Expr e,
+                      const std::vector<Var>& reduce_axis,
+                      Expr initial = Expr()) {
   if (!initial.defined()) {
     initial = max_value(e.type());
   }
   return ir::Reduce::Make(ir::Reduce::kMin, initial, e, reduce_axis);
 }
-inline Expr ReduceAll(Expr e, const std::vector<Var>& reduce_axis, Expr initial = Expr()) {
+inline Expr ReduceAll(Expr e,
+                      const std::vector<Var>& reduce_axis,
+                      Expr initial = Expr()) {
   if (!initial.defined()) {
     initial = Expr(true);
   }
   return ir::Reduce::Make(ir::Reduce::kAll, initial, e, reduce_axis);
 }
-inline Expr ReduceAny(Expr e, const std::vector<Var>& reduce_axis, Expr initial = Expr()) {
+inline Expr ReduceAny(Expr e,
+                      const std::vector<Var>& reduce_axis,
+                      Expr initial = Expr()) {
   if (!initial.defined()) {
     initial = Expr(false);
   }
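FloorDivide in builtin.cc builds Python-style floor division on top of C++'s truncate-toward-zero semantics: for signed integers with mixed signs and a non-zero remainder, the truncated quotient is adjusted down by one via the Select chain shown above. A worked plain-int example of that branch (the IR version expresses the same logic as Select/EQ nodes):

#include <cassert>

int main() {
  // Mirror of the signed-integer branch of FloorDivide for a = -7, b = 2.
  int a = -7, b = 2;
  int div = a / b;  // -3: C++ integer division truncates toward zero
  int mod = a % b;  // -1: the remainder takes the dividend's sign
  // Same sign -> keep div; otherwise keep div only when mod == 0.
  int ret = ((a > 0 && b > 0) || (a < 0 && b < 0))
                ? div
                : (mod == 0 ? div : div - 1);
  assert(ret == -4);  // floor(-7.0 / 2.0) == -4, not the truncated -3
  return 0;
}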
diff --git a/paddle/cinn/lang/compute.cc b/paddle/cinn/lang/compute.cc
index 629c915a4f628..a81ea059cc3fa 100644
--- a/paddle/cinn/lang/compute.cc
+++ b/paddle/cinn/lang/compute.cc
@@ -136,7 +136,8 @@ ir::Tensor Compute(const std::vector<Expr> &domain,
   std::vector<Var> reduce_axis;
   if (fn_body.defined() && fn_body.As<ir::Reduce>()) {
     auto &fn_reduce_axis = fn_body.As<ir::Reduce>()->reduce_axis;
-    reduce_axis.insert(std::begin(reduce_axis), fn_reduce_axis.begin(), fn_reduce_axis.end());
+    reduce_axis.insert(
+        std::begin(reduce_axis), fn_reduce_axis.begin(), fn_reduce_axis.end());
   }

   // When the fn_body is a CallExtern, a tensor will return directly.
@@ -161,7 +162,8 @@ ir::Tensor Compute(const std::vector<Expr> &domain,
     shape_simplified.push_back(copied);
   }

-  auto real_shape = shape_simplified.empty() ? domain_without_reduce_axis : shape_simplified;
+  auto real_shape =
+      shape_simplified.empty() ? domain_without_reduce_axis : shape_simplified;

   // The body returns void, that means no buffer is needed.
   if (fn_body.type() == Void()) real_shape.clear();
@@ -170,25 +172,39 @@ ir::Tensor Compute(const std::vector<Expr> &domain,

   // check reduce_axis not include the reserved axis name
   for (auto &ra : reduce_axis) {
-    CHECK(!common::IsAxisNameReserved(ra->name)) << "reduce axis [" << ra->name << "]'s name is reserved";
+    CHECK(!common::IsAxisNameReserved(ra->name))
+        << "reduce axis [" << ra->name << "]'s name is reserved";
   }

-  VLOG(3) << "tensor " << name << "'s domain is : " << domain_without_reduce_axis;
+  VLOG(3) << "tensor " << name
+          << "'s domain is : " << domain_without_reduce_axis;

-  auto op     = ir::ComputeOp::Make(unique_name, fn, real_shape, domain_without_reduce_axis, reduce_axis);
-  auto tensor = ir::Tensor(unique_name, fn_body.type(), real_shape, domain_without_reduce_axis, op, reduce_axis);
+  auto op = ir::ComputeOp::Make(
+      unique_name, fn, real_shape, domain_without_reduce_axis, reduce_axis);
+  auto tensor = ir::Tensor(unique_name,
+                           fn_body.type(),
+                           real_shape,
+                           domain_without_reduce_axis,
+                           op,
+                           reduce_axis);
   return tensor;
 }

-std::vector<ir::Tensor> CallLowered(const std::string &func_name,
-                                    const std::vector<Expr> &args,
-                                    const std::vector<ReturnType> &return_types) {
-  auto call = ir::Call::Make(Void(), func_name, args, {}, ir::CallType::CINN, ir::FunctionRef(), 0);
+std::vector<ir::Tensor> CallLowered(
+    const std::string &func_name,
+    const std::vector<Expr> &args,
+    const std::vector<ReturnType> &return_types) {
+  auto call = ir::Call::Make(
+      Void(), func_name, args, {}, ir::CallType::CINN, ir::FunctionRef(), 0);
   std::vector<ir::Tensor> new_tensors;
   for (int i = 0; i < return_types.size(); i++) {
     auto &return_type = return_types[i];
-    auto call_op      = ir::CallOp::Make(func_name, call);
-    auto new_tensor   = ir::Tensor(return_type.name, return_type.type, return_type.dims, {Expr(1)}, call_op);
+    auto call_op = ir::CallOp::Make(func_name, call);
+    auto new_tensor = ir::Tensor(return_type.name,
+                                 return_type.type,
+                                 return_type.dims,
+                                 {Expr(1)},
+                                 call_op);
     // Append write tensors in the tail.
     call.As<ir::Call>()->write_args.push_back(new_tensor);
     new_tensor->set_type(return_type.type);
@@ -202,22 +218,33 @@ std::vector<ir::Tensor> CallLowered(const std::string &func_name,
 Expr CallExtern(const std::string &func_name,
                 const std::vector<Expr> &args,
                 const std::map<std::string, attr_t> &attrs) {
-  auto *proto = backends::ExternFunctionProtoRegistry::Global().Lookup(func_name);
-  CHECK(proto) << "No extern function prototype " << func_name << " found\n"
-               << "existing records are:\n"
-               << backends::ExternFunctionProtoRegistry::Global().debug_string();
+  auto *proto =
+      backends::ExternFunctionProtoRegistry::Global().Lookup(func_name);
+  CHECK(proto)
+      << "No extern function prototype " << func_name << " found\n"
+      << "existing records are:\n"
+      << backends::ExternFunctionProtoRegistry::Global().debug_string();

-  auto call = ir::Call::Make(proto->ret_type, func_name, args, {}, ir::CallType::Extern, ir::FunctionRef(), 0, attrs);
+  auto call = ir::Call::Make(proto->ret_type,
+                             func_name,
+                             args,
+                             {},
+                             ir::CallType::Extern,
+                             ir::FunctionRef(),
+                             0,
+                             attrs);
   std::vector<Expr> mutable_args;
   // Call a function with multiple outputs.
   if (proto->ret_type.is_void()) {
     for (int i = 0; i < proto->mutable_arg_types.size(); i++) {
-      auto shape                     = proto->shape_inference(args, i);
-      auto op                        = ir::CallOp::Make(func_name, call);
-      op->as<ir::CallOp>()->value_slot = i;
+      auto shape = proto->shape_inference(args, i);
+      auto op = ir::CallOp::Make(func_name, call);
+      op->as<ir::CallOp>()->value_slot = i;
       op->as<ir::CallOp>()->is_tuple_get = true;
-      auto name = cinn::UniqName("tuple_" + func_name + "_out" + std::to_string(i) + "_");
-      auto ret  = ir::Tensor(name, proto->mutable_arg_types[i], shape, shape, op, {});
+      auto name = cinn::UniqName("tuple_" + func_name + "_out" +
+                                 std::to_string(i) + "_");
+      auto ret =
+          ir::Tensor(name, proto->mutable_arg_types[i], shape, shape, op, {});
       mutable_args.push_back(ret);
     }
     call.As<ir::Call>()->write_args = mutable_args;
diff --git a/paddle/cinn/lang/compute.h b/paddle/cinn/lang/compute.h
index 0970caa179603..5917db305b46f 100755
--- a/paddle/cinn/lang/compute.h
+++ b/paddle/cinn/lang/compute.h
@@ -30,7 +30,7 @@ namespace cinn {
 namespace lang {

 using compute_handler_t = std::function<Expr(const std::vector<Expr> &)>;
-using attr_t            = absl::variant<int, float, bool, std::string>;
+using attr_t = absl::variant<int, float, bool, std::string>;

 //! Compute methods for one to five Vars as arguments.
 // @{
@@ -83,16 +83,19 @@ struct ReturnType {
  *
  * A lowered function is generated by lang::Lower method.
  *
- * TODO(Superjomn) Add a registry (symbol table?) to make return result inference automatically.
+ * TODO(Superjomn) Add a registry (symbol table?) to make return result
+ * inference automatically.
  *
  * @param func_name The name of the function to call.
- * @param args The readonly arguments(while the mutable tensors are return result).
+ * @param args The readonly arguments(while the mutable tensors are return
+ * result).
  * @param return_types The types of the return values.
  * @return Return one or more tensors as result.
  */
-std::vector<ir::Tensor> CallLowered(const std::string &func_name,
-                                    const std::vector<Expr> &args,
-                                    const std::vector<ReturnType> &return_types);
+std::vector<ir::Tensor> CallLowered(
+    const std::string &func_name,
+    const std::vector<Expr> &args,
+    const std::vector<ReturnType> &return_types);

 /**
  * \brief Call an external function and get some tensors as result.
@@ -104,13 +107,12 @@ std::vector<ir::Tensor> CallLowered(const std::string &func_name,
  * Tensor tuple = Compute({M}, []() { return CallExtern("mkl_gemm", {X, W}); });
  * \endcode
  *
- * To support returning multiple value one time, we include the tuple concept, it is a Tensor with CallOp marked with
- * value_offset(from 0 to num_returns-1).
+ * To support returning multiple value one time, we include the tuple concept,
+ * it is a Tensor with CallOp marked with value_offset(from 0 to num_returns-1).
  *
- * 2. POD value, return an expression directly, and it can be inline expand in following computations.
- * \code
- * Tensor tanh_out = Compute({M}, [](Var i) { return CallExtern("tanh", X(i)); });
- * \endcode
+ * 2. POD value, return an expression directly, and it can be inline expand in
+ * following computations. \code Tensor tanh_out = Compute({M}, [](Var i) {
+ * return CallExtern("tanh", X(i)); }); \endcode
  *
  * Will generate something like
  *
@@ -121,7 +123,8 @@ std::vector<ir::Tensor> CallLowered(const std::string &func_name,
  * \endcode
  *
  * @param func_name The name of the function to call.
- * @param args The readonly arguments(while there should be only one tensor as result).
+ * @param args The readonly arguments(while there should be only one tensor as
+ * result).
  * @param attrs The readonly attrs.
  */
 Expr CallExtern(const std::string &func_name,
diff --git a/paddle/cinn/lang/compute_test.cc b/paddle/cinn/lang/compute_test.cc
index f5244016012e9..d666a47547ac7 100644
--- a/paddle/cinn/lang/compute_test.cc
+++ b/paddle/cinn/lang/compute_test.cc
@@ -31,7 +31,8 @@ TEST(Call, basic) {
   Placeholder<float> x("x", {M, Expr(10)});
   Placeholder<float> y("y", {M, Expr(10)});

-  std::vector<ReturnType> return_types({{Float(32), std::vector<Expr>{{M, Expr(20)}}, "C"}});
+  std::vector<ReturnType> return_types(
+      {{Float(32), std::vector<Expr>{{M, Expr(20)}}, "C"}});
   auto tensors = CallLowered("lowered_fun0", {Expr(x), Expr(y)}, return_types);
 }
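The Call, basic test above drives the multi-output machinery that compute.cc implements: CallLowered builds one ir::Call plus one tensor per ReturnType and appends each tensor to the call's write_args, while CallExtern does the analogous thing for extern protos with void return, marking each output's CallOp as a tuple get with value_slot = i. Below is a plain-data analogue of that tuple-get bookkeeping; the struct and names are illustrative stand-ins, not the CINN API ("mkl_gemm" is the example extern name used in compute.h above):

#include <cassert>
#include <string>
#include <vector>

// Plain-data analogue of the tuple-get bookkeeping in CallExtern: each output
// of a multi-result extern call records which slot of the call it extracts.
struct TupleGetOp {
  std::string func_name;
  int value_slot = 0;
  bool is_tuple_get = false;
};

int main() {
  std::vector<TupleGetOp> outputs;
  for (int i = 0; i < 2; ++i) {  // a hypothetical two-output extern call
    outputs.push_back({"mkl_gemm", i, true});
  }
  assert(outputs[1].value_slot == 1 && outputs[1].is_tuple_get);
  return 0;
}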
} @@ -127,26 +138,30 @@ std::vector GetTempBuffers(const std::vector& tensor_args, } //! Collect the temporary tensors from a computational graph. -std::vector GetTempBuffers(const std::vector& args, Expr body) { +std::vector GetTempBuffers(const std::vector& args, + Expr body) { std::unordered_set buffer_arg_names; for (auto& a : args) { if (a.is_buffer()) { buffer_arg_names.insert(a.name()); } } - std::map name_to_buffer; // used to avoid duplication. + std::map + name_to_buffer; // used to avoid duplication. - auto all_temp_tensors = ir::CollectIRNodesWithoutTensor(body, [&](const Expr* x) { - return x->as_tensor() && x->as_tensor()->buffer.defined() && - (!buffer_arg_names.count(x->as_tensor()->buffer->name) || - utils::Endswith(x->as_tensor()->buffer->name, "temp_buffer")); - }); + auto all_temp_tensors = + ir::CollectIRNodesWithoutTensor(body, [&](const Expr* x) { + return x->as_tensor() && x->as_tensor()->buffer.defined() && + (!buffer_arg_names.count(x->as_tensor()->buffer->name) || + utils::Endswith(x->as_tensor()->buffer->name, "temp_buffer")); + }); for (auto& e : all_temp_tensors) { auto buffer_name = e.as_tensor()->buffer->name; if (!name_to_buffer.count(buffer_name)) { name_to_buffer[buffer_name] = e.as_tensor()->buffer; } else { - if (e.as_tensor()->buffer->numel() < name_to_buffer[buffer_name]->numel()) { + if (e.as_tensor()->buffer->numel() < + name_to_buffer[buffer_name]->numel()) { name_to_buffer[buffer_name] = e.as_tensor()->buffer; } } @@ -155,7 +170,9 @@ std::vector GetTempBuffers(const std::vector& args, Ex auto update_map = ir::CollectIRNodesWithoutTensor(body, [&](const Expr* x) { if (x->as_tensor() && x->as_tensor()->buffer.defined()) { auto buffer_name = x->as_tensor()->buffer->name; - if (name_to_buffer.count(buffer_name) && x->as_tensor()->buffer->numel() < name_to_buffer[buffer_name]->numel()) { + if (name_to_buffer.count(buffer_name) && + x->as_tensor()->buffer->numel() < + name_to_buffer[buffer_name]->numel()) { name_to_buffer[buffer_name] = x->as_tensor()->buffer; } } @@ -167,11 +184,13 @@ std::vector GetTempBuffers(const std::vector& args, Ex return temp_buffers; } -std::set CollectTempTensorsFromCtrlDepends(StageMap stages, const std::vector& tensor_args) { +std::set CollectTempTensorsFromCtrlDepends( + StageMap stages, const std::vector& tensor_args) { std::set res; for (auto& stage : stages) { res.emplace(ir::Tensor(stage.second->tensor())); - res.insert(stage.second->ctrl_depends().begin(), stage.second->ctrl_depends().end()); + res.insert(stage.second->ctrl_depends().begin(), + stage.second->ctrl_depends().end()); } for (auto& t : tensor_args) { if (res.count(t)) res.erase(t); @@ -179,14 +198,18 @@ std::set CollectTempTensorsFromCtrlDepends(StageMap stages, const st return res; } -void InitReduceTensor(StageMap stages, const Tensor& tensor, const Target& target) { +void InitReduceTensor(StageMap stages, + const Tensor& tensor, + const Target& target) { if (tensor->is_reduce_tensor() && !tensor->IsReduceInited(stages)) { tensor->InitReduction(stages, target); } - auto uninited_reduce_tensors = ir::CollectIRNodes(tensor->body(), [&](const Expr* x) { - return x && x->defined() && x->as_tensor() && x->as_tensor()->is_reduce_tensor() && - !x->as_tensor()->IsReduceInited(stages); - }); + auto uninited_reduce_tensors = + ir::CollectIRNodes(tensor->body(), [&](const Expr* x) { + return x && x->defined() && x->as_tensor() && + x->as_tensor()->is_reduce_tensor() && + !x->as_tensor()->IsReduceInited(stages); + }); for (auto& t : uninited_reduce_tensors) { 
VLOG(3) << "Init reduce tensor: " << t.as_tensor()->name; t.as_tensor()->InitReduction(stages, target); @@ -207,14 +230,15 @@ ir::LoweredFunc Lower(const std::string& name, // Merge the ctrl_deps with the given temp_tensors ang get a new temp_tensors auto ctrl_deps = CollectTempTensorsFromCtrlDepends(stages, tensor_args); ctrl_deps.insert(temp_tensors.begin(), temp_tensors.end()); - auto lower_impl_instance = detail::LowerImpl(name, - stages, - tensor_args, - scalar_args, - std::vector(ctrl_deps.begin(), ctrl_deps.end()), - target, - support_ir_schedule); - auto result = lower_impl_instance(); + auto lower_impl_instance = + detail::LowerImpl(name, + stages, + tensor_args, + scalar_args, + std::vector(ctrl_deps.begin(), ctrl_deps.end()), + target, + support_ir_schedule); + auto result = lower_impl_instance(); std::vector return_value; for (auto& res : result) { auto temp_buffers = GetTempBuffers(tensor_args, stages, res->body); @@ -257,13 +281,14 @@ std::vector LowerVec(const std::string& name, // Merge the ctrl_deps with the given temp_tensors ang get a new temp_tensors auto ctrl_deps = CollectTempTensorsFromCtrlDepends(stages, tensor_args); ctrl_deps.insert(temp_tensors.begin(), temp_tensors.end()); - auto lower_impl_instance = detail::LowerImpl(name, - stages, - tensor_args, - scalar_args, - std::vector(ctrl_deps.begin(), ctrl_deps.end()), - target, - support_ir_schedule); + auto lower_impl_instance = + detail::LowerImpl(name, + stages, + tensor_args, + scalar_args, + std::vector(ctrl_deps.begin(), ctrl_deps.end()), + target, + support_ir_schedule); // return vectorof ir::LoweredFunc. auto result = lower_impl_instance(); std::vector return_value; diff --git a/paddle/cinn/lang/lower.h b/paddle/cinn/lang/lower.h index 92ffb101dedd6..af8a186583a69 100644 --- a/paddle/cinn/lang/lower.h +++ b/paddle/cinn/lang/lower.h @@ -33,45 +33,48 @@ using ir::Tensor; using poly::StageMap; /** - * \brief Lower the computation of \p tensor_args and \p scalar_args to a LoweredFunc. + * \brief Lower the computation of \p tensor_args and \p scalar_args to a + * LoweredFunc. * @param name The name of the function. * @param tensor_args The tensor arguments, where the computation logic locates. * @param scalar_args The scalar arguments, indicate some dimensions. * @param temp_tensors The temporary tensors(buffers) used in the body. * @param b The module this function belongs to. - * @return A LoweredFunc, whose name is \p name, the argument list is the concatenation of \p tensor_args and \p - * scalar_args. + * @return A LoweredFunc, whose name is \p name, the argument list is the + * concatenation of \p tensor_args and \p scalar_args. */ ir::LoweredFunc Lower(const std::string &name, StageMap stages, const std::vector &tensor_args, - const std::vector &scalar_args = {}, + const std::vector &scalar_args = {}, const std::vector &temp_tensors = {}, - ir::Module::Builder *b = nullptr, - const Target &target = common::DefaultHostTarget(), - bool support_ir_schedule = false); + ir::Module::Builder *b = nullptr, + const Target &target = common::DefaultHostTarget(), + bool support_ir_schedule = false); /** - * \brief Lower the computation of \p tensor_args and \p scalar_args to a vector of LoweredFuncs. Each schedule group - * forms a LoweredFunc. + * \brief Lower the computation of \p tensor_args and \p scalar_args to a vector + * of LoweredFuncs. Each schedule group forms a LoweredFunc. * @param name The name of the function. * @param tensor_args The tensor arguments, where the computation logic locates. 
* @param scalar_args The scalar arguments, indicate some dimensions. * @param temp_tensors The temporary tensors(buffers) used in the body. * @param b The module this function belongs to. - * @return A vector of LoweredFuncs, whose name is \p name, name + "_1", name + "_2"... The argument list is deduced - * from the expression of each func. + * @return A vector of LoweredFuncs, whose name is \p name, name + "_1", name + + * "_2"... The argument list is deduced from the expression of each func. */ -std::vector LowerVec(const std::string &name, - StageMap stages, - const std::vector &tensor_args, - const std::vector &scalar_args = {}, - const std::vector &temp_tensors = {}, - ir::Module::Builder *b = nullptr, - const Target &target = common::DefaultHostTarget(), - bool support_ir_schedule = false); +std::vector LowerVec( + const std::string &name, + StageMap stages, + const std::vector &tensor_args, + const std::vector &scalar_args = {}, + const std::vector &temp_tensors = {}, + ir::Module::Builder *b = nullptr, + const Target &target = common::DefaultHostTarget(), + bool support_ir_schedule = false); -std::vector GetArgs(const Expr &func_body, const std::vector &input_output_nodes); +std::vector GetArgs( + const Expr &func_body, const std::vector &input_output_nodes); //! Collect the temporary tensors from a computational graph. std::vector GetTempBuffers(const std::vector &tensor_args, @@ -79,7 +82,8 @@ std::vector GetTempBuffers(const std::vector &tensor_args, Expr body); //! Collect the temporary tensors from a computational graph. -std::vector GetTempBuffers(const std::vector &args, Expr body); +std::vector GetTempBuffers(const std::vector &args, + Expr body); } // namespace lang } // namespace cinn diff --git a/paddle/cinn/lang/lower_impl.cc b/paddle/cinn/lang/lower_impl.cc index 1bedacdf256ae..247888246c2cf 100644 --- a/paddle/cinn/lang/lower_impl.cc +++ b/paddle/cinn/lang/lower_impl.cc @@ -35,16 +35,17 @@ namespace lang { namespace detail { void CheckNoIslCallRemains(Expr* expr) { - auto isl_calls = ir::CollectIRNodes( - *expr, [](const Expr* expr) { return expr->As() && expr->As()->is_isl_call(); }); + auto isl_calls = ir::CollectIRNodes(*expr, [](const Expr* expr) { + return expr->As() && expr->As()->is_isl_call(); + }); #ifdef CINN_DEBUG for (auto& item : isl_calls) { LOG(ERROR) << "ISL call: " << item; } #endif if (!isl_calls.empty()) { - LOG(WARNING) << "Some ISL call nodes remained, get " << isl_calls.size() << " isl_calls, the first one is " - << *isl_calls.begin(); + LOG(WARNING) << "Some ISL call nodes remained, get " << isl_calls.size() + << " isl_calls, the first one is " << *isl_calls.begin(); } } @@ -54,13 +55,15 @@ void BindBuffer(StageMap& stages) { tensor_map[stage.second->tensor()->name] = stage.second->tensor(); } for (auto& stage : stages) { - if (!stage.second->tensor()->buffer.defined() && !stage.second->meta.tensors_to_share_buffer_with.empty()) { + if (!stage.second->tensor()->buffer.defined() && + !stage.second->meta.tensors_to_share_buffer_with.empty()) { for (auto& str : stage.second->meta.tensors_to_share_buffer_with) { if (tensor_map[str]->buffer.defined()) { auto edited_shape = tensor_map[str]->buffer->shape; stage.second->tensor()->Bind(tensor_map[str]->buffer); tensor_map[str]->buffer->shape = edited_shape; - VLOG(3) << "Tensor " << stage.second->tensor()->name << " bind buffer to " << tensor_map[str]->name << " , " + VLOG(3) << "Tensor " << stage.second->tensor()->name + << " bind buffer to " << tensor_map[str]->name << " , " << 
tensor_map[str]->buffer->name; } } @@ -68,12 +71,13 @@ void BindBuffer(StageMap& stages) { } } -Expr LowerGroup(const poly::ScheduleGroup& group, - const std::map& tuple_to_expr, - std::map* global_tensor_map, - std::unordered_map>& resized_buffer_cache, - StageMap stage_map, - ir::CudaAxisInfo* cuda_axis_info) { +Expr LowerGroup( + const poly::ScheduleGroup& group, + const std::map& tuple_to_expr, + std::map* global_tensor_map, + std::unordered_map>& resized_buffer_cache, + StageMap stage_map, + ir::CudaAxisInfo* cuda_axis_info) { BindBuffer(stage_map); std::vector stages; for (auto& node : group.nodes) { @@ -101,16 +105,19 @@ Expr LowerGroup(const poly::ScheduleGroup& group, VLOG(6) << "before ast to expr"; // poly::IslAstNodeToCinnExpr(ast, &e); poly::IslAstNodeToCinnExpr(ast, gen.domain(), &e); - // now we get a workable expression, but the statement are something like `B(((16 * po0) + po1), po2)`, we need to - // transform this to some realworld statement in CINN. + // now we get a workable expression, but the statements are something like + // `B(((16 * po0) + po1), po2)`; we need to transform these to real-world + // statements in CINN. VLOG(1) << "ast to expr: \n" << e << std::endl; - // replace isl call to the corresponding CINN statement, we need to replace the axis at the same time. + // replace each isl call with the corresponding CINN statement; we need to + // replace the axis at the same time. for (auto& statement : tuple_to_expr) { VLOG(2) << "LowerGroup working on statement: " << statement.first; if (!gen.ContainsStatement(statement.first)) continue; - // the axis_ast_map contains the axis from the original (like `i`) to the transformed (like `i+3`). + // the axis_ast_map contains the axis from the original (like `i`) to the + // transformed (like `i+3`). auto axis_expr_map = gen.axis2expr(statement.first); for (auto& item : axis_expr_map) { VLOG(4) << "statement ast map axis [" << item.first << "] to " @@ -120,8 +127,10 @@ Expr LowerGroup(const poly::ScheduleGroup& group, // the original CINN statements. Expr statement_candi_expr = tuple_to_expr.at(statement.first); - VLOG(3) << "replacing " << statement.first << " to " << statement_candi_expr; + VLOG(3) << "replacing " << statement.first << " to " + << statement_candi_expr; + optim::ReplaceIslCallWithExpr( + &e, statement.first, statement_candi_expr, axis_expr_map); } CheckNoIslCallRemains(&e); @@ -193,7 +202,8 @@ std::string CompuGraphNode::id() const { /** * \brief Add nodes to graph with dependencies. * We create a computation graph based on the tensor dependency relations. - * NOTE The graph will contain the inline tensors so that the dependency will be reserved. + * NOTE The graph will contain the inline tensors so that the dependency will be + * preserved. * @param graph The graph * @param t The tensor. * @param stages The stage map. @@ -213,12 +223,14 @@ void CreateCompGraphWithInlineTensors(common::Graph* graph, // collect dependency tensors of t // here we just collect the tensors in Load nodes // NOTE there may be some other cases.
- auto deps = ir::CollectLoadTensors(t->body(), [](const Expr* x) { return x->as_tensor(); }); + auto deps = ir::CollectLoadTensors( + t->body(), [](const Expr* x) { return x->as_tensor(); }); for (const auto& dep : deps) { auto e_tensor = dep.as_tensor_ref(); - auto* e_node = graph->RetrieveNode(e_tensor->name); + auto* e_node = graph->RetrieveNode(e_tensor->name); if (!e_node) { - e_node = graph->RegisterNode(e_tensor->name, new CompuGraphNode(e_tensor)); + e_node = + graph->RegisterNode(e_tensor->name, new CompuGraphNode(e_tensor)); } e_node->Controls(t_node); if (!visited->count(e_tensor)) { @@ -227,8 +239,8 @@ void CreateCompGraphWithInlineTensors(common::Graph* graph, } } -std::unique_ptr CreateCompGraphWithInlineTensorHidden(const std::vector& tensors, - StageMap stages) { +std::unique_ptr CreateCompGraphWithInlineTensorHidden( + const std::vector& tensors, StageMap stages) { // create a graph with inline tensor first. std::unique_ptr graph(new common::Graph); std::set visited; @@ -236,7 +248,8 @@ std::unique_ptr CreateCompGraphWithInlineTensorHidden(const std:: CreateCompGraphWithInlineTensors(graph.get(), t, stages, &visited); } - // greedy remove the inline tensor, each time merge the inputs of an inline tensor to its sink node. + // greedily remove the inline tensors; each time, merge the inputs of an + // inline tensor into its sink node. std::set inline_nodes; do { @@ -255,7 +268,7 @@ std::unique_ptr CreateCompGraphWithInlineTensorHidden(const std:: */ for (auto* inline_node : inline_nodes) { // remove this node, merge its inputs to the sink nodes. - auto inline_inlinks = inline_node->inlinks(); + auto inline_inlinks = inline_node->inlinks(); auto inline_outlinks = inline_node->outlinks(); // unlink the inline node from its inputs and outputs @@ -296,9 +309,8 @@ void CompuGraphAddCtrlDepLinks(common::Graph* graph, StageMap stages) { } } -std::unique_ptr CreateCompGraph(const std::vector& tensors, - StageMap stages, - bool hide_inline) { +std::unique_ptr CreateCompGraph( + const std::vector& tensors, StageMap stages, bool hide_inline) { if (hide_inline) { auto graph = CreateCompGraphWithInlineTensorHidden(tensors, stages); CompuGraphAddCtrlDepLinks(graph.get(), stages); @@ -316,7 +328,8 @@ std::unique_ptr CreateCompGraph(const std::vector& te void LowerImpl::CheckArgsUnique() { for (auto& tensor : tensor_args_) { - CHECK(!stages_[tensor]->inlined()) << "Inline tensor cannot be argument of function"; + CHECK(!stages_[tensor]->inlined()) + << "Inline tensor cannot be an argument of a function"; if (!tensor->buffer.defined()) { LOG(ERROR) << "tensor [" << tensor->name << "] buffer is null"; continue; @@ -324,7 +337,8 @@ void LowerImpl::CheckArgsUnique() { } } -std::vector LowerImpl::GenerateFunctionArgumentList(Expr fn_body) { +std::vector LowerImpl::GenerateFunctionArgumentList( + Expr fn_body) { CheckArgsUnique(); std::vector args; @@ -344,15 +358,19 @@ std::vector LowerImpl::GenerateFunctionArgumentList(Expr fn_body) for (auto& tensor : tensor_args_) { auto* tensor_node = tensor.As(); - bool is_output = teller.IsWrite(tensor->name); - VLOG(1) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name; + bool is_output = teller.IsWrite(tensor->name); + VLOG(1) << "tensor argument " << tensor->name << " buffer " + << tensor->buffer->name; // avoid duplicate if (!tensor_node->buffer.defined()) continue; - // if a argument is already marked as kInput, mark it as kOutput and move it to the back.
+ // if an argument is already marked as kInput, mark it as kOutput and move it + // to the back. if (arg_names.count(tensor_node->buffer->name)) { - auto it = std::find_if( - args.begin(), args.end(), [&](const ir::Argument& x) { return x.name() == tensor_node->buffer->name; }); + auto it = + std::find_if(args.begin(), args.end(), [&](const ir::Argument& x) { + return x.name() == tensor_node->buffer->name; + }); CHECK(it != args.end()); if (it->is_input()) { args.erase(it); @@ -364,15 +382,16 @@ std::vector LowerImpl::GenerateFunctionArgumentList(Expr fn_body) arg_names.insert(tensor_node->buffer->name); auto io = is_output ? ir::Argument::IO::kOutput : ir::Argument::IO::kInput; - VLOG(3) << "Collect " << (is_output ? "W" : "R") << " argument " << tensor->buffer->name; + VLOG(3) << "Collect " << (is_output ? "W" : "R") << " argument " + << tensor->buffer->name; args.emplace_back(tensor_node->buffer, io); } return args; } // Generate Function Arguments for split kernel. -std::vector LowerImpl::GenFuncArgForSplitKernel(Expr func_iterator, - std::vector temp_tensors) { +std::vector LowerImpl::GenFuncArgForSplitKernel( + Expr func_iterator, std::vector temp_tensors) { CheckArgsUnique(); std::vector in_args; @@ -391,10 +410,12 @@ std::vector LowerImpl::GenFuncArgForSplitKernel(Expr func_iterator in_args.emplace_back(scalar, ir::Argument::IO::kInput); } - auto all_tensors = ir::CollectIRNodes( - func_iterator, [&](const Expr* x) { return x->as_tensor() && !stages_[x->as_tensor()]->inlined(); }); + auto all_tensors = ir::CollectIRNodes(func_iterator, [&](const Expr* x) { + return x->as_tensor() && !stages_[x->as_tensor()]->inlined(); + }); - auto all_vars = ir::CollectIRNodes(func_iterator, [&](const Expr* x) { return x->as_var(); }); + auto all_vars = ir::CollectIRNodes( + func_iterator, [&](const Expr* x) { return x->as_var(); }); for (auto& i : all_tensors) { auto* tensor = i.as_tensor(); @@ -428,17 +449,21 @@ std::vector LowerImpl::GenFuncArgForSplitKernel(Expr func_iterator if (temp_tensor_names.count(tensor->name) > 0) continue; if (all_tensor_names.count(tensor->name) == 0) continue; bool is_output = teller.IsWrite(tensor->name); - VLOG(3) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name; + VLOG(3) << "tensor argument " << tensor->name << " buffer " + << tensor->buffer->name; // avoid duplicate if (!tensor->buffer.defined()) { VLOG(3) << "tensor->buffer is not defined"; continue; } - // if a argument is already marked as kInput, mark it as kOutput and move it to the back. + // if an argument is already marked as kInput, mark it as kOutput and move it + // to the back.
if (arg_names.count(tensor->buffer->name)) { auto it = std::find_if( - in_args.begin(), in_args.end(), [&](const ir::Argument& x) { return x.name() == tensor->buffer->name; }); + in_args.begin(), in_args.end(), [&](const ir::Argument& x) { + return x.name() == tensor->buffer->name; + }); if (it != in_args.end()) { in_args.erase(it); } else { @@ -459,8 +484,10 @@ std::vector LowerImpl::GenFuncArgForSplitKernel(Expr func_iterator auto* tensor = i.as_tensor(); VLOG(3) << "Tensor " << tensor->name; if (tensor->buffer.defined() && !arg_names.count(tensor->buffer->name)) { - bool is_output = teller.IsWrite(tensor->name) && teller.IsWrite(tensor->name); - if (is_output) out_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); + bool is_output = + teller.IsWrite(tensor->name) && teller.IsWrite(tensor->name); + if (is_output) + out_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); } } } @@ -471,7 +498,8 @@ std::vector LowerImpl::GenFuncArgForSplitKernel(Expr func_iterator } std::vector LowerImpl::CollectTemporaryTensors() { - // a temporary should be in the comp_graph but not contained in the tensor_args. + // a temporary should be in the comp_graph but not contained in the + // tensor_args. absl::flat_hash_map tensor_arg_map = GenTensorArgMap(); absl::flat_hash_map temp_tensor_map; @@ -484,10 +512,11 @@ std::vector LowerImpl::CollectTemporaryTensors() { } std::vector temp_tensors; - std::transform(temp_tensor_map.begin(), - temp_tensor_map.end(), - std::back_inserter(temp_tensors), - [&](const decltype(temp_tensor_map)::value_type& x) { return x.second; }); + std::transform( + temp_tensor_map.begin(), + temp_tensor_map.end(), + std::back_inserter(temp_tensors), + [&](const decltype(temp_tensor_map)::value_type& x) { return x.second; }); return temp_tensors; } @@ -515,9 +544,12 @@ std::vector LowerImpl::operator()() { if (!stages_[t]->inlined()) stages.push_back(stages_[t]); } - auto deps = CollectExtraDependencies(); - auto schedule = poly::CreateSchedule( - stages, poly::ScheduleKind::Poly, std::vector>(deps.begin(), deps.end())); + auto deps = CollectExtraDependencies(); + auto schedule = + poly::CreateSchedule(stages, + poly::ScheduleKind::Poly, + std::vector>( + deps.begin(), deps.end())); auto func_body = GenerateFunctionBody(schedule.get()); std::vector result; @@ -526,22 +558,30 @@ std::vector LowerImpl::operator()() { if (support_ir_schedule_) { // add ScheduleBlockRealize func_iterator = ir::ScheduleBlockRealize::Make( - {}, ir::ScheduleBlock::Make({}, {}, {}, common::UniqName("root"), func_iterator)); + {}, + ir::ScheduleBlock::Make( + {}, {}, {}, common::UniqName("root"), func_iterator)); } std::set temp_tensor_names; for (auto& t : temp_tensor_args_) temp_tensor_names.insert(t->name); - auto tensor_map = - optim::InitialAssignBuffer(&func_iterator, stages_, all_tensor_map, comp_graph(), temp_tensor_names); + auto tensor_map = optim::InitialAssignBuffer(&func_iterator, + stages_, + all_tensor_map, + comp_graph(), + temp_tensor_names); // copy the tensor(with buffer assigned) back to func's args. 
{ for (auto& arg : tensor_args_) { if (arg->is_placeholder_node()) continue; if (arg->buffer.defined()) continue; - if (arg->body().As() && arg->body().type().is_void()) continue; // extern call + if (arg->body().As() && arg->body().type().is_void()) + continue; // extern call if (tensor_map.find(arg->name) == tensor_map.end()) { - LOG(INFO) << "Didn't find arg tensor " << arg->name << "in tensor_map.\n" - << "The function is " << fn_name_ << "\nAnd all the arg tensors are:\n"; + LOG(INFO) << "Didn't find arg tensor " << arg->name + << "in tensor_map.\n" + << "The function is " << fn_name_ + << "\nAnd all the arg tensors are:\n"; for (auto& i : tensor_args_) { LOG(INFO) << i->name; } @@ -550,7 +590,8 @@ std::vector LowerImpl::operator()() { Reference(&arg)->buffer = tensor_map.at(arg->name)->buffer; } } - auto store_exprs = ir::CollectIRNodes(func_iterator, [](const Expr* x) { return x->As(); }); + auto store_exprs = ir::CollectIRNodes( + func_iterator, [](const Expr* x) { return x->As(); }); std::vector new_temp_tensors; for (auto& expr : store_exprs) { auto* store_node = expr.As(); @@ -591,7 +632,8 @@ std::vector LowerImpl::operator()() { ir::LoweredFunc func; if (target_ == common::DefaultNVGPUTarget()) { - auto func_args2 = GenFuncArgForSplitKernel(func_iterator, new_temp_tensors); + auto func_args2 = + GenFuncArgForSplitKernel(func_iterator, new_temp_tensors); std::string new_fn_name = fn_name_; if (num_func > 0) { new_fn_name += "_" + std::to_string(num_func); @@ -603,10 +645,12 @@ std::vector LowerImpl::operator()() { for (auto& i : temp_buffers) { VLOG(3) << "temp_buffers is : " << i->name; } - func = ir::_LoweredFunc_::Make(new_fn_name, func_args2, func_iterator, temp_buffers); + func = ir::_LoweredFunc_::Make( + new_fn_name, func_args2, func_iterator, temp_buffers); } else { auto func_args = GenerateFunctionArgumentList(func_iterator); - func = ir::_LoweredFunc_::Make(fn_name_, func_args, func_iterator, temp_buffers); + func = ir::_LoweredFunc_::Make( + fn_name_, func_args, func_iterator, temp_buffers); } if (support_ir_schedule_) { @@ -617,11 +661,14 @@ std::vector LowerImpl::operator()() { num_func++; } else { optim::ComputeInlineExpand(&func->body, stages_, &all_tensor_map); - auto res = - optim::Optimize(func, target_, FLAGS_cinn_runtime_display_debug_info, /* remove_gpu_for_loops = */ false); - - if (cuda_axis_info_.size() > num_func && cuda_axis_info_[num_func].valid()) { - auto* res_func = res.as_lowered_func(); + auto res = optim::Optimize(func, + target_, + FLAGS_cinn_runtime_display_debug_info, + /* remove_gpu_for_loops = */ false); + + if (cuda_axis_info_.size() > num_func && + cuda_axis_info_[num_func].valid()) { + auto* res_func = res.as_lowered_func(); res_func->cuda_axis_info = cuda_axis_info_[num_func]; } result.push_back(ir::LoweredFunc(res.get())); @@ -634,8 +681,8 @@ std::vector LowerImpl::operator()() { std::vector LowerImpl::CollectAllTensors() { std::vector tensors; auto topo_order = compu_graph_->topological_order(); // NOLINT - auto& nodes = std::get<0>(topo_order); - auto& edges = std::get<1>(topo_order); + auto& nodes = std::get<0>(topo_order); + auto& edges = std::get<1>(topo_order); for (auto* node : nodes) { auto* cnode = node->safe_as(); CHECK(cnode); @@ -644,7 +691,8 @@ std::vector LowerImpl::CollectAllTensors() { return tensors; } -std::set> LowerImpl::CollectExtraDependencies() const { +std::set> +LowerImpl::CollectExtraDependencies() const { std::set> deps; for (auto* node : compu_graph_->nodes()) { auto* cnode = node->safe_as(); @@ -656,7 +704,8 
@@ std::set> LowerImpl::CollectExtraDependencie return deps; } -std::vector LowerImpl::GenerateFunctionBody(const poly::Schedule* schedule) { +std::vector LowerImpl::GenerateFunctionBody( + const poly::Schedule* schedule) { // generate the expressions for each group. std::vector exprs; std::vector result; @@ -678,15 +727,18 @@ std::vector LowerImpl::GenerateFunctionBody(const poly::Schedule* schedule auto& tensor = tensor_map[node->id()]; if (!tensor->has_expression()) continue; all_temp_tensor = - all_temp_tensor && (stages_[tensor]->inlined() || - (tensor->buffer.defined() && (tensor->buffer->memory_type == ir::MemoryType::GPUShared || - tensor->buffer->memory_type == ir::MemoryType::GPULocal))); + all_temp_tensor && + (stages_[tensor]->inlined() || + (tensor->buffer.defined() && + (tensor->buffer->memory_type == ir::MemoryType::GPUShared || + tensor->buffer->memory_type == ir::MemoryType::GPULocal))); auto store_body = tensor->tensor_store_expanded_body(); if (support_ir_schedule_) { // add schedule block of tensor computation for schedule IR int var_counts = tensor->shape.size() + tensor->reduce_axis.size(); std::vector int_shape; - VLOG(3) << "Tensor " << tensor->name << "'s shape is : " << utils::Join(tensor->shape, ","); + VLOG(3) << "Tensor " << tensor->name + << "'s shape is : " << utils::Join(tensor->shape, ","); for (auto& expr : tensor->shape) { CHECK(expr.is_constant()); int_shape.push_back((int)expr.get_constant()); @@ -701,31 +753,45 @@ std::vector LowerImpl::GenerateFunctionBody(const poly::Schedule* schedule // create block itervars, i0,i1... std::vector block_vars; std::vector iter_values; - std::vector axis_vars = common::GenDefaultAxis(tensor->shape.size()); + std::vector axis_vars = + common::GenDefaultAxis(tensor->shape.size()); // bind var_values - axis_vars.insert(axis_vars.end(), tensor->reduce_axis.begin(), tensor->reduce_axis.end()); + axis_vars.insert(axis_vars.end(), + tensor->reduce_axis.begin(), + tensor->reduce_axis.end()); for (int i = 0; i < var_counts; i++) { - block_vars.push_back(Var(Expr(0), Expr(int_shape[i]), cinn::UniqName("i" + std::to_string(i)), false)); + block_vars.push_back(Var(Expr(0), + Expr(int_shape[i]), + cinn::UniqName("i" + std::to_string(i)), + false)); if (i >= tensor->shape.size()) { block_vars[i]->is_reduce_axis = true; - axis_vars[i]->is_reduce_axis = true; + axis_vars[i]->is_reduce_axis = true; } iter_values.push_back(axis_vars[i]); // replace store's indice - VLOG(3) << "replace axis_var " << axis_vars[i]->name << " to block_var " << block_vars[i]; + VLOG(3) << "replace axis_var " << axis_vars[i]->name + << " to block_var " << block_vars[i]; optim::ReplaceVarWithExpr(&store_body, axis_vars[i], block_vars[i]); } store_body = ir::ScheduleBlockRealize::Make( - iter_values, ir::ScheduleBlock::Make(block_vars, {}, {}, tensor->name, store_body)); - // iter_values, ir::ScheduleBlock::Make(block_vars, {}, {}, common::UniqName(tensor->name), store_body)); + iter_values, + ir::ScheduleBlock::Make( + block_vars, {}, {}, tensor->name, store_body)); + // iter_values, ir::ScheduleBlock::Make(block_vars, {}, {}, + // common::UniqName(tensor->name), store_body)); VLOG(3) << "store body\n" << store_body; } tuple_to_expr[tensor->name] = store_body; } ir::CudaAxisInfo temp_cuda_axis_info; - Expr group_expr = - LowerGroup(group, tuple_to_expr, &global_tensor_map, resized_buffer_cache, stages_, &temp_cuda_axis_info); + Expr group_expr = LowerGroup(group, + tuple_to_expr, + &global_tensor_map, + resized_buffer_cache, + stages_, + 
&temp_cuda_axis_info); if (group_expr.defined()) { cuda_axis_info_.emplace_back(std::move(temp_cuda_axis_info)); @@ -768,7 +834,8 @@ LowerImpl::LowerImpl(const std::string& fn_name, support_ir_schedule_(support_ir_schedule) { { // Initialize the graph std::vector tensors(tensor_args.begin(), tensor_args.end()); - tensors.insert(std::end(tensors), temp_tensor_args.begin(), temp_tensor_args.end()); + tensors.insert( + std::end(tensors), temp_tensor_args.begin(), temp_tensor_args.end()); compu_graph_ = CreateCompGraph(tensors, stages, false /*inline_hide*/); @@ -779,7 +846,8 @@ LowerImpl::LowerImpl(const std::string& fn_name, { // update schedule. std::vector tensors(tensor_args.begin(), tensor_args.end()); - tensors.insert(std::end(tensors), temp_tensor_args_.begin(), temp_tensor_args_.end()); + tensors.insert( + std::end(tensors), temp_tensor_args_.begin(), temp_tensor_args_.end()); compu_graph_ = CreateCompGraph(tensors, stages, true /*inline_hide*/); VLOG(1) << "Computation Graph:\n" << compu_graph_->Visualize(); diff --git a/paddle/cinn/lang/lower_impl.h b/paddle/cinn/lang/lower_impl.h index b2acb773e6806..505e80ca6a49e 100644 --- a/paddle/cinn/lang/lower_impl.h +++ b/paddle/cinn/lang/lower_impl.h @@ -49,18 +49,21 @@ namespace lang { namespace detail { /** - * After the AstGen build the forloop from isl exprs, all the ISL Call nodes should be mapped to the corresponding CINN - expressions, there should be no remaining. + * After the AstGen builds the forloop from isl exprs, all the ISL Call nodes + * should be mapped to the corresponding CINN expressions; there should be none + * remaining. */ void CheckNoIslCallRemains(const Expr* expr); /** * \brief Lower a single group of nodes. * - * We partition the whole computation of a function into several groups, each group is a basic element for ISL - polyhedral computation, that is, we transform a group into a isl domain and schedule, and generate ast latter. + * We partition the whole computation of a function into several groups, each + * group is a basic element for ISL polyhedral computation, that is, we + * transform a group into an isl domain and schedule, and generate the AST later. - * @param group A single schedule group containing several Stages and the scheduling order. + * @param group A single schedule group containing several Stages and the + * scheduling order. * @param tuple_to_expr A map from isl set tuple name to CINN expressions. */ Expr LowerGroup(const poly::ScheduleGroup& group, @@ -92,9 +95,10 @@ struct CompuGraphNode : public common::GraphNode { * @param hide_inline hide inline tensor nodes. * @return a graph. */ -std::unique_ptr CreateCompGraph(const std::vector& tensors, - StageMap stages, - bool hide_inline = false); +std::unique_ptr CreateCompGraph( + const std::vector& tensors, + StageMap stages, + bool hide_inline = false); class LowerImpl { public: @@ -111,8 +115,8 @@ class LowerImpl { const std::vector& tensor_args, const std::vector& scalar_args, const std::vector& temp_tensor_args = {}, - const Target& target = common::DefaultHostTarget(), - bool support_ir_schedule = false); + const Target& target = common::DefaultHostTarget(), + bool support_ir_schedule = false); std::vector operator()(); @@ -123,13 +127,15 @@ class LowerImpl { /** * \brief generate the argument list of the final output function. - * We put the scalar_args in front of tensor_args, e.g.
get tensor_args{A,B}, scalar_args{m}, the final argument list - is {m, A, B}, the input and output tensor can be mixed in the tensor_args, the kInput and kOutput token will deduce - from their usage in the computation. + * We put the scalar_args in front of tensor_args, e.g. get tensor_args{A,B}, + * scalar_args{m}, the final argument list is {m, A, B}; the input and output + * tensors can be mixed in the tensor_args, and the kInput and kOutput tokens + * will be deduced from their usage in the computation. */ std::vector GenerateFunctionArgumentList(Expr fn_body); - std::vector GenFuncArgForSplitKernel(Expr func_iterator, std::vector temp_tensors); + std::vector GenFuncArgForSplitKernel( + Expr func_iterator, std::vector temp_tensors); /** * \brief generate the body expression of the final output function. @@ -139,23 +145,26 @@ class LowerImpl { private: /** * \brief Collect the temporary tensors. - * A temporary tensor is one that is in the computation graph, not inlined and not in the tensor_args(similar to a - * temporary variable inside function). + * A temporary tensor is one that is in the computation graph, not inlined and + * not in the tensor_args (similar to a temporary variable inside a function). */ std::vector CollectTemporaryTensors(); /** - * \brief Check both the tensor_args and sclar_args not contain duplication (different arguemnt with the same name). + * \brief Check that both the tensor_args and scalar_args contain no + * duplication (different arguments with the same name). */ void CheckArgsUnique(); /** - * \brief Get a map, for each tensor in the tensor_args, map from name to itself. + * \brief Get a map, for each tensor in the tensor_args, map from name to + * itself. */ inline absl::flat_hash_map GenTensorArgMap(); /** - * \brief Get a map, for each tensor in the computation graph, map from name to itself. + * \brief Get a map, for each tensor in the computation graph, map from name + * to itself. */ inline absl::flat_hash_map GenAllTensorMap(); @@ -172,7 +181,8 @@ class LowerImpl { * * TODO(Superjomn) remove the field `extra_depend_stages` */ - std::set> CollectExtraDependencies() const; + std::set> CollectExtraDependencies() + const; private: const std::string& fn_name_; @@ -193,7 +203,8 @@ class LowerImpl { }; /** - * \brief Tell whether a tensor contains some GPU related information, such some schedule. + * \brief Tell whether a tensor contains some GPU related information, such as + * some schedule.
*/ bool TensorContainsGPUInfo(ir::Tensor t, poly::Stage* stage); @@ -203,7 +214,8 @@ bool TensorContainsGPUInfo(ir::Tensor t, poly::Stage* stage); struct MarkVectorizeMutator : public ir::IRMutator { const std::map& vectorizes; - explicit MarkVectorizeMutator(const std::map& vectorizes) + explicit MarkVectorizeMutator(const std::map& vectorizes) : vectorizes(vectorizes) {} void operator()(Expr* expr) { ir::IRMutator::Visit(expr, expr); } @@ -237,7 +249,9 @@ struct MarkVectorizeMutator : public ir::IRMutator { struct MarkUnrollMutator : public ir::IRMutator { std::map /*level*/> unrolls; - explicit MarkUnrollMutator(const std::map>& unrolls) : unrolls(unrolls) {} + explicit MarkUnrollMutator( + const std::map>& unrolls) + : unrolls(unrolls) {} void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } @@ -271,7 +285,9 @@ struct MarkUnrollMutator : public ir::IRMutator { struct MarkParallelMutator : public ir::IRMutator { std::map /*level*/> parallels; - explicit MarkParallelMutator(const std::map>& parallels) : parallels(parallels) {} + explicit MarkParallelMutator( + const std::map>& parallels) + : parallels(parallels) {} void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } diff --git a/paddle/cinn/lang/lower_impl_test.cc b/paddle/cinn/lang/lower_impl_test.cc index 3c9637128a59e..273854ac2566c 100644 --- a/paddle/cinn/lang/lower_impl_test.cc +++ b/paddle/cinn/lang/lower_impl_test.cc @@ -42,7 +42,7 @@ TEST(CreateCompGraph, single_layer) { } auto stages = CreateStages({C}); - auto graph = CreateCompGraph({A, B, C}, stages); + auto graph = CreateCompGraph({A, B, C}, stages); LOG(INFO) << "graph:\n" << graph->Visualize(); @@ -86,10 +86,12 @@ TEST(CreateCompGraph, multi_layers) { // C->E // D->E auto E = Compute( - {M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j) + C(i, j) + D(i, j); }, "E"); + {M, N}, + [&](Expr i, Expr j) { return A(i, j) + B(i, j) + C(i, j) + D(i, j); }, + "E"); auto stages = CreateStages({C, D, E}); - auto graph = CreateCompGraph({A, B, E}, stages); + auto graph = CreateCompGraph({A, B, E}, stages); LOG(INFO) << "graph:\n" << graph->Visualize(); @@ -220,7 +222,9 @@ TEST(CreateCompGraph, inline_compatible) { // C->E // D->E auto E = Compute( - {M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j) + C(i, j) + D(i, j); }, "E"); + {M, N}, + [&](Expr i, Expr j) { return A(i, j) + B(i, j) + C(i, j) + D(i, j); }, + "E"); auto stages = CreateStages({C, D, E}); stages[D]->ComputeInline(); @@ -279,7 +283,9 @@ TEST(CreateCompGraph, inline_compatible1) { // C->E // D->E auto E = Compute( - {M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j) + C(i, j) + D(i, j); }, "E"); + {M, N}, + [&](Expr i, Expr j) { return A(i, j) + B(i, j) + C(i, j) + D(i, j); }, + "E"); auto stages = CreateStages({C, D, E}); stages[C]->ComputeInline(); diff --git a/paddle/cinn/lang/lower_test.cc b/paddle/cinn/lang/lower_test.cc index a5b95bcaaf69a..14f81090e30cb 100755 --- a/paddle/cinn/lang/lower_test.cc +++ b/paddle/cinn/lang/lower_test.cc @@ -69,7 +69,9 @@ TEST(lower, more_complex) { Placeholder B("B", {Expr(N), Expr(K)}); auto C = Compute( - {M, N, K}, [=](Var i, Var j, Var k) -> Expr { return A(i, j) * B(j, k); }, "C"); + {M, N, K}, + [=](Var i, Var j, Var k) -> Expr { return A(i, j) * B(j, k); }, + "C"); auto stages = CreateStages({C}); @@ -78,7 +80,8 @@ TEST(lower, more_complex) { std::cout << "func:\n" << Expr(lower_funcs->self()) << std::endl; } -//! To support training, the dynamic shape support is vital. We test the corresponding lower ability here. +//! 
To support training, the dynamic shape support is vital. We test the +//! corresponding lower ability here. TEST(lower, dynamic_shape) { Var B("B"); // B is like shape here. Expr N(15); @@ -89,9 +92,11 @@ TEST(lower, dynamic_shape) { Placeholder W("W", {Expr(N), Expr(K)}); auto C = Compute( - {B, N, K}, [=](Var i, Var j, Var k) -> Expr { return X(i, j) * W(j, k); }, "C"); + {B, N, K}, + [=](Var i, Var j, Var k) -> Expr { return X(i, j) * W(j, k); }, + "C"); - auto stages = CreateStages({C}); + auto stages = CreateStages({C}); auto lower_funcs = Lower("cal_C", stages, {X, W, C}); std::cout << "func:\n" << Expr(lower_funcs->self()) << std::endl; @@ -108,9 +113,10 @@ TEST(lower, lowered_call) { auto Z = Compute( {B, N}, [&](Var i, Var j) { return X(i, j) + Y(i, j); }, "Z"); - std::vector return_types({{Float(32), std::vector{{B, N}}, "C"}}); + std::vector return_types( + {{Float(32), std::vector{{B, N}}, "C"}}); auto tensors = CallLowered("lowered_fun0", {X, Y, Z}, return_types); - auto C = tensors[0]; + auto C = tensors[0]; auto stages = CreateStages({X, Y, Z, C}); diff --git a/paddle/cinn/lang/packed_func.h b/paddle/cinn/lang/packed_func.h index aba67cdbf524b..fa7f3e05cd34b 100644 --- a/paddle/cinn/lang/packed_func.h +++ b/paddle/cinn/lang/packed_func.h @@ -54,7 +54,9 @@ class Args { ArgValue& operator[](int i) { return values_[i]; } const ArgValue& operator[](int i) const { return values_[i]; } - common::CINNValuePack ToValuePack() const { return common::CINNValuePack(values_); } + common::CINNValuePack ToValuePack() const { + return common::CINNValuePack(values_); + } private: std::vector values_; @@ -67,7 +69,8 @@ struct for_each_dispatcher { template static void Run(const F& f, T&& value, Args&&... args) { f(I, std::forward(value)); - for_each_dispatcher::Run(f, std::forward(args)...); + for_each_dispatcher::Run( + f, std::forward(args)...); } }; @@ -78,7 +81,8 @@ struct for_each_dispatcher { template inline void for_each(const F& f, Args&&... args) { - for_each_dispatcher::Run(f, std::forward(args)...); + for_each_dispatcher::Run( + f, std::forward(args)...); } struct FuncArgsSetter { @@ -96,7 +100,8 @@ struct FuncArgsSetter { } // namespace detail /** - * A function defininer with the arguments packed, all the PackedFuncs have the same signature. + * A function definer with the arguments packed; all the PackedFuncs have the + * same signature.
*/ class PackedFunc { public: diff --git a/paddle/cinn/lang/packed_func_test.cc b/paddle/cinn/lang/packed_func_test.cc index 4a3eeb4b16e3d..b5e941f1abed2 100644 --- a/paddle/cinn/lang/packed_func_test.cc +++ b/paddle/cinn/lang/packed_func_test.cc @@ -27,7 +27,7 @@ TEST(Function, test) { PackedFunc::body_t func_body = [](Args args, RetValue* ret) { int a = args[0]; int b = args[1]; - *ret = (a + b); + *ret = (a + b); }; PackedFunc func(func_body); @@ -38,12 +38,12 @@ TEST(Function, test) { TEST(Function, test1) { PackedFunc::body_t body = [](Args args, RetValue* ret) { auto* msg = static_cast(args[0]); - (*ret) = msg; + (*ret) = msg; }; PackedFunc func(body); const char* msg = "hello world"; - char* c = func(msg); + char* c = func(msg); LOG(INFO) << static_cast(c); } @@ -84,8 +84,8 @@ TEST(Function, ReturnMultiValue) { PackedFunc func(body); common::CINNValuePack ret = func(1, 2); - int c = ret[0]; - int d = ret[1]; + int c = ret[0]; + int d = ret[1]; EXPECT_EQ(c, 3); EXPECT_EQ(d, -1); diff --git a/paddle/cinn/lang/placeholder.cc b/paddle/cinn/lang/placeholder.cc index f3dfd043178e3..2a71504c364b2 100644 --- a/paddle/cinn/lang/placeholder.cc +++ b/paddle/cinn/lang/placeholder.cc @@ -22,7 +22,9 @@ namespace lang { using cinn::common::bfloat16; using cinn::common::float16; -ir::Tensor CreatePlaceHolder(const std::vector &shape, Type type, const std::string &name) { +ir::Tensor CreatePlaceHolder(const std::vector &shape, + Type type, + const std::string &name) { std::vector expr_shape; for (int s : shape) { expr_shape.push_back(Expr(s)); @@ -30,7 +32,9 @@ ir::Tensor CreatePlaceHolder(const std::vector &shape, Type type, const std return CreatePlaceHolder(expr_shape, type, name); } -ir::Tensor CreatePlaceHolder(const std::vector &shape, Type type, const std::string &name) { +ir::Tensor CreatePlaceHolder(const std::vector &shape, + Type type, + const std::string &name) { if (type.is_float(32)) { return Placeholder(name, shape); } else if (type.is_float(64)) { diff --git a/paddle/cinn/lang/placeholder.h b/paddle/cinn/lang/placeholder.h index 588ab44d40de4..a0eaaab5ddc51 100644 --- a/paddle/cinn/lang/placeholder.h +++ b/paddle/cinn/lang/placeholder.h @@ -44,7 +44,9 @@ class Placeholder { Expr operator()(Expr a) const { return Call({a}); } Expr operator()(Expr a, Expr b) const { return Call({a, b}); } Expr operator()(Expr a, Expr b, Expr c) const { return Call({a, b, c}); } - Expr operator()(Expr a, Expr b, Expr c, Expr d) const { return Call({a, b, c, d}); } + Expr operator()(Expr a, Expr b, Expr c, Expr d) const { + return Call({a, b, c, d}); + } Expr operator()(const std::vector &indices) const; // @} @@ -77,24 +79,31 @@ Expr Placeholder::Call(const std::vector &indices) const { } template -Placeholder::Placeholder(const std::string &name, const std::vector &shape) { +Placeholder::Placeholder(const std::string &name, + const std::vector &shape) { std::vector _shape; for (int v : shape) _shape.push_back(Expr(v)); Init(name, _shape); } template -Placeholder::Placeholder(const std::string &name, const std::vector &shape) { +Placeholder::Placeholder(const std::string &name, + const std::vector &shape) { Init(name, shape); } -ir::Tensor CreatePlaceHolder(const std::vector &shape, Type type, const std::string &name); +ir::Tensor CreatePlaceHolder(const std::vector &shape, + Type type, + const std::string &name); -ir::Tensor CreatePlaceHolder(const std::vector &shape, Type type, const std::string &name); +ir::Tensor CreatePlaceHolder(const std::vector &shape, + Type type, + const std::string &name); 
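For orientation, a minimal usage sketch of the placeholder API declared above, written against the lang calls exercised by the tests in this diff (Placeholder, Compute, CreateStages, Lower). The function name fn_add, the shapes, and the compute.h header path are illustrative assumptions, not part of this change:

#include "paddle/cinn/lang/compute.h"      // assumed header for Compute
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"

using namespace cinn;        // Expr, Var
using namespace cinn::lang;  // Placeholder, Compute, CreateStages, Lower

void PlaceholderSketch() {
  Expr M(16), N(16);
  // Typed placeholders; CreatePlaceHolder above is the type-erased variant
  // that dispatches on a runtime Type instead of a template parameter.
  Placeholder<float> A("A", {M, N});
  Placeholder<float> B("B", {M, N});

  // Derive an elementwise computation from the placeholders.
  auto C = Compute(
      {M, N}, [&](Var i, Var j) -> Expr { return A(i, j) + B(i, j); }, "C");

  auto stages = CreateStages({C});
  auto fn = Lower("fn_add", stages, {A, B, C});  // hypothetical function name
}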
/// ------- details ------- template -void Placeholder::Init(const std::string &name, const std::vector &shape) { +void Placeholder::Init(const std::string &name, + const std::vector &shape) { ir::Var buffer_ptr(Context::Global().NewName("buffer")); buffer_ptr->set_type(type_of()); @@ -102,7 +111,8 @@ void Placeholder::Init(const std::string &name, const std::vector &shap Expr offset(0); std::vector axis; - for (int i = 0; i < shape.size(); i++) axis.emplace_back(common::axis_name(i)); + for (int i = 0; i < shape.size(); i++) + axis.emplace_back(common::axis_name(i)); auto op = ir::PlaceholderOp::Make(name, shape, type_of()); diff --git a/paddle/cinn/optim/buffer_assign.cc b/paddle/cinn/optim/buffer_assign.cc index 74507ffe1807b..f5f0e47a68fee 100644 --- a/paddle/cinn/optim/buffer_assign.cc +++ b/paddle/cinn/optim/buffer_assign.cc @@ -38,7 +38,8 @@ const char* BufferUFNode::__type_info__ = "BufferUFNode"; struct IRReplaceTensorMutator : ir::IRMutator<> { const std::map& tensor_map; - IRReplaceTensorMutator(const std::map& tensor_map) : tensor_map(tensor_map) {} + IRReplaceTensorMutator(const std::map& tensor_map) + : tensor_map(tensor_map) {} void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } void Visit(const ir::_Tensor_* op, Expr* expr) override { @@ -51,12 +52,14 @@ struct IRReplaceTensorMutator : ir::IRMutator<> { } // namespace -std::map InitialAssignBuffer(Expr* expr, - poly::StageMap stages, - const std::map& all_tensor_map, - const common::Graph* comp_graph, - const std::set& temp_tensor_names) { - // The tensor map helps to reserve only one tensor instance for a tensor(called the same name). +std::map InitialAssignBuffer( + Expr* expr, + poly::StageMap stages, + const std::map& all_tensor_map, + const common::Graph* comp_graph, + const std::set& temp_tensor_names) { + // The tensor map helps to reserve only one tensor instance for a + // tensor (called the same name). std::map buffer_updated_tensor; for (auto& item : all_tensor_map) { @@ -67,8 +70,8 @@ std::map InitialAssignBuffer(Expr* expr, // union-find to cluster the tensors with the same buffer. common::UnionFind union_find; - // unify all the tensor occurance with a global one, e.g. there are multiple tensor B exists in the expression, - // replace them with a shared one. + // unify all the tensor occurrences with a global one, e.g. if multiple + // tensor B instances exist in the expression, replace them with a shared one. ir::CollectIRNodes(*expr, [&](const Expr* x) -> bool { auto* t = x->as_tensor(); if (t && !stages[t]->inlined()) { @@ -79,7 +82,7 @@ std::map InitialAssignBuffer(Expr* expr, std::map uf_map; for (auto& item : all_tensor_map) { - auto* n = union_find.AddNode(new BufferUFNode(item.second->name)); + auto* n = union_find.AddNode(new BufferUFNode(item.second->name)); uf_map[item.second->name] = n->safe_as(); } @@ -90,17 +93,19 @@ std::map InitialAssignBuffer(Expr* expr, auto* other_n = uf_map[other]; if (!other_n) continue; - VLOG(3) << "share buffer between " << item.first << " " << other_n->tensor_name; + VLOG(3) << "share buffer between " << item.first << " " + << other_n->tensor_name; cur_n->Union(other_n); } } - // determine which tensor to have the initial buffer, and will share across the cluster, we take a topological order - of the computational graph, and find out which tensor comes first in a cluster.
+ // determine which tensor to have the initial buffer, and will share across + // the cluster, we take a topological order of the computational graph, and + // find out which tensor comes first in a cluster. auto _topo_order_topo_edges_ = comp_graph->topological_order(); - auto& topo_order = std::get<0>(_topo_order_topo_edges_); - auto& topo_edges = std::get<1>(_topo_order_topo_edges_); + auto& topo_order = std::get<0>(_topo_order_topo_edges_); + auto& topo_edges = std::get<1>(_topo_order_topo_edges_); for (common::GraphNode* n : topo_order) { auto nn = n->safe_as(); CHECK(nn); @@ -108,7 +113,8 @@ std::map InitialAssignBuffer(Expr* expr, auto it = uf_map.find(nn->tensor->name); CHECK(it != uf_map.end()); auto& cluster_info = std::get<0>(it->second->GetRoot())->cluster_info; - if (cluster_info.empty()) { // buffer owner(a tensor) of this cluster not set yet. + if (cluster_info.empty()) { // buffer owner(a tensor) of this cluster not + // set yet. cluster_info = nn->tensor->name; } } @@ -116,20 +122,22 @@ std::map InitialAssignBuffer(Expr* expr, // Get a center of the cluster, it will consider the following rules // 1. Prefer a tensor arg than a temp tensor. - auto cluster_get_center_tensor = [&](const std::vector& cluster) { - ir::Tensor some_tensor; - // try to find a node that is a tensor_arg, allocate buffer for it, and make others share buffer with it. - for (auto* n : cluster) { - auto* node = n->safe_as(); - bool is_temp = temp_tensor_names.count(node->tensor_name); - if (!is_temp) return all_tensor_map.at(node->tensor_name); - if (all_tensor_map.at(node->tensor_name)->buffer.defined()) { - return all_tensor_map.at(node->tensor_name); - } - some_tensor = all_tensor_map.at(node->tensor_name); - } - return some_tensor; - }; + auto cluster_get_center_tensor = + [&](const std::vector& cluster) { + ir::Tensor some_tensor; + // try to find a node that is a tensor_arg, allocate buffer for it, and + // make others share buffer with it. + for (auto* n : cluster) { + auto* node = n->safe_as(); + bool is_temp = temp_tensor_names.count(node->tensor_name); + if (!is_temp) return all_tensor_map.at(node->tensor_name); + if (all_tensor_map.at(node->tensor_name)->buffer.defined()) { + return all_tensor_map.at(node->tensor_name); + } + some_tensor = all_tensor_map.at(node->tensor_name); + } + return some_tensor; + }; for (auto& cluster : union_find.GetClusters()) { auto root_tensor = cluster_get_center_tensor(cluster); @@ -142,7 +150,7 @@ std::map InitialAssignBuffer(Expr* expr, if (tensor != root_tensor) { auto keep_shape = root_tensor->buffer->shape; Reference(&tensor)->Bind(root_tensor->buffer); - root_tensor->buffer->shape = keep_shape; + root_tensor->buffer->shape = keep_shape; Reference(&tensor)->buffer->shape = keep_shape; VLOG(3) << "keep_shape is : " << utils::GetStreamCnt(keep_shape[0]); } diff --git a/paddle/cinn/optim/buffer_assign.h b/paddle/cinn/optim/buffer_assign.h index bd5e0c1359413..e44b3a77cee2e 100644 --- a/paddle/cinn/optim/buffer_assign.h +++ b/paddle/cinn/optim/buffer_assign.h @@ -29,11 +29,12 @@ namespace optim { * @param expr * @param stages The stage map. 
*/ -std::map InitialAssignBuffer(Expr* expr, - poly::StageMap stages, - const std::map& all_tensor_map, - const common::Graph* comp_graph, - const std::set& temp_tensor_names); +std::map InitialAssignBuffer( + Expr* expr, + poly::StageMap stages, + const std::map& all_tensor_map, + const common::Graph* comp_graph, + const std::set& temp_tensor_names); } // namespace optim } // namespace cinn diff --git a/paddle/cinn/optim/call_arg_list_to_pod_value.cc b/paddle/cinn/optim/call_arg_list_to_pod_value.cc index c1f9389cdfa2f..62afec620f364 100644 --- a/paddle/cinn/optim/call_arg_list_to_pod_value.cc +++ b/paddle/cinn/optim/call_arg_list_to_pod_value.cc @@ -34,23 +34,25 @@ struct CallArgListToPodValueMutator : ir::IRMutator<> { void Visit(const ir::Call* op, Expr* expr) override { if (op->is_cinn_call()) { auto _oprs_args_ = pack_arg_exprs(op); // NOLINT - auto& oprs = std::get<0>(_oprs_args_); - auto& args = std::get<1>(_oprs_args_); + auto& oprs = std::get<0>(_oprs_args_); + auto& args = std::get<1>(_oprs_args_); - Var pod_array_var(Context::Global().NewName("_pod_arr"), - type_of().with_lanes(op->total_args_count())); + Var pod_array_var( + Context::Global().NewName("_pod_arr"), + type_of().with_lanes(op->total_args_count())); // Declare pod_array. oprs.push_back(ir::Let::Make(pod_array_var, Expr())); oprs.push_back(ir::intrinsics::ArgsConstruct::Make(pod_array_var, args)); - auto new_call = ir::Call::Make(Void(), - op->name, - {pod_array_var, common::make_const(Int(32), args.size())}, - {}, - ir::CallType::CINN, - op->func, - op->value_index); + auto new_call = ir::Call::Make( + Void(), + op->name, + {pod_array_var, common::make_const(Int(32), args.size())}, + {}, + ir::CallType::CINN, + op->func, + op->value_index); oprs.push_back(new_call); @@ -58,12 +60,14 @@ struct CallArgListToPodValueMutator : ir::IRMutator<> { } } - std::tuple /*oprs*/, std::vector /*args*/> pack_arg_exprs(const ir::Call* op) { + std::tuple /*oprs*/, std::vector /*args*/> + pack_arg_exprs(const ir::Call* op) { std::vector exprs; std::vector args; auto pack_arg = [&](const Expr& arg) { - Var pod_var(Context::Global().NewName("_pod_val_"), type_of()); + Var pod_var(Context::Global().NewName("_pod_val_"), + type_of()); // declare the array. exprs.push_back(ir::Let::Make(pod_var, Expr())); @@ -73,14 +77,23 @@ struct CallArgListToPodValueMutator : ir::IRMutator<> { Expr cast; if (arg.As()) { cast = runtime::IntrinsicCall( - Void(), runtime::intrinsic::buffer_p_to_cinn_pod_value_repr, {arg}, {pod_val_addr_expr}); + Void(), + runtime::intrinsic::buffer_p_to_cinn_pod_value_repr, + {arg}, + {pod_val_addr_expr}); } else if (arg.type() == type_of()) { cast = runtime::IntrinsicCall( - Void(), runtime::intrinsic::float_to_cinn_pod_value_repr, {arg}, {pod_val_addr_expr}); + Void(), + runtime::intrinsic::float_to_cinn_pod_value_repr, + {arg}, + {pod_val_addr_expr}); } else if (arg.type() == type_of()) { cast = runtime::IntrinsicCall( - Void(), runtime::intrinsic::int32_to_cinn_pod_value_repr, {arg}, {pod_val_addr_expr}); + Void(), + runtime::intrinsic::int32_to_cinn_pod_value_repr, + {arg}, + {pod_val_addr_expr}); } else { CINN_NOT_IMPLEMENTED } diff --git a/paddle/cinn/optim/cast_simplify.h b/paddle/cinn/optim/cast_simplify.h index 595f85e2786da..072f39783d187 100644 --- a/paddle/cinn/optim/cast_simplify.h +++ b/paddle/cinn/optim/cast_simplify.h @@ -23,7 +23,8 @@ namespace cinn::optim { * * There are several patterns: * 1. the source and target type are the same, drop the Cast node - * 2. 
for intermediate numbers, just replace the Cast node with a Node of the target type + * 2. for intermediate numbers, just replace the Cast node with a Node of the + * target type */ void CastSimplify(Expr* e); diff --git a/paddle/cinn/optim/collect_undefined_vars.cc b/paddle/cinn/optim/collect_undefined_vars.cc index 31b91c1f26944..a912a484fce1b 100644 --- a/paddle/cinn/optim/collect_undefined_vars.cc +++ b/paddle/cinn/optim/collect_undefined_vars.cc @@ -28,8 +28,10 @@ struct Mutator : public ir::IRMutator<> { std::set used_vars; void CollectVarDef(const std::string& var) { - CHECK(!defined_vars.count(var)) << "var " << var << " has been defined, please check"; - CHECK(!used_vars.count(var)) << "var " << var << " is wrongly used before definition"; + CHECK(!defined_vars.count(var)) + << "var " << var << " has been defined, please check"; + CHECK(!used_vars.count(var)) + << "var " << var << " is wrongly used before definition"; defined_vars.insert(var); } @@ -47,7 +49,7 @@ struct Mutator : public ir::IRMutator<> { void Visit(const ir::Let* op, Expr* expr) final { Expr symbol = op->symbol; - auto var = symbol.as_var_ref(); + auto var = symbol.as_var_ref(); CHECK(var.defined()); CollectVarDef(var->name); auto* node = expr->As(); diff --git a/paddle/cinn/optim/compute_inline_expand.cc b/paddle/cinn/optim/compute_inline_expand.cc index 7dae4cfeae5eb..aef64af01e011 100644 --- a/paddle/cinn/optim/compute_inline_expand.cc +++ b/paddle/cinn/optim/compute_inline_expand.cc @@ -45,7 +45,9 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> { TensorInlineExpandMutator(const std::string &tensor_name, std::map *all_tensor_map, poly::StageMap stages) - : tensor_name_(tensor_name), all_tensor_map_(all_tensor_map), stages_(stages) {} + : tensor_name_(tensor_name), + all_tensor_map_(all_tensor_map), + stages_(stages) {} void operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); @@ -57,7 +59,8 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> { void Visit(const ir::_Var_ *expr, Expr *op) override { if (inline_code && temp_buffer) { - if (utils::Startswith(expr->name, "blockIdx") || (utils::Startswith(expr->name, "threadIdx") && memory_local)) { + if (utils::Startswith(expr->name, "blockIdx") || + (utils::Startswith(expr->name, "threadIdx") && memory_local)) { *op = ir::Expr(0); } } @@ -65,7 +68,8 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> { void Visit(const ir::_Tensor_ *op, Expr *expr) override { if (inline_code && utils::Endswith(op->name, "_write_cache") && - (*all_tensor_map_).at(op->name)->buffer->memory_type == ir::MemoryType::Heap) { + (*all_tensor_map_).at(op->name)->buffer->memory_type == + ir::MemoryType::Heap) { auto no_cache_name = op->name.substr(0, op->name.size() - 12); VLOG(2) << "no_cache_name: " << no_cache_name; CHECK(all_tensor_map_->count(no_cache_name)); @@ -75,30 +79,33 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> { void Visit(const ir::For *op, Expr *expr) override { CHECK(op->extent.is_constant()); - int cons_extent = (int)op->extent.get_constant(); + int cons_extent = (int)op->extent.get_constant(); var_to_extent[op->loop_var->name] = op->extent; ir::IRMutator<>::Visit(op, expr); } void Visit(const ir::PolyFor *op, Expr *expr) override { - auto extent = op->ExtractExtent(); + auto extent = op->ExtractExtent(); var_to_extent[op->iterator->name] = extent; ir::IRMutator<>::Visit(op, expr); } void Visit(const ir::Load *op, Expr *expr) override { - auto *node = expr->As(); + auto *node = expr->As(); auto *tensor = 
node->tensor.as_tensor(); if (tensor && tensor->name == tensor_name_) { - *expr = tensor->inline_expanded(op->indices); + *expr = tensor->inline_expanded(op->indices); inline_code = true; ir::IRMutator<>::Visit(expr, expr); inline_code = false; } else if (inline_code && tensor->buffer.defined()) { - bool is_heap = (*all_tensor_map_).at(tensor->name)->buffer->memory_type == ir::MemoryType::Heap; + bool is_heap = (*all_tensor_map_).at(tensor->name)->buffer->memory_type == + ir::MemoryType::Heap; if (utils::Endswith(tensor->buffer->name, "_write_cache") && is_heap) { - // temp fix: cache_write will change the tensor to the cache tensor wrongly - auto no_cache_name = tensor->buffer->name.substr(1, tensor->buffer->name.size() - 13); + // temp fix: cache_write will change the tensor to the cache tensor + // wrongly + auto no_cache_name = + tensor->buffer->name.substr(1, tensor->buffer->name.size() - 13); if (all_tensor_map_->count(no_cache_name)) { ir::IRMutator<>::Visit(&node->tensor, &node->tensor); } else { @@ -117,7 +124,7 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> { utils::Endswith(tensor->buffer->name, "_read_cache") || utils::Endswith(tensor->buffer->name, "_temp_buffer")) { #ifdef CINN_WITH_CUDA - auto axis_names = stages_[tensor]->axis_names(); + auto axis_names = stages_[tensor]->axis_names(); auto compute_ats = stages_[tensor]->GetComputeAts(); if (compute_ats.size() == 1) { int level_tmp; @@ -127,16 +134,18 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> { std::vector replace_vars; for (int j = 0; j <= level_tmp; j++) { if (var_to_extent.count(axis_names[j]) == 0) continue; - replace_vars.push_back(Var(var_to_extent[axis_names[j]], axis_names[j])); + replace_vars.push_back( + Var(var_to_extent[axis_names[j]], axis_names[j])); } replace_var.push_back(replace_vars); tensor_names.push_back(tensor->buffer->name); } #endif - bool keep_buffer = temp_buffer; - temp_buffer = true; + bool keep_buffer = temp_buffer; + temp_buffer = true; bool keep_memory_local = memory_local; - if ((*all_tensor_map_).at(tensor->name)->buffer->memory_type == ir::MemoryType::GPULocal) { + if ((*all_tensor_map_).at(tensor->name)->buffer->memory_type == + ir::MemoryType::GPULocal) { memory_local = true; } ir::IRMutator<>::Visit(&node->tensor, &node->tensor); @@ -145,7 +154,7 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> { ir::IRMutator<>::Visit(&temp, &temp); node->indices[i] = temp; } - temp_buffer = keep_buffer; + temp_buffer = keep_buffer; memory_local = keep_memory_local; } else { ir::IRMutator<>::Visit(&node->tensor, &node->tensor); @@ -178,8 +187,9 @@ struct SSANode : public common::GraphNode { static constexpr char *__type_info__ = "optim::SSANode"; }; -// TODO(Superjomn) the graph here is not a SSA now, it is flattern for the ir::CollectIRNodes method collects all the -// tensors recursively, so it can not reserve the level information, fix it. +// TODO(Superjomn) the graph here is not an SSA now; it is flattened because the +// ir::CollectIRNodes method collects all the tensors recursively, so it cannot +// preserve the level information. Fix it.
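A hedged sketch of the direction that TODO points at, using only graph facilities that already appear in this diff (topological_order(), GraphNode::id()); ReverseTopoOrder is an illustrative helper, not an existing API:

#include <algorithm>
#include <string>
#include <tuple>
#include <vector>

// If the dependency graph kept its levels, inline expansion could walk the
// tensors consumers-first instead of re-scanning to a fixed point the way
// ComputeInlineExpand below does.
std::vector<std::string> ReverseTopoOrder(const cinn::common::Graph &graph) {
  auto topo = graph.topological_order();  // yields (nodes, edges)
  auto &nodes = std::get<0>(topo);
  std::vector<std::string> order;
  for (auto *n : nodes) order.push_back(n->id());
  std::reverse(order.begin(), order.end());  // expand sinks before sources
  return order;
}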
struct SSABuilder : public ir::IRMutator<> { common::Graph graph; @@ -193,7 +203,9 @@ struct SSABuilder : public ir::IRMutator<> { auto *cur_graph_node = graph.RetrieveNode(node->tensor.as_tensor()->name); if (!cur_graph_node) { - cur_graph_node = graph.RegisterNode(node->tensor.as_tensor()->name, new SSANode(node->tensor.as_tensor()->name)); + cur_graph_node = + graph.RegisterNode(node->tensor.as_tensor()->name, + new SSANode(node->tensor.as_tensor()->name)); } auto deps_tensor_names = node->tensor.as_tensor()->GetDependTensorNames(); @@ -209,14 +221,18 @@ struct SSABuilder : public ir::IRMutator<> { } // namespace -void ComputeInlineExpand(Expr *expr, poly::StageMap stages, std::map *all_tensor_map) { +void ComputeInlineExpand(Expr *expr, + poly::StageMap stages, + std::map *all_tensor_map) { // the inline tensors contained in the expression. - auto inline_tensors = - ir::CollectIRNodes(*expr, [&](const Expr *x) { return x->as_tensor() && stages[x->as_tensor()]->inlined(); }); + auto inline_tensors = ir::CollectIRNodes(*expr, [&](const Expr *x) { + return x->as_tensor() && stages[x->as_tensor()]->inlined(); + }); // keep inline expand if any inline tensor exists - // NOTE This is a naive method to greedily expand the inline tensors until none exists, a better way is to create a - // SSA graph and expand the inline tensors in the reversed dependency order. + // NOTE This is a naive method to greedily expand the inline tensors until + // none exists, a better way is to create a SSA graph and expand the inline + // tensors in the reversed dependency order. // TODO(Superjomn) Use the SSA graph to improve this. while (!inline_tensors.empty()) { for (const auto &t : inline_tensors) { @@ -224,8 +240,9 @@ void ComputeInlineExpand(Expr *expr, poly::StageMap stages, std::mapname, all_tensor_map, stages)(expr); } - inline_tensors = ir::CollectLoadTensors( - *expr, [&](const Expr *x) { return x->as_tensor() && stages[x->as_tensor()]->inlined(); }); + inline_tensors = ir::CollectLoadTensors(*expr, [&](const Expr *x) { + return x->as_tensor() && stages[x->as_tensor()]->inlined(); + }); } } diff --git a/paddle/cinn/optim/compute_inline_expand.h b/paddle/cinn/optim/compute_inline_expand.h index eb17641dc50b1..07187ec9eeb39 100644 --- a/paddle/cinn/optim/compute_inline_expand.h +++ b/paddle/cinn/optim/compute_inline_expand.h @@ -27,7 +27,9 @@ namespace optim { * @param tensor_name name of the tensor to expand inline. * @param memo a memo to avoid duplicate expand. */ -void ComputeInlineExpand(Expr* expr, poly::StageMap stages, std::map* all_tensor_map); +void ComputeInlineExpand(Expr* expr, + poly::StageMap stages, + std::map* all_tensor_map); } // namespace optim } // namespace cinn diff --git a/paddle/cinn/optim/eliminate_broadcast_in_forloop.cc b/paddle/cinn/optim/eliminate_broadcast_in_forloop.cc index 46a17b4954abb..d9f65cdb80a3e 100644 --- a/paddle/cinn/optim/eliminate_broadcast_in_forloop.cc +++ b/paddle/cinn/optim/eliminate_broadcast_in_forloop.cc @@ -36,14 +36,18 @@ struct EliminateBroadcastInForloop : public ir::IRMutator { auto* node = expr->As(); - auto broadcasts = ir::CollectIRNodes(node->value, [&](const Expr* expr) { return expr->As(); }); + auto broadcasts = ir::CollectIRNodes(node->value, [&](const Expr* expr) { + return expr->As(); + }); std::vector let_exprs; Var tmp; Expr let_expr; - Var cur_level_loop_var = forloop_stack.back()->As() ? forloop_stack.back()->As()->loop_var - : forloop_stack.back()->As()->iterator; + Var cur_level_loop_var = + forloop_stack.back()->As() + ? 
forloop_stack.back()->As()->loop_var + : forloop_stack.back()->As()->iterator; for (Expr broadcast : broadcasts) { if (ContainsLoopVar(broadcast, cur_level_loop_var)) continue; VLOG(4) << "eliminating " << broadcast; @@ -57,13 +61,16 @@ struct EliminateBroadcastInForloop : public ir::IRMutator { Expr* outer_forloop = forloop_stack[forloop_stack.size() - 2]; - auto& outer_forloop_body = - outer_forloop->As() ? outer_forloop->As()->body : outer_forloop->As()->body; + auto& outer_forloop_body = outer_forloop->As() + ? outer_forloop->As()->body + : outer_forloop->As()->body; auto* outer_forloop_body_block = outer_forloop_body.As(); if (outer_forloop_body_block) { outer_forloop_body_block->stmts.insert( - std::begin(outer_forloop_body_block->stmts), let_exprs.begin(), let_exprs.end()); + std::begin(outer_forloop_body_block->stmts), + let_exprs.begin(), + let_exprs.end()); } else { let_exprs.push_back(outer_forloop_body); @@ -73,7 +80,8 @@ struct EliminateBroadcastInForloop : public ir::IRMutator { bool ContainsLoopVar(Expr expr, Var loop_var) { return !ir::CollectIRNodes(expr, [&](const Expr* e) -> bool { - return e->As() && e->As()->name == loop_var->name; + return e->As() && + e->As()->name == loop_var->name; }).empty(); } diff --git a/paddle/cinn/optim/extern_call_process.cc b/paddle/cinn/optim/extern_call_process.cc index 6c9532a02fa99..be3636c81982e 100644 --- a/paddle/cinn/optim/extern_call_process.cc +++ b/paddle/cinn/optim/extern_call_process.cc @@ -35,7 +35,9 @@ struct ExternCallMultiOutputShallowStoreMutator : public ir::IRMutator<> { } // namespace -void ExternCallMultiOutputShallowStore(Expr* e) { ExternCallMultiOutputShallowStoreMutator()(e); } +void ExternCallMultiOutputShallowStore(Expr* e) { + ExternCallMultiOutputShallowStoreMutator()(e); +} } // namespace optim } // namespace cinn diff --git a/paddle/cinn/optim/fold_cinn_call_arguments.cc b/paddle/cinn/optim/fold_cinn_call_arguments.cc index 3ce60fd569e54..8ce5743d2230c 100644 --- a/paddle/cinn/optim/fold_cinn_call_arguments.cc +++ b/paddle/cinn/optim/fold_cinn_call_arguments.cc @@ -82,7 +82,8 @@ struct FoldCINNCallArgumentsMutator : public ir::IRMutator<> { std::vector write_args; for (auto& arg : call->read_args) { if (arg.as_tensor()) { - CHECK(arg.as_tensor()->buffer.defined()) << "arg tensor [" << arg.as_tensor()->name << "] not has buffer"; + CHECK(arg.as_tensor()->buffer.defined()) + << "arg tensor [" << arg.as_tensor()->name << "] not has buffer"; read_args.push_back(arg.as_tensor()->buffer); } else { read_args.push_back(arg); @@ -97,7 +98,7 @@ struct FoldCINNCallArgumentsMutator : public ir::IRMutator<> { } } - call->read_args = read_args; + call->read_args = read_args; call->write_args = write_args; } diff --git a/paddle/cinn/optim/fold_cinn_call_arguments.h b/paddle/cinn/optim/fold_cinn_call_arguments.h index 78cf2bd91a5cd..740facfba306f 100644 --- a/paddle/cinn/optim/fold_cinn_call_arguments.h +++ b/paddle/cinn/optim/fold_cinn_call_arguments.h @@ -22,8 +22,8 @@ namespace cinn { namespace optim { /** - * \brief Rewrite the Call Nodes marked type as CINN, pack their arguments into `void*, int` so that they can trigger a - * `LoweredFunc`. + * \brief Rewrite the Call Nodes marked type as CINN, pack their arguments into + * `void*, int` so that they can trigger a `LoweredFunc`. 
* * For example, input the IR * \code diff --git a/paddle/cinn/optim/if_simplify.cc b/paddle/cinn/optim/if_simplify.cc index 5226ae64c81e2..b3b46385d6270 100644 --- a/paddle/cinn/optim/if_simplify.cc +++ b/paddle/cinn/optim/if_simplify.cc @@ -24,7 +24,7 @@ struct Mutator : public ir::IRMutator<> { using ir::IRMutator<>::Visit; void Visit(const ir::IfThenElse* op, Expr* expr) { - auto* condition_int = op->condition.As(); + auto* condition_int = op->condition.As(); auto* condition_uint = op->condition.As(); int64_t value; if (condition_int || condition_uint) { diff --git a/paddle/cinn/optim/if_simplify_test.cc b/paddle/cinn/optim/if_simplify_test.cc index 201d82f8f1aca..2be36eb14c3fa 100644 --- a/paddle/cinn/optim/if_simplify_test.cc +++ b/paddle/cinn/optim/if_simplify_test.cc @@ -24,7 +24,8 @@ namespace cinn::optim { TEST(IfSimplify, if_true) { Var n("n"); - auto e = ir::IfThenElse::Make(Expr(1) /*true*/, ir::Let::Make(n, Expr(1)), ir::Let::Make(n, Expr(2))); + auto e = ir::IfThenElse::Make( + Expr(1) /*true*/, ir::Let::Make(n, Expr(1)), ir::Let::Make(n, Expr(2))); LOG(INFO) << "\n" << e; @@ -37,7 +38,8 @@ TEST(IfSimplify, if_true) { TEST(IfSimplify, if_false) { Var n("n"); - auto e = ir::IfThenElse::Make(Expr(0) /*false*/, ir::Let::Make(n, Expr(1)), ir::Let::Make(n, Expr(2))); + auto e = ir::IfThenElse::Make( + Expr(0) /*false*/, ir::Let::Make(n, Expr(1)), ir::Let::Make(n, Expr(2))); LOG(INFO) << "\n" << e; diff --git a/paddle/cinn/optim/insert_debug_log_callee.cc b/paddle/cinn/optim/insert_debug_log_callee.cc index 7addcad664d01..f4e3c556a72b9 100644 --- a/paddle/cinn/optim/insert_debug_log_callee.cc +++ b/paddle/cinn/optim/insert_debug_log_callee.cc @@ -129,24 +129,26 @@ struct InsertDebugLogCalleeMutator : public ir::IRMutator<> { void operator()(Expr *e) { ir::IRMutator<>::Visit(e, e); } void Visit(const ir::_LoweredFunc_ *op, Expr *expr) { - auto *node = expr->As(); + auto *node = expr->As(); auto *body_block = node->body.As(); CHECK(body_block); - auto msg = StringFormat("running : %s", GetDebugString(*expr).c_str()); + auto msg = StringFormat("running : %s", GetDebugString(*expr).c_str()); auto debug_node = CreateDebugStatement(msg); ir::IRMutator<>::Visit(&node->body, &node->body); - auto deal_with_exprs = [&](std::vector *exprs) { // deal with op->argument_preapre_exprs - std::vector new_stmts; - for (auto &expr : *exprs) { - auto msg = StringFormat("running : %s", GetDebugString(expr).c_str()); - new_stmts.push_back(CreateDebugStatement(msg)); - new_stmts.push_back(expr); - } - *exprs = new_stmts; - }; + auto deal_with_exprs = + [&](std::vector *exprs) { // deal with op->argument_preapre_exprs + std::vector new_stmts; + for (auto &expr : *exprs) { + auto msg = + StringFormat("running : %s", GetDebugString(expr).c_str()); + new_stmts.push_back(CreateDebugStatement(msg)); + new_stmts.push_back(expr); + } + *exprs = new_stmts; + }; deal_with_exprs(&node->alloc_output_buffer_exprs); deal_with_exprs(&node->dealloc_output_buffer_exprs); @@ -163,14 +165,15 @@ struct InsertDebugLogCalleeMutator : public ir::IRMutator<> { if (!IsDebugInfoNode(e)) { std::string msg; if (!e.As()) { - msg = StringFormat("running: %s", GetDebugString(e).c_str()); + msg = StringFormat("running: %s", GetDebugString(e).c_str()); auto debug_info_node = CreateDebugStatement(msg); new_stmts.push_back(debug_info_node); } else { - auto _msg_args_ = StoreDebugInfo(e); - auto &msg = std::get<0>(_msg_args_); - auto &args = std::get<1>(_msg_args_); - auto debug_info_node = CreateDebugStatement("running: " + msg, 
std::move(args)); + auto _msg_args_ = StoreDebugInfo(e); + auto &msg = std::get<0>(_msg_args_); + auto &args = std::get<1>(_msg_args_); + auto debug_info_node = + CreateDebugStatement("running: " + msg, std::move(args)); new_stmts.push_back(debug_info_node); } } @@ -180,16 +183,16 @@ struct InsertDebugLogCalleeMutator : public ir::IRMutator<> { new_stmts.push_back(e); if (!IsDebugInfoNode(e) && e.As()) { - auto _msg_args_ = StoreDebugInfo(e); - auto &msg = std::get<0>(_msg_args_); - auto &args = std::get<1>(_msg_args_); + auto _msg_args_ = StoreDebugInfo(e); + auto &msg = std::get<0>(_msg_args_); + auto &args = std::get<1>(_msg_args_); auto debug_info_node = CreateDebugStatement(msg, std::move(args)); new_stmts.push_back(debug_info_node); { // detailed debug auto _format_args_ = StoreDebugInfoBuilder()(&e); - auto &format = std::get<0>(_format_args_); - auto &args = std::get<1>(_format_args_); + auto &format = std::get<0>(_format_args_); + auto &args = std::get<1>(_format_args_); new_stmts.push_back(CreateDebugStatement(format, std::move(args))); } } @@ -206,12 +209,14 @@ struct InsertDebugLogCalleeMutator : public ir::IRMutator<> { break; case ir::IrNodeTy::For: { auto *node = e.As(); - ss << "loop_var << " in [" << node->min << ", " << node->extent << ")>"; + ss << "loop_var << " in [" << node->min << ", " + << node->extent << ")>"; break; } case ir::IrNodeTy::PolyFor: { auto *node = e.As(); - ss << "iterator << " in [" << node->init << ", " << node->ExtractExtent() << ")" + ss << "iterator << " in [" << node->init << ", " + << node->ExtractExtent() << ")" << " with condition: " << node->condition << ">"; break; } @@ -257,13 +262,20 @@ struct InsertDebugLogCalleeMutator : public ir::IRMutator<> { } inline bool IsDebugInfoNode(const Expr &e) { - return e.As() && e.As()->name == runtime::intrinsic::debug_log_repr; + return e.As() && + e.As()->name == runtime::intrinsic::debug_log_repr; } - Expr CreateDebugStatement(const std::string &msg, std::vector &&args = {}) { + Expr CreateDebugStatement(const std::string &msg, + std::vector &&args = {}) { args.insert(args.begin(), Expr(msg)); - return ir::Call::Make( - Void(), runtime::intrinsic::debug_log_repr, args, {}, ir::CallType ::Intrinsic, ir::FunctionRef(), 0); + return ir::Call::Make(Void(), + runtime::intrinsic::debug_log_repr, + args, + {}, + ir::CallType ::Intrinsic, + ir::FunctionRef(), + 0); } }; diff --git a/paddle/cinn/optim/ir_copy.cc b/paddle/cinn/optim/ir_copy.cc index 6adf9b44a1a9b..485b1606de5c9 100644 --- a/paddle/cinn/optim/ir_copy.cc +++ b/paddle/cinn/optim/ir_copy.cc @@ -40,10 +40,18 @@ struct IRCopyVisitor : public ir::IRVisitorBase { protected: // The methods of ir nodes follows the order defined in node.h - Expr Visit(const ir::IntImm* op) override { return Expr(make_shared(op->type(), op->value)); } - Expr Visit(const ir::UIntImm* op) override { return Expr(make_shared(op->type(), op->value)); } - Expr Visit(const ir::FloatImm* op) override { return Expr(make_shared(op->type(), op->value)); } - Expr Visit(const ir::StringImm* op) override { return Expr(common::make_shared(op->value)); } + Expr Visit(const ir::IntImm* op) override { + return Expr(make_shared(op->type(), op->value)); + } + Expr Visit(const ir::UIntImm* op) override { + return Expr(make_shared(op->type(), op->value)); + } + Expr Visit(const ir::FloatImm* op) override { + return Expr(make_shared(op->type(), op->value)); + } + Expr Visit(const ir::StringImm* op) override { + return Expr(common::make_shared(op->value)); + } Expr Visit(const ir::Cast* op) 
override { auto v = Visit(&op->v()); @@ -51,8 +59,8 @@ struct IRCopyVisitor : public ir::IRVisitorBase { } Expr Visit(const Select* op) override { - auto condition = Visit(&op->condition); - auto true_value = Visit(&op->true_value); + auto condition = Visit(&op->condition); + auto true_value = Visit(&op->true_value); auto false_value = Visit(&op->false_value); return Select::Make(condition, true_value, false_value); } @@ -74,15 +82,22 @@ struct IRCopyVisitor : public ir::IRVisitorBase { } Expr Visit(const Call* op) override { - auto read_args = Visit(op->read_args); + auto read_args = Visit(op->read_args); auto write_args = Visit(op->write_args); - return Call::Make(op->type(), op->name, read_args, write_args, op->call_type, FunctionRef(), 0, op->attrs); + return Call::Make(op->type(), + op->name, + read_args, + write_args, + op->call_type, + FunctionRef(), + 0, + op->attrs); } Expr Visit(const _Var_* op) override { auto* n = make_shared<_Var_>(); - n->name = op->name; + n->name = op->name; n->is_reduce_axis = op->is_reduce_axis; n->set_type(op->type()); @@ -107,7 +122,7 @@ struct IRCopyVisitor : public ir::IRVisitorBase { Expr Visit(const Store* op) override { auto tensor = Visit(&op->tensor); - auto value = Visit(&op->value); + auto value = Visit(&op->value); std::vector indices; for (auto& idx : op->indices) indices.push_back(Visit(&idx)); @@ -131,25 +146,25 @@ struct IRCopyVisitor : public ir::IRVisitorBase { return buffer_map[op->name]; } - auto shape = Visit(op->shape); - auto strides = Visit(op->strides); - auto name = op->name; - auto scope = op->scope; + auto shape = Visit(op->shape); + auto strides = Visit(op->strides); + auto name = op->name; + auto scope = op->scope; int data_alignment = op->data_alignment; - auto elem_offset = Visit(&op->elem_offset); - int offset_factor = op->offset_factor; - Target target = op->target; - - auto new_node = _Buffer_::Make(name, shape); - new_node->strides = strides; - new_node->dtype = op->dtype; // copy data element's type. - new_node->name = name; - new_node->scope = scope; + auto elem_offset = Visit(&op->elem_offset); + int offset_factor = op->offset_factor; + Target target = op->target; + + auto new_node = _Buffer_::Make(name, shape); + new_node->strides = strides; + new_node->dtype = op->dtype; // copy data element's type. + new_node->name = name; + new_node->scope = scope; new_node->data_alignment = data_alignment; - new_node->elem_offset = elem_offset; - new_node->offset_factor = offset_factor; - new_node->target = target; - new_node->memory_type = op->memory_type; + new_node->elem_offset = elem_offset; + new_node->offset_factor = offset_factor; + new_node->target = target; + new_node->memory_type = op->memory_type; new_node->set_type(op->type()); op->CopyMeta(new_node.As()); @@ -163,23 +178,23 @@ struct IRCopyVisitor : public ir::IRVisitorBase { return tensor_map[op->name]; } - auto shape = Visit(op->shape); - auto domain = Visit(op->domain); + auto shape = Visit(op->shape); + auto domain = Visit(op->domain); auto buffer_expr = Expr(op->buffer); // TODO(Superjomn) copy the operation. 
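// Editor's note (clarifying, not part of this diff): the operation is
// aliased rather than deep-copied here, which is what the TODO above
// records; by contrast, tensor_map and buffer_map memoize nodes already
// visited, so a tensor or buffer referenced several times in the source
// IR maps to a single copied node rather than to duplicates.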
auto operaion = op->operation; - auto name = op->name; - auto tensor = make_shared<_Tensor_>(); + auto name = op->name; + auto tensor = make_shared<_Tensor_>(); if (buffer_expr.defined()) { - auto buffer = Visit(&buffer_expr); + auto buffer = Visit(&buffer_expr); tensor->buffer = buffer.as_buffer_ref(); } - tensor->domain = domain; - tensor->shape = shape; + tensor->domain = domain; + tensor->shape = shape; tensor->reduce_axis = op->reduce_axis; - tensor->operation = operaion; - tensor->name = name; + tensor->operation = operaion; + tensor->name = name; tensor->set_type(op->type()); tensor->axis_ = op->axis_; @@ -190,19 +205,25 @@ struct IRCopyVisitor : public ir::IRVisitorBase { Expr Visit(const For* op) override { auto extent = Visit(&op->extent); - auto min = Visit(&op->min); - auto body = Visit(&op->body); + auto min = Visit(&op->min); + auto body = Visit(&op->body); - return ir::For::Make( - op->loop_var, min, extent, op->for_type(), op->device_api, body, op->vectorize_info(), op->bind_info()); + return ir::For::Make(op->loop_var, + min, + extent, + op->for_type(), + op->device_api, + body, + op->vectorize_info(), + op->bind_info()); } Expr Visit(const ir::PolyFor* op) override { - auto init = Visit(&op->init); + auto init = Visit(&op->init); auto condition = Visit(&op->condition); - auto inc = Visit(&op->inc); - auto body = Visit(&op->body); - auto expr = PolyFor::Make(op->iterator, + auto inc = Visit(&op->inc); + auto body = Visit(&op->body); + auto expr = PolyFor::Make(op->iterator, init, condition, inc, @@ -231,9 +252,9 @@ struct IRCopyVisitor : public ir::IRVisitorBase { submodules.push_back(Visit(&expr)); } - auto res = ir::_Module_::Make(op->name, op->target); - res->buffers = buffers; - res->functions = functions; + auto res = ir::_Module_::Make(op->name, op->target); + res->buffers = buffers; + res->functions = functions; res->submodules = submodules; return Expr(res); @@ -242,9 +263,9 @@ struct IRCopyVisitor : public ir::IRVisitorBase { Expr Visit(const _LoweredFunc_* op) override { auto func = make_shared<_LoweredFunc_>(); - func->name = op->name; - func->args = op->args; - func->body = Visit(&op->body); + func->name = op->name; + func->args = op->args; + func->body = Visit(&op->body); func->temp_bufs = op->temp_bufs; func->device_api = op->device_api; @@ -272,7 +293,7 @@ struct IRCopyVisitor : public ir::IRVisitorBase { Expr Visit(const Let* op) override { auto value = Visit(&op->symbol); - auto body = Visit(&op->body); + auto body = Visit(&op->body); return Let::Make(value, body); } @@ -280,24 +301,25 @@ struct IRCopyVisitor : public ir::IRVisitorBase { Expr Visit(const Reduce* op) override { auto init = Visit(&op->init); auto body = Visit(&op->body); - std::vector reduce_axis(op->reduce_axis.begin(), op->reduce_axis.end()); + std::vector reduce_axis(op->reduce_axis.begin(), + op->reduce_axis.end()); return Reduce::Make(op->reduce_type, init, body, reduce_axis); } Expr Visit(const Ramp* op) override { - auto base = Visit(&op->base); + auto base = Visit(&op->base); auto stride = Visit(&op->stride); - int lanes = op->lanes; + int lanes = op->lanes; return Ramp::Make(base, stride, lanes); } Expr Visit(const Broadcast* op) override { auto value = Visit(&op->value); - int lanes = op->lanes; + int lanes = op->lanes; CHECK(value.defined()); CHECK(value.type().valid()); - auto* n = make_shared(); + auto* n = make_shared(); n->value = value; n->lanes = lanes; return Expr(n); @@ -310,8 +332,8 @@ struct IRCopyVisitor : public ir::IRVisitorBase { CHECK(b.defined()); auto* n = 
make_shared(); - n->a() = a; - n->b() = b; + n->a() = a; + n->b() = b; return Expr(n); } @@ -337,9 +359,9 @@ struct IRCopyVisitor : public ir::IRVisitorBase { arguments.push_back(Visit(args)); } - auto n = common::make_shared(); - n->name = op->name; - n->attrs = op->attrs; // attrs are PODs + auto n = common::make_shared(); + n->name = op->name; + n->attrs = op->attrs; // attrs are PODs n->arguments = arguments; return Expr(n); } @@ -368,7 +390,8 @@ struct IRCopyVisitor : public ir::IRVisitorBase { for (auto buffer_range : op->write_buffers) { write_buffers.push_back(Visit(&buffer_range)); } - Expr res = ir::ScheduleBlock::Make(iter_vars, read_buffers, write_buffers, op->name, Visit(&op->body)); + Expr res = ir::ScheduleBlock::Make( + iter_vars, read_buffers, write_buffers, op->name, Visit(&op->body)); res.As()->attrs = op->attrs; return res; } @@ -378,7 +401,8 @@ struct IRCopyVisitor : public ir::IRVisitorBase { for (auto iter_value : op->iter_values) { iter_values.push_back(Visit(&iter_value)); } - return ir::ScheduleBlockRealize::Make(iter_values, Visit(&op->schedule_block)); + return ir::ScheduleBlockRealize::Make(iter_values, + Visit(&op->schedule_block)); } #define __(x__) Expr Visit(const ir::intrinsics::x__* op); @@ -428,12 +452,15 @@ Expr IRCopyVisitor::Visit(const ir::intrinsics::BufferGetDataConstHandle* op) { return intrinsics::BufferGetDataConstHandle::Make(Visit(&op->buffer)); } Expr IRCopyVisitor::Visit(const ir::intrinsics::PodValueToX* op) { - return intrinsics::PodValueToX::Make(Visit(&op->pod_value_ptr), op->GetOutputType(0)); + return intrinsics::PodValueToX::Make(Visit(&op->pod_value_ptr), + op->GetOutputType(0)); } Expr IRCopyVisitor::Visit(const ir::intrinsics::BufferCreate* op) { return intrinsics::BufferCreate::Make(Visit(&op->buffer)); } -Expr IRCopyVisitor::Visit(const ir::intrinsics::GetAddr* op) { return intrinsics::GetAddr::Make(Visit(&op->data)); } +Expr IRCopyVisitor::Visit(const ir::intrinsics::GetAddr* op) { + return intrinsics::GetAddr::Make(Visit(&op->data)); +} Expr IRCopyVisitor::Visit(const ir::intrinsics::ArgsConstruct* op) { llvm::SmallVector args; for (auto& arg : op->args) { @@ -442,7 +469,8 @@ Expr IRCopyVisitor::Visit(const ir::intrinsics::ArgsConstruct* op) { return intrinsics::ArgsConstruct::Make(op->var, args); } Expr IRCopyVisitor::Visit(const ir::intrinsics::BuiltinIntrin* op) { - return intrinsics::BuiltinIntrin::Make(op->name, op->args, op->id, op->arg_nums, op->type()); + return intrinsics::BuiltinIntrin::Make( + op->name, op->args, op->id, op->arg_nums, op->type()); } Expr IRCopy(Expr x) { @@ -459,10 +487,12 @@ std::vector IRCopy(const std::vector& x) { return res; } -ir::ModuleExpr IRCopy(const ir::ModuleExpr& x) { return ir::ModuleExpr(IRCopy(x.GetExprs())); } +ir::ModuleExpr IRCopy(const ir::ModuleExpr& x) { + return ir::ModuleExpr(IRCopy(x.GetExprs())); +} ir::LoweredFunc IRCopy(const ir::LoweredFunc& x) { - ir::Expr copy_func_expr = IRCopy(static_cast(x)); + ir::Expr copy_func_expr = IRCopy(static_cast(x)); ir::_LoweredFunc_* copy_func_ptr = copy_func_expr.As(); return ir::LoweredFunc(copy_func_ptr); } diff --git a/paddle/cinn/optim/ir_replace.cc b/paddle/cinn/optim/ir_replace.cc index ce6f1f3c57f8c..5b80d8e59e28d 100755 --- a/paddle/cinn/optim/ir_replace.cc +++ b/paddle/cinn/optim/ir_replace.cc @@ -28,22 +28,27 @@ using utils::GetStreamCnt; namespace { struct IrReplaceMutator : ir::IRMutator { - std::set valid_nodetys{{ir::IrNodeTy::Broadcast, ir::IrNodeTy::_Var_}}; + std::set valid_nodetys{ + {ir::IrNodeTy::Broadcast, 
ir::IrNodeTy::_Var_}};

-  IrReplaceMutator(ir::Expr from, Expr to) : from_(from), to_(to), from_repr_(GetStreamCnt(from)) {
-    CHECK(valid_nodetys.count(from->node_type())) << "Not valid node type got " << from->node_type();
+  IrReplaceMutator(ir::Expr from, Expr to)
+      : from_(from), to_(to), from_repr_(GetStreamCnt(from)) {
+    CHECK(valid_nodetys.count(from->node_type()))
+        << "Not valid node type got " << from->node_type();
   }

   void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

  private:
   void Visit(const ir::_Var_* op, Expr* expr) override {
-    if (op->node_type() == from_->node_type() && from_repr_ == GetStreamCnt(*expr)) {
+    if (op->node_type() == from_->node_type() &&
+        from_repr_ == GetStreamCnt(*expr)) {
       *expr = optim::IRCopy(to_);
     }
   }

   void Visit(const ir::Broadcast* op, Expr* expr) override {
-    if (op->node_type() == from_->node_type() && from_repr_ == GetStreamCnt(*expr)) {
+    if (op->node_type() == from_->node_type() &&
+        from_repr_ == GetStreamCnt(*expr)) {
       *expr = optim::IRCopy(to_);
     }
   }
diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc
index 910a6f17b03eb..cabe7ad0b01f9 100644
--- a/paddle/cinn/optim/ir_simplify.cc
+++ b/paddle/cinn/optim/ir_simplify.cc
@@ -41,23 +41,31 @@ using utils::Replace;

 namespace {

-//! Simplify some sub-expression in the `expr`. Due to the simplify strategy just fit several kinds of IR noedes, we
-//! partition the original expression to several sub-expression those supported by simplify, and process each of them.
-void PartialSimplify(Expr* expr, const absl::flat_hash_map<std::string, common::CasInterval>& var_intervals = {}) {
+//! Simplify some sub-expressions in `expr`. Because the simplify strategy only
+//! fits several kinds of IR nodes, we partition the original expression into
+//! several sub-expressions supported by simplify, and process each of
+//! them.
+void PartialSimplify(
+    Expr* expr,
+    const absl::flat_hash_map<std::string, common::CasInterval>& var_intervals =
+        {}) {
   *expr = common::AutoSimplify(*expr, var_intervals);
 }

 //! Simplify the expression but Load.
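// Editor's sketch (illustrative, not part of this diff): PartialSimplify
// is a thin wrapper over common::AutoSimplify, and var_intervals is how
// loop bounds reach the CAS. Assuming a var "i" known to lie in [0, 3]:
//
//   absl::flat_hash_map<std::string, common::CasInterval> intervals;
//   intervals.emplace("i", common::CasInterval{0, 3});
//   PartialSimplify(&e, intervals);
//
// With that interval the CAS can, for example, fold an i / 4 term inside
// e down to 0, which plain syntactic simplification could not justify.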
struct SimplifyButStoreLoadMutator : public ir::IRMutator { common::cas_intervals_t& var_intervals; - explicit SimplifyButStoreLoadMutator(common::cas_intervals_t& var_intervals) : var_intervals(var_intervals) {} + explicit SimplifyButStoreLoadMutator(common::cas_intervals_t& var_intervals) + : var_intervals(var_intervals) {} void operator()(Expr* x) { ir::IRMutator::Visit(x, x); } using ir::IRMutator<>::Visit; -#define __(op__) \ - void Visit(const op__* op, Expr* expr) override { PartialSimplify(expr, var_intervals); } +#define __(op__) \ + void Visit(const op__* op, Expr* expr) override { \ + PartialSimplify(expr, var_intervals); \ + } __(Add) __(Mul) @@ -81,7 +89,7 @@ struct SimplifyButStoreLoadMutator : public ir::IRMutator { } void Visit(const PolyFor* op, Expr* expr) override { - auto* node = expr->As(); + auto* node = expr->As(); node->condition = common::SolveInequality(op->condition, op->iterator); Visit(&node->body, &node->body); @@ -91,12 +99,15 @@ struct SimplifyButStoreLoadMutator : public ir::IRMutator { auto* node = expr->As(); Visit(&node->min, &node->min); Visit(&node->extent, &node->extent); - auto* min_i = op->min.As(); + auto* min_i = op->min.As(); auto* extent_i = op->extent.As(); if (min_i && extent_i && extent_i->value > min_i->value) { - var_intervals.emplace(op->loop_var->name, common::CasInterval{min_i->value, extent_i->value - 1}); + var_intervals.emplace( + op->loop_var->name, + common::CasInterval{min_i->value, extent_i->value - 1}); } else { - var_intervals.emplace(op->loop_var->name, common::CasInterval{op->min, op->extent - 1}); + var_intervals.emplace(op->loop_var->name, + common::CasInterval{op->min, op->extent - 1}); } Visit(&node->body, &node->body); @@ -133,10 +144,12 @@ struct SimplifyLoadMutator : public ir::IRMutator { } void Visit(const For* op, Expr* expr) override { - auto* min_i = op->min.As(); + auto* min_i = op->min.As(); auto* extent_i = op->extent.As(); if (min_i && extent_i && extent_i->value > min_i->value) { - var_intervals_.emplace(op->loop_var->name, common::CasInterval{min_i->value, extent_i->value - 1}); + var_intervals_.emplace( + op->loop_var->name, + common::CasInterval{min_i->value, extent_i->value - 1}); } auto* node = expr->As(); @@ -169,10 +182,12 @@ struct SimplifyStoreMutator : public ir::IRMutator { } void Visit(const For* op, Expr* expr) override { - auto* min_i = op->min.As(); + auto* min_i = op->min.As(); auto* extent_i = op->extent.As(); if (min_i && extent_i) { - var_intervals_.emplace(op->loop_var->name, common::CasInterval{min_i->value, extent_i->value - 1}); + var_intervals_.emplace( + op->loop_var->name, + common::CasInterval{min_i->value, extent_i->value - 1}); } auto* node = expr->As(); @@ -194,24 +209,26 @@ struct SimplifyRampMutator : public ir::IRMutator { void Visit(const Ramp* op, Expr* expr) override { auto* node = expr->As(); - CHECK(common::IsPureMath(node->base)) << node->base << "is not a pure math!"; - CHECK(common::IsPureMath(node->stride)) << node->stride << "is not a pure math!"; + CHECK(common::IsPureMath(node->base)) + << node->base << "is not a pure math!"; + CHECK(common::IsPureMath(node->stride)) + << node->stride << "is not a pure math!"; ; Simplify(&node->base); Simplify(&node->stride); } // ramp + ramp void Visit(const Add* op, Expr* expr) override { - auto* node = expr->As(); - Expr a = node->a(); - Expr b = node->b(); + auto* node = expr->As(); + Expr a = node->a(); + Expr b = node->b(); auto a_ramp = a.As(); auto b_ramp = b.As(); if (a_ramp && b_ramp && a_ramp->lanes == b_ramp->lanes) { - 
Expr base_add = common::AutoSimplify(a_ramp->base + b_ramp->base); + Expr base_add = common::AutoSimplify(a_ramp->base + b_ramp->base); Expr stride_add = common::AutoSimplify(a_ramp->stride + b_ramp->stride); - *expr = ir::Ramp::Make(base_add, stride_add, a_ramp->lanes); + *expr = ir::Ramp::Make(base_add, stride_add, a_ramp->lanes); } } }; @@ -222,7 +239,7 @@ struct SimplifyIfThenElseMutator : public ir::IRMutator<> { using ir::IRMutator<>::Visit; void Visit(const IfThenElse* op, Expr* expr) override { - auto* node = expr->As(); + auto* node = expr->As(); node->condition = common::AutoSimplify(node->condition); if (node->true_case.defined()) Visit(&node->true_case, &node->true_case); @@ -313,13 +330,16 @@ struct SimplifyForLoopsMutator : public ir::IRMutator<> { auto* node = expr->As(); Visit(&node->min, &node->min); Visit(&node->extent, &node->extent); - auto* min_i = node->min.As(); + auto* min_i = node->min.As(); auto* extent_i = node->extent.As(); - if (min_i && extent_i && extent_i->value > min_i->value && extent_i->value - min_i->value == 1) { + if (min_i && extent_i && extent_i->value > min_i->value && + extent_i->value - min_i->value == 1) { VLOG(6) << "Simplify current For Loop"; std::string var_name = node->loop_var->name; - var_intervals.emplace(var_name, common::CasInterval{min_i->value, extent_i->value - 1}); - if (node->body.As() && node->body.As()->stmts.size() == 1) { + var_intervals.emplace( + var_name, common::CasInterval{min_i->value, extent_i->value - 1}); + if (node->body.As() && + node->body.As()->stmts.size() == 1) { *expr = node->body.As()->stmts[0]; } else { *expr = node->body; @@ -336,7 +356,7 @@ struct SimplifyForLoopsMutator : public ir::IRMutator<> { if (var_intervals.count(node->name)) { auto loop_range = var_intervals.at(node->name); - *expr = Expr(loop_range.l); + *expr = Expr(loop_range.l); } } }; diff --git a/paddle/cinn/optim/ir_simplify.h b/paddle/cinn/optim/ir_simplify.h index 972f3d44564ba..1d6abf1cd9a9f 100644 --- a/paddle/cinn/optim/ir_simplify.h +++ b/paddle/cinn/optim/ir_simplify.h @@ -25,7 +25,8 @@ namespace optim { * a*0 => 0 * A[i*0+2*a+3*a+1+2] => A[5*a+3] * - * This only works on the simple IR nodes such as Load, Store, and the math operators such as Add, Sub and so on. + * This only works on the simple IR nodes such as Load, Store, and the math + * operators such as Add, Sub and so on. 
*/ void Simplify(Expr *expr); diff --git a/paddle/cinn/optim/ir_simplify_test.cc b/paddle/cinn/optim/ir_simplify_test.cc index 21d904abb2e77..fd2b5be74d062 100755 --- a/paddle/cinn/optim/ir_simplify_test.cc +++ b/paddle/cinn/optim/ir_simplify_test.cc @@ -52,14 +52,14 @@ TEST(IrSimplify, basic) { auto B = Compute( {Expr(100), Expr(20)}, [&](Expr i, Expr j) { - return x(i + 0, j + 0) + y(i, j * 0) * 1.f + 0.f * x(i, j) + 25.f + 100.f - 0.f + - 9.f * 10000.f * 1.f * 1.f * 0.f; + return x(i + 0, j + 0) + y(i, j * 0) * 1.f + 0.f * x(i, j) + 25.f + + 100.f - 0.f + 9.f * 10000.f * 1.f * 1.f * 0.f; }, "B"); auto stages = CreateStages({B}); - auto func = Lower("func", stages, {B}); - auto body = func->body; + auto func = Lower("func", stages, {B}); + auto body = func->body; LOG(INFO) << "original body:\n" << body; Simplify(&body); @@ -84,7 +84,8 @@ TEST(IrSimplify, basic) { auto B = Compute( {Expr(100), Expr(20)}, [&](Expr i, Expr j) { - return x(100 * 10 * 1 * i + 0, j * 0) + y(i, j * 0) / (1.f + 2.f) + 0.f * x(i, j) + 25.f + 100.f - 0.f + + return x(100 * 10 * 1 * i + 0, j * 0) + y(i, j * 0) / (1.f + 2.f) + + 0.f * x(i, j) + 25.f + 100.f - 0.f + 9.f * 10000.f * 1.f * 1.f * 0.f; }, "B"); @@ -119,7 +120,7 @@ TEST(reverse, prod) { {M, N}, [=](Var i, Var j) { return Expr(1.f) / A(i, j); }, "C"); auto stages = CreateStages({A, C}); - auto fn = Lower("fn", stages, {A, C}); + auto fn = Lower("fn", stages, {A, C}); LOG(INFO) << "fn:\n" << fn; } diff --git a/paddle/cinn/optim/lower_function_call_bind_vars.cc b/paddle/cinn/optim/lower_function_call_bind_vars.cc index 333ae623c620c..90ef5a1606a5a 100644 --- a/paddle/cinn/optim/lower_function_call_bind_vars.cc +++ b/paddle/cinn/optim/lower_function_call_bind_vars.cc @@ -38,10 +38,12 @@ struct LowerFunctionCallBindVarsMutator : public ir::IRMutator<> { auto* node = expr->As(); if (op->is_cinn_call()) { const std::string& target = op->name; - auto it = std::find_if(m_->functions.begin(), m_->functions.end(), [&](const Expr& x) { - return x.as_lowered_func()->name == target; - }); - CHECK(it != m_->functions.end()) << "The called function [" << target << "] is not exist"; + auto it = std::find_if( + m_->functions.begin(), m_->functions.end(), [&](const Expr& x) { + return x.as_lowered_func()->name == target; + }); + CHECK(it != m_->functions.end()) + << "The called function [" << target << "] is not exist"; std::vector extra_var_args; @@ -51,8 +53,11 @@ struct LowerFunctionCallBindVarsMutator : public ir::IRMutator<> { } } - // insert the extra var arguments to the begining of the original call's argument list. - node->read_args.insert(std::begin(op->read_args), extra_var_args.begin(), extra_var_args.end()); + // insert the extra var arguments to the begining of the original call's + // argument list. 
+ node->read_args.insert(std::begin(op->read_args), + extra_var_args.begin(), + extra_var_args.end()); } ir::IRMutator<>::Visit(op, expr); diff --git a/paddle/cinn/optim/lower_intrin.cc b/paddle/cinn/optim/lower_intrin.cc index 7431f6b66d292..f6fce4009c548 100644 --- a/paddle/cinn/optim/lower_intrin.cc +++ b/paddle/cinn/optim/lower_intrin.cc @@ -44,9 +44,17 @@ void LowerIntrin(Expr *e, Target target) { Expr ret; if (node->type().is_float()) { if (const ir::Mul *mul = node->b().As()) { - ret = ir::Call::Make(node->type(), "fma", {mul->a(), mul->b(), node->a()}, {}, ir::CallType::Intrinsic); + ret = ir::Call::Make(node->type(), + "fma", + {mul->a(), mul->b(), node->a()}, + {}, + ir::CallType::Intrinsic); } else if (const ir::Mul *mul = node->a().As()) { - ret = ir::Call::Make(node->type(), "fma", {mul->a(), mul->b(), node->b()}, {}, ir::CallType::Intrinsic); + ret = ir::Call::Make(node->type(), + "fma", + {mul->a(), mul->b(), node->b()}, + {}, + ir::CallType::Intrinsic); } if (ret.defined()) { ir::IRMutator<>::Visit(&ret, &ret); diff --git a/paddle/cinn/optim/lower_intrin.h b/paddle/cinn/optim/lower_intrin.h index 86ac60bda9a84..2880caf4d056c 100644 --- a/paddle/cinn/optim/lower_intrin.h +++ b/paddle/cinn/optim/lower_intrin.h @@ -23,10 +23,12 @@ namespace cinn { namespace optim { static const std::set kIntrinsicCalls{ - {"exp", "exp2", "sqrt", "log", "log2", "log10", "floor", - "ceil", "round", "trunc", "cos", "cosh", "tan", "tanh", - "sin", "sinh", "fabs", "isnan", "isfinite", "isinf", "left_shift", - "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not", "fma", "rsqrt"}}; + {"exp", "exp2", "sqrt", "log", "log2", + "log10", "floor", "ceil", "round", "trunc", + "cos", "cosh", "tan", "tanh", "sin", + "sinh", "fabs", "isnan", "isfinite", "isinf", + "left_shift", "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", + "bitwise_not", "fma", "rsqrt"}}; /** * Map the Call nodes to llvm intrinsic. diff --git a/paddle/cinn/optim/map_extern_call.cc b/paddle/cinn/optim/map_extern_call.cc index 024a8e8385aee..3a9531391ca9d 100644 --- a/paddle/cinn/optim/map_extern_call.cc +++ b/paddle/cinn/optim/map_extern_call.cc @@ -23,9 +23,11 @@ namespace cinn { namespace optim { static const std::set kExternFp32CallsGPU{ - {"exp", "erf", "sigmoid", "sqrt", "log", "log2", "log10", "floor", "ceil", "round", "trunc", - "cos", "cosh", "tan", "sin", "sinh", "acos", "acosh", "asin", "asinh", "atan", "atanh", - "isnan", "tanh", "isfinite", "isinf", "remainder", "rsqrt", "cbrt", "abs", "pow", "mod"}}; + {"exp", "erf", "sigmoid", "sqrt", "log", "log2", "log10", + "floor", "ceil", "round", "trunc", "cos", "cosh", "tan", + "sin", "sinh", "acos", "acosh", "asin", "asinh", "atan", + "atanh", "isnan", "tanh", "isfinite", "isinf", "remainder", "rsqrt", + "cbrt", "abs", "pow", "mod"}}; static const std::set kExternInt32CallsGPU{{"left_shift", "right_shift", @@ -65,10 +67,11 @@ void MapExternCall(Expr *e, Target target) { if (kExternFp32CallsCPU.count(node->name)) { CHECK_GE(node->read_args.size(), 1UL); CHECK(node->read_args.front().type().is_float()) - << "CPU extern call instrinsices only support float now! Please check."; + << "CPU extern call instrinsices only support float now! 
Please " + "check."; if (node->read_args.front().type().is_float(32)) { auto out_type = node->type(); - *expr = lang::CallExtern(node->name + "f", node->read_args); + *expr = lang::CallExtern(node->name + "f", node->read_args); } } } @@ -80,16 +83,17 @@ void MapExternCall(Expr *e, Target target) { return; } const auto &dtype = node->read_args.front().type(); - const auto &name = node->name; + const auto &name = node->name; - bool node_in_extern_fp32 = kExternFp32CallsGPU.count(name); + bool node_in_extern_fp32 = kExternFp32CallsGPU.count(name); bool node_in_extern_int32 = kExternInt32CallsGPU.count(name); if (!node_in_extern_fp32 && !node_in_extern_int32) { return; } - std::string extern_func = hlir::GetExternFuncName(common::DefaultNVGPUTarget(), dtype, name); - *expr = lang::CallExtern(extern_func, node->read_args, node->attrs); + std::string extern_func = + hlir::GetExternFuncName(common::DefaultNVGPUTarget(), dtype, name); + *expr = lang::CallExtern(extern_func, node->read_args, node->attrs); } // Replace pow(x, 0.5) to sqrt(x) and pow(x, -0.5) to rsqrt(x), which @@ -98,7 +102,8 @@ void MapExternCall(Expr *e, Target target) { // Reference: // https://en.wikipedia.org/wiki/Fast_inverse_square_root void OptimizeConstantPow(ir::Call *node) { - if (node->name == "pow" && node->read_args.size() >= 2 && node->read_args[1].is_constant()) { + if (node->name == "pow" && node->read_args.size() >= 2 && + node->read_args[1].is_constant()) { float pow_constant = node->read_args[1].get_constant(); if (pow_constant == 0.5) { node->name = "sqrt"; diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc index afd91b84be5d1..4899c73254ea0 100644 --- a/paddle/cinn/optim/optimize.cc +++ b/paddle/cinn/optim/optimize.cc @@ -42,7 +42,10 @@ DECLARE_bool(cinn_ir_schedule); namespace cinn { namespace optim { -Expr Optimize(Expr e, Target target, bool runtime_debug_info, bool remove_gpu_for_loops) { +Expr Optimize(Expr e, + Target target, + bool runtime_debug_info, + bool remove_gpu_for_loops) { CHECK(e.defined()); auto copied = IRCopy(e); diff --git a/paddle/cinn/optim/optimize.h b/paddle/cinn/optim/optimize.h index 7b7a0afbff672..e5beeb095781b 100644 --- a/paddle/cinn/optim/optimize.h +++ b/paddle/cinn/optim/optimize.h @@ -25,7 +25,10 @@ namespace optim { * @param runtime_debug_info * @return */ -Expr Optimize(Expr e, Target target, bool runtime_debug_info = false, bool remove_gpu_for_loops = true); +Expr Optimize(Expr e, + Target target, + bool runtime_debug_info = false, + bool remove_gpu_for_loops = true); /** * Optimize a Module. diff --git a/paddle/cinn/optim/remove_nested_block.cc b/paddle/cinn/optim/remove_nested_block.cc index a748b53867f18..de42ad46f03fc 100644 --- a/paddle/cinn/optim/remove_nested_block.cc +++ b/paddle/cinn/optim/remove_nested_block.cc @@ -33,7 +33,8 @@ Expr GetExprInsideBlock(Expr op) { return node; } -// This will remove the nested blocks, but it will also remove the block outside the forloop's body. +// This will remove the nested blocks, but it will also remove the block outside +// the forloop's body. 
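// Editor's illustration (not part of this diff): the flattening these
// mutators perform. A nested body such as
//   Block { Block { S0; S1 }; S2 }
// becomes
//   Block { S0; S1; S2 }
// NestedBlockRemover splices the inner Block's stmts into the parent
// list, and NestedBlockSimplifer appears to peel Blocks that wrap a
// single statement via GetExprInsideBlock above.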
struct NestedBlockSimplifer : public ir::IRMutator { void operator()(ir::Expr* expr) { Visit(expr); } @@ -67,7 +68,8 @@ struct NestedBlockRemover : public ir::IRMutator { auto* block = it->As(); if (block) { detect_nested = true; - new_exprs.insert(std::end(new_exprs), block->stmts.begin(), block->stmts.end()); + new_exprs.insert( + std::end(new_exprs), block->stmts.begin(), block->stmts.end()); } else { new_exprs.push_back(*it); } diff --git a/paddle/cinn/optim/remove_nested_block_test.cc b/paddle/cinn/optim/remove_nested_block_test.cc index 3ae1948a703a9..b91c8204e242f 100644 --- a/paddle/cinn/optim/remove_nested_block_test.cc +++ b/paddle/cinn/optim/remove_nested_block_test.cc @@ -28,7 +28,7 @@ namespace optim { TEST(RemoveNestedBlock, basic) { auto block0 = ir::Block::Make({Expr(1.f), Expr(1.f)}); auto block1 = ir::Block::Make({block0}); - auto e = Expr(block1); + auto e = Expr(block1); std::string origin = utils::GetStreamCnt(e); EXPECT_EQ(origin, utils::Trim(R"ROC( diff --git a/paddle/cinn/optim/remove_schedule_block.cc b/paddle/cinn/optim/remove_schedule_block.cc index 853c13811e2d6..007174801550d 100644 --- a/paddle/cinn/optim/remove_schedule_block.cc +++ b/paddle/cinn/optim/remove_schedule_block.cc @@ -30,11 +30,11 @@ struct ScheduleBlockRemover : public ir::IRMutator { void Visit(const ir::ScheduleBlockRealize* op, Expr* expr) override { auto* node = expr->As(); CHECK(node); - auto& iter_values = node->iter_values; + auto& iter_values = node->iter_values; auto* schedule_block = node->schedule_block.As(); CHECK(schedule_block); auto& iter_vars = schedule_block->iter_vars; - Expr body = schedule_block->body; + Expr body = schedule_block->body; CHECK_EQ(iter_vars.size(), iter_values.size()); for (int i = 0; i < iter_vars.size(); i++) { optim::ReplaceVarWithExpr(&body, iter_vars[i], iter_values[i]); diff --git a/paddle/cinn/optim/remove_schedule_block_test.cc b/paddle/cinn/optim/remove_schedule_block_test.cc index 556d10e4275fd..4fd2d7999e426 100755 --- a/paddle/cinn/optim/remove_schedule_block_test.cc +++ b/paddle/cinn/optim/remove_schedule_block_test.cc @@ -38,9 +38,11 @@ TEST(RemovescheduleBlock, basic) { // C = A * B Var k(20, "k0"); Tensor C = Compute( - {Expr(100), Expr(50)}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + {Expr(100), Expr(50)}, + [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, + "C"); auto stages = CreateStages({A, B, C}); - auto func = Lower("matmul", stages, {A, B, C}, {}, {}, nullptr, target, true); + auto func = Lower("matmul", stages, {A, B, C}, {}, {}, nullptr, target, true); LOG(INFO) << "func\n" << func; std::string origin = utils::GetStreamCnt(func); diff --git a/paddle/cinn/optim/replace_call_with_expr.cc b/paddle/cinn/optim/replace_call_with_expr.cc index 388f5ae076e88..3f1344fd6f021 100644 --- a/paddle/cinn/optim/replace_call_with_expr.cc +++ b/paddle/cinn/optim/replace_call_with_expr.cc @@ -23,7 +23,8 @@ namespace cinn { namespace optim { struct ReplaceCallWithExprModifier : public ir::IRMutator<> { - ReplaceCallWithExprModifier(const std::string &statement, const Expr &candidate) + ReplaceCallWithExprModifier(const std::string &statement, + const Expr &candidate) : statement_(statement), candidate_(candidate) {} void operator()(Expr *e) { IRMutator<>::Visit(e, e); } @@ -49,7 +50,9 @@ struct ReplaceCallWithExprModifier : public ir::IRMutator<> { const Expr &candidate_; }; -void ReplaceCallWithExpr(Expr *e, const std::string &statement, const Expr &candidate) { +void ReplaceCallWithExpr(Expr *e, 
+ const std::string &statement, + const Expr &candidate) { ReplaceCallWithExprModifier modifier(statement, candidate); modifier(e); } @@ -62,16 +65,17 @@ void ReplaceIslCallWithExpr(Expr *e, Expr copied = IRCopy(candidate); // update the axis in the copied expression. - // we treat the Store node as the normal statement, the others like Call node has no axis. + // we treat the Store node as the normal statement, the others like Call node + // has no axis. std::map local_axis; std::vector origin_axes; std::map new_axis_map = axis_map; for (auto &item : axis_map) { origin_axes.push_back(item.first); } - // Add '_after' to the transformed var's name to avoid duplicating transforming. - // For example, given indices [i,j], if we want to switch 'i' and 'j'(i->j, j->i) - // When we don't add '_after', the processing will be : + // Add '_after' to the transformed var's name to avoid duplicating + // transforming. For example, given indices [i,j], if we want to switch 'i' + // and 'j'(i->j, j->i) When we don't add '_after', the processing will be : // 1. [i,j] to [j,j] // 2. [j,j] to [i,i] // Then we get result [i,i], which is different form the correct result [j,i] @@ -92,23 +96,28 @@ void ReplaceIslCallWithExpr(Expr *e, if (indice.is_var() || indice.is_constant()) { if (!new_axis_map.count(std::to_string(i))) continue; if (!indice.is_constant()) { - local_axis[indice.as_var()->name] = new_axis_map.at(std::to_string(i)); + local_axis[indice.as_var()->name] = + new_axis_map.at(std::to_string(i)); } } } - // the store indices just contains the ones of transform's domain, not the range. - // e.g. { s[i,j] -> s[i0,i1,j]: i0=i/4 and i1=i%4 }, the store's indices just contains i,j while in the final code, - // the axis are from the range, that is, there are some new axis not exists in store->indice, i0 and i1. + // the store indices just contains the ones of transform's domain, not the + // range. e.g. { s[i,j] -> s[i0,i1,j]: i0=i/4 and i1=i%4 }, the store's + // indices just contains i,j while in the final code, the axis are from the + // range, that is, there are some new axis not exists in store->indice, i0 + // and i1. } for (auto &laxis : local_axis) { - VLOG(3) << "local_axis Replacing axis: " << laxis.first << " to " << laxis.second; + VLOG(3) << "local_axis Replacing axis: " << laxis.first << " to " + << laxis.second; ReplaceVarWithExpr(&copied, Var(laxis.first), laxis.second); } // replace the remaining axis(in the transform's range) for (auto &item : new_axis_map) { if (!local_axis.count(item.first)) { - VLOG(3) << "new_axis_map Replacing axis: " << item.first << " to " << item.second; + VLOG(3) << "new_axis_map Replacing axis: " << item.first << " to " + << item.second; ReplaceVarWithExpr(&copied, Var(item.first), item.second); } } @@ -117,7 +126,8 @@ void ReplaceIslCallWithExpr(Expr *e, ReplaceVarWithExpr(&copied, Var(axis + "_after"), Expr(Var(axis))); } - VLOG(3) << "After replacing, the statement [" << statement << "] is : " << copied; + VLOG(3) << "After replacing, the statement [" << statement + << "] is : " << copied; ReplaceCallWithExpr(e, statement, copied); } diff --git a/paddle/cinn/optim/replace_call_with_expr.h b/paddle/cinn/optim/replace_call_with_expr.h index 219cb9984736b..d8e54865f2703 100644 --- a/paddle/cinn/optim/replace_call_with_expr.h +++ b/paddle/cinn/optim/replace_call_with_expr.h @@ -27,7 +27,9 @@ namespace optim { * @param statement The map from tuple_name to the expression candidate. * @param candidate Var of each axis in the expression candidate. 
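 *
 * Editor's illustration (hypothetical names, not from this diff): with
 * statement = "S0" and candidate = `A[i] = B[i] + 1`, every Call node
 * named "S0" inside *e is replaced by that Store expression; the ISL
 * variant declared below additionally remaps axis vars via its axis_map.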
*/ -void ReplaceCallWithExpr(Expr *e, const std::string &statement, const Expr &candidate); +void ReplaceCallWithExpr(Expr *e, + const std::string &statement, + const Expr &candidate); /** * Replace a Call node with a Expr (inline). diff --git a/paddle/cinn/optim/replace_const_param_to_integer.cc b/paddle/cinn/optim/replace_const_param_to_integer.cc index 4e14f8c74c4aa..ad72439a1631a 100644 --- a/paddle/cinn/optim/replace_const_param_to_integer.cc +++ b/paddle/cinn/optim/replace_const_param_to_integer.cc @@ -28,7 +28,7 @@ struct Mutator : public ir::IRMutator<> { void Visit(const ir::_Var_* op, Expr* expr) override { if (utils::Startswith(op->name, poly::kIslParamConstPrefix)) { std::string value = op->name.substr(strlen(poly::kIslParamConstPrefix)); - *expr = Expr(std::stoi(value)); + *expr = Expr(std::stoi(value)); } } }; diff --git a/paddle/cinn/optim/replace_const_param_to_integer.h b/paddle/cinn/optim/replace_const_param_to_integer.h index 496540213fa01..a1bcb822b3df2 100644 --- a/paddle/cinn/optim/replace_const_param_to_integer.h +++ b/paddle/cinn/optim/replace_const_param_to_integer.h @@ -18,7 +18,8 @@ namespace cinn::optim { /** - * Replace the constant parameter(included in ISL param) to the corresponding integer. + * Replace the constant parameter(included in ISL param) to the corresponding + * integer. * * e.g. * diff --git a/paddle/cinn/optim/replace_var_with_expr.cc b/paddle/cinn/optim/replace_var_with_expr.cc index 6a1cece5be7a2..10dc22c80b097 100644 --- a/paddle/cinn/optim/replace_var_with_expr.cc +++ b/paddle/cinn/optim/replace_var_with_expr.cc @@ -28,7 +28,9 @@ namespace cinn { namespace optim { struct ReplaceVarWithExprMutator : public ir::IRMutator<> { - ReplaceVarWithExprMutator(const Var& var, const Expr& expr, const std::string& tensor_name) + ReplaceVarWithExprMutator(const Var& var, + const Expr& expr, + const std::string& tensor_name) : var_(var), expr_(expr), tensor_name_(tensor_name) {} void operator()(Expr* expr) { @@ -40,7 +42,7 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<> { void Visit(const ir::_Var_* expr, Expr* op) override { if (expr->name == var_->name && (do_replace_ || visit_all_)) { auto copied = IRCopy(expr_); - *op = copied; + *op = copied; } } @@ -49,7 +51,8 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<> { ir::IRMutator<>::Visit(&node->min, &node->min); ir::IRMutator<>::Visit(&node->extent, &node->extent); ir::IRMutator<>::Visit(&node->body, &node->body); - if (node->loop_var->name == var_->name && expr_.As() && visit_all_) { + if (node->loop_var->name == var_->name && expr_.As() && + visit_all_) { node->loop_var = expr_.As(); } } @@ -60,13 +63,14 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<> { ir::IRMutator<>::Visit(&node->condition, &node->condition); ir::IRMutator<>::Visit(&node->inc, &node->inc); ir::IRMutator<>::Visit(&node->body, &node->body); - if (node->iterator->name == var_->name && expr_.As() && visit_all_) { + if (node->iterator->name == var_->name && expr_.As() && + visit_all_) { node->iterator = expr_.As(); } } void Visit(const ir::Store* op, Expr* expr) override { - auto* node = expr->As(); + auto* node = expr->As(); auto* tensor = node->tensor.as_tensor(); if (tensor->name == tensor_name_) { @@ -83,7 +87,7 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<> { } void Visit(const ir::Load* expr, Expr* op) override { - auto* node = op->As(); + auto* node = op->As(); auto* tensor = node->tensor.as_tensor(); if (tensor->name == tensor_name_) { do_replace_ = true; @@ -103,13 +107,17 @@ 
struct ReplaceVarWithExprMutator : public ir::IRMutator<> { const std::string& tensor_name_; }; -void ReplaceVarWithExpr(Expr* source, const Var& var, const Expr& expr, const std::string& tensor_name) { +void ReplaceVarWithExpr(Expr* source, + const Var& var, + const Expr& expr, + const std::string& tensor_name) { ReplaceVarWithExprMutator mutator(var, expr, tensor_name); mutator(source); } struct CollectTensorIndexMutator : public ir::IRMutator<> { - CollectTensorIndexMutator(const std::string& tensor_name) : tensor_name_(tensor_name) {} + CollectTensorIndexMutator(const std::string& tensor_name) + : tensor_name_(tensor_name) {} std::vector> operator()(Expr* expr) { IRMutator::Visit(expr, expr); @@ -128,7 +136,7 @@ struct CollectTensorIndexMutator : public ir::IRMutator<> { } void Visit(const ir::Load* expr, Expr* op) override { - auto* node = op->As(); + auto* node = op->As(); auto* tensor = node->tensor.as_tensor(); if (tensor->name == tensor_name_) { ir::IRMutator<>::Visit(&node->tensor, &node->tensor); @@ -144,7 +152,8 @@ struct CollectTensorIndexMutator : public ir::IRMutator<> { const std::string& tensor_name_; }; -std::vector> CollectTensorIndex(Expr* source, const std::string& tensor_name) { +std::vector> CollectTensorIndex( + Expr* source, const std::string& tensor_name) { CollectTensorIndexMutator mutator(tensor_name); std::vector> result = mutator(source); for (auto& i : result) { diff --git a/paddle/cinn/optim/replace_var_with_expr.h b/paddle/cinn/optim/replace_var_with_expr.h index 6f99de5dc0002..c56848f358052 100644 --- a/paddle/cinn/optim/replace_var_with_expr.h +++ b/paddle/cinn/optim/replace_var_with_expr.h @@ -27,8 +27,9 @@ namespace optim { * Replace the variable with a expression. * @param var The variable to replace. * @param expr The candidate expression. - * @param tensor_name Name of the tensor whose indices will be edited. If it is empty, means we will - * do the replace in all Expr instead of only in specific tensor's indices. + * @param tensor_name Name of the tensor whose indices will be edited. If it is + * empty, means we will do the replace in all Expr instead of only in specific + * tensor's indices. */ /** * Example 1: ReplaceVarWithExpr(source, Var("i"), Expr(0), "A") @@ -53,12 +54,16 @@ namespace optim { * for(j, 0, 10) * B[k,j] = A[k,j] */ -void ReplaceVarWithExpr(Expr *source, const Var &var, const Expr &expr, const std::string &tensor_name = ""); +void ReplaceVarWithExpr(Expr *source, + const Var &var, + const Expr &expr, + const std::string &tensor_name = ""); /** * Collect the specific tensor's indices. * @param tensor_name The specific tensor's name. - * @return Return a vector containing all the indices of the specific tensor appeared in source. + * @return Return a vector containing all the indices of the specific tensor + * appeared in source. */ /** * Example: CollectTensorIndex(source, "A") @@ -71,7 +76,8 @@ void ReplaceVarWithExpr(Expr *source, const Var &var, const Expr &expr, const st * Return value: * {{i,j},{0,j}} */ -std::vector> CollectTensorIndex(Expr *source, const std::string &tensor_name); +std::vector> CollectTensorIndex( + Expr *source, const std::string &tensor_name); } // namespace optim } // namespace cinn diff --git a/paddle/cinn/optim/tensor_write_tell.h b/paddle/cinn/optim/tensor_write_tell.h index 6a5ba1fd03877..6e7b8b9f48f32 100644 --- a/paddle/cinn/optim/tensor_write_tell.h +++ b/paddle/cinn/optim/tensor_write_tell.h @@ -26,12 +26,16 @@ struct TensorWriteTeller : public ir::IRMutator { //! Collect the write info in \p op. 
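  //! Editor's usage sketch (hypothetical body name, not from this diff):
  //!   TensorWriteTeller teller;
  //!   teller.Collect(&fn_body);      // fn_body: some lowered ir::Expr
  //!   bool w = teller.IsWrite("A");  // true iff some Store targets "A"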
void Collect(const Expr* op) { Visit(op, op); }

-  bool IsWrite(const std::string& tensor_name) const { return tensor_written.count(tensor_name); }
+  bool IsWrite(const std::string& tensor_name) const {
+    return tensor_written.count(tensor_name);
+  }

  private:
   std::set<std::string> tensor_written;

-  void Visit(const Expr* expr, const Expr* op) override { IRMutator::Visit(expr, op); }
+  void Visit(const Expr* expr, const Expr* op) override {
+    IRMutator::Visit(expr, op);
+  }

   void Visit(const ir::Store* expr, const Expr* op) override {
     auto* node = op->As<ir::Store>();
diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc
index e99f2fdd81399..a95d0a3425af9 100644
--- a/paddle/cinn/optim/transform_gpu_forloop.cc
+++ b/paddle/cinn/optim/transform_gpu_forloop.cc
@@ -39,14 +39,16 @@ namespace optim {

 /**
  * 1. Determine the grid and block dimensions.
- * It takes the domains like `[0, 20]` or `[0, min(20, M/2)]`, the domain should have a integer right bound.
+ * It takes the domains like `[0, 20]` or `[0, min(20, M/2)]`; the domain
+ * should have an integer right bound.
  *
- * 2. Replace the grid/thread iterators with something like `threadIdx.x`, `threadIdx.y`.
+ * 2. Replace the grid/thread iterators with something like `threadIdx.x`,
+ * `threadIdx.y`.
  *
 * 3. Remove the forloops owning the gpu axis.
 *   1. if the extent is an IntImm, just remove this forloop.
- *   2. if the extent is a Min, replace the forloop with an IfThenElse, with forloop's condition, new check will add (if
- * the min of forloop is not zero).
+ *   2. if the extent is a Min, replace the forloop with an IfThenElse carrying
+ *      the forloop's condition; a new check is added (if the min of the
+ *      forloop is not zero).
 *
 * @param expr The expression to mutate.
 */
@@ -80,10 +82,12 @@ void RemoveGpuForloopsAxis(Expr *expr) {
       }
     }

-    bool NeedToReplaceForloopWithIfThenElse(const ir::For *n) const { return true; }
+    bool NeedToReplaceForloopWithIfThenElse(const ir::For *n) const {
+      return true;
+    }

     void ReplaceForloopWithIfThenElse(Expr *expr) {
-      auto *for_n      = expr->As<ir::For>();
+      auto *for_n = expr->As<ir::For>();
       auto *poly_for_n = expr->As<ir::PolyFor>();
       CHECK(for_n || poly_for_n);
@@ -109,7 +113,8 @@ void RemoveGpuForloopsAxis(Expr *expr) {
         condition_append(ir::LT::Make(for_n->loop_var, for_n->extent));
       } else {
         if (poly_for_n->init != common::make_const(0)) {
-          condition_append(ir::GE::Make(poly_for_n->iterator, poly_for_n->init));
+          condition_append(
+              ir::GE::Make(poly_for_n->iterator, poly_for_n->init));
         }

         condition_append(poly_for_n->condition);
@@ -125,7 +130,8 @@ void RemoveGpuForloopsAxis(Expr *expr) {
     }

     void Visit(const ir::PolyFor *op, Expr *expr) override {
-      const auto msg = "PolyFor is not allowed for GPU, only For nodes are allowed";
+      const auto msg =
+          "PolyFor is not allowed for GPU, only For nodes are allowed";
       CHECK(op->for_type() != ir::ForType::GPUBlock) << msg;
       CHECK(op->for_type() != ir::ForType::GPUThread) << msg;
       CHECK(op->for_type() != ir::ForType::GPULane) << msg;
@@ -137,8 +143,9 @@ void RemoveGpuForloopsAxis(Expr *expr) {
 }

 /**
- * The generated __syncthreads call will be wrapped with a `if (xxxx == 0) { }`, this is the problem of isl AST output,
- * drop it to make it run in all the threads.
+ * The generated __syncthreads call will be wrapped with a `if (xxxx == 0) { }`;
+ * this is an artifact of the isl AST output, so drop the guard to make the
+ * call run in all the threads.
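 *
 * Editor's illustration (assumed shape of the isl output, not taken from
 * this diff): a guarded barrier such as
 *   if (xxxx == 0) { __syncthreads(); }
 * is rewritten to a bare
 *   __syncthreads();
 * so that every thread in the block reaches the barrier.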
*/ void CudaSyncThreadsDropIfThenElse(Expr *expr) { struct Mutator : public ir::IRMutator<> { @@ -202,11 +209,15 @@ class ReplaceIndexToBindExpr : public ir::IRMutator<> { private: void Visit(const ir::ScheduleBlockRealize *op, Expr *expr) override { - ir::ScheduleBlockRealize *schedule_block_realize = expr->As(); + ir::ScheduleBlockRealize *schedule_block_realize = + expr->As(); CHECK(schedule_block_realize->schedule_block.As()); std::vector iter_values = schedule_block_realize->iter_values; - ir::Expr body = schedule_block_realize->schedule_block.As()->body; - std::vector iter_vars = schedule_block_realize->schedule_block.As()->iter_vars; + ir::Expr body = + schedule_block_realize->schedule_block.As()->body; + std::vector iter_vars = + schedule_block_realize->schedule_block.As() + ->iter_vars; CHECK_EQ(iter_values.size(), iter_vars.size()); for (int idx = 0; idx < iter_values.size(); ++idx) { @@ -225,11 +236,14 @@ class CollectTensorLoopVisitor : public ir::IRMutator<> { void Visit(const ir::Store *op, Expr *expr) override { auto tensor = op->tensor.as_tensor_ref(); // if buffer defined and buffer is not Heap. - if (tensor->buffer.defined() && tensor->buffer->memory_type != ir::MemoryType::Heap) { + if (tensor->buffer.defined() && + tensor->buffer->memory_type != ir::MemoryType::Heap) { if (buffer_tensor_loop_map_.count(tensor->buffer->name)) { - buffer_tensor_loop_map_[tensor->buffer->name].push_back(std::make_pair(*expr, loops_)); + buffer_tensor_loop_map_[tensor->buffer->name].push_back( + std::make_pair(*expr, loops_)); } else { - buffer_tensor_loop_map_[tensor->buffer->name] = {std::make_pair(*expr, loops_)}; + buffer_tensor_loop_map_[tensor->buffer->name] = { + std::make_pair(*expr, loops_)}; } } @@ -242,11 +256,14 @@ class CollectTensorLoopVisitor : public ir::IRMutator<> { } auto tensor = op->tensor.as_tensor_ref(); // if buffer defined and buffer is not Heap. - if (tensor->buffer.defined() && tensor->buffer->memory_type != ir::MemoryType::Heap) { + if (tensor->buffer.defined() && + tensor->buffer->memory_type != ir::MemoryType::Heap) { if (buffer_tensor_loop_map_.count(tensor->buffer->name)) { - buffer_tensor_loop_map_[tensor->buffer->name].push_back(std::make_pair(*expr, loops_)); + buffer_tensor_loop_map_[tensor->buffer->name].push_back( + std::make_pair(*expr, loops_)); } else { - buffer_tensor_loop_map_[tensor->buffer->name] = {std::make_pair(*expr, loops_)}; + buffer_tensor_loop_map_[tensor->buffer->name] = { + std::make_pair(*expr, loops_)}; } } @@ -259,11 +276,14 @@ class CollectTensorLoopVisitor : public ir::IRMutator<> { loops_.pop_back(); } - void Visit(const ir::PolyFor *op, Expr *expr) override { LOG(FATAL) << "Unkown PolyFor!"; } + void Visit(const ir::PolyFor *op, Expr *expr) override { + LOG(FATAL) << "Unkown PolyFor!"; + } public: std::vector loops_; - std::unordered_map> buffer_tensor_loop_map_; + std::unordered_map> + buffer_tensor_loop_map_; }; void UpdateBufferAxisPass(ir::Expr *expr) { @@ -276,10 +296,12 @@ void UpdateBufferAxisPass(ir::Expr *expr) { auto tensor_loop_v = tmp.second; auto &front = tensor_loop_v.front(); - int count = tensor_loop_v.size() > 1 ? front.second.size() : 0; + int count = tensor_loop_v.size() > 1 ? 
front.second.size() : 0; for (int idx = 1; idx < tensor_loop_v.size(); ++idx) { auto &other = tensor_loop_v[idx]; - for (int idy = 0; idy < std::min(front.second.size(), other.second.size()); ++idy) { + for (int idy = 0; + idy < std::min(front.second.size(), other.second.size()); + ++idy) { if (front.second[idy] != other.second[idy]) { count = std::min(count, idy); break; @@ -289,7 +311,8 @@ void UpdateBufferAxisPass(ir::Expr *expr) { auto get_thread_bind_var = [](const std::vector &loops) { // threadidx loop_var,extent. - using ThreadLoopVarExtentMap = std::unordered_map>; + using ThreadLoopVarExtentMap = + std::unordered_map>; ThreadLoopVarExtentMap thread_loop_var_exent_map; for (auto loop : loops) { auto loop_ir = loop.As(); @@ -307,10 +330,12 @@ void UpdateBufferAxisPass(ir::Expr *expr) { if (thread_loop_var_exent_map.count(axis)) { auto &loop_var_extent = thread_loop_var_exent_map[axis]; if (loop_var_extent.second >= loop_ir->extent.as_int32()) { - thread_loop_var_exent_map[axis] = std::make_pair(loop_ir->loop_var->name, loop_ir->extent.as_int32()); + thread_loop_var_exent_map[axis] = std::make_pair( + loop_ir->loop_var->name, loop_ir->extent.as_int32()); } } else { - thread_loop_var_exent_map[axis] = std::make_pair(loop_ir->loop_var->name, loop_ir->extent.as_int32()); + thread_loop_var_exent_map[axis] = std::make_pair( + loop_ir->loop_var->name, loop_ir->extent.as_int32()); } } } @@ -323,9 +348,10 @@ void UpdateBufferAxisPass(ir::Expr *expr) { return loop_var_map; }; - auto load = front.first.As(); - auto store = front.first.As(); - auto tensor = load ? load->tensor.as_tensor_ref() : store->tensor.as_tensor_ref(); + auto load = front.first.As(); + auto store = front.first.As(); + auto tensor = + load ? load->tensor.as_tensor_ref() : store->tensor.as_tensor_ref(); // find store and load keep loop for shared std::vector> keep_loop_vars; if (tensor->buffer->memory_type == ir::MemoryType::GPUShared) { @@ -338,14 +364,15 @@ void UpdateBufferAxisPass(ir::Expr *expr) { auto &loops = front.second; for (int idx = 0; idx < count; ++idx) { auto loop_expr = loops[idx]; - auto loop_ir = loop_expr.As(); - auto loop_var = loop_ir->loop_var; + auto loop_ir = loop_expr.As(); + auto loop_var = loop_ir->loop_var; for (int idy = 0; idy < tensor_loop_v.size(); ++idy) { - auto expr = tensor_loop_v[idy].first; - auto load = expr.As(); + auto expr = tensor_loop_v[idy].first; + auto load = expr.As(); auto store = expr.As(); - if (keep_loop_vars.size() == 0 || !keep_loop_vars[idy].count(loop_var->name)) { + if (keep_loop_vars.size() == 0 || + !keep_loop_vars[idy].count(loop_var->name)) { auto &indices = load ? load->indices : store->indices; for (auto &indice : indices) { optim::ReplaceVarWithExpr(&indice, loop_var, ir::Expr(0)); @@ -377,15 +404,19 @@ class ReplaceLoopVarToGpu : public ir::IRMutator<> { var_name = "z"; if (for_ir->is_gpu_block_binded()) { var_name = "blockIdx." + var_name; - optim::ReplaceVarWithExpr(expr, op->loop_var, ir::Expr(ir::Var(var_name))); + optim::ReplaceVarWithExpr( + expr, op->loop_var, ir::Expr(ir::Var(var_name))); } else if (for_ir->is_gpu_thread_binded()) { var_name = "threadIdx." 
+ var_name; - optim::ReplaceVarWithExpr(expr, op->loop_var, ir::Expr(ir::Var(var_name))); + optim::ReplaceVarWithExpr( + expr, op->loop_var, ir::Expr(ir::Var(var_name))); } ir::IRMutator<>::Visit(&for_ir->body, &for_ir->body); } - void Visit(const ir::PolyFor *op, Expr *expr) override { LOG(FATAL) << "Unkown PolyFor!"; } + void Visit(const ir::PolyFor *op, Expr *expr) override { + LOG(FATAL) << "Unkown PolyFor!"; + } }; class SharedAxisVisitor : public ir::IRMutator<> { @@ -399,7 +430,8 @@ class SharedAxisVisitor : public ir::IRMutator<> { return; } - if (store->tensor.as_tensor_ref()->buffer->memory_type == ir::MemoryType::GPUShared) { + if (store->tensor.as_tensor_ref()->buffer->memory_type == + ir::MemoryType::GPUShared) { for (auto &indice : store->indices) { for (auto axis : gpu_axis) { optim::ReplaceVarWithExpr(&indice, ir::Var(axis), ir::Expr(0)); @@ -419,7 +451,8 @@ class SharedAxisVisitor : public ir::IRMutator<> { return; } - if (load->tensor.as_tensor_ref()->buffer->memory_type == ir::MemoryType::GPUShared) { + if (load->tensor.as_tensor_ref()->buffer->memory_type == + ir::MemoryType::GPUShared) { for (auto &indice : load->indices) { for (auto axis : gpu_axis) { optim::ReplaceVarWithExpr(&indice, ir::Var(axis), ir::Expr(0)); @@ -430,7 +463,8 @@ class SharedAxisVisitor : public ir::IRMutator<> { ir::IRMutator<>::Visit(op, expr); } - const std::vector gpu_axis = {"blockIdx.x", "blockIdx.y", "blockIdx.z"}; + const std::vector gpu_axis = { + "blockIdx.x", "blockIdx.y", "blockIdx.z"}; }; class LocalAxisVisitor : public ir::IRMutator<> { @@ -444,7 +478,8 @@ class LocalAxisVisitor : public ir::IRMutator<> { return; } - if (store->tensor.as_tensor_ref()->buffer->memory_type == ir::MemoryType::GPULocal) { + if (store->tensor.as_tensor_ref()->buffer->memory_type == + ir::MemoryType::GPULocal) { for (auto &indice : store->indices) { for (auto axis : gpu_axis) { optim::ReplaceVarWithExpr(&indice, ir::Var(axis), ir::Expr(0)); @@ -464,7 +499,8 @@ class LocalAxisVisitor : public ir::IRMutator<> { return; } - if (load->tensor.as_tensor_ref()->buffer->memory_type == ir::MemoryType::GPULocal) { + if (load->tensor.as_tensor_ref()->buffer->memory_type == + ir::MemoryType::GPULocal) { for (auto &indice : load->indices) { for (auto axis : gpu_axis) { optim::ReplaceVarWithExpr(&indice, ir::Var(axis), ir::Expr(0)); @@ -475,8 +511,12 @@ class LocalAxisVisitor : public ir::IRMutator<> { ir::IRMutator<>::Visit(op, expr); } - const std::vector gpu_axis = { - "blockIdx.x", "blockIdx.y", "blockIdx.z", "threadIdx.x", "threadIdx.y", "threadIdx.z"}; + const std::vector gpu_axis = {"blockIdx.x", + "blockIdx.y", + "blockIdx.z", + "threadIdx.x", + "threadIdx.y", + "threadIdx.z"}; }; class ResizeBufferSizeVisitor : public ir::IRMutator<> { @@ -485,7 +525,7 @@ class ResizeBufferSizeVisitor : public ir::IRMutator<> { private: void Visit(const ir::Store *op, Expr *expr) override { - auto store = expr->As(); + auto store = expr->As(); auto store_tensor = store->tensor.as_tensor_ref(); if (!store_tensor->buffer.defined()) { @@ -497,8 +537,8 @@ class ResizeBufferSizeVisitor : public ir::IRMutator<> { } auto &indices = store->indices; - auto &shape = store_tensor->shape; - auto &buffer = store_tensor->buffer->shape; + auto &shape = store_tensor->shape; + auto &buffer = store_tensor->buffer->shape; shape.clear(); buffer.clear(); @@ -515,15 +555,18 @@ class ResizeBufferSizeVisitor : public ir::IRMutator<> { return; } - if (load->tensor.as_tensor_ref()->buffer->memory_type == ir::MemoryType::Heap) { + if 
(load->tensor.as_tensor_ref()->buffer->memory_type == + ir::MemoryType::Heap) { ir::IRMutator<>::Visit(op, expr); return; } - load->tensor.as_tensor_ref()->shape = load->tensor.as_tensor_ref()->buffer->shape; + load->tensor.as_tensor_ref()->shape = + load->tensor.as_tensor_ref()->buffer->shape; - // For the moment, align the load tensor indices with the tensor shape using the trick method. - // A better way would be to modify the FlattenLoop Schedule. + // For the moment, align the load tensor indices with the tensor shape using + // the trick method. A better way would be to modify the FlattenLoop + // Schedule. int cnt = load->indices.size() - load->tensor.as_tensor_ref()->shape.size(); for (int i = 0; i < cnt; i++) { load->indices.erase(load->indices.begin()); @@ -533,7 +576,7 @@ class ResizeBufferSizeVisitor : public ir::IRMutator<> { void Visit(const ir::For *op, Expr *expr) override { CHECK(expr->As()); - auto for_ir = expr->As(); + auto for_ir = expr->As(); auto var_name = for_ir->loop_var->name; auto extent_i = for_ir->extent; @@ -543,11 +586,13 @@ class ResizeBufferSizeVisitor : public ir::IRMutator<> { int BufferSize(ir::Expr indice) { auto copy = IRCopy(indice); - auto vars = ir::CollectIRNodesInOrder(copy, [](const ir::Expr *expr) { return expr->As(); }); + auto vars = ir::CollectIRNodesInOrder( + copy, [](const ir::Expr *expr) { return expr->As(); }); int max_range = 1; // using recursion funcitons index range. - std::function compute_range = [&](const int deep, ir::Expr index) { + std::function compute_range = [&](const int deep, + ir::Expr index) { auto var = vars[deep].as_var_ref(); CHECK(loop_2_extent_.count(var->name)) << var->name; auto extent = loop_2_extent_.find(var->name)->second; @@ -558,7 +603,7 @@ class ResizeBufferSizeVisitor : public ir::IRMutator<> { if (deep == vars.size() - 1) { auto simplify = common::AutoSimplify(tmp); - auto range = common::AutoSimplify(simplify); + auto range = common::AutoSimplify(simplify); CHECK(range.is_constant()); max_range = std::max(max_range, range.as_int32() + 1); } else { @@ -614,11 +659,12 @@ class ReplaceVarToZero : public ir::IRMutator<> { void Visit(const ir::For *op, Expr *expr) override { CHECK(expr->As()); - auto for_ir = expr->As(); + auto for_ir = expr->As(); auto var_name = for_ir->loop_var->name; auto extent_i = for_ir->extent; - if (extent_i.is_constant() && extent_i.as_int32() == 1) loop_var_.insert(var_name); + if (extent_i.is_constant() && extent_i.as_int32() == 1) + loop_var_.insert(var_name); ir::IRMutator<>::Visit(op, expr); loop_var_.erase(var_name); } diff --git a/paddle/cinn/optim/transform_gpu_forloop.h b/paddle/cinn/optim/transform_gpu_forloop.h index 23884bd583394..907c8e51e90aa 100644 --- a/paddle/cinn/optim/transform_gpu_forloop.h +++ b/paddle/cinn/optim/transform_gpu_forloop.h @@ -33,8 +33,8 @@ void OptimizeExprGPU(Expr* expr); */ /** - * Remove the forloops of block and thread axis, add the kernel launch thread dimension information to the outermost - * LoweredFunc. + * Remove the forloops of block and thread axis, add the kernel launch thread + * dimension information to the outermost LoweredFunc. * * For example, input the code: * \code @@ -51,8 +51,9 @@ void OptimizeExprGPU(Expr* expr); * A(blockIdx.x, threadIdx.x) * \endcode * - * \note For that the dimensions of each threadIdx or blockIdx should be constant, so this only takes For nodes, not - * \note PolyFor nodes is allowed to be GPU related. 
+ * \note Since the dimensions of each threadIdx or blockIdx should be
+ * constant, this only takes For nodes; PolyFor nodes are not allowed to be
+ * GPU related.
 */
 void RemoveGpuForloopsAxis(Expr* expr);
diff --git a/paddle/cinn/optim/transform_polyfor_to_for.cc b/paddle/cinn/optim/transform_polyfor_to_for.cc
index 544ead54780ce..ebee754b787ba 100644
--- a/paddle/cinn/optim/transform_polyfor_to_for.cc
+++ b/paddle/cinn/optim/transform_polyfor_to_for.cc
@@ -89,21 +89,24 @@ struct PolyForWithSimpleConditionToForMutator : public ir::IRMutator<Expr *> {
     if (!(lt_n || le_n)) return;

     // check the lhs is the iterator
-    bool can_extract_extent = (lt_n && lt_n->a().as_var() && lt_n->a().as_var()->name == op->iterator->name) ||
-                              (le_n && le_n->a().as_var() && le_n->a().as_var()->name == op->iterator->name);
+    bool can_extract_extent =
+        (lt_n && lt_n->a().as_var() &&
+         lt_n->a().as_var()->name == op->iterator->name) ||
+        (le_n && le_n->a().as_var() &&
+         le_n->a().as_var()->name == op->iterator->name);

     if (!can_extract_extent) {
       if (node->condition.As<ir::LE>()) {
         auto le = node->condition.As<ir::LE>();
         CHECK(le->a().As<ir::Sub>());
         CHECK_EQ(le->b().As<ir::IntImm>()->value, 0UL);
-        auto sub       = le->a().As<ir::Sub>();
+        auto sub = le->a().As<ir::Sub>();
         node->condition = ir::LE::Make(sub->a(), sub->b());
       } else if (node->condition.As<ir::LT>()) {
         auto lt = node->condition.As<ir::LT>();
         CHECK(lt->a().As<ir::Sub>());
         CHECK_EQ(lt->b().As<ir::IntImm>()->value, 0UL);
-        auto sub       = lt->a().As<ir::Sub>();
+        auto sub = lt->a().As<ir::Sub>();
         node->condition = ir::LT::Make(sub->a(), sub->b());
       } else {
         LOG(FATAL) << "Unknown Type!";
@@ -116,12 +119,17 @@ struct PolyForWithSimpleConditionToForMutator : public ir::IRMutator<Expr *> {
     Expr lhs = lt_n ? lt_n->a() : le_n->a();
     Expr rhs = lt_n ? lt_n->b() : PlusOneWithMinMax(le_n->b());
-    rhs      = common::AutoSimplify(rhs);
+    rhs = common::AutoSimplify(rhs);

     if (op->is_vectorized()) CHECK(op->vectorize_info().valid());

-    Expr new_for =
-        ir::For::Make(op->iterator, op->init, rhs, op->for_type(), op->device_api, op->body, op->vectorize_info());
+    Expr new_for = ir::For::Make(op->iterator,
+                                 op->init,
+                                 rhs,
+                                 op->for_type(),
+                                 op->device_api,
+                                 op->body,
+                                 op->vectorize_info());
     *expr = new_for;

     Visit(&new_for.As<ir::For>()->body);
@@ -130,7 +138,9 @@
 }  // namespace

-void TransformPolyForToFor(Expr* expr, bool auto_separate) { PolyForWithSimpleConditionToForMutator()(expr); }
+void TransformPolyForToFor(Expr* expr, bool auto_separate) {
+  PolyForWithSimpleConditionToForMutator()(expr);
+}

 }  // namespace optim
 }  // namespace cinn
diff --git a/paddle/cinn/optim/transform_polyfor_to_for.h b/paddle/cinn/optim/transform_polyfor_to_for.h
index d4942333ae5b7..a23e84f46b070 100644
--- a/paddle/cinn/optim/transform_polyfor_to_for.h
+++ b/paddle/cinn/optim/transform_polyfor_to_for.h
@@ -18,8 +18,8 @@ namespace cinn {
 namespace optim {

-//! Transform the PolyFor node to For node. This will also separate the PolyFor with Min or Max conditions into two For
-//! nodes if \p auto_separate is true.
+//! Transform a PolyFor node to a For node. This will also separate a PolyFor
+//! with Min or Max conditions into two For nodes if \p auto_separate is true.
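+//!
+//! An illustrative sketch of the basic rewrite, assuming a simple condition:
+//! \code
+//! poly_for (i, init = 0, cond = i <= 31, inc = 1) body
+//! // becomes (extent = upper bound + 1 for a `<=` condition)
+//! for (i, 0, 32) body
+//! \endcode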
void TransformPolyForToFor(Expr* expr, bool auto_separate = true); namespace detail { diff --git a/paddle/cinn/optim/transform_polyfor_to_for_test.cc b/paddle/cinn/optim/transform_polyfor_to_for_test.cc index 9fedeb9c65c62..b6f7c073df154 100644 --- a/paddle/cinn/optim/transform_polyfor_to_for_test.cc +++ b/paddle/cinn/optim/transform_polyfor_to_for_test.cc @@ -34,7 +34,9 @@ TEST(Expr, basic) { Var k(K.as_int32(), "k0"); Tensor C = Compute( - {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + {M, N}, + [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, + "C"); auto stages = CreateStages({C}); @@ -49,7 +51,7 @@ TEST(Expr, basic) { Target target; target.arch = Target::Arch ::X86; target.bits = Target::Bit ::k32; - target.os = Target::OS ::Linux; + target.os = Target::OS ::Linux; { ir::Module::Builder builder("module1", target); diff --git a/paddle/cinn/optim/unroll_loops.cc b/paddle/cinn/optim/unroll_loops.cc index 3727b87464171..6d48104bd47fa 100755 --- a/paddle/cinn/optim/unroll_loops.cc +++ b/paddle/cinn/optim/unroll_loops.cc @@ -45,7 +45,8 @@ struct UnrollMutator : public ir::IRMutator { std::swap(auto_max_step_, value); return; } else { - LOG(WARNING) << "Get invalid value of attr:" << ir::attr::auto_unroll_max_step; + LOG(WARNING) << "Get invalid value of attr:" + << ir::attr::auto_unroll_max_step; } } ir::IRMutator<>::Visit(op, expr); @@ -68,10 +69,12 @@ struct UnrollMutator : public ir::IRMutator { // predicate this for-loop can be unrolled by auto-unroll conditions bool unrollable = - (op->is_serial() && extent >= 0 && not_unrolled_depth_ == 0 && extent * flat_step_ <= auto_max_step_); + (op->is_serial() && extent >= 0 && not_unrolled_depth_ == 0 && + extent * flat_step_ <= auto_max_step_); // predicate this for-loop can be unrolled by the unrolled tag - unrollable = (unrollable || op->is_unrolled()) && extent <= max_unroll_extent_; + unrollable = + (unrollable || op->is_unrolled()) && extent <= max_unroll_extent_; if (unrollable) { Unroll(op, expr); @@ -85,7 +88,7 @@ struct UnrollMutator : public ir::IRMutator { void Unroll(const ir::For* op, Expr* expr) { std::vector body; - auto* min = op->min.As(); + auto* min = op->min.As(); auto* extent = op->extent.As(); if (!(min && extent)) return; diff --git a/paddle/cinn/optim/unroll_loops_test.cc b/paddle/cinn/optim/unroll_loops_test.cc index 809a1e30ab893..87995fd7bf8eb 100644 --- a/paddle/cinn/optim/unroll_loops_test.cc +++ b/paddle/cinn/optim/unroll_loops_test.cc @@ -40,7 +40,8 @@ TEST(UnrollLoops, unrolled_tag) { auto stages = CreateStages({C}); Target target = common::DefaultHostTarget(); - auto func = cinn::lang::LowerVec("test_unrolled_tag", stages, {A, B, C}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_unrolled_tag", stages, {A, B, C}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; ir::ModuleExpr mod_expr({ast_expr}); @@ -74,11 +75,14 @@ TEST(UnrollLoops, auto_unroll) { // B = A + 2.11 Tensor B = Compute( - {M, N, O}, [&](Var i, Var j, Var k) { return A(i, j, k) + const_value; }, "B"); + {M, N, O}, + [&](Var i, Var j, Var k) { return A(i, j, k) + const_value; }, + "B"); - auto stages = CreateStages({B}); + auto stages = CreateStages({B}); Target target = common::DefaultHostTarget(); - auto func = cinn::lang::LowerVec("test_auto_unroll", stages, {A, B}, {}, {}, nullptr, target, true); + auto func = cinn::lang::LowerVec( + "test_auto_unroll", stages, {A, B}, {}, {}, nullptr, target, true); auto ast_expr = func[0]->body; 
ir::ModuleExpr mod_expr({ast_expr}); ir::IRSchedule ir_sch(mod_expr); @@ -87,8 +91,11 @@ TEST(UnrollLoops, auto_unroll) { // check after the last UnrollLoop pass it will remain unchanged ASSERT_EQ(ir_sch.GetLoops("B").size(), 3); - ASSERT_TRUE(ast_expr.As()->stmts.front().As() != nullptr); - auto* block_realize = ast_expr.As()->stmts.front().As(); + ASSERT_TRUE( + ast_expr.As()->stmts.front().As() != + nullptr); + auto* block_realize = + ast_expr.As()->stmts.front().As(); auto* schedule_block = block_realize->schedule_block.As(); // set the 'auto_unroll_max_step' attribute as value 25 that is bigger than // the product of extent of the inner 2 loops diff --git a/paddle/cinn/optim/var_mod_simplify.cc b/paddle/cinn/optim/var_mod_simplify.cc index 81ab15797a3d6..dcd6de24fef2e 100644 --- a/paddle/cinn/optim/var_mod_simplify.cc +++ b/paddle/cinn/optim/var_mod_simplify.cc @@ -30,11 +30,11 @@ struct ReplaceModWithDivMutator : public ir::IRMutator<> { void Visit(const Mod* op, Expr* expr) override { auto* node = expr->As(); - auto a = node->operand(0); - auto b = node->operand(1); - *expr = ir::Div::Make(a, b); - *expr = ir::Mul::Make(b, *expr); - *expr = ir::Sub::Make(a, *expr); + auto a = node->operand(0); + auto b = node->operand(1); + *expr = ir::Div::Make(a, b); + *expr = ir::Mul::Make(b, *expr); + *expr = ir::Sub::Make(a, *expr); } }; @@ -53,16 +53,17 @@ struct ReplaceDivWithVarMutator : public ir::IRMutator<> { auto b_int = b.As(); CHECK(a_var); CHECK(b_int); - std::string var_name = a_var->name + "/" + std::to_string(b_int->value); + std::string var_name = a_var->name + "/" + std::to_string(b_int->value); div_var_map_[var_name] = ir::Div::Make(a, b); - *expr = Var(var_name); + *expr = Var(var_name); } } }; struct ReplaceVarWithDivMutator : public ir::IRMutator<> { absl::flat_hash_map div_var_map_; - void operator()(Expr* x, const absl::flat_hash_map& div_var_map) { + void operator()(Expr* x, + const absl::flat_hash_map& div_var_map) { div_var_map_ = div_var_map; ir::IRMutator<>::Visit(x, x); } @@ -83,7 +84,7 @@ void VarModSimplify(Expr* e) { ReplaceModWithDivMutator()(e); ReplaceDivWithVarMutator mutator; mutator(e); - *e = common::AutoSimplify(*e); + *e = common::AutoSimplify(*e); auto div_var_map = mutator.div_var_map_; ReplaceVarWithDivMutator()(e, mutator.div_var_map_); } diff --git a/paddle/cinn/optim/vectorize_loops.cc b/paddle/cinn/optim/vectorize_loops.cc index d58c4b6cf0538..eb75b457e9a4f 100644 --- a/paddle/cinn/optim/vectorize_loops.cc +++ b/paddle/cinn/optim/vectorize_loops.cc @@ -51,7 +51,8 @@ Expr Widen(Expr e, int lanes) { } } - CHECK_EQ(e.type().lanes(), 1) << "Cannot broadcast lanes from " << e.type().lanes() << " to " << lanes; + CHECK_EQ(e.type().lanes(), 1) + << "Cannot broadcast lanes from " << e.type().lanes() << " to " << lanes; return ir::Broadcast::Make(e, lanes); } @@ -59,9 +60,11 @@ Expr Widen(Expr e, int lanes) { // of tensors which meet all check predicates of vectoring class TensorVectorizeTeller : public ir::IRMutator { public: - TensorVectorizeTeller(const Var &iter_var, - const int factor, - const absl::flat_hash_map *var_intervals) + TensorVectorizeTeller( + const Var &iter_var, + const int factor, + const absl::flat_hash_map + *var_intervals) : iter_var_(iter_var), factor_(factor), var_intervals_(var_intervals) {} void Collect(const Expr *op) { IRMutator::Visit(op, op); } @@ -73,10 +76,12 @@ class TensorVectorizeTeller : public ir::IRMutator { } private: - const Var iter_var_; // loop var of new for-loop split from the vectorized loop + const Var + 
iter_var_; // loop var of new for-loop split from the vectorized loop const int factor_; const absl::flat_hash_map *var_intervals_; - // save (tensor name) -> (bool flag) to indentify whether tensors can be vectorized or not + // save (tensor name) -> (bool flag) to indentify whether tensors can be + // vectorized or not std::unordered_map tensor2flag_; void Visit(const ir::Store *expr, const Expr *op) override { @@ -88,7 +93,7 @@ class TensorVectorizeTeller : public ir::IRMutator { // a tensor should pass all check of pre-conditions in every time it appears if (!tensor2flag_.count(tensor->name) || tensor2flag_.at(tensor->name)) { - bool flag = MeetConditions(node->tensor, node->indices); + bool flag = MeetConditions(node->tensor, node->indices); tensor2flag_[tensor->name] = flag; } } @@ -101,7 +106,7 @@ class TensorVectorizeTeller : public ir::IRMutator { // a tensor should pass all check of pre-conditions in every time it appears if (!tensor2flag_.count(tensor->name) || tensor2flag_.at(tensor->name)) { - bool flag = MeetConditions(node->tensor, node->indices); + bool flag = MeetConditions(node->tensor, node->indices); tensor2flag_[tensor->name] = flag; } } @@ -109,19 +114,25 @@ class TensorVectorizeTeller : public ir::IRMutator { // return true if the tensor meets all conditions of vectorizing bool MeetConditions(const Expr &expr, const std::vector &indices) { const ir::_Tensor_ *tensor = expr.As(); - auto find_matched_var_fn = [&](const Expr *x) { return x->As<_Var_>() && x->As<_Var_>()->name == iter_var_->name; }; + auto find_matched_var_fn = [&](const Expr *x) { + return x->As<_Var_>() && x->As<_Var_>()->name == iter_var_->name; + }; // the size of the last dim should be divisible by factor Expr last_size = tensor->shape.back(); - if (tensor->shape.empty() || !tensor->shape.back().As() || tensor->shape.back().as_int32() % factor_ != 0) { - VLOG(5) << "Size of the last dim of tensor:" << tensor->name << " can't be divisible by factor:" << factor_ + if (tensor->shape.empty() || !tensor->shape.back().As() || + tensor->shape.back().as_int32() % factor_ != 0) { + VLOG(5) << "Size of the last dim of tensor:" << tensor->name + << " can't be divisible by factor:" << factor_ << ", shape:" << utils::Join(tensor->shape, ","); return false; } // the iter val must appear in the last index - if (indices.empty() || ir::CollectIRNodes(indices.back(), find_matched_var_fn).empty()) { - VLOG(5) << "Loop var:" << iter_var_->name << " is not used in the last index"; + if (indices.empty() || + ir::CollectIRNodes(indices.back(), find_matched_var_fn).empty()) { + VLOG(5) << "Loop var:" << iter_var_->name + << " is not used in the last index"; return false; } @@ -129,7 +140,8 @@ class TensorVectorizeTeller : public ir::IRMutator { for (int i = 0; i < indices.size() - 1; ++i) { auto repeat_found = ir::CollectIRNodes(indices[i], find_matched_var_fn); if (!repeat_found.empty()) { - VLOG(5) << "Loop var:" << iter_var_->name << " is used at more than last index, current:" << i; + VLOG(5) << "Loop var:" << iter_var_->name + << " is used at more than last index, current:" << i; return false; } } @@ -143,19 +155,24 @@ class TensorVectorizeTeller : public ir::IRMutator { optim::IrReplace(&next_idx, Expr(iter_var_), Expr(i)); auto gap = common::AutoSimplify(Expr(next_idx - first_idx)); if (!gap.As() || gap.as_int32() != i) { - VLOG(5) << "Tensor:" << tensor->name << " is not accessed sequentially, next:" << next_idx + VLOG(5) << "Tensor:" << tensor->name + << " is not accessed sequentially, next:" << next_idx << ", 
first:" << first_idx << ", gap:" << gap; return false; } - VLOG(5) << "Tensor:" << tensor->name << " is accessed sequentially, next:" << next_idx << ", first:" << first_idx - << ", gap:" << gap; + VLOG(5) << "Tensor:" << tensor->name + << " is accessed sequentially, next:" << next_idx + << ", first:" << first_idx << ", gap:" << gap; } auto dtype = expr->type().ElementOf(); - bool type_supported = - dtype.is_float(32) || dtype.is_int(32) || dtype.is_uint(32) || dtype.is_float16() || dtype.is_bfloat16(); + bool type_supported = dtype.is_float(32) || dtype.is_int(32) || + dtype.is_uint(32) || dtype.is_float16() || + dtype.is_bfloat16(); if (!type_supported) { - VLOG(5) << "Only support vectorizing int,uint,float,float16,bloat16, but got " << dtype; + VLOG(5) + << "Only support vectorizing int,uint,float,float16,bloat16, but got " + << dtype; return false; } return true; @@ -179,17 +196,23 @@ class CudaVectorizer : public IRMutator { static constexpr int CudaVectorTypeMaxLanes = 8; CudaVectorizer(const Var &iter_var, const int factor, - const absl::flat_hash_map *var_intervals) - : iter_var_(iter_var), factor_(factor), vectorized_teller_(iter_var, factor, var_intervals) { + const absl::flat_hash_map + *var_intervals) + : iter_var_(iter_var), + factor_(factor), + vectorized_teller_(iter_var, factor, var_intervals) { CHECK(factor <= CudaVectorTypeMaxLanes) - << "The maximum lanes of valid CUDA vector types: " << CudaVectorTypeMaxLanes << ", but factor: " << factor; + << "The maximum lanes of valid CUDA vector types: " + << CudaVectorTypeMaxLanes << ", but factor: " << factor; } // return all cast statements collected through vectorizing std::vector VectorizedTypeCastExprs() { return vectorized_cast_exprs_; } // return all store statements collected through vectorizing - std::vector VectorizedTypeStoreExprs() { return vectorized_store_exprs_; } + std::vector VectorizedTypeStoreExprs() { + return vectorized_store_exprs_; + } void Visit(Expr *expr) { write_teller_.Collect(expr); @@ -198,15 +221,16 @@ class CudaVectorizer : public IRMutator { } void Visit(const Load *op, Expr *expr) override { - auto *node = expr->As(); + auto *node = expr->As(); auto *tensor = node->tensor.As(); - if (node->is_addr_tensor() && vectorized_teller_.CanBeVectorized(tensor->name)) { + if (node->is_addr_tensor() && + vectorized_teller_.CanBeVectorized(tensor->name)) { TensorVectorized(node, &node->indices, false); } } void Visit(const Store *op, Expr *expr) override { - auto *node = expr->As(); + auto *node = expr->As(); auto *tensor = node->tensor.As(); CHECK(tensor); if (vectorized_teller_.CanBeVectorized(tensor->name)) { @@ -217,7 +241,9 @@ class CudaVectorizer : public IRMutator { } private: - void TensorVectorized(ir::LoadStoreAddrMnger *node, std::vector *indices, bool is_store) { + void TensorVectorized(ir::LoadStoreAddrMnger *node, + std::vector *indices, + bool is_store) { auto *tensor = node->tensor.As(); VLOG(5) << "Vectorizing tensor:" << tensor->name; @@ -228,8 +254,14 @@ class CudaVectorizer : public IRMutator { auto vectorized_var = tensor2vectorized_vars_.at(tensor->name); // substitue a new tensor with the vector name and dtype - auto t = vectorized_var->type().is_cpp_handle() ? node->tensor->type().PointerOf() : node->tensor->type(); - node->tensor = ir::Tensor(vectorized_var->name, t, {Expr(factor_)}, {Expr(factor_)}, tensor->operation); + auto t = vectorized_var->type().is_cpp_handle() + ? 
node->tensor->type().PointerOf() + : node->tensor->type(); + node->tensor = ir::Tensor(vectorized_var->name, + t, + {Expr(factor_)}, + {Expr(factor_)}, + tensor->operation); // remain the last iterative indice indices->assign({iter_var_}); } @@ -253,8 +285,10 @@ class CudaVectorizer : public IRMutator { return ""; } - void AppendCast(Expr tensor, const std::vector &indices, bool is_store) { - auto *node = tensor.As(); + void AppendCast(Expr tensor, + const std::vector &indices, + bool is_store) { + auto *node = tensor.As(); bool is_const = !write_teller_.IsWrite(node->name); // generate the corresponding vector type @@ -270,7 +304,7 @@ class CudaVectorizer : public IRMutator { // generate a local vector variable to be used in subsequent statements std::string vectorized_name = "vectorized_" + node->name; - Var vectorized_var = _Var_::Make(vectorized_name, vector_type); + Var vectorized_var = _Var_::Make(vectorized_name, vector_type); tensor2vectorized_vars_.emplace(node->name, vectorized_var); // generate a get_addr expr to get the address of the tensor @@ -282,11 +316,12 @@ class CudaVectorizer : public IRMutator { auto cast = ir::Cast::Make(vector_type_ptr, get_addr); if (!is_store) { auto load = Load::Make(cast, {make_const(0)}); - auto let = Let::Make(vectorized_var, load); + auto let = Let::Make(vectorized_var, load); vectorized_cast_exprs_.emplace_back(let); VLOG(5) << "Append a vectorized expr:" << let; } else { - Var vectorized_ptr = _Var_::Make(vectorized_name + "_ptr", vector_type_ptr); + Var vectorized_ptr = + _Var_::Make(vectorized_name + "_ptr", vector_type_ptr); auto let1 = Let::Make(vectorized_ptr, cast); auto let2 = Let::Make(vectorized_var, Expr(0)); @@ -296,8 +331,11 @@ class CudaVectorizer : public IRMutator { VLOG(5) << "Append a vectorized expr:" << let1; VLOG(5) << "Append a vectorized expr:" << let2; - auto t = - ir::Tensor(vectorized_ptr->name, node->type().PointerOf(), {Expr(factor_)}, {Expr(factor_)}, node->operation); + auto t = ir::Tensor(vectorized_ptr->name, + node->type().PointerOf(), + {Expr(factor_)}, + {Expr(factor_)}, + node->operation); auto store = Store::Make(t, vectorized_var, {make_const(0)}); vectorized_store_exprs_.emplace_back(store); @@ -325,7 +363,10 @@ class Vectorizer : public IRMutator { std::string widen_suffix; public: - Vectorizer(const Var &var, int lanes, const absl::flat_hash_map &var_intervals = {}) + Vectorizer(const Var &var, + int lanes, + const absl::flat_hash_map + &var_intervals = {}) : var(var), lanes_(lanes), var_intervals_(var_intervals) { // the identity ramp. 
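     // (Illustrative note) Ramp::Make(base, stride, lanes) denotes the index
     // vector <base, base+stride, ..., base+(lanes-1)*stride>; with base 0 and
     // stride 1 it is the identity <0, 1, ..., lanes_-1> standing for the
     // vectorized loop variable, so an access A[b + var] can later be widened
     // into the vector access A[Ramp(b, 1, lanes_)].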
ramp_ = Ramp::Make(make_zero(), make_one(), lanes_); @@ -343,7 +384,7 @@ class Vectorizer : public IRMutator { void Visit(const Cast *op, Expr *expr) override { auto *node = expr->As(); - auto v0 = node->v(); + auto v0 = node->v(); Visit(&node->v()); if (v0.same_as(node->v())) return; @@ -357,46 +398,64 @@ class Vectorizer : public IRMutator { } } - void Visit(const Add *op, Expr *expr) override { MutateAddSubOperator(op, expr); } - void Visit(const Sub *op, Expr *expr) override { MutateAddSubOperator(op, expr); } - void Visit(const Mul *op, Expr *expr) override { MutateMulDivOperator(op, expr); } - void Visit(const Div *op, Expr *expr) override { MutateMulDivOperator(op, expr); } - void Visit(const Mod *op, Expr *expr) override { MutateMulDivOperator(op, expr); } - void Visit(const Min *op, Expr *expr) override { BinaryOperatorVec(op, expr); } - void Visit(const Max *op, Expr *expr) override { BinaryOperatorVec(op, expr); } + void Visit(const Add *op, Expr *expr) override { + MutateAddSubOperator(op, expr); + } + void Visit(const Sub *op, Expr *expr) override { + MutateAddSubOperator(op, expr); + } + void Visit(const Mul *op, Expr *expr) override { + MutateMulDivOperator(op, expr); + } + void Visit(const Div *op, Expr *expr) override { + MutateMulDivOperator(op, expr); + } + void Visit(const Mod *op, Expr *expr) override { + MutateMulDivOperator(op, expr); + } + void Visit(const Min *op, Expr *expr) override { + BinaryOperatorVec(op, expr); + } + void Visit(const Max *op, Expr *expr) override { + BinaryOperatorVec(op, expr); + } void Visit(const EQ *op, Expr *expr) override { BinaryOperatorVec(op, expr); } void Visit(const NE *op, Expr *expr) override { BinaryOperatorVec(op, expr); } void Visit(const LT *op, Expr *expr) override { BinaryOperatorVec(op, expr); } void Visit(const LE *op, Expr *expr) override { BinaryOperatorVec(op, expr); } void Visit(const GT *op, Expr *expr) override { BinaryOperatorVec(op, expr); } void Visit(const GE *op, Expr *expr) override { BinaryOperatorVec(op, expr); } - void Visit(const And *op, Expr *expr) override { BinaryOperatorVec(op, expr); } + void Visit(const And *op, Expr *expr) override { + BinaryOperatorVec(op, expr); + } void Visit(const Or *op, Expr *expr) override { BinaryOperatorVec(op, expr); } void Visit(const Ramp *op, Expr *expr) override {} void Visit(const Select *op, Expr *expr) override { - auto *node = expr->As(); + auto condition0 = node->condition; + auto true_value0 = node->true_value; auto false_value0 = node->false_value; Visit(&node->condition); Visit(&node->true_value); Visit(&node->false_value); - if (condition0.same_as(node->condition) && true_value0.same_as(node->true_value) && + if (condition0.same_as(node->condition) && + true_value0.same_as(node->true_value) && false_value0.same_as(node->false_value)) return; - int lanes = - utils::Max(node->condition.type().lanes(), node->true_value.type().lanes(), node->false_value.type().lanes()); - node->true_value = Widen(node->true_value, lanes); + int lanes = utils::Max(node->condition.type().lanes(), + node->true_value.type().lanes(), + node->false_value.type().lanes()); + node->true_value = Widen(node->true_value, lanes); node->false_value = Widen(node->false_value, lanes); } void Visit(const Load *op, Expr *expr) override { - auto *node = expr->As(); + auto *node = expr->As(); std::vector indices = node->indices; // We ignore the predicate here. 
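     // (Illustrative note) e.g. with lanes_ = 4, a load A[i] whose index
     // depends on the vectorized var is widened to A[Ramp(i0, 1, 4)], reading
     // the four lanes i0 .. i0+3 at once (i0 names the scalar base index here
     // for illustration only).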
bool need_visit = false; @@ -419,7 +478,7 @@ class Vectorizer : public IRMutator { } void Visit(const Store *op, Expr *expr) override { - auto *node = expr->As(); + auto *node = expr->As(); auto value0 = node->value; Visit(&node->value); @@ -451,12 +510,12 @@ class Vectorizer : public IRMutator { } void Visit(const Call *op, Expr *expr) override { - std::vector read_args = op->read_args; + std::vector read_args = op->read_args; std::vector write_args = op->write_args; - auto *node = expr->As(); + auto *node = expr->As(); ir::IRMutator<>::Visit(op, expr); bool is_changed = false; - int lanes = 0; + int lanes = 0; for (int i = 0; i < node->read_args.size(); i++) { lanes = std::max(node->read_args[i].type().lanes(), lanes); if (!node->read_args[i].same_as(read_args[i])) { @@ -480,7 +539,7 @@ class Vectorizer : public IRMutator { CHECK(!read_args.empty()); Type type = op->type().with_lanes(lanes); - *expr = Call::Make(type, + *expr = Call::Make(type, node->name, node->read_args, node->write_args, @@ -505,7 +564,9 @@ class Vectorizer : public IRMutator { LOG(ERROR) << "Ignore Width IfThenElse"; } - void Visit(const For *op, Expr *expr) override { ir::IRMutator<>::Visit(op, expr); } + void Visit(const For *op, Expr *expr) override { + ir::IRMutator<>::Visit(op, expr); + } void Scalarize(Expr *expr) { Var idx(var->name + "_s", Int(32)); @@ -513,15 +574,19 @@ class Vectorizer : public IRMutator { var_map[var.As()] = idx; common::Substitute(expr, var_map); - *expr = - ir::For::Make(idx, common::make_const(0), common::make_const(lanes_), ForType::Serial, DeviceAPI::Host, *expr); + *expr = ir::For::Make(idx, + common::make_const(0), + common::make_const(lanes_), + ForType::Serial, + DeviceAPI::Host, + *expr); } template void MutateAddSubOperator(const T *op, Expr *expr) { auto *node = expr->As(); - Expr a0 = node->a(); - Expr b0 = node->b(); + Expr a0 = node->a(); + Expr b0 = node->b(); Visit(&node->a()); Visit(&node->b()); @@ -552,8 +617,8 @@ class Vectorizer : public IRMutator { template void MutateMulDivOperator(const T *op, Expr *expr) { - Expr a0 = op->a(); - Expr b0 = op->b(); + Expr a0 = op->a(); + Expr b0 = op->b(); auto *node = expr->As(); Visit(&node->a()); Visit(&node->b()); @@ -586,14 +651,14 @@ class Vectorizer : public IRMutator { template void BinaryOperatorVec(const T *op, Expr *expr) { auto *node = expr->As(); - Expr a0 = node->a(); - Expr b0 = node->b(); + Expr a0 = node->a(); + Expr b0 = node->b(); Visit(&node->a()); Visit(&node->b()); // if (a0.same_as(node->a()) && b0.same_as(node->b())) return *expr; int lanes = std::max(node->a().type().lanes(), node->b().type().lanes()); - *expr = T::Make(Widen(node->a(), lanes), Widen(node->b(), lanes)); + *expr = T::Make(Widen(node->a(), lanes), Widen(node->b(), lanes)); } }; @@ -607,7 +672,7 @@ struct VectorizeLoops_ : public IRMutator { void operator()(Expr *expr) { IRMutator::Visit(expr, expr); } void Visit(const Load *op, Expr *expr) override { - auto *node = expr->As(); + auto *node = expr->As(); std::vector indices = node->indices; bool is_changed = false; @@ -630,7 +695,7 @@ struct VectorizeLoops_ : public IRMutator { IRMutator::Visit(&node->value, &node->value); std::vector indices = node->indices; - bool is_changed = false; + bool is_changed = false; // simplify the complicated index from poly in the format of div/mod for (int i = 0; i < indices.size(); i++) { node->indices[i] = common::AutoSimplify(node->indices[i], var_intervals); @@ -652,12 +717,14 @@ struct VectorizeLoops_ : public IRMutator { } void Visit(const For *forloop, 
Expr *expr) { - auto *node = expr->As(); + auto *node = expr->As(); auto loopvar_name = forloop->loop_var->name; if (forloop->extent.As()) { - var_intervals.emplace(loopvar_name, common::CasInterval{0, forloop->extent.as_int32() - 1}); + var_intervals.emplace( + loopvar_name, common::CasInterval{0, forloop->extent.as_int32() - 1}); } else { - var_intervals.emplace(loopvar_name, common::CasInterval{Expr(0), forloop->extent - 1}); + var_intervals.emplace(loopvar_name, + common::CasInterval{Expr(0), forloop->extent - 1}); } // the extent the forloops marked as Vectorized should be int constant if (forloop->is_vectorized()) { @@ -668,7 +735,7 @@ struct VectorizeLoops_ : public IRMutator { CHECK(is_zero(forloop->min)); Expr for_extent = common::AutoSimplify(forloop->extent); Simplify(&for_extent); - node->extent = for_extent; + node->extent = for_extent; auto *extent_min = for_extent.As(); auto *extent_max = for_extent.As(); @@ -676,9 +743,12 @@ struct VectorizeLoops_ : public IRMutator { IRMutator<>::Visit(&node->body, &node->body); if (target == common::DefaultNVGPUTarget()) { - if (!forloop->extent.As() || forloop->extent.as_int32() % forloop->vectorize_info().factor != 0) { + if (!forloop->extent.As() || + forloop->extent.as_int32() % forloop->vectorize_info().factor != + 0) { vectorizable_ = false; - VLOG(5) << "GPU vectorize only support extent is a multiple of factor"; + VLOG(5) + << "GPU vectorize only support extent is a multiple of factor"; } } @@ -689,7 +759,7 @@ struct VectorizeLoops_ : public IRMutator { return; } - const int factor = forloop->vectorize_info().factor; + const int factor = forloop->vectorize_info().factor; auto _new_forloop = SplitForLoop(node, factor); if (!_new_forloop.defined()) { IRMutator<>::Visit(&node->body, &node->body); @@ -701,8 +771,9 @@ struct VectorizeLoops_ : public IRMutator { auto *new_forloop = _new_forloop.As(); - // The forloop generated from polyhedral analysis might have a complex condition that is not something like - // "i<20" or "i<=20", those cases is not possible to extract the extent. + // The forloop generated from polyhedral analysis might have a complex + // condition that is not something like "i<20" or "i<=20", those cases is + // not possible to extract the extent. auto *extent_int = new_forloop->extent.As(); if (!extent_int) { @@ -712,35 +783,45 @@ struct VectorizeLoops_ : public IRMutator { } int extent = extent_int->value; - CHECK_GT(extent, 0) << "Loop over " << Expr(new_forloop->loop_var) << " has extent " << new_forloop->extent - << ". Can only vectorize loops over a constant extent > 1"; + CHECK_GT(extent, 0) + << "Loop over " << Expr(new_forloop->loop_var) << " has extent " + << new_forloop->extent + << ". 
Can only vectorize loops over a constant extent > 1";

-    VLOG(2) << "Vectorizing " << new_forloop->loop_var << " extent " << extent;
+    VLOG(2) << "Vectorizing " << new_forloop->loop_var << " extent "
+            << extent;
     VLOG(2) << "before vectorize body:\n" << node->body;

     if (target == common::DefaultNVGPUTarget()) {
-      CudaVectorizer cuda_vectorizer(new_forloop->loop_var, factor, &var_intervals);
+      CudaVectorizer cuda_vectorizer(
+          new_forloop->loop_var, factor, &var_intervals);
       cuda_vectorizer.Visit(&new_forloop->body);
-      // unroll the new forloop to compute each element of the vector iteratively
+      // unroll the new forloop to compute each element of the vector
+      // iteratively
       auto copied_loop = optim::IRCopy(_new_forloop);
       copied_loop.As<ir::For>()->set_unrolled();
       optim::UnrollLoop(&copied_loop);
       // add cast exprs of vector type in the front of vectorized forloop,
-      // and replace original compute statements with the correspond unrolled ones
+      // and replace original compute statements with the corresponding
+      // unrolled ones
       auto unroll_body = copied_loop.As<ir::Block>()->stmts;
-      auto cast_exprs  = cuda_vectorizer.VectorizedTypeCastExprs();
+      auto cast_exprs = cuda_vectorizer.VectorizedTypeCastExprs();
       auto store_exprs = cuda_vectorizer.VectorizedTypeStoreExprs();
       auto &body_stmts = new_forloop->body.As<ir::Block>()->stmts;
       body_stmts.assign(cast_exprs.begin(), cast_exprs.end());
-      body_stmts.insert(body_stmts.end(), unroll_body.begin(), unroll_body.end());
-      body_stmts.insert(body_stmts.end(), store_exprs.begin(), store_exprs.end());
+      body_stmts.insert(
+          body_stmts.end(), unroll_body.begin(), unroll_body.end());
+      body_stmts.insert(
+          body_stmts.end(), store_exprs.begin(), store_exprs.end());
     } else {
-      Vectorizer(new_forloop->loop_var, extent, var_intervals).Visit(&new_forloop->body);
+      Vectorizer(new_forloop->loop_var, extent, var_intervals)
+          .Visit(&new_forloop->body);
     }

     VLOG(2) << "after vectorize body:\n" << node->body;

-    // Remove the forloop, the new_forloop's body is vectorized to Ramp, so no forloop is needed.
+    // Remove the forloop; the new_forloop's body is vectorized to Ramp, so no
+    // forloop is needed.
     if (is_zero(forloop->extent - 1)) {
       *expr = new_forloop->body;
     } else {
@@ -752,7 +833,8 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
     var_intervals.erase(loopvar_name);
   }

-  //! unroll the forloop if its' extent is min type by solving the condition extent
+  //! Unroll the forloop if its extent is a Min by solving the condition
+  //! extent
   //! @return The new forloop.
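  //! An illustrative sketch: an inner loop `for (i, 0, min(a, b))` is split
  //! into a vectorizable `for (i, 0, a)` guarded by the condition solved from
  //! `a <= b`, plus a serial remainder loop that covers the rest.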
bool UnrollCmpFor(For *outer_for, For *inner_for, Expr *expr) { CHECK(outer_for); @@ -765,39 +847,48 @@ struct VectorizeLoops_ : public IRMutator { // simplify the complicated indices of load/store from poly IRMutator::Visit(&inner_for->body, &inner_for->body); Expr a, b, condition; - a = extent_min->a(); - b = extent_min->b(); + a = extent_min->a(); + b = extent_min->b(); auto a_int = a.As(); auto b_int = a.As(); if (a_int || b_int) { - condition = common::SolveInequality(LE::Make(a, b), outer_for->loop_var); + condition = + common::SolveInequality(LE::Make(a, b), outer_for->loop_var); Simplify(&condition); } if (condition.defined()) { - auto le_n = condition.As(); + auto le_n = condition.As(); bool can_split = le_n && le_n->b().is_constant(); if (le_n && le_n->b().is_constant()) { - Expr inner_for_a = Block::Make({For::Make(inner_for->loop_var, - inner_for->min, - a, - ForType::Vectorized, - DeviceAPI::UNK, - inner_for->body, - inner_for->vectorize_info())}); + Expr inner_for_a = + Block::Make({For::Make(inner_for->loop_var, + inner_for->min, + a, + ForType::Vectorized, + DeviceAPI::UNK, + inner_for->body, + inner_for->vectorize_info())}); Expr new_extent_a = common::AutoSimplify(le_n->b() + 1); - Expr out_for_a = For::Make(outer_for->loop_var, + Expr out_for_a = For::Make(outer_for->loop_var, outer_for->min, new_extent_a, outer_for->for_type(), outer_for->device_api, inner_for_a, outer_for->vectorize_info()); - Var new_iterator_inner(common::UniqName(inner_for->loop_var->name + "_s")); - Var new_iterator_outer(common::UniqName(outer_for->loop_var->name + "_s")); + Var new_iterator_inner( + common::UniqName(inner_for->loop_var->name + "_s")); + Var new_iterator_outer( + common::UniqName(outer_for->loop_var->name + "_s")); - Expr inner_for_b = Block::Make({For::Make( - new_iterator_inner, inner_for->min, b, ForType::Serial, DeviceAPI::UNK, IRCopy(inner_for->body))}); - optim::IrReplace(&inner_for_b, inner_for->loop_var, Expr(new_iterator_inner)); + Expr inner_for_b = Block::Make({For::Make(new_iterator_inner, + inner_for->min, + b, + ForType::Serial, + DeviceAPI::UNK, + IRCopy(inner_for->body))}); + optim::IrReplace( + &inner_for_b, inner_for->loop_var, Expr(new_iterator_inner)); Expr out_for_b = For::Make(new_iterator_outer, new_extent_a, @@ -806,7 +897,8 @@ struct VectorizeLoops_ : public IRMutator { outer_for->device_api, inner_for_b, outer_for->vectorize_info()); - optim::IrReplace(&out_for_b, outer_for->loop_var, Expr(new_iterator_outer)); + optim::IrReplace( + &out_for_b, outer_for->loop_var, Expr(new_iterator_outer)); *expr = Block::Make({out_for_a, out_for_b}); VLOG(2) << *expr; IRMutator::Visit(expr, expr); @@ -829,12 +921,14 @@ struct VectorizeLoops_ : public IRMutator { auto *extent_ptr = forloop->extent.As(); Expr times; if (extent_ptr) { - int extent_int = forloop->extent.as_int32(); + int extent_int = forloop->extent.as_int32(); int extent_trunc = extent_int / factor; - int extent_times = extent_int % factor == 0 ? extent_trunc : extent_trunc + 1; - times = common::make_const(forloop->extent->type(), extent_times); + int extent_times = + extent_int % factor == 0 ? 
extent_trunc : extent_trunc + 1; + times = common::make_const(forloop->extent->type(), extent_times); } else { - times = common::AutoSimplify(Div::Make(forloop->extent, make_const(factor))); + times = + common::AutoSimplify(Div::Make(forloop->extent, make_const(factor))); Simplify(×); } @@ -844,16 +938,20 @@ struct VectorizeLoops_ : public IRMutator { forloop->extent = times; if (times_int && forloop->extent.as_int32() >= 1) { - var_intervals.emplace(forloop->loop_var->name, common::CasInterval{0, forloop->extent.as_int32() - 1}); + var_intervals.emplace( + forloop->loop_var->name, + common::CasInterval{0, forloop->extent.as_int32() - 1}); } else { var_intervals.erase(forloop->loop_var->name); - var_intervals.emplace(forloop->loop_var->name, common::CasInterval{Expr(0), forloop->extent - 1}); + var_intervals.emplace(forloop->loop_var->name, + common::CasInterval{Expr(0), forloop->extent - 1}); } // create the new forloop { Var new_iterator(Context::Global().NewName("vi")); - var_intervals.emplace(new_iterator->name, common::CasInterval{0, factor - 1}); + var_intervals.emplace(new_iterator->name, + common::CasInterval{0, factor - 1}); // eliminate for 1 Expr new_index; if (common::is_zero(times - 1)) { @@ -869,13 +967,15 @@ struct VectorizeLoops_ : public IRMutator { DeviceAPI::UNK, forloop->body, forloop->vectorize_info()); - forloop->body = Block::Make({new_forloop}); + forloop->body = Block::Make({new_forloop}); return new_forloop; } } }; -void VectorizeLoops(Expr *expr, const Target &target) { return VectorizeLoops_(target)(expr); } +void VectorizeLoops(Expr *expr, const Target &target) { + return VectorizeLoops_(target)(expr); +} namespace detail { diff --git a/paddle/cinn/optim/vectorize_loops_test.cc b/paddle/cinn/optim/vectorize_loops_test.cc index 01b0a10f3ca52..55f8b3097d91e 100644 --- a/paddle/cinn/optim/vectorize_loops_test.cc +++ b/paddle/cinn/optim/vectorize_loops_test.cc @@ -57,14 +57,14 @@ TEST(Vectorize, replace_var) { Target target; target.arch = Target::Arch ::X86; target.bits = Target::Bit ::k32; - target.os = Target::OS ::Linux; + target.os = Target::OS ::Linux; ir::Module::Builder builder("module1", target); builder.AddFunction(ir::LoweredFunc(func.As())); CodeGenC codegen(target); codegen.SetInlineBuiltinCodes(false); - auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); auto target_out = R"ROC( #include #include @@ -101,7 +101,7 @@ TEST(Vectorize, TestMarkVectorize) { Target target; target.arch = Target::Arch ::X86; target.bits = Target::Bit ::k32; - target.os = Target::OS ::Linux; + target.os = Target::OS ::Linux; Placeholder A("A", {M, N}); Placeholder B("B", {M, N}); @@ -196,10 +196,11 @@ TEST(Vectorize, vectorize) { Placeholder C("C", std::vector{{10}}); auto expr = Load::Make(ir::Tensor(A), {a * 2 + b * 2}); - expr = expr + 10.f * expr; + expr = expr + 10.f * expr; detail::Vectorize(a, 16, &expr); EXPECT_EQ(GetStreamCnt(expr), - "(A[Ramp(((b * 2) + (0 * 2)),(1 * 2),16)] + (Broadcast(10.0000000f,16) * A[Ramp(((b * 2) + (0 * 2)),(1 * " + "(A[Ramp(((b * 2) + (0 * 2)),(1 * 2),16)] + " + "(Broadcast(10.0000000f,16) * A[Ramp(((b * 2) + (0 * 2)),(1 * " "2),16)]))"); } } @@ -216,7 +217,7 @@ TEST(Vectorize, single_for) { ir::Load::Make(ir::Tensor(A), {Expr(loop_var)}), ir::Load::Make(ir::Tensor(B), {Expr(loop_var)})), {Expr(loop_var)}); - body = ir::Block::Make({body}); + body = ir::Block::Make({body}); VectorizeInfo vectorize_info(0, 16); auto forloop = 
ir::For::Make(loop_var, @@ -244,7 +245,7 @@ TEST(Vectorize, cuda_vectorize) { auto stages = CreateStages({C}); stages[C]->Vectorize(1, 4); Target target = common::DefaultNVGPUTarget(); - auto func = Lower("matmul", stages, {A, B, C}, {}, {}, nullptr, target); + auto func = Lower("matmul", stages, {A, B, C}, {}, {}, nullptr, target); auto target_expr = R"ROC( function matmul (_A, _B, _C) @@ -281,7 +282,7 @@ TEST(Vectorize, cuda_vectorize_with_constant) { auto stages = CreateStages({C}); stages[C]->Vectorize(1, 4); Target target = common::DefaultNVGPUTarget(); - auto func = Lower("mul_const", stages, {A, C}, {}, {}, nullptr, target); + auto func = Lower("mul_const", stages, {A, C}, {}, {}, nullptr, target); } } // namespace optim diff --git a/paddle/cinn/poly/ast_gen.cc b/paddle/cinn/poly/ast_gen.cc index a2baa023845fc..f71ec5fed9ed6 100644 --- a/paddle/cinn/poly/ast_gen.cc +++ b/paddle/cinn/poly/ast_gen.cc @@ -42,12 +42,13 @@ struct AstGen::Impl { isl::ctx ctx() const; /** - * Help to collect the map from the axis(and the pos) in statement to the transformed indice. - * e.g. If s[i,j] will be generated to something like s[a+2, b] in the final AST, this will return + * Help to collect the map from the axis(and the pos) in statement to the + * transformed indice. e.g. If s[i,j] will be generated to something like + * s[a+2, b] in the final AST, this will return * - a map { i->a+2, j->b, 0->a+2, 1->b }. */ - static std::map ExtractIslTransformedIndiceMap(const isl::set& iterator_domain, - isl_ast_build* build); + static std::map ExtractIslTransformedIndiceMap( + const isl::set& iterator_domain, isl_ast_build* build); //! Get the polyhedral stages. const std::vector>& stages() const { return stages_; } @@ -58,7 +59,8 @@ struct AstGen::Impl { const poly::ScheduleGroup& schedule_group_; std::vector iterator_names_; //! 
tuple name -> { axis -> isl_ast } - std::map> transformed_indice_map_; + std::map> + transformed_indice_map_; isl::union_map build_options_; friend class AstGen; @@ -68,8 +70,8 @@ isl::union_set AstGen::domain() const { return impl_->domain(); } isl::union_set AstGen::Impl::domain() const { CHECK(!stages_.empty()); - auto sets = - utils::Map>, isl::set>(stages_, [](const Shared& e) { return e->domain(); }); + auto sets = utils::Map>, isl::set>( + stages_, [](const Shared& e) { return e->domain(); }); return isl_sets_to_union_set(sets); } @@ -99,21 +101,28 @@ isl::set TransIdentityExtentToContextId(isl::set set) { isl::set res_set = set; for (auto offset_val : iden_dim_offsets) { auto& offset = std::get<0>(offset_val); - auto& val = std::get<1>(offset_val); - res_set = isl::manage(isl_set_drop_constraints_involving_dims(res_set.copy(), isl_dim_set, offset, 1)); - - std::string const_param_name = llvm::formatv("{0}{1}", kIslParamConstPrefix, val); - - std::string cond_str = llvm::formatv( - "{0} <= {1} <= {2}", val, isl_set_get_dim_name(res_set.get(), isl_dim_set, offset), const_param_name); - std::string param_cond_str = llvm::formatv("{0} <= {1} < {2}", val, const_param_name, val + 2); - - std::string set_repr = llvm::formatv("[{0}] -> { {1}[{2}]: {3} and {4} }", - const_param_name, - isl_set_get_tuple_name(res_set.get()), - utils::Join(isl_get_dim_names(res_set.get()), ","), - cond_str, - param_cond_str); + auto& val = std::get<1>(offset_val); + res_set = isl::manage(isl_set_drop_constraints_involving_dims( + res_set.copy(), isl_dim_set, offset, 1)); + + std::string const_param_name = + llvm::formatv("{0}{1}", kIslParamConstPrefix, val); + + std::string cond_str = + llvm::formatv("{0} <= {1} <= {2}", + val, + isl_set_get_dim_name(res_set.get(), isl_dim_set, offset), + const_param_name); + std::string param_cond_str = + llvm::formatv("{0} <= {1} < {2}", val, const_param_name, val + 2); + + std::string set_repr = + llvm::formatv("[{0}] -> { {1}[{2}]: {3} and {4} }", + const_param_name, + isl_set_get_tuple_name(res_set.get()), + utils::Join(isl_get_dim_names(res_set.get()), ","), + cond_str, + param_cond_str); VLOG(4) << "repr: " << set_repr; @@ -129,7 +138,7 @@ isl::union_set TransIdentityExtentToContextId(isl::union_set set) { llvm::SmallVector sets; for (int i = 0; i < isl_set_list_n_set(set_list); i++) { auto set = isl::manage(isl_set_list_get_set(set_list, i)); - set = TransIdentityExtentToContextId(set); + set = TransIdentityExtentToContextId(set); sets.push_back(set); } isl_set_list_free(set_list); @@ -143,7 +152,8 @@ isl::ast_node AstGen::Build() { std::vector maps; for (auto& stage : impl_->stages_) { auto it = schedule_map.find(stage->id()); - CHECK(it != std::end(schedule_map)) << "stage " << stage->id() << " not found in the map"; + CHECK(it != std::end(schedule_map)) + << "stage " << stage->id() << " not found in the map"; maps.push_back(it->second); } auto schedule = isl_maps_to_union_map(maps); @@ -152,39 +162,51 @@ isl::ast_node AstGen::Build() { auto ast_build = isl::ast_build::from_context(impl_->context_); if (!impl_->build_options_.is_null()) - ast_build = isl::manage(isl_ast_build_set_options(ast_build.release(), impl_->build_options_.release())); + ast_build = isl::manage(isl_ast_build_set_options( + ast_build.release(), impl_->build_options_.release())); // Set iterators names for readable code. - auto iterator_names = - impl_->iterator_names_.empty() ? 
impl_->schedule_group_.dimension_names : impl_->iterator_names_; + auto iterator_names = impl_->iterator_names_.empty() + ? impl_->schedule_group_.dimension_names + : impl_->iterator_names_; - iterator_names = SchedulerBase::WrapIteratorNames(iterator_names); - isl::id_list ids = isl::manage(isl_id_list_alloc(ctx().get(), iterator_names.size())); + iterator_names = SchedulerBase::WrapIteratorNames(iterator_names); + isl::id_list ids = + isl::manage(isl_id_list_alloc(ctx().get(), iterator_names.size())); for (int i = 0; i < iterator_names.size(); i++) { - ids = isl::manage(isl_id_list_add(ids.release(), isl_id_alloc(ctx().get(), iterator_names[i].c_str(), nullptr))); + ids = isl::manage(isl_id_list_add( + ids.release(), + isl_id_alloc(ctx().get(), iterator_names[i].c_str(), nullptr))); } - ast_build = isl::manage(isl_ast_build_set_iterators(ast_build.release(), ids.release())); + ast_build = isl::manage( + isl_ast_build_set_iterators(ast_build.release(), ids.release())); // collect iterator map auto get_domain_by_name = [this](const std::string& name) -> isl::set { auto ele_it = std::find_if( - impl_->stages_.begin(), impl_->stages_.end(), [&name](const Shared& ele) { return ele->id() == name; }); + impl_->stages_.begin(), + impl_->stages_.end(), + [&name](const Shared& ele) { return ele->id() == name; }); CHECK(ele_it != std::end(impl_->stages_)); return (*ele_it)->domain(); }; - auto collect = [&](isl::ast_node node, isl::ast_build build) -> isl::ast_node { + auto collect = [&](isl::ast_node node, + isl::ast_build build) -> isl::ast_node { auto tuple_name = detail::GetTupleName(node.get()); - auto indice_map = impl_->ExtractIslTransformedIndiceMap(get_domain_by_name(tuple_name), build.get()); + auto indice_map = impl_->ExtractIslTransformedIndiceMap( + get_domain_by_name(tuple_name), build.get()); impl_->transformed_indice_map_[tuple_name] = indice_map; return node; }; ast_build = ast_build.set_at_each_domain(collect); - isl::union_map transformed_schedule = impl_->transform().apply_range(schedule); + isl::union_map transformed_schedule = + impl_->transform().apply_range(schedule); VLOG(4) << "transformed_schedule: " << transformed_schedule; - isl::union_map schedule_domain = transformed_schedule.intersect_domain(impl_->domain()); + isl::union_map schedule_domain = + transformed_schedule.intersect_domain(impl_->domain()); VLOG(4) << "domain: " << impl_->domain(); VLOG(4) << "transform schedule " << impl_->stages()[0]->transform(); VLOG(4) << "schedule: " << schedule; @@ -199,22 +221,26 @@ AstGen& AstGen::SetIteratorNames(const std::vector& names) { return *this; } -isl::ast_expr CreateIslAstIndexExpression(isl_ast_build* build, const isl::map& access); +isl::ast_expr CreateIslAstIndexExpression(isl_ast_build* build, + const isl::map& access); -std::map AstGen::Impl::ExtractIslTransformedIndiceMap(const isl::set& iterator_domain, - isl_ast_build* build) { +std::map +AstGen::Impl::ExtractIslTransformedIndiceMap(const isl::set& iterator_domain, + isl_ast_build* build) { std::map iterator_map; isl::map identity = isl::manage(isl_set_identity(iterator_domain.copy())); isl::map schedule = identity; - identity = identity.apply_domain(schedule); - isl::ast_expr idx_expr = CreateIslAstIndexExpression(build, identity); + identity = identity.apply_domain(schedule); + isl::ast_expr idx_expr = CreateIslAstIndexExpression(build, identity); isl::space domain_space = iterator_domain.space(); for (int i = 1; i < isl_ast_expr_get_op_n_arg(idx_expr.get()); i++) { if 
(isl_space_has_dim_name(domain_space.get(), isl_dim_set, i - 1)) { - std::string original_idx_name = isl_space_get_dim_name(domain_space.get(), isl_dim_set, i - 1); - isl::ast_expr transformed_index = isl::manage(isl_ast_expr_get_op_arg(idx_expr.get(), i)); + std::string original_idx_name = + isl_space_get_dim_name(domain_space.get(), isl_dim_set, i - 1); + isl::ast_expr transformed_index = + isl::manage(isl_ast_expr_get_op_arg(idx_expr.get(), i)); VLOG(4) << "axis-" << i - 1 << " named " << original_idx_name << ", is " << isl_ast_expr_to_C_str(transformed_index.get()); iterator_map.emplace(original_idx_name, transformed_index); @@ -225,13 +251,15 @@ std::map AstGen::Impl::ExtractIslTransformedIndiceMa return iterator_map; } -const std::map& AstGen::axis2ast(const std::string& tuple_name) const { +const std::map& AstGen::axis2ast( + const std::string& tuple_name) const { auto it = impl_->transformed_indice_map_.find(tuple_name); CHECK(it != impl_->transformed_indice_map_.end()) << "no id " << tuple_name; return it->second; } -const std::map AstGen::axis2expr(const std::string& tuple_name) const { +const std::map AstGen::axis2expr( + const std::string& tuple_name) const { const auto& axis_to_ast = axis2ast(tuple_name); std::map res; for (auto item : axis_to_ast) { @@ -242,26 +270,35 @@ const std::map AstGen::axis2expr(const std::string& tuple_nam return res; } -isl::ast_expr CreateIslAstIndexExpression(isl_ast_build* build, const isl::map& access) { +isl::ast_expr CreateIslAstIndexExpression(isl_ast_build* build, + const isl::map& access) { CHECK(build); - isl::map schedule = isl::manage(isl_map_from_union_map(isl_ast_build_get_schedule(build))); + isl::map schedule = + isl::manage(isl_map_from_union_map(isl_ast_build_get_schedule(build))); // get identity access from schedule. 
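  // Illustrative note (hypothetical example, not from the original sources):
  // the schedule map fetched above relates statement instances to schedule
  // points, e.g. { s[i, j] -> [c0, c1] }. The code below reverses it and
  // pulls the statement's identity access back through it, producing an
  // ast_expr such as s[c0, c1 + 2]: the original indices rewritten in terms
  // of the AST iterators.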
- auto statement = isl_map_get_statement_repr(schedule.get(), isl_dim_in); - auto statement_set = isl::manage(isl_set_read_from_str(isl_map_get_ctx(schedule.get()), - utils::StringFormat("{ %s : }", statement.c_str()).c_str())); + auto statement = isl_map_get_statement_repr(schedule.get(), isl_dim_in); + auto statement_set = isl::manage(isl_set_read_from_str( + isl_map_get_ctx(schedule.get()), + utils::StringFormat("{ %s : }", statement.c_str()).c_str())); auto identity_access = isl::manage(isl_set_identity(statement_set.release())); - isl::map map = isl::manage(isl_map_reverse(schedule.copy())); - - isl::pw_multi_aff iterator_map = isl::manage(isl_pw_multi_aff_from_map(map.copy())); - isl::pw_multi_aff index_aff = isl::manage(isl_pw_multi_aff_from_map(identity_access.copy())); - - isl::space model2 = iterator_map.space(); - index_aff = isl::manage(isl_pw_multi_aff_align_params(index_aff.copy(), model2.copy())); - isl::space model = index_aff.space(); - iterator_map = isl::manage(isl_pw_multi_aff_align_params(iterator_map.copy(), model.copy())); - iterator_map = isl::manage(isl_pw_multi_aff_pullback_pw_multi_aff(index_aff.copy(), iterator_map.copy())); - isl::ast_expr index_expr = isl::manage(isl_ast_build_access_from_pw_multi_aff(build, iterator_map.copy())); + isl::map map = isl::manage(isl_map_reverse(schedule.copy())); + + isl::pw_multi_aff iterator_map = + isl::manage(isl_pw_multi_aff_from_map(map.copy())); + isl::pw_multi_aff index_aff = + isl::manage(isl_pw_multi_aff_from_map(identity_access.copy())); + + isl::space model2 = iterator_map.space(); + index_aff = isl::manage( + isl_pw_multi_aff_align_params(index_aff.copy(), model2.copy())); + isl::space model = index_aff.space(); + iterator_map = isl::manage( + isl_pw_multi_aff_align_params(iterator_map.copy(), model.copy())); + iterator_map = isl::manage(isl_pw_multi_aff_pullback_pw_multi_aff( + index_aff.copy(), iterator_map.copy())); + isl::ast_expr index_expr = isl::manage( + isl_ast_build_access_from_pw_multi_aff(build, iterator_map.copy())); return index_expr; } @@ -278,7 +315,7 @@ namespace detail { std::string GetTupleName(isl_ast_node* node) { auto expr = isl::manage(isl_ast_node_user_get_expr(node)); - auto arg = isl::manage(isl_ast_expr_get_op_arg(expr.get(), 0)); + auto arg = isl::manage(isl_ast_expr_get_op_arg(expr.get(), 0)); auto name = isl_id_get_name(isl_ast_expr_get_id(arg.get())); return name; } @@ -322,7 +359,8 @@ void IslAstNodeToCinnExpr(const isl::ast_node& node, ir::Expr* expr) { // EatMark(node, expr); } break; default: - LOG(FATAL) << "Unexpected ISL node type " << isl_ast_node_get_type(node.get()); + LOG(FATAL) << "Unexpected ISL node type " + << isl_ast_node_get_type(node.get()); break; } } @@ -333,10 +371,12 @@ void EatBlock(const isl::ast_node& node, ir::Expr* expr) { CHECK(!node.is_null()); CHECK(expr); CHECK_EQ(isl_ast_node_get_type(node.get()), isl_ast_node_block); - isl::ast_node_list list = isl::manage(isl_ast_node_block_get_children(node.get())); + isl::ast_node_list list = + isl::manage(isl_ast_node_block_get_children(node.get())); std::vector exprs; for (int i = 0; i < isl_ast_node_list_n_ast_node(list.get()); i++) { - isl::ast_node child = isl::manage(isl_ast_node_list_get_ast_node(list.get(), i)); + isl::ast_node child = + isl::manage(isl_ast_node_list_get_ast_node(list.get(), i)); // visit child ir::Expr child_expr; IslAstNodeToCinnExpr(child, &child_expr); @@ -355,15 +395,16 @@ void EatFor(const isl::ast_node& node, ir::Expr* expr) { CHECK_EQ(isl_ast_node_get_type(node.get()), 
isl_ast_node_for); // iter name - isl::ast_expr iter = isl::manage(isl_ast_node_for_get_iterator(node.get())); - isl::id iter_id = isl::manage(isl_ast_expr_get_id(iter.get())); + isl::ast_expr iter = isl::manage(isl_ast_node_for_get_iterator(node.get())); + isl::id iter_id = isl::manage(isl_ast_expr_get_id(iter.get())); std::string iter_name = iter_id.name(); // get condition - isl::ast_expr condition = isl::manage(isl_ast_node_for_get_cond(node.get())); + isl::ast_expr condition = isl::manage(isl_ast_node_for_get_cond(node.get())); isl::ast_expr incrementor = isl::manage(isl_ast_node_for_get_inc(node.get())); - isl::ast_expr initializer = isl::manage(isl_ast_node_for_get_init(node.get())); - isl::ast_node body = isl::manage(isl_ast_node_for_get_body(node.get())); + isl::ast_expr initializer = + isl::manage(isl_ast_node_for_get_init(node.get())); + isl::ast_node body = isl::manage(isl_ast_node_for_get_body(node.get())); ir::Expr ir_body; IslAstNodeToCinnExpr(body, &ir_body); @@ -384,8 +425,13 @@ void EatFor(const isl::ast_node& node, ir::Expr* expr) { ir::Var ir_iter(iter_name); - *expr = ir::PolyFor::Make( - ir::Var(iter_name), ir_initializer, ir_condition, ir_inc, ir::ForType::Serial, ir::DeviceAPI ::Host, ir_body); + *expr = ir::PolyFor::Make(ir::Var(iter_name), + ir_initializer, + ir_condition, + ir_inc, + ir::ForType::Serial, + ir::DeviceAPI ::Host, + ir_body); } void EatIf(const isl::ast_node& node, ir::Expr* expr) { @@ -416,11 +462,11 @@ void IslAstExprToCinnExpr(const isl::ast_expr& node, ir::Expr* expr) { switch (isl_ast_expr_get_type(node.get())) { case isl_ast_expr_int: { isl::val val = isl::manage(isl_ast_expr_get_val(node.get())); - *expr = ir::Expr(static_cast(isl_val_get_num_si(val.get()))); + *expr = ir::Expr(static_cast(isl_val_get_num_si(val.get()))); } break; case isl_ast_expr_id: { isl::id id = isl::manage(isl_ast_expr_get_id(node.get())); - *expr = ir::Var(id.name()); + *expr = ir::Var(id.name()); } break; case isl_ast_expr_op: { std::vector ops; @@ -428,7 +474,8 @@ void IslAstExprToCinnExpr(const isl::ast_expr& node, ir::Expr* expr) { for (int i = 0; i < n_args; i++) { ir::Expr op; - isl::ast_expr expr0 = isl::manage(isl_ast_expr_get_op_arg(node.get(), i)); + isl::ast_expr expr0 = + isl::manage(isl_ast_expr_get_op_arg(node.get(), i)); IslAstExprToCinnExpr(expr0, &op); ops.push_back(op); } @@ -501,13 +548,20 @@ void IslAstExprToCinnExpr(const isl::ast_expr& node, ir::Expr* expr) { std::string caller = caller_expr.As()->name; ops.erase(ops.begin()); // NOTE the type here is not important. 
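        // Illustrative note (hypothetical example, not from the original
        // sources): an ISL user call such as S0(c0, c1 + 1) reaches this
        // branch with ops = {S0, c0, c1 + 1}; after the callee is erased,
        // the statement is rebuilt roughly as
        // ir::Call::Make(Float(32), "S0", {c0, c1 + 1}, ...), where
        // Float(32) is only a placeholder return type.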
- *expr = ir::Call::Make(Float(32), caller, ops, {}, ir::CallType::ISL, ir::FunctionRef(), 0); + *expr = ir::Call::Make(Float(32), + caller, + ops, + {}, + ir::CallType::ISL, + ir::FunctionRef(), + 0); } break; case isl_ast_op_fdiv_q: *expr = ir::Div::Make(ops[0], ops[1]); break; case isl_ast_op_select: - CHECK_EQ(ops.size(), 3UL) << "In ir::Select, the ops size should be 3"; + CHECK_EQ(ops.size(), 3UL) + << "In ir::Select, the ops size should be 3"; ops[0]->set_type(Bool()); *expr = ir::Select::Make(ops[0], ops[1], ops[2]); break; @@ -520,13 +574,15 @@ void IslAstExprToCinnExpr(const isl::ast_expr& node, ir::Expr* expr) { } } -void AddUnitLoopOfDomain(const isl::ast_node& node, const isl::set& domain, ir::Expr* expr) { +void AddUnitLoopOfDomain(const isl::ast_node& node, + const isl::set& domain, + ir::Expr* expr) { std::vector dim_names = isl_get_dim_names(domain); std::vector> dim_min_max; for (int i = 0; i < dim_names.size(); ++i) { auto minv_maxv = isl_set_get_axis_range(domain.get(), i); - int min_iv = std::get<0>(minv_maxv).get_num_si(); - int max_iv = std::get<1>(minv_maxv).get_num_si(); + int min_iv = std::get<0>(minv_maxv).get_num_si(); + int max_iv = std::get<1>(minv_maxv).get_num_si(); dim_min_max.emplace_back(i, min_iv, max_iv); } @@ -534,7 +590,9 @@ void AddUnitLoopOfDomain(const isl::ast_node& node, const isl::set& domain, ir:: mutator(expr); } -void IslAstNodeToCinnExpr(const isl::ast_node& node, const isl::union_set& domain, ir::Expr* expr) { +void IslAstNodeToCinnExpr(const isl::ast_node& node, + const isl::union_set& domain, + ir::Expr* expr) { IslAstNodeToCinnExpr(node, expr); isl_set_list* set_list = isl_union_set_get_set_list(domain.get()); @@ -552,13 +610,19 @@ void AstGen::Impl::InitIslAstConfig() { isl_options_set_ast_build_allow_else(ctx().get(), 1); } -AstGen::AstGen(const isl::set& context, const std::vector& stages, const poly::ScheduleGroup& group) +AstGen::AstGen(const isl::set& context, + const std::vector& stages, + const poly::ScheduleGroup& group) : impl_(new Impl(context, group)) { for (auto* x : stages) impl_->stages_.emplace_back(x); impl_->InitIslAstConfig(); } -void AstGen::SetBuildOptions(const isl::union_map& options) { impl_->build_options_ = options; } -bool AstGen::ContainsStatement(const std::string& name) const { return impl_->transformed_indice_map_.count(name); } +void AstGen::SetBuildOptions(const isl::union_map& options) { + impl_->build_options_ = options; +} +bool AstGen::ContainsStatement(const std::string& name) const { + return impl_->transformed_indice_map_.count(name); +} AstGen::~AstGen() {} diff --git a/paddle/cinn/poly/ast_gen.h b/paddle/cinn/poly/ast_gen.h index faa28e5a30c68..48e295930d93b 100644 --- a/paddle/cinn/poly/ast_gen.h +++ b/paddle/cinn/poly/ast_gen.h @@ -13,8 +13,8 @@ // limitations under the License. /** - * This file implements the isl AST build interface, it helps to generate isl AST given the polyhedral domain and - * schedule. + * This file implements the isl AST build interface, it helps to generate isl + * AST given the polyhedral domain and schedule. */ #pragma once #include @@ -40,7 +40,9 @@ static const char* kIslParamConstPrefix = "_const_"; */ class AstGen { public: - AstGen(const isl::set& context, const std::vector& stages, const poly::ScheduleGroup& group); + AstGen(const isl::set& context, + const std::vector& stages, + const poly::ScheduleGroup& group); ~AstGen(); /** @@ -54,10 +56,13 @@ class AstGen { isl::ast_node Build(); - //! 
Get the map from original CINN iterators to the transformed actual ISL ast nodes. - const std::map& axis2ast(const std::string& tuple_name) const; + //! Get the map from original CINN iterators to the transformed actual ISL ast + //! nodes. + const std::map& axis2ast( + const std::string& tuple_name) const; - const std::map axis2expr(const std::string& tuple_name) const; + const std::map axis2expr( + const std::string& tuple_name) const; bool ContainsStatement(const std::string& name) const; @@ -70,13 +75,17 @@ class AstGen { std::unique_ptr impl_; }; -void AddUnitLoopOfDomain(const isl::ast_node& node, const isl::set& domain, ir::Expr* expr); +void AddUnitLoopOfDomain(const isl::ast_node& node, + const isl::set& domain, + ir::Expr* expr); /** * Transform the isl ast to Expr. */ void IslAstNodeToCinnExpr(const isl::ast_node& node, ir::Expr* expr); -void IslAstNodeToCinnExpr(const isl::ast_node& node, const isl::union_set& domain, ir::Expr* expr); +void IslAstNodeToCinnExpr(const isl::ast_node& node, + const isl::union_set& domain, + ir::Expr* expr); void IslAstExprToCinnExpr(const isl::ast_expr& node, ir::Expr* expr); /** diff --git a/paddle/cinn/poly/ast_gen_test.cc b/paddle/cinn/poly/ast_gen_test.cc index 308dd9fe15ecf..0add6b1ce2137 100644 --- a/paddle/cinn/poly/ast_gen_test.cc +++ b/paddle/cinn/poly/ast_gen_test.cc @@ -42,7 +42,8 @@ TEST(TransIdentityExtentToContextId, basic) { LOG(INFO) << new_set; ASSERT_EQ(utils::GetStreamCnt(new_set), - "[_const_0] -> { s[i, j, k] : _const_0 <= 1 and 0 <= i <= 11 and 0 <= j <= _const_0 and 13 <= k <= 31 }"); + "[_const_0] -> { s[i, j, k] : _const_0 <= 1 and 0 <= i <= 11 and 0 " + "<= j <= _const_0 and 13 <= k <= 31 }"); } TEST(TransIdentityExtentToContextId, basic1) { @@ -69,18 +70,24 @@ TEST(AstGen_Build, not_delete_length1_loop) { len1_shape[i] = Expr(1); } } - LOG(INFO) << "index_length1 hint = " << index_length1[0] << index_length1[1] << index_length1[2] - << index_length1[3]; + LOG(INFO) << "index_length1 hint = " << index_length1[0] + << index_length1[1] << index_length1[2] << index_length1[3]; Placeholder A("A", len1_shape); Tensor B = lang::Compute( - len1_shape, [&](const std::vector& indice) { return lang::Relu(A(indice), 0); }, "relu_test"); + len1_shape, + [&](const std::vector& indice) { + return lang::Relu(A(indice), 0); + }, + "relu_test"); StageMap stage_map = CreateStages({B}); std::vector stages; stages.push_back(stage_map[B]); - std::unique_ptr schedule = - poly::CreateSchedule(stages, poly::ScheduleKind::Poly, std::vector>()); + std::unique_ptr schedule = poly::CreateSchedule( + stages, + poly::ScheduleKind::Poly, + std::vector>()); for (auto& group : schedule->groups) { isl::set context(Context::isl_ctx(), "{:}"); @@ -94,8 +101,8 @@ TEST(AstGen_Build, not_delete_length1_loop) { std::stringstream ss; ss << e; - std::string expr_str = ss.str(); - std::string target_str = R"ROC(poly_for (i, 0, (i <= 9), 1) + std::string expr_str = ss.str(); + std::string target_str = R"ROC(poly_for (i, 0, (i <= 9), 1) { poly_for (j, 0, (j <= 9), 1) { @@ -108,12 +115,12 @@ TEST(AstGen_Build, not_delete_length1_loop) { } } })ROC"; - int pos = -1; + int pos = -1; std::vector iterator_names = {'i', 'j', 'k', 'a'}; for (int i = 0; i < origin_shape.size(); ++i) { pos = target_str.find("9", pos + 1); if (index_length1[i] == 1) { - target_str[pos] = '0'; + target_str[pos] = '0'; target_str[target_str.rfind(iterator_names[i])] = '0'; } } diff --git a/paddle/cinn/poly/compute_at_transform.cc b/paddle/cinn/poly/compute_at_transform.cc index 
05bcc61b283ae..cc2023c3c2014 100755 --- a/paddle/cinn/poly/compute_at_transform.cc +++ b/paddle/cinn/poly/compute_at_transform.cc @@ -19,7 +19,7 @@ namespace poly { void ComputeAtTransform::AdjustPdomain() { isl::map ct_with_params = ctransform_with_params(); - isl::set ct_domain = ct_with_params.domain(); + isl::set ct_domain = ct_with_params.domain(); isl::set cdomain1 = isl::manage(AddParamsTo(cdomain_.copy())); @@ -38,7 +38,8 @@ void ComputeAtTransform::AdjustPdomain() { auto pdomain_params = isl::manage(AddParamsTo(pdomain_.copy())); VLOG(4) << "pdomain: " << pdomain; VLOG(4) << "pdomain_params: " << pdomain_params; - adjusted_pdomain_ = isl::manage(isl_set_intersect(pdomain.release(), pdomain_params.release())); + adjusted_pdomain_ = isl::manage( + isl_set_intersect(pdomain.release(), pdomain_params.release())); adjusted_pdomain_ = isl::manage(isl_simplify(adjusted_pdomain_.release())); VLOG(4) << "adjusted pdomain: " << adjusted_pdomain_; } @@ -49,50 +50,69 @@ void ComputeAtTransform::AdjustPtransform() { { // insert empty dims to ptransform's range adjusted_ptransform_ = ptransform_; - adjusted_ptransform_ = isl::manage(isl_map_insert_dims(adjusted_ptransform_.release(), isl_dim_out, 0, level_ + 1)); + adjusted_ptransform_ = isl::manage(isl_map_insert_dims( + adjusted_ptransform_.release(), isl_dim_out, 0, level_ + 1)); // update the tuple name - adjusted_ptransform_ = isl::manage(isl_map_set_tuple_name(adjusted_ptransform_.release(), isl_dim_in, ptuple())); - adjusted_ptransform_ = isl::manage(isl_map_set_tuple_name(adjusted_ptransform_.release(), isl_dim_out, ptuple())); + adjusted_ptransform_ = isl::manage(isl_map_set_tuple_name( + adjusted_ptransform_.release(), isl_dim_in, ptuple())); + adjusted_ptransform_ = isl::manage(isl_map_set_tuple_name( + adjusted_ptransform_.release(), isl_dim_out, ptuple())); } { - // make ctransform range the same space with ptransform's range so that we can copy the dims - isl::set ct_range = cdomain_.apply(ctransform_); + // make ctransform range the same space with ptransform's range so that we + // can copy the dims + isl::set ct_range = cdomain_.apply(ctransform_); isl::set ct_range1 = isl::manage(isl_set_project_out( - ct_range.release(), isl_dim_set, level_ + 1, isl_set_dim(ct_range.get(), isl_dim_set) - level_ - 1)); - ct_range1 = isl::manage(isl_set_add_dims( - ct_range1.release(), isl_dim_set, isl_map_dim(adjusted_ptransform_.get(), isl_dim_out) - level_ - 1)); + ct_range.release(), + isl_dim_set, + level_ + 1, + isl_set_dim(ct_range.get(), isl_dim_set) - level_ - 1)); + ct_range1 = isl::manage(isl_set_add_dims( + ct_range1.release(), + isl_dim_set, + isl_map_dim(adjusted_ptransform_.get(), isl_dim_out) - level_ - 1)); // set as the producer's tuple to make a same space - ct_range1 = isl::manage(isl_set_set_tuple_name(ct_range1.release(), ptuple())); + ct_range1 = + isl::manage(isl_set_set_tuple_name(ct_range1.release(), ptuple())); adjusted_ptransform_ = adjusted_ptransform_.intersect_range(ct_range1); VLOG(4) << "adjusted_ptransform: " << adjusted_ptransform_; } { // add params - adjusted_ptransform_ = isl::manage(AddParamsTo(adjusted_ptransform_.release())); + adjusted_ptransform_ = + isl::manage(AddParamsTo(adjusted_ptransform_.release())); } } isl::set ComputeAtTransform::cdomain_with_params() { // add level+1 param to consumer transform - isl::set cd_with_params = isl::manage(isl_set_add_dims(cdomain_.copy(), isl_dim_param, level_ + 1)); + isl::set cd_with_params = + isl::manage(isl_set_add_dims(cdomain_.copy(), isl_dim_param, level_ 
+ 1)); return cd_with_params; } isl::map ComputeAtTransform::ctransform_with_params() { // add level+1 param to consumer transform - int num_existing_param = isl_map_dim(ctransform_.get(), isl_dim_param); + int num_existing_param = isl_map_dim(ctransform_.get(), isl_dim_param); isl::map ct_with_params = isl::manage(AddParamsTo(ctransform_.copy())); { - isl_local_space* local_space = isl_local_space_from_space(ct_with_params.space().release()); + isl_local_space* local_space = + isl_local_space_from_space(ct_with_params.space().release()); for (int i = 0; i < level_ + 1; i++) { - isl_constraint* cst = isl_constraint_alloc_equality(isl_local_space_copy(local_space)); - cst = isl_constraint_set_coefficient_val( - cst, isl_dim_param, num_existing_param + i, isl_val_int_from_si(ctransform_.ctx().get(), -1)); - cst = isl_constraint_set_coefficient_val(cst, isl_dim_out, i, isl_val_int_from_si(ctransform_.ctx().get(), 1)); - ct_with_params = isl::manage(isl_map_add_constraint(ct_with_params.release(), cst)); + isl_constraint* cst = + isl_constraint_alloc_equality(isl_local_space_copy(local_space)); + cst = isl_constraint_set_coefficient_val( + cst, + isl_dim_param, + num_existing_param + i, + isl_val_int_from_si(ctransform_.ctx().get(), -1)); + cst = isl_constraint_set_coefficient_val( + cst, isl_dim_out, i, isl_val_int_from_si(ctransform_.ctx().get(), 1)); + ct_with_params = + isl::manage(isl_map_add_constraint(ct_with_params.release(), cst)); } isl_local_space_free(local_space); } @@ -107,18 +127,24 @@ void ComputeAtTransform::DisplayC(isl_map* pschedule, isl_map* cschedule) { auto adjusted_ptransform = adjusted_ptransform_; if (cschedule) { - adjusted_ctransform = isl::manage(isl_map_apply_range(adjusted_ctransform.release(), cschedule)); + adjusted_ctransform = isl::manage( + isl_map_apply_range(adjusted_ctransform.release(), cschedule)); } if (pschedule) { - adjusted_ptransform = isl::manage(isl_map_apply_range(adjusted_ptransform.release(), pschedule)); + adjusted_ptransform = isl::manage( + isl_map_apply_range(adjusted_ptransform.release(), pschedule)); } - auto whole_domain = isl::manage(isl_union_set_from_set(adjusted_pdomain_.copy())); - whole_domain = isl::manage(isl_union_set_add_set(whole_domain.release(), adjusted_cdomain_.copy())); + auto whole_domain = + isl::manage(isl_union_set_from_set(adjusted_pdomain_.copy())); + whole_domain = isl::manage( + isl_union_set_add_set(whole_domain.release(), adjusted_cdomain_.copy())); VLOG(3) << "whole domain: " << whole_domain; - auto whole_schedule = isl::manage(isl_union_map_from_map(adjusted_ptransform.copy())); - whole_schedule = isl::manage(isl_union_map_add_map(whole_schedule.release(), adjusted_ctransform.copy())); + auto whole_schedule = + isl::manage(isl_union_map_from_map(adjusted_ptransform.copy())); + whole_schedule = isl::manage(isl_union_map_add_map( + whole_schedule.release(), adjusted_ctransform.copy())); VLOG(3) << "whole_schedule: " << whole_schedule; isl::set context(whole_domain.ctx(), "{:}"); @@ -126,7 +152,8 @@ void ComputeAtTransform::DisplayC(isl_map* pschedule, isl_map* cschedule) { auto intersect_schedule = whole_schedule.intersect_domain(whole_domain); auto* build = isl_ast_build_from_context(context.release()); - auto* node = isl_ast_build_node_from_schedule_map(build, intersect_schedule.release()); + auto* node = + isl_ast_build_node_from_schedule_map(build, intersect_schedule.release()); VLOG(3) << "code:\n\n" << isl_ast_node_to_C_str(node); @@ -135,30 +162,36 @@ void ComputeAtTransform::DisplayC(isl_map* pschedule, 
isl_map* cschedule) { isl_set* ComputeAtTransform::AddParamsTo(isl_set* set) { int existing_params = isl_set_dim(set, isl_dim_param); - set = isl_set_add_dims(set, isl_dim_param, level_ + 1); + set = isl_set_add_dims(set, isl_dim_param, level_ + 1); // set name for (int i = 0; i < level_ + 1; i++) { std::string pname = GenConsumerParamName(ctuple(), i); - set = isl_set_set_dim_name(set, isl_dim_param, existing_params + i, pname.c_str()); + set = isl_set_set_dim_name( + set, isl_dim_param, existing_params + i, pname.c_str()); } return set; } isl_map* ComputeAtTransform::AddParamsTo(isl_map* map) { int existing_params = isl_map_dim(map, isl_dim_param); - map = isl_map_add_dims(map, isl_dim_param, level_ + 1); + map = isl_map_add_dims(map, isl_dim_param, level_ + 1); // set name for (int i = 0; i < level_ + 1; i++) { std::string pname = GenConsumerParamName(ctuple(), i); - map = isl_map_set_dim_name(map, isl_dim_param, existing_params + i, pname.c_str()); + map = isl_map_set_dim_name( + map, isl_dim_param, existing_params + i, pname.c_str()); } return map; } -ComputeAtTransform::ComputeAtTransform( - isl::set pdomain, isl::set cdomain, isl::map access, isl::map ptransform, isl::map ctransform, int level) +ComputeAtTransform::ComputeAtTransform(isl::set pdomain, + isl::set cdomain, + isl::map access, + isl::map ptransform, + isl::map ctransform, + int level) : pdomain_(pdomain), cdomain_(cdomain), access_(access), @@ -172,7 +205,7 @@ ComputeAtTransform::ComputeAtTransform( VLOG(2) << "access: " << access; adjusted_ctransform_ = isl::manage(AddParamsTo(ctransform_.copy())); - adjusted_cdomain_ = isl::manage(AddParamsTo(cdomain_.copy())); + adjusted_cdomain_ = isl::manage(AddParamsTo(cdomain_.copy())); } std::string GenConsumerParamName(const char* tuple, int id) { @@ -181,13 +214,18 @@ std::string GenConsumerParamName(const char* tuple, int id) { std::vector ComputeAtTransform::GetProducerAdjustedShape() const { VLOG(3) << "domain: " << adjusted_pdomain(); - isl::set param_limit = isl::manage(isl_set_universe(adjusted_pdomain().space().release())); + isl::set param_limit = + isl::manage(isl_set_universe(adjusted_pdomain().space().release())); // set all the params to 0 - isl_local_space* local_space = isl_local_space_from_space(param_limit.space().release()); + isl_local_space* local_space = + isl_local_space_from_space(param_limit.space().release()); for (int i = 0; i < isl_set_dim(param_limit.get(), isl_dim_param); i++) { - isl_constraint* cst = isl_constraint_alloc_equality(isl_local_space_copy(local_space)); - cst = isl_constraint_set_coefficient_val(cst, isl_dim_param, i, isl_val_int_from_si(ctransform_.ctx().get(), 1)); - param_limit = isl::manage(isl_set_add_constraint(param_limit.release(), cst)); + isl_constraint* cst = + isl_constraint_alloc_equality(isl_local_space_copy(local_space)); + cst = isl_constraint_set_coefficient_val( + cst, isl_dim_param, i, isl_val_int_from_si(ctransform_.ctx().get(), 1)); + param_limit = + isl::manage(isl_set_add_constraint(param_limit.release(), cst)); } VLOG(3) << "param_limit: " << param_limit; @@ -197,15 +235,16 @@ std::vector ComputeAtTransform::GetProducerAdjustedShape() const { // collect the min and max and get the num elements for each axis. 
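  // Illustrative note (hypothetical example, not from the original sources):
  // for a producer domain such as { p[i, j] : 0 <= i <= 3 and 2 <= j <= 5 },
  // the loop below reads the per-axis ranges (0, 3) and (2, 5) and records
  // max - min + 1 elements each, so the adjusted shape becomes {4, 4}.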
for (int i = 0; i < isl_set_dim(domain.get(), isl_dim_set); i++) { auto _minv_maxv_ = isl_set_get_axis_range(domain.get(), i); - auto& minv = std::get<0>(_minv_maxv_); - auto& maxv = std::get<1>(_minv_maxv_); + auto& minv = std::get<0>(_minv_maxv_); + auto& maxv = std::get<1>(_minv_maxv_); int num_elements = maxv.num_si() - minv.num_si() + 1; shape.push_back(num_elements); } return shape; } -std::vector ComputeAtTransform::GetAccessesPrecedingIndicesMinAssumingParamsZero() { +std::vector +ComputeAtTransform::GetAccessesPrecedingIndicesMinAssumingParamsZero() { std::vector res; isl::set cdomain_with_param = isl::manage(AddParamsTo(cdomain_.copy())); @@ -218,11 +257,15 @@ std::vector ComputeAtTransform::GetAccessesPrecedingIndicesMinAssumingParam isl::set access_domain = param_limited_cdomain.apply(access_with_param); // set all the params to 0 - isl_local_space* local_space = isl_local_space_from_space(access_domain.space().release()); + isl_local_space* local_space = + isl_local_space_from_space(access_domain.space().release()); for (int i = 0; i < isl_set_dim(access_domain.get(), isl_dim_param); i++) { - isl_constraint* cst = isl_constraint_alloc_equality(isl_local_space_copy(local_space)); - cst = isl_constraint_set_coefficient_val(cst, isl_dim_param, i, isl_val_int_from_si(ctransform_.ctx().get(), 1)); - access_domain = isl::manage(isl_set_add_constraint(access_domain.release(), cst)); + isl_constraint* cst = + isl_constraint_alloc_equality(isl_local_space_copy(local_space)); + cst = isl_constraint_set_coefficient_val( + cst, isl_dim_param, i, isl_val_int_from_si(ctransform_.ctx().get(), 1)); + access_domain = + isl::manage(isl_set_add_constraint(access_domain.release(), cst)); } isl_local_space_free(local_space); @@ -232,8 +275,8 @@ std::vector ComputeAtTransform::GetAccessesPrecedingIndicesMinAssumingParam for (int i = 0; i < level_ + 1; i++) { auto _minv_maxv_ = isl_set_get_axis_range(access_domain.get(), i); - auto& minv = std::get<0>(_minv_maxv_); - auto& maxv = std::get<1>(_minv_maxv_); + auto& minv = std::get<0>(_minv_maxv_); + auto& maxv = std::get<1>(_minv_maxv_); res.push_back(minv.get_num_si()); } diff --git a/paddle/cinn/poly/compute_at_transform.h b/paddle/cinn/poly/compute_at_transform.h index 6c650bbcd509f..b8ad01024d433 100644 --- a/paddle/cinn/poly/compute_at_transform.h +++ b/paddle/cinn/poly/compute_at_transform.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! This file implements the class ComputeAtTransform, which help to perform the isl transformation in `compute_at` -//! optimization. +//! This file implements the class ComputeAtTransform, which helps to perform the +//! isl transformation in `compute_at` optimization. #pragma once #include @@ -46,24 +46,31 @@ std::string GenConsumerParamName(const char* tuple, int id); /** * \brief The ComputeAt transform implemented in polyhedral way. * - * The current implementation for `ComputeAt` schedule primitive is quite complex, it contains the polyhedral transform - * before the AST generation, and the several passes after AST generation. This class only contains the polyhedral - * transform: + * The current implementation for the `ComputeAt` schedule primitive is quite + * complex: it contains the polyhedral transform before the AST generation, and + * several passes after AST generation. This class only contains the + * polyhedral transform: * 1. Adjust the producer's domain by the consume accesses. * 2.
Adjust the producer's transform by - * a. Insert the preceding level+1 consumer axis to the head of the original producer transform's domain, to make it - * compute in the level of consumer forloops. b. - * b. Adjust the range of the producer's transform by fixing the preceding axis(from the previous step). + * a. Insert the preceding level+1 consumer axis to the head of the original + * producer transform's domain, to make it compute in the level of consumer + * for loops. + * b. Adjust the range of the producer's transform by fixing the + * preceding axis (from the previous step). * * The latter process after the execution of this class remains, including * 1. Get the adjusted shape of the producer after compute_at * 2. Update the adjusted buffer's shape - * 3. Normalize the accesses of the consumers(by making the leftmost access start from zero). + * 3. Normalize the accesses of the consumers (by making the leftmost access + * start from zero). */ class ComputeAtTransform { public: - ComputeAtTransform( - isl::set pdomain, isl::set cdomain, isl::map access, isl::map ptransform, isl::map ctransform, int level); + ComputeAtTransform(isl::set pdomain, + isl::set cdomain, + isl::map access, + isl::map ptransform, + isl::map ctransform, + int level); void operator()() { AdjustPdomain(); @@ -74,13 +81,15 @@ class ComputeAtTransform { const isl::map& adjusted_ptransform() const { return adjusted_ptransform_; } //! Display C code - void DisplayC(isl_map* __isl_give pschedule = nullptr, isl_map* __isl_give cschedule = nullptr); + void DisplayC(isl_map* __isl_give pschedule = nullptr, + isl_map* __isl_give cschedule = nullptr); //! Re-calculate the producer buffer shape after compute_at transform. std::vector GetProducerAdjustedShape() const; - //! Get the the minimum of the preceding level+1 axis in accesses by assuming all the isl param is zero(for the - //! consumer, the preceding level+1 axis is fixed in producer computation). + //! Get the minimum of the preceding level+1 axis in accesses by assuming + //! all the isl params are zero (for the consumer, the preceding level+1 axis + //! is fixed in producer computation).
std::vector GetAccessesPrecedingIndicesMinAssumingParamsZero(); protected: diff --git a/paddle/cinn/poly/compute_at_transform_test.cc b/paddle/cinn/poly/compute_at_transform_test.cc index 825a9738f3517..dd3168b2bd33d 100644 --- a/paddle/cinn/poly/compute_at_transform_test.cc +++ b/paddle/cinn/poly/compute_at_transform_test.cc @@ -22,23 +22,30 @@ namespace poly { TEST(ComputeAtTransform2, basic) { isl::ctx ctx(isl_ctx_alloc()); isl::set pdomain(ctx, "{ p[i,j]: 0<=i,j<100 }"); - isl::map ptransform(ctx, "{ p[i,j]->p[t0,t1,t2]: t0=i%4 and t1=i/4 and t2=j }"); + isl::map ptransform(ctx, + "{ p[i,j]->p[t0,t1,t2]: t0=i%4 and t1=i/4 and t2=j }"); isl::set cdomain(ctx, "{ c[i,j,k]: 0<=i,j,k<50 }"); - isl::map ctransform(ctx, "{ c[i,j,k]->c[t0,t1,t2,t3]: t0=i/4 and t1=i%4 and t2=j and t3=k }"); + isl::map ctransform( + ctx, "{ c[i,j,k]->c[t0,t1,t2,t3]: t0=i/4 and t1=i%4 and t2=j and t3=k }"); - isl::map access(ctx, "{ c[i,j,k]->p[i,j]; c[i,j,k]->p[i+1,j]; c[i,j,k]->p[i-1,j] }"); + isl::map access( + ctx, "{ c[i,j,k]->p[i,j]; c[i,j,k]->p[i+1,j]; c[i,j,k]->p[i-1,j] }"); - poly::ComputeAtTransform t(pdomain, cdomain, access, ptransform, ctransform, 1); + poly::ComputeAtTransform t( + pdomain, cdomain, access, ptransform, ctransform, 1); t(); t.DisplayC(); isl::map pschedule(ctx, - "{ p[i0,i1,i2,i3,i4] -> [t0,t1,t1t, t2,t3,t4,t5]: t0=i0 and t1=i1 and t2=i2 and t3=i3 and t4=i4 " + "{ p[i0,i1,i2,i3,i4] -> [t0,t1,t1t, t2,t3,t4,t5]: t0=i0 " + "and t1=i1 and t2=i2 and t3=i3 and t4=i4 " "and t5=0 and t1t=0 }"); - isl::map cschedule(ctx, - "[_c_0,_c_1] -> { c[i0,i1,i2,i3] -> [t0,t1,t1t,t2,t3,t4,t5]: t0=i0 and t1=i1 and t2=i2 and t3=i3 " - "and t4=0 and t5=0 and t1t=1 }"); + isl::map cschedule( + ctx, + "[_c_0,_c_1] -> { c[i0,i1,i2,i3] -> [t0,t1,t1t,t2,t3,t4,t5]: t0=i0 and " + "t1=i1 and t2=i2 and t3=i3 " + "and t4=0 and t5=0 and t1t=1 }"); t.DisplayC(pschedule.release(), cschedule.release()); diff --git a/paddle/cinn/poly/dim.cc b/paddle/cinn/poly/dim.cc index 5094da439ef2a..e72a3e5ab264c 100644 --- a/paddle/cinn/poly/dim.cc +++ b/paddle/cinn/poly/dim.cc @@ -22,8 +22,10 @@ namespace cinn { namespace poly { std::string Dim::range_repr() const { - return utils::StringFormat( - "%s <= %s <= %s", utils::GetStreamCnt(lower_bound).c_str(), id.c_str(), utils::GetStreamCnt(upper_bound).c_str()); + return utils::StringFormat("%s <= %s <= %s", + utils::GetStreamCnt(lower_bound).c_str(), + id.c_str(), + utils::GetStreamCnt(upper_bound).c_str()); } Dim::Dim(std::string id, ir::Expr lower_bound, ir::Expr upper_bound) diff --git a/paddle/cinn/poly/dim.h b/paddle/cinn/poly/dim.h index c0e2896a95d24..6b197eaf214ff 100644 --- a/paddle/cinn/poly/dim.h +++ b/paddle/cinn/poly/dim.h @@ -58,7 +58,9 @@ struct Dim { //! Return the range composed of (lower_bound, upper_bound). range_t range() const { return std::make_pair(lower_bound, upper_bound); } - bool is_param() const { return !lower_bound.defined() && !lower_bound.defined(); } + bool is_param() const { + return !lower_bound.defined() && !upper_bound.defined(); + } //! Return the ISL style range representation, such as '0 <= i <= 20'.
std::string range_repr() const; diff --git a/paddle/cinn/poly/domain.cc b/paddle/cinn/poly/domain.cc index 56a7b4d2c8f89..09f988b920620 100644 --- a/paddle/cinn/poly/domain.cc +++ b/paddle/cinn/poly/domain.cc @@ -32,21 +32,32 @@ namespace poly { std::string Domain::__str__() const { CHECK(!id.empty()) << "domain's id is empty"; std::vector range_fields; - std::transform( - dims.begin(), dims.end(), std::back_inserter(range_fields), [](const Dim& x) { return x.range_repr(); }); + std::transform(dims.begin(), + dims.end(), + std::back_inserter(range_fields), + [](const Dim& x) { return x.range_repr(); }); std::string range_repr = utils::Join(range_fields, " and "); std::vector dim_fields; - std::transform(dims.begin(), dims.end(), std::back_inserter(dim_fields), [](const Dim& x) { return x.id; }); + std::transform(dims.begin(), + dims.end(), + std::back_inserter(dim_fields), + [](const Dim& x) { return x.id; }); std::string dims_repr = utils::Join(dim_fields, ", "); // parameters std::vector param_reprs; - std::transform(params.begin(), params.end(), std::back_inserter(param_reprs), [](const Dim& x) { return x.id; }); + std::transform(params.begin(), + params.end(), + std::back_inserter(param_reprs), + [](const Dim& x) { return x.id; }); std::string params_repr = utils::Join(param_reprs, ", "); - return utils::StringFormat( - "[%s]->{ %s[%s]: %s }", params_repr.c_str(), id.c_str(), dims_repr.c_str(), range_repr.c_str()); + return utils::StringFormat("[%s]->{ %s[%s]: %s }", + params_repr.c_str(), + id.c_str(), + dims_repr.c_str(), + range_repr.c_str()); } isl::set Domain::to_isl() const { @@ -59,7 +70,8 @@ void Domain::ExtractParams() { std::unordered_set var_names; auto collect_param_fn = [&](Expr& e) { if (!e.is_constant()) { - auto vars = ir::CollectIRNodes(e, [](const Expr* e) { return e->is_var(); }); + auto vars = + ir::CollectIRNodes(e, [](const Expr* e) { return e->is_var(); }); for (auto& var : vars) var_names.insert(var.As()->name); } }; diff --git a/paddle/cinn/poly/domain.h b/paddle/cinn/poly/domain.h index 6511b6ca37c15..41e21db5c3dfd 100644 --- a/paddle/cinn/poly/domain.h +++ b/paddle/cinn/poly/domain.h @@ -35,7 +35,8 @@ struct Domain { //! The ISL context. 
isl::ctx ctx; - Domain(isl::ctx ctx, std::string id, std::vector dims) : ctx(ctx), id(std::move(id)), dims(std::move(dims)) { + Domain(isl::ctx ctx, std::string id, std::vector dims) + : ctx(ctx), id(std::move(id)), dims(std::move(dims)) { ExtractParams(); } diff --git a/paddle/cinn/poly/domain_add_unit_loop_mutator.cc b/paddle/cinn/poly/domain_add_unit_loop_mutator.cc index 63abd9567c5b1..d526012e7862d 100644 --- a/paddle/cinn/poly/domain_add_unit_loop_mutator.cc +++ b/paddle/cinn/poly/domain_add_unit_loop_mutator.cc @@ -26,8 +26,9 @@ namespace cinn { namespace poly { -DomainAddUnitLoopMutator::DomainAddUnitLoopMutator(const std::vector& dim_names, - const std::vector>& dim_min_max) +DomainAddUnitLoopMutator::DomainAddUnitLoopMutator( + const std::vector& dim_names, + const std::vector>& dim_min_max) : dim_names_(dim_names), dim_min_max_(dim_min_max) {} void DomainAddUnitLoopMutator::operator()(ir::Expr* expr) { @@ -40,12 +41,13 @@ void DomainAddUnitLoopMutator::operator()(ir::Expr* expr) { void DomainAddUnitLoopMutator::Visit(const ir::For* op, Expr* expr) { VLOG(6) << "DomainAddUnitLoopMutator Visit For"; - ir::For* node = expr->As(); + ir::For* node = expr->As(); bool add_unit_loop = false; if (parent_for_.size() < dim_names_.size()) { - std::string check_name = dim_names_[parent_for_.size()]; + std::string check_name = dim_names_[parent_for_.size()]; std::tuple t = dim_min_max_[parent_for_.size()]; - if (!utils::Startswith(node->loop_var->name, check_name) && (std::get<2>(t) - std::get<1>(t) == 0)) { + if (!utils::Startswith(node->loop_var->name, check_name) && + (std::get<2>(t) - std::get<1>(t) == 0)) { ir::Expr unit_loop = ir::For::Make(ir::Var(check_name), ir::Expr(0), ir::Expr(1), @@ -57,7 +59,8 @@ void DomainAddUnitLoopMutator::Visit(const ir::For* op, Expr* expr) { parent_for_.push_back(unit_loop.As()); longest_loop_.push_back(unit_loop); add_unit_loop = true; - } else if (parent_for_.back()->body.As() && parent_for_.back()->body == *expr) { + } else if (parent_for_.back()->body.As() && + parent_for_.back()->body == *expr) { parent_for_.back()->body = ir::Block::Make({unit_loop}); parent_for_.push_back(unit_loop.As()); longest_loop_.push_back(unit_loop); @@ -75,7 +78,8 @@ void DomainAddUnitLoopMutator::Visit(const ir::For* op, Expr* expr) { } if (add_unit_loop) { - ir::IRMutator<>::Visit(&(parent_for_.back()->body), &(parent_for_.back()->body)); + ir::IRMutator<>::Visit(&(parent_for_.back()->body), + &(parent_for_.back()->body)); parent_for_.pop_back(); } else { parent_for_.push_back(node); @@ -87,26 +91,29 @@ void DomainAddUnitLoopMutator::Visit(const ir::For* op, Expr* expr) { void DomainAddUnitLoopMutator::Visit(const ir::PolyFor* op, Expr* expr) { VLOG(6) << "DomainAddUnitLoopMutator Visit PolyFor"; - ir::PolyFor* node = expr->As(); + ir::PolyFor* node = expr->As(); bool add_unit_loop = false; if (parent_poly_for_.size() < dim_names_.size()) { - std::string check_name = dim_names_[parent_poly_for_.size()]; + std::string check_name = dim_names_[parent_poly_for_.size()]; std::tuple t = dim_min_max_[parent_poly_for_.size()]; - if (!utils::Startswith(node->iterator->name, check_name) && (std::get<2>(t) - std::get<1>(t) == 0)) { - ir::Expr unit_loop = ir::PolyFor::Make(ir::Var(check_name), - ir::Expr(0), - ir::LE::Make(ir::Var(check_name), ir::Expr(0)), - ir::Expr(1), - ir::ForType::Serial, - node->device_api, - ir::Block::Make({*expr})); + if (!utils::Startswith(node->iterator->name, check_name) && + (std::get<2>(t) - std::get<1>(t) == 0)) { + ir::Expr unit_loop = + 
ir::PolyFor::Make(ir::Var(check_name), + ir::Expr(0), + ir::LE::Make(ir::Var(check_name), ir::Expr(0)), + ir::Expr(1), + ir::ForType::Serial, + node->device_api, + ir::Block::Make({*expr})); if (parent_poly_for_.empty()) { *expr = unit_loop; parent_poly_for_.push_back(unit_loop.As()); longest_loop_.push_back(unit_loop); add_unit_loop = true; - } else if (parent_poly_for_.back()->body.As() && parent_poly_for_.back()->body == *expr) { + } else if (parent_poly_for_.back()->body.As() && + parent_poly_for_.back()->body == *expr) { parent_poly_for_.back()->body = ir::Block::Make({unit_loop}); parent_poly_for_.push_back(unit_loop.As()); longest_loop_.push_back(unit_loop); @@ -124,7 +131,8 @@ void DomainAddUnitLoopMutator::Visit(const ir::PolyFor* op, Expr* expr) { } if (add_unit_loop) { - ir::IRMutator<>::Visit(&(parent_poly_for_.back()->body), &(parent_poly_for_.back()->body)); + ir::IRMutator<>::Visit(&(parent_poly_for_.back()->body), + &(parent_poly_for_.back()->body)); parent_poly_for_.pop_back(); } else { parent_poly_for_.push_back(node); @@ -145,8 +153,9 @@ void DomainAddUnitLoopMutator::MutateAfterVisit(ir::Expr* expr) { std::tuple t = dim_min_max_[i]; if (longest_loop_[i].As()) { const ir::For* node = longest_loop_[i].As(); - if (utils::Startswith(node->loop_var->name, dim_names_[i]) && node->min.is_constant() && - node->min.as_int32() == std::get<1>(t) && node->extent.is_constant() && + if (utils::Startswith(node->loop_var->name, dim_names_[i]) && + node->min.is_constant() && node->min.as_int32() == std::get<1>(t) && + node->extent.is_constant() && node->extent.as_int32() == std::get<2>(t)) { ++loop_match_len; } else { @@ -155,9 +164,10 @@ void DomainAddUnitLoopMutator::MutateAfterVisit(ir::Expr* expr) { } } else if (longest_loop_[i].As()) { const ir::PolyFor* node = longest_loop_[i].As(); - if (utils::Startswith(node->iterator->name, dim_names_[i]) && node->init.is_constant() && - node->init.as_int32() == std::get<1>(t) && - node->condition == ir::LE::Make(ir::Var(dim_names_[i]), ir::Expr(std::get<2>(t)))) { + if (utils::Startswith(node->iterator->name, dim_names_[i]) && + node->init.is_constant() && node->init.as_int32() == std::get<1>(t) && + node->condition == + ir::LE::Make(ir::Var(dim_names_[i]), ir::Expr(std::get<2>(t)))) { ++loop_match_len; } else { loop_match_len = -1; @@ -182,7 +192,9 @@ void DomainAddUnitLoopMutator::MutateAfterVisit(ir::Expr* expr) { } if (longest_loop_.empty() || longest_loop_.back().As()) { - ir::Expr body = longest_loop_.empty() ? *expr : longest_loop_.back().As()->body; + ir::Expr body = longest_loop_.empty() + ? *expr + : longest_loop_.back().As()->body; for (int i = dim_min_max_.size() - 1; i >= loop_match_len; --i) { if (!body.As()) { body = ir::Block::Make({body}); @@ -193,7 +205,9 @@ void DomainAddUnitLoopMutator::MutateAfterVisit(ir::Expr* expr) { ir::LE::Make(ir::Var(dim_names_[i]), ir::Expr(0)), ir::Expr(1), ir::ForType::Serial, - longest_loop_.empty() ? ir::DeviceAPI::UNK : longest_loop_.back().As()->device_api, + longest_loop_.empty() + ? 
ir::DeviceAPI::UNK + : longest_loop_.back().As()->device_api, body); } if (longest_loop_.empty()) { @@ -205,8 +219,12 @@ void DomainAddUnitLoopMutator::MutateAfterVisit(ir::Expr* expr) { ir::For* node = longest_loop_.back().As(); ir::Expr body = node->body; for (int i = dim_min_max_.size() - 1; i >= loop_match_len; --i) { - ir::Expr unit_loop = - ir::For::Make(ir::Var(dim_names_[i]), ir::Expr(0), ir::Expr(1), ir::ForType::Serial, node->device_api, body); + ir::Expr unit_loop = ir::For::Make(ir::Var(dim_names_[i]), + ir::Expr(0), + ir::Expr(1), + ir::ForType::Serial, + node->device_api, + body); body = ir::Block::Make({unit_loop}); } node->body = body; diff --git a/paddle/cinn/poly/domain_add_unit_loop_mutator.h b/paddle/cinn/poly/domain_add_unit_loop_mutator.h index ea20088754c7b..0b2ce648b760c 100644 --- a/paddle/cinn/poly/domain_add_unit_loop_mutator.h +++ b/paddle/cinn/poly/domain_add_unit_loop_mutator.h @@ -27,8 +27,9 @@ namespace poly { */ class DomainAddUnitLoopMutator : public ir::IRMutator<> { public: - DomainAddUnitLoopMutator(const std::vector& dim_names, - const std::vector>& dim_min_max); + DomainAddUnitLoopMutator( + const std::vector& dim_names, + const std::vector>& dim_min_max); void operator()(ir::Expr* expr); diff --git a/paddle/cinn/poly/graph.cc b/paddle/cinn/poly/graph.cc index 8a0d2841043e9..c647cf49565dc 100755 --- a/paddle/cinn/poly/graph.cc +++ b/paddle/cinn/poly/graph.cc @@ -36,20 +36,21 @@ DataFlowGraphNode* DataFlowGraphNode::group_ancestor() { return p; } -bool DataFlowGraphNode::TransformedDomainIsSame(const DataFlowGraphNode* a, const DataFlowGraphNode* b) { +bool DataFlowGraphNode::TransformedDomainIsSame(const DataFlowGraphNode* a, + const DataFlowGraphNode* b) { VLOG(3) << "a.domain " << a->stage->domain(); VLOG(3) << "a.transform " << a->stage->transform(); VLOG(3) << "b.domain " << b->stage->domain(); VLOG(3) << "b.transform " << b->stage->transform(); auto a_domain = a->stage->transformed_domain(); auto b_domain = b->stage->transformed_domain(); - a_domain = isl::manage(isl_set_set_tuple_name(a_domain.release(), "")); - b_domain = isl::manage(isl_set_set_tuple_name(b_domain.release(), "")); + a_domain = isl::manage(isl_set_set_tuple_name(a_domain.release(), "")); + b_domain = isl::manage(isl_set_set_tuple_name(b_domain.release(), "")); return isl_set_is_equal(a_domain.get(), b_domain.get()); } int DataFlowGraphNode::group_height() const { - int h = 0; + int h = 0; auto* p = this; while (p) { ++h; @@ -59,19 +60,20 @@ int DataFlowGraphNode::group_height() const { return h; } -DataFlowGraphNode* DataFlowGraphNode::MergeGroup(DataFlowGraphNode* a, DataFlowGraphNode* b) { - int ah = a->group_height(); - int bh = b->group_height(); +DataFlowGraphNode* DataFlowGraphNode::MergeGroup(DataFlowGraphNode* a, + DataFlowGraphNode* b) { + int ah = a->group_height(); + int bh = b->group_height(); auto* a_anc = a->group_ancestor(); auto* b_anc = b->group_ancestor(); DataFlowGraphNode* common_anc{}; if (ah < bh) { // take a's ancestor b_anc->group_parent = a_anc; - b->group_parent = a_anc; + b->group_parent = a_anc; return a_anc; } else { a_anc->group_parent = b_anc; - a->group_parent = b_anc; + a->group_parent = b_anc; return b_anc; } } @@ -81,26 +83,33 @@ std::string DataFlowGraphNode::id() const { } bool DataFlowGraphNode::IsLinkedTo(const DataFlowGraphNode* node) const { - bool found = std::find_if(inlinks_.begin(), inlinks_.end(), [=](const Shared& x) { - return x->source() == node; - }) != std::end(inlinks_); - return found || std::find_if(outlinks_.begin(), 
outlinks_.end(), [=](const Shared& x) { return x->sink() == node; }) != std::end(outlinks_); + bool found = std::find_if(inlinks_.begin(), + inlinks_.end(), + [=](const Shared& x) { + return x->source() == node; + }) != std::end(inlinks_); + return found || std::find_if(outlinks_.begin(), + outlinks_.end(), + [=](const Shared& x) { + return x->sink() == node; + }) != std::end(outlinks_); } -std::unique_ptr CreateGraph(const std::vector& stages, - const std::vector>& extra_links) { +std::unique_ptr CreateGraph( + const std::vector& stages, + const std::vector>& extra_links) { std::map> id2stage; for (auto* x : stages) id2stage[x->id()] = make_shared(x); for (auto* stage : stages) { auto depend_statement_names = stage->input_statements(); - VLOG(3) << stage->id() << " depend " << utils::Join(depend_statement_names, ", "); + VLOG(3) << stage->id() << " depend " + << utils::Join(depend_statement_names, ", "); for (auto& depend_statement : depend_statement_names) { auto input_it = id2stage.find(depend_statement); - // We removed some node in the original stages(such as placeholders), so that there might be missing of some input - // nodes, just ignore the dependence. + // We removed some nodes from the original stages (such as placeholders), + // so some input nodes may be missing; just ignore the dependence. if (input_it != std::end(id2stage)) { auto& input_node = input_it->second; input_node->Controls(id2stage.at(stage->id()).get()); @@ -120,7 +129,8 @@ std::unique_ptr CreateGraph(const std::vector& stages, } std::unique_ptr graph(new DataFlowGraph); - for (auto& item : id2stage) graph->RegisterNode(item.first, item.second.get()); + for (auto& item : id2stage) + graph->RegisterNode(item.first, item.second.get()); VLOG(3) << "created graph:\n" << graph->Visualize(); return graph; } diff --git a/paddle/cinn/poly/graph.h b/paddle/cinn/poly/graph.h index fc10a0a14aa3e..e0c15f7be793f 100644 --- a/paddle/cinn/poly/graph.h +++ b/paddle/cinn/poly/graph.h @@ -44,15 +44,18 @@ struct DataFlowGraphNode : public common::GraphNode { //! Get the tree height for union find. int group_height() const; - //! Tell whether this node is connected to another `node`, either inlink or outlink. + //! Tell whether this node is connected to another `node`, either inlink or + //! outlink. bool IsLinkedTo(const DataFlowGraphNode* node) const; //! Merge two nodes into the same group. //! returns: the common ancestor. - static DataFlowGraphNode* MergeGroup(DataFlowGraphNode* a, DataFlowGraphNode* b); + static DataFlowGraphNode* MergeGroup(DataFlowGraphNode* a, + DataFlowGraphNode* b); //! Compare the the iteration_domain.apply(transform), return true if same. - static bool TransformedDomainIsSame(const DataFlowGraphNode* a, const DataFlowGraphNode* b); + static bool TransformedDomainIsSame(const DataFlowGraphNode* a, + const DataFlowGraphNode* b); }; struct DataFlowGraphEdge : public common::GraphEdge {}; @@ -69,24 +72,26 @@ struct DataFlowGraph : public common::Graph {}; * @param stages The stages.
* @param extra_links The extra links, each element is a pair of (a -> b) */ -std::unique_ptr CreateGraph(const std::vector& stages, - const std::vector>& extra_links = {}); +std::unique_ptr CreateGraph( + const std::vector& stages, + const std::vector>& extra_links = {}); namespace detail { struct Group { Group() = default; - explicit Group(const std::vector>& nodes) : nodes(nodes) {} + explicit Group(const std::vector>& nodes) + : nodes(nodes) {} std::vector> nodes; std::vector dimension_names; }; /** - * GraphPartitionBySpace partitions a data flow graph into several sub-graph with consider of the dependency and space - * of the iteration domain. - * If two Nodes has the stages has dependency relation and has the same iteration domain, then they will be put in the - * same sub-graph. + * GraphPartitionBySpace partitions a data flow graph into several sub-graphs, + * taking into account both the dependencies and the spaces of the iteration + * domains. If two nodes' stages have a dependency relation and the same + * iteration domain, they will be put in the same sub-graph. */ std::vector PartitionGraphByIterationDomain(common::Graph* graph); diff --git a/paddle/cinn/poly/isl_utils.cc b/paddle/cinn/poly/isl_utils.cc index 3b3712dbaf1b1..34f1dd21fac2f 100644 --- a/paddle/cinn/poly/isl_utils.cc +++ b/paddle/cinn/poly/isl_utils.cc @@ -36,7 +36,8 @@ std::vector isl_get_dim_names(const isl::set &x) { return res; } -std::vector isl_get_dim_names(const isl::map &x, isl_dim_type dim_type) { +std::vector isl_get_dim_names(const isl::map &x, + isl_dim_type dim_type) { std::vector res; for (int i = 0; i < isl_map_dim(x.get(), dim_type); i++) { res.push_back(isl_map_get_dim_name(x.get(), dim_type, i)); @@ -52,12 +53,15 @@ std::vector isl_get_dim_names(isl_set *set) { return res; } -void isl_set_dim_names(isl::map *map, isl_dim_type dim_type, const std::vector &names) { +void isl_set_dim_names(isl::map *map, + isl_dim_type dim_type, + const std::vector &names) { const int dim = isl_map_dim(map->get(), dim_type); CHECK_EQ(dim, names.size()); for (int i = 0; i < dim; i++) { - *map = isl::manage(isl_map_set_dim_name(map->release(), dim_type, i, names[i].c_str())); + *map = isl::manage( + isl_map_set_dim_name(map->release(), dim_type, i, names[i].c_str())); } } @@ -66,13 +70,15 @@ void isl_set_dim_names(isl::set *set, const std::vector &names) { const int dim = ... CHECK_EQ(dim, names.size()); for (int i = 0; i < dim; i++) { - *set = isl::manage(isl_set_set_dim_name(set->release(), isl_dim_set, i, names[i].c_str())); + *set = isl::manage( + isl_set_set_dim_name(set->release(), isl_dim_set, i, names[i].c_str())); } } isl::union_map isl_maps_to_union_map(const std::vector &maps) { CHECK(!maps.empty()); - isl::union_map umap = isl::manage(isl_union_map_from_map(maps.front().copy())); + isl::union_map umap = + isl::manage(isl_union_map_from_map(maps.front().copy())); for (int i = 1; i < maps.size(); i++) { umap = isl::manage(isl_union_map_add_map(umap.release(), maps[i].copy())); } @@ -81,14 +87,16 @@ isl::union_map isl_maps_to_union_map(const std::vector &maps) { isl::union_set isl_sets_to_union_set(const std::vector &sets) { CHECK(!sets.empty()); - isl::union_set uset = isl::manage(isl_union_set_from_set(sets.front().copy())); + isl::union_set uset = + isl::manage(isl_union_set_from_set(sets.front().copy())); for (int i = 1; i < sets.size(); i++) { uset = isl::manage(isl_union_set_add_set(uset.release(), sets[i].copy())); } return uset; } -std::string isl_map_get_statement_repr(__isl_keep isl_map *map, isl_dim_type type) {
+std::string isl_map_get_statement_repr(__isl_keep isl_map *map, + isl_dim_type type) { CHECK(map); auto tuple_name = isl_map_get_tuple_name(map, type); std::vector dims; @@ -99,7 +107,8 @@ std::string isl_map_get_statement_repr(__isl_keep isl_map *map, isl_dim_type typ return StringFormat("%s[%s]", tuple_name, Join(dims, ", ").c_str()); } -std::vector isl_get_dim_names(isl_map *map, isl_dim_type dim_type) { +std::vector isl_get_dim_names(isl_map *map, + isl_dim_type dim_type) { std::vector res; int n = isl_map_dim(map, dim_type); for (int i = 0; i < n; i++) { @@ -110,23 +119,26 @@ std::vector isl_get_dim_names(isl_map *map, isl_dim_type dim_type) isl::set SetGetDims(isl::set set, const std::vector &dims) { std::string tuple_name = isl_set_get_tuple_name(set.get()); - auto dim_names = isl_get_dim_names(set); + auto dim_names = isl_get_dim_names(set); std::vector selected_dim_names; for (int v : dims) { CHECK_LT(v, dim_names.size()); selected_dim_names.push_back(dim_names[v]); } - std::string transform_repr = StringFormat("{ %s[%s] -> %s[%s] }", - tuple_name.c_str(), // - Join(dim_names, ", ").c_str(), // - tuple_name.c_str(), // - Join(selected_dim_names, ", ").c_str()); + std::string transform_repr = + StringFormat("{ %s[%s] -> %s[%s] }", + tuple_name.c_str(), // + Join(dim_names, ", ").c_str(), // + tuple_name.c_str(), // + Join(selected_dim_names, ", ").c_str()); isl::map transform(set.ctx(), transform_repr); return set.apply(transform); } -isl_set *isl_get_precending_aixs(isl_set *set, int level, bool with_tuple_name) { +isl_set *isl_get_precending_aixs(isl_set *set, + int level, + bool with_tuple_name) { int n = isl_set_dim(set, isl_dim_set); CHECK_LT(level, n); @@ -143,17 +155,20 @@ isl_set *isl_get_precending_aixs(isl_set *set, int level, bool with_tuple_name) const char *statement = isl_set_get_tuple_name(set); - std::string repr = utils::StringFormat("{ %s[%s] -> %s[%s] }", - statement, - utils::Join(domain_iterators, ", ").c_str(), - statement, - utils::Join(range_iterators, ", ").c_str()); - auto transform = isl::manage(isl_map_read_from_str(isl_set_get_ctx(set), repr.c_str())); + std::string repr = + utils::StringFormat("{ %s[%s] -> %s[%s] }", + statement, + utils::Join(domain_iterators, ", ").c_str(), + statement, + utils::Join(range_iterators, ", ").c_str()); + auto transform = + isl::manage(isl_map_read_from_str(isl_set_get_ctx(set), repr.c_str())); return isl_set_apply(set, transform.release()); } -int isl_get_original_axes_from_optimized_level(isl_set __isl_keep *a, int level) { +int isl_get_original_axes_from_optimized_level(isl_set __isl_keep *a, + int level) { int original_level = -1; std::vector> iden_dim_offsets; for (int i = 0; i <= level;) { @@ -217,8 +232,10 @@ int isl_max_level_compatible(isl_set *a, isl_set *b) { int compatible_level = -1; for (int i = 0; i < std::min(an, bn); i++) { - isl::set a_prefix = isl::manage(isl_get_precending_aixs(isl_set_copy(a), i, false)); - isl::set b_prefix = isl::manage(isl_get_precending_aixs(isl_set_copy(b), i, false)); + isl::set a_prefix = + isl::manage(isl_get_precending_aixs(isl_set_copy(a), i, false)); + isl::set b_prefix = + isl::manage(isl_get_precending_aixs(isl_set_copy(b), i, false)); a_prefix = isl::manage(isl_set_set_tuple_name(a_prefix.release(), "s")); b_prefix = isl::manage(isl_set_set_tuple_name(b_prefix.release(), "s")); @@ -233,23 +250,28 @@ int isl_max_level_compatible(isl_set *a, isl_set *b) { isl_set *isl_remove_axis_by_name(isl_set *set, const char *axis_name) { std::string tuple_name = 
isl_set_get_tuple_name(set); - int offset = isl_set_find_dim_by_name(set, isl_dim_set, axis_name); - set = isl_set_remove_dims(set, isl_dim_set, offset, 1); - set = isl_set_set_tuple_name(set, tuple_name.c_str()); + int offset = isl_set_find_dim_by_name(set, isl_dim_set, axis_name); + set = isl_set_remove_dims(set, isl_dim_set, offset, 1); + set = isl_set_set_tuple_name(set, tuple_name.c_str()); return set; } -isl_map *isl_remove_axis_by_name(isl_map *map, isl_dim_type dim_type, const char *axis_name) { - int offset = isl_map_find_dim_by_name(map, dim_type, axis_name); +isl_map *isl_remove_axis_by_name(isl_map *map, + isl_dim_type dim_type, + const char *axis_name) { + int offset = isl_map_find_dim_by_name(map, dim_type, axis_name); std::string tuple_name = isl_map_get_tuple_name(map, dim_type); - map = isl_map_remove_dims(map, dim_type, offset, 1); - map = isl_map_set_tuple_name(map, dim_type, tuple_name.c_str()); + map = isl_map_remove_dims(map, dim_type, offset, 1); + map = isl_map_set_tuple_name(map, dim_type, tuple_name.c_str()); return map; } isl_set *isl_rename_axis(isl_set *set, int offset, const char *name) { return isl_set_set_dim_name(set, isl_dim_set, offset, name); } -isl_map *isl_rename_axis(isl_map *map, isl_dim_type dim_type, int offset, const char *name) { +isl_map *isl_rename_axis(isl_map *map, + isl_dim_type dim_type, + int offset, + const char *name) { return isl_map_set_dim_name(map, dim_type, offset, name); } @@ -268,7 +290,8 @@ isl::union_set isl_union_set_from_sets(llvm::ArrayRef sets) { return res; } -std::tuple isl_set_get_axis_range_by_name(isl_set *set, std::string axis_name) { +std::tuple isl_set_get_axis_range_by_name( + isl_set *set, std::string axis_name) { std::vector from_iters; for (int i = 0; i < isl_set_dim(set, isl_dim_set); i++) { auto *name = isl_set_get_dim_name(set, isl_dim_set, i); @@ -279,10 +302,11 @@ std::tuple isl_set_get_axis_range_by_name(isl_set *set, std: } } - isl::aff aff( - isl_set_get_ctx(set), - utils::StringFormat( - "{ %s[%s] -> [%s] }", isl_set_get_tuple_name(set), utils::Join(from_iters, ",").c_str(), axis_name.c_str())); + isl::aff aff(isl_set_get_ctx(set), + utils::StringFormat("{ %s[%s] -> [%s] }", + isl_set_get_tuple_name(set), + utils::Join(from_iters, ",").c_str(), + axis_name.c_str())); isl::val max_val = isl::manage(isl_set_max_val(set, aff.get())); isl::val min_val = isl::manage(isl_set_min_val(set, aff.get())); @@ -291,7 +315,8 @@ std::tuple isl_set_get_axis_range_by_name(isl_set *set, std: } std::tuple isl_set_get_axis_range(isl_set *set, int pos) { - CHECK(isl_set_dim_is_bounded(set, isl_dim_set, pos)) << "an unbound cannot get range, " << isl_set_to_str(set); + CHECK(isl_set_dim_is_bounded(set, isl_dim_set, pos)) + << "an unbound cannot get range, " << isl_set_to_str(set); std::vector from_iters; std::string target_axis_name; @@ -336,9 +361,12 @@ bool isl_set_axis_has_noparam_constant_bound(isl_set __isl_keep *set, int pos) { bool is_param_involved = false; isl_pw_aff_foreach_piece( val, - [](isl_set *__isl_give set, isl_aff *__isl_give aff, void *user) -> isl_stat { - // Ignore the set piece, e.g. [_cp_C_0, _cp_C_1] -> { cache[0, 0] : _cp_C_0 = 0 and _cp_C_1 = 0 } - // will get a set [_cp_C_0, _cp_C_1] -> { : _cp_C_0 = 0 and _cp_C_1 = 0 } + [](isl_set *__isl_give set, + isl_aff *__isl_give aff, + void *user) -> isl_stat { + // Ignore the set piece, e.g. 
[_cp_C_0, _cp_C_1] -> { cache[0, 0] : + // _cp_C_0 = 0 and _cp_C_1 = 0 } will get a set [_cp_C_0, _cp_C_1] -> + // { : _cp_C_0 = 0 and _cp_C_1 = 0 } if (set) { // ignore } @@ -349,10 +377,10 @@ bool isl_set_axis_has_noparam_constant_bound(isl_set __isl_keep *set, int pos) { // drop unused params, so the Aff [n]->{ [(0)] } will be []->{ [(0)] } auto *pw_aff = isl_pw_aff_from_aff(aff); - pw_aff = isl_pw_aff_drop_unused_params(pw_aff); + pw_aff = isl_pw_aff_drop_unused_params(pw_aff); // check if some params is involved. - isl::set params = isl::manage(isl_pw_aff_params(pw_aff)); + isl::set params = isl::manage(isl_pw_aff_params(pw_aff)); is_param_involved = isl_set_dim(params.get(), isl_dim_param) > 0; isl_set_free(set); @@ -366,13 +394,15 @@ bool isl_set_axis_has_noparam_constant_bound(isl_set __isl_keep *set, int pos) { return is_dim_a_constant(max_val) && is_dim_a_constant(min_val); } -isl::map isl_set_dim_name_if_null(isl_map *map, std::function namer) { - int in_dims = isl_map_dim(map, isl_dim_in); - int out_dims = isl_map_dim(map, isl_dim_out); +isl::map isl_set_dim_name_if_null( + isl_map *map, std::function namer) { + int in_dims = isl_map_dim(map, isl_dim_in); + int out_dims = isl_map_dim(map, isl_dim_out); auto set_name = [&](isl_dim_type dim_type) { for (int i = 0; i < isl_map_dim(map, dim_type); i++) { if (!isl_map_get_dim_name(map, dim_type, i)) { - map = isl_map_set_dim_name(map, dim_type, i, namer(dim_type, i).c_str()); + map = + isl_map_set_dim_name(map, dim_type, i, namer(dim_type, i).c_str()); } } }; @@ -383,10 +413,12 @@ isl::map isl_set_dim_name_if_null(isl_map *map, std::function namer) { +isl::set isl_set_dim_name_if_null( + isl_set *set, std::function namer) { for (int i = 0; i < isl_set_dim(set, isl_dim_set); i++) { if (!isl_set_get_dim_name(set, isl_dim_set, i)) { - set = isl_set_set_dim_name(set, isl_dim_set, i, namer(isl_dim_set, i).c_str()); + set = isl_set_set_dim_name( + set, isl_dim_set, i, namer(isl_dim_set, i).c_str()); } } return isl::manage(set); @@ -396,39 +428,47 @@ isl::map RemoveAxiesByInputNames(const isl::map &x, const isl::set &origin_domain, const std::vector &dim_in_names) { std::string map_str = isl_map_to_str(x.get()); - isl::ctx this_ctx = x.ctx(); + isl::ctx this_ctx = x.ctx(); isl::map temp_transform(this_ctx, map_str); - auto related_output_names = GetRelatedOutputAxies(x, origin_domain, dim_in_names); + auto related_output_names = + GetRelatedOutputAxies(x, origin_domain, dim_in_names); if (dim_in_names.empty()) return temp_transform; for (auto &i : dim_in_names) { - temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_in, i.c_str())); + temp_transform = isl::manage(isl_remove_axis_by_name( + temp_transform.release(), isl_dim_in, i.c_str())); } for (auto &i : related_output_names) { - temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_out, i.c_str())); + temp_transform = isl::manage(isl_remove_axis_by_name( + temp_transform.release(), isl_dim_out, i.c_str())); } return temp_transform; } -isl::map RemoveAxiesByOutputNames(const isl::map &x, - const isl::set &origin_domain, - const std::vector &dim_out_names) { +isl::map RemoveAxiesByOutputNames( + const isl::map &x, + const isl::set &origin_domain, + const std::vector &dim_out_names) { std::string map_str = isl_map_to_str(x.get()); - isl::ctx this_ctx = x.ctx(); + isl::ctx this_ctx = x.ctx(); isl::map temp_transform(this_ctx, map_str); - auto related_input_names = GetRelatedInputAxies(x, origin_domain, dim_out_names); + 
auto related_input_names = + GetRelatedInputAxies(x, origin_domain, dim_out_names); if (dim_out_names.empty()) return temp_transform; for (auto &i : dim_out_names) { - temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_out, i.c_str())); + temp_transform = isl::manage(isl_remove_axis_by_name( + temp_transform.release(), isl_dim_out, i.c_str())); } for (auto &i : related_input_names) { - temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_in, i.c_str())); + temp_transform = isl::manage(isl_remove_axis_by_name( + temp_transform.release(), isl_dim_in, i.c_str())); } return temp_transform; } -std::vector GetRelatedOutputAxies(const isl::map &x, - const isl::set &origin_domain, - const std::vector &dim_in_names) { +std::vector GetRelatedOutputAxies( + const isl::map &x, + const isl::set &origin_domain, + const std::vector &dim_in_names) { std::string map_str = isl_map_to_str(x.get()); VLOG(1) << "GetRelatedOutputAxies map_str is : " << map_str; isl::ctx this_ctx = x.ctx(); @@ -441,7 +481,8 @@ std::vector GetRelatedOutputAxies(const isl::map &x, } std::set res_set; for (auto &i : dim_out_names) { - auto related_in_dim = GetRelatedInputAxies(temp_transform, origin_domain, {i}); + auto related_in_dim = + GetRelatedInputAxies(temp_transform, origin_domain, {i}); for (auto &j : related_in_dim) { if (dim_in_set.count(j) > 0) { res_set.insert(i); @@ -456,10 +497,11 @@ std::vector GetRelatedOutputAxies(const isl::map &x, return res; } -std::vector GetRelatedInputAxies(const isl::map &x, - const isl::set &origin_domain, - const std::vector &dim_out_names, - bool strict) { +std::vector GetRelatedInputAxies( + const isl::map &x, + const isl::set &origin_domain, + const std::vector &dim_out_names, + bool strict) { std::string map_str = isl_map_to_str(x.get()); VLOG(1) << "GetRelatedInputAxies map_str is : " << map_str; isl::ctx this_ctx = x.ctx(); @@ -467,14 +509,15 @@ std::vector GetRelatedInputAxies(const isl::map &x, auto dim_in_names = isl_get_dim_names(temp_transform, isl_dim_in); for (auto &i : dim_out_names) { VLOG(1) << "GetRelatedInputAxies dim_out_names is : " << i; - temp_transform = isl::manage(isl_remove_axis_by_name(temp_transform.release(), isl_dim_out, i.c_str())); + temp_transform = isl::manage(isl_remove_axis_by_name( + temp_transform.release(), isl_dim_out, i.c_str())); } std::string deleted_map = isl_map_to_str(temp_transform.get()); std::vector res; std::set out_set; std::set out_set_without_suffix; std::string set_str = isl_set_to_str(origin_domain.get()); - isl::ctx set_ctx = origin_domain.ctx(); + isl::ctx set_ctx = origin_domain.ctx(); isl::set temp_set(this_ctx, set_str); auto transformed_domain = temp_set.apply(x); for (auto &i : dim_out_names) { diff --git a/paddle/cinn/poly/isl_utils.h b/paddle/cinn/poly/isl_utils.h index a1c5637b204b1..005135071f7cc 100644 --- a/paddle/cinn/poly/isl_utils.h +++ b/paddle/cinn/poly/isl_utils.h @@ -27,60 +27,83 @@ namespace poly { //! Get dimension names from isl containers. 
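//! For instance, a set like { S[i, j] : 0 <= i, j < 10 } would yield
//! {"i", "j"}; an illustrative reading, not a tested case.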
// @{ std::vector isl_get_dim_names(const isl::set& x); -std::vector isl_get_dim_names(const isl::map& x, isl_dim_type dim_type); +std::vector isl_get_dim_names(const isl::map& x, + isl_dim_type dim_type); std::vector isl_get_dim_names(isl_map* map, isl_dim_type dim_type); std::vector isl_get_dim_names(isl_set* set); // @} -void isl_set_dim_names(isl::set* __isl_keep set, const std::vector& names); -void isl_set_dim_names(isl::map* __isl_keep map, isl_dim_type dim_type, const std::vector& names); +void isl_set_dim_names(isl::set* __isl_keep set, + const std::vector& names); +void isl_set_dim_names(isl::map* __isl_keep map, + isl_dim_type dim_type, + const std::vector& names); isl::union_set isl_union_set_from_sets(llvm::ArrayRef sets); -isl::map isl_set_dim_name_if_null(isl_map* __isl_take map, std::function namer); -isl::set isl_set_dim_name_if_null(isl_set* __isl_take set, std::function namer); +isl::map isl_set_dim_name_if_null( + isl_map* __isl_take map, + std::function namer); +isl::set isl_set_dim_name_if_null( + isl_set* __isl_take set, + std::function namer); //! Convert a list of isl::map to isl::union_map isl::union_map isl_maps_to_union_map(const std::vector& maps); isl::union_set isl_sets_to_union_set(const std::vector& sets); //! Get a representation of the tuple in the map. -std::string isl_map_get_statement_repr(__isl_keep isl_map* map, isl_dim_type type); +std::string isl_map_get_statement_repr(__isl_keep isl_map* map, + isl_dim_type type); -isl_set* __isl_give isl_get_precending_aixs(isl_set* set, int level, bool with_tuple_name); +isl_set* __isl_give isl_get_precending_aixs(isl_set* set, + int level, + bool with_tuple_name); -//! If the min and max bounds of the axis are same, isl will remove this axis after ast_build. Counts the removed axes -//! before the given axis. +//! If the min and max bounds of the axis are same, isl will remove this axis +//! after ast_build. Counts the removed axes before the given axis. int isl_get_precending_removed_axes_counts(isl_set __isl_keep* a, int level); //! Get the original level from the level after removing axes. -int isl_get_original_axes_from_optimized_level(isl_set __isl_keep* a, int level); +int isl_get_original_axes_from_optimized_level(isl_set __isl_keep* a, + int level); -//! If the min and max bounds of the axis are same, isl will remove this axis after ast_build. Judge whether or not the -//! axis will be removed by isl. +//! If the min and max bounds of the axis are same, isl will remove this axis +//! after ast_build. Judge whether or not the axis will be removed by isl. bool isl_is_removed_axis(isl_set __isl_keep* a, int level); //! Get the maximum level of axis that is has the same domain. 
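//! For example, { s[i, j] : 0 <= i < 8 and 0 <= j < 4 } and
//! { s[i, k] : 0 <= i < 8 and 0 <= k < 2 } agree only on axis 0, so the
//! result would be 0; an illustrative reading, not a tested case.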
int isl_max_level_compatible(isl_set* __isl_keep a, isl_set* __isl_keep b); -isl_set* __isl_give isl_remove_axis_by_name(isl_set* __isl_take set, const char* axis_name); -isl_map* __isl_give isl_remove_axis_by_name(isl_map* __isl_take map, isl_dim_type dim_type, const char* axis_name); -isl_set* __isl_give isl_rename_axis(isl_set* __isl_take set, int offset, const char* name); -isl_map* __isl_give isl_rename_axis(isl_map* __isl_take map, isl_dim_type dim_type, int offset, const char* name); +isl_set* __isl_give isl_remove_axis_by_name(isl_set* __isl_take set, + const char* axis_name); +isl_map* __isl_give isl_remove_axis_by_name(isl_map* __isl_take map, + isl_dim_type dim_type, + const char* axis_name); +isl_set* __isl_give isl_rename_axis(isl_set* __isl_take set, + int offset, + const char* name); +isl_map* __isl_give isl_rename_axis(isl_map* __isl_take map, + isl_dim_type dim_type, + int offset, + const char* name); isl_set* __isl_give isl_simplify(isl_set* __isl_take set); // { s[i]: 0 < i < 20 } bool isl_set_axis_has_noparam_constant_bound(isl_set* __isl_keep set, int pos); -//! get a minimum and maximum range of a set, if the bound not exists, return a INT_MAX instead. -//! NOTE the set should be bound. -//! returns: a tuple of (min, max) -std::tuple isl_set_get_axis_range(isl_set* __isl_keep set, int pos); +//! get the minimum and maximum range of a set; if a bound does not exist, +//! return INT_MAX instead. NOTE the set should be bounded. returns: a tuple of +//! (min, max) +std::tuple isl_set_get_axis_range(isl_set* __isl_keep set, + int pos); -std::tuple isl_set_get_axis_range_by_name(isl_set* __isl_keep set, std::string axis_name); +std::tuple isl_set_get_axis_range_by_name( + isl_set* __isl_keep set, std::string axis_name); -//! Port the set from \p from to \p to with the \p poses dims constraints remained. +//! Port the set from \p from to \p to with the \p poses dims constraints +//! retained. //! @param from The set to port. //! @param to The set to be. //! @param poses The dimensions to remained. @@ -110,9 +133,10 @@ isl::map RemoveAxiesByInputNames(const isl::map& x, * @param dim_in_names The names of output dims to remove. * @return The edited map. */ -isl::map RemoveAxiesByOutputNames(const isl::map& x, - const isl::set& origin_domain, - const std::vector& dim_out_names); +isl::map RemoveAxiesByOutputNames( + const isl::map& x, + const isl::set& origin_domain, + const std::vector& dim_out_names); /** * Given an isl::map and a vector of names of dim_out, * @param x The input map. * @param dim_out_names The names of output dims. * @param strict Indicates whether computes the strictly related input axies. - * For example, if strict == true, then input 'j' is related to output 'j_outer_inner_outer' + * For example, if strict == true, then input 'j' is related to output + * 'j_outer_inner_outer' * @return The vector of names of related input dims. */ -std::vector GetRelatedInputAxies(const isl::map& x, - const isl::set& origin_domain, - const std::vector& dim_out_names, - bool strict = false); +std::vector GetRelatedInputAxies( + const isl::map& x, + const isl::set& origin_domain, + const std::vector& dim_out_names, + bool strict = false); /** * Given an isl::map and a vector of names of dim_in, * @param x The input map. * @param dim_in_names The names of input dims. * @return The vector of names of related output dims. 
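 * A hypothetical usage sketch (identifiers are illustrative only):
 *   isl::map x(ctx, "{ S[i, j] -> S[i_outer, i_inner, j] : i_outer = floor(i/4) and i_inner = i%4 }");
 *   auto outs = GetRelatedOutputAxies(x, domain, {"i"});
 *   // expected to contain "i_outer" and "i_inner", the output dims driven by "i"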
*/ -std::vector GetRelatedOutputAxies(const isl::map& x, - const isl::set& origin_domain, - const std::vector& dim_in_names); +std::vector GetRelatedOutputAxies( + const isl::map& x, + const isl::set& origin_domain, + const std::vector& dim_in_names); } // namespace poly } // namespace cinn diff --git a/paddle/cinn/poly/map.cc b/paddle/cinn/poly/map.cc index 2496964c8a482..c8c77e5e0fe4f 100644 --- a/paddle/cinn/poly/map.cc +++ b/paddle/cinn/poly/map.cc @@ -25,16 +25,21 @@ std::string Map::__str__() const { auto get_ids_repr = [](const std::vector& ids) { std::vector fields; - std::transform(ids.begin(), ids.end(), std::back_inserter(fields), [](const Iterator& x) { return x.id; }); + std::transform(ids.begin(), + ids.end(), + std::back_inserter(fields), + [](const Iterator& x) { return x.id; }); return utils::Join(fields, ", "); }; auto domain_iterators_repr = get_ids_repr(domain_iterators_); - auto range_iterators_repr = get_ids_repr(range_iterators_); + auto range_iterators_repr = get_ids_repr(range_iterators_); std::vector conds_fields; - std::transform( - conds_.begin(), conds_.end(), std::back_inserter(conds_fields), [](const Condition& x) { return x.__str__(); }); + std::transform(conds_.begin(), + conds_.end(), + std::back_inserter(conds_fields), + [](const Condition& x) { return x.__str__(); }); auto conds_repr = utils::Join(conds_fields, " and "); if (!conds_.empty()) { @@ -69,9 +74,11 @@ Map::Map(isl::ctx ctx, isl::map Map::to_isl() const { auto map = isl::map(ctx_, __str__()); // set dimension names - auto handler = [](const Iterator& x) { return x.id; }; - auto domain_dim_names = utils::Map, std::string>(domain_iterators_, handler); - auto range_dim_names = utils::Map, std::string>(range_iterators_, handler); + auto handler = [](const Iterator& x) { return x.id; }; + auto domain_dim_names = utils::Map, std::string>( + domain_iterators_, handler); + auto range_dim_names = + utils::Map, std::string>(range_iterators_, handler); isl_set_dim_names(&map, isl_dim_in, domain_dim_names); isl_set_dim_names(&map, isl_dim_out, range_dim_names); return map; diff --git a/paddle/cinn/poly/map.h b/paddle/cinn/poly/map.h index e0390d0d6d4a5..3e3b72704403f 100644 --- a/paddle/cinn/poly/map.h +++ b/paddle/cinn/poly/map.h @@ -37,8 +37,12 @@ struct Iterator { explicit Iterator(Iterator&& x) : id(std::move(x.id)) {} Iterator& operator=(const Iterator& other); - friend bool operator==(const Iterator& a, const Iterator& b) { return a.id == b.id; } - friend bool operator!=(const Iterator& a, const Iterator& b) { return !(a.id == b.id); } + friend bool operator==(const Iterator& a, const Iterator& b) { + return a.id == b.id; + } + friend bool operator!=(const Iterator& a, const Iterator& b) { + return !(a.id == b.id); + } friend std::ostream& operator<<(std::ostream& os, const Iterator& x); }; @@ -53,7 +57,9 @@ struct Condition { return os; } - std::string __str__() const { return utils::StringFormat("%s", cond.c_str()); } + std::string __str__() const { + return utils::StringFormat("%s", cond.c_str()); + } }; /** @@ -102,7 +108,9 @@ class Aff : public Map { std::ostream& operator<<(std::ostream& os, const Map& x); std::ostream& operator<<(std::ostream& os, const Aff& x); -static bool operator<(const Iterator& a, const Iterator& b) { return a.id < b.id; } +static bool operator<(const Iterator& a, const Iterator& b) { + return a.id < b.id; +} } // namespace poly } // namespace cinn diff --git a/paddle/cinn/poly/naive_scheduler.cc b/paddle/cinn/poly/naive_scheduler.cc index 97cc89ef15eb0..90ab99d9bbed3 
100644 --- a/paddle/cinn/poly/naive_scheduler.cc +++ b/paddle/cinn/poly/naive_scheduler.cc @@ -26,7 +26,8 @@ std::unique_ptr NaiveScheduler::BuildSchedule() { for (auto &group : groups_) { std::vector status; CHECK_EQ(group.nodes.size(), 1UL); - NaiveGroupScheduler scheduler(const_cast(group.nodes.front()->stage)); + NaiveGroupScheduler scheduler( + const_cast(group.nodes.front()->stage)); scheduler.Build(); } @@ -38,7 +39,7 @@ std::unique_ptr NaiveScheduler::BuildSchedule() { void NaiveScheduler::PartitionGroups() { // treat each node as a unique group, collect the groups in topological order. - auto topo_order = schedule_graph_.topological_order(); // NOLINT + auto topo_order = schedule_graph_.topological_order(); // NOLINT auto &nodes_in_order = std::get<0>(topo_order); auto &edges_in_order = std::get<1>(topo_order); diff --git a/paddle/cinn/poly/naive_scheduler.h b/paddle/cinn/poly/naive_scheduler.h index 7a3bb4d0af909..a4074ea84ecca 100644 --- a/paddle/cinn/poly/naive_scheduler.h +++ b/paddle/cinn/poly/naive_scheduler.h @@ -35,9 +35,9 @@ class NaiveGroupScheduler : public SchedulerBase { }; /** - * The NaiveScheduler just schedule each noninlined Tensor as a unique group. Only the `compute_at` will merge two - * tensor in the same group. - * It is simple and robust. + * The NaiveScheduler just schedules each non-inlined Tensor as a unique group. + * Only `compute_at` will merge two tensors into the same group. It is simple + * and robust. */ class NaiveScheduler : public SchedulerBase { public: diff --git a/paddle/cinn/poly/poly_scheduler.cc b/paddle/cinn/poly/poly_scheduler.cc index 59b77124c4b4c..d3a34e4544507 100755 --- a/paddle/cinn/poly/poly_scheduler.cc +++ b/paddle/cinn/poly/poly_scheduler.cc @@ -30,15 +30,15 @@ namespace poly { namespace detail { -//! Visit the nodes in topological order, if one node is valid to visit, visit it and check whether its out link -//! children are ready to visit, merge them to the same group. -//! NOTE this is discarded. +//! Visit the nodes in topological order; if a node is valid to visit, visit +//! it and check whether its out-link children are ready to visit, merging them +//! into the same group. NOTE this is discarded. std::vector PartitionGraphByIterationDomain(common::Graph* graph) { VLOG(3) << "graph:\n" << graph->Visualize(); // collect indegrees for naive topological traversal. std::map indegree; for (common::GraphNode* n : graph->nodes()) { - auto* node = n->safe_as(); + auto* node = n->safe_as(); indegree[node] = node->inlinks().size(); } @@ -62,11 +62,13 @@ std::vector PartitionGraphByIterationDomain(common::Graph* graph) { auto* child = c->sink()->safe_as(); --indegree[child]; - VLOG(3) << node->stage->transformed_domain() << " -> " << child->stage->transformed_domain(); + VLOG(3) << node->stage->transformed_domain() << " -> " + << child->stage->transformed_domain(); if (indegree[child] == 0) { // Merge the two groups if their iteration domain is the same. 
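// For example, two stages whose transformed domains are both
// { S[i, j] : 0 <= i, j < 16 } would merge here, while a stage over
// { S[i] : 0 <= i < 16 } would stay in its own group; an illustrative
// reading of TransformedDomainIsSame, not a tested case.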
if (DataFlowGraphNode::TransformedDomainIsSame(node, child)) { - VLOG(4) << child->id() << " ready to merge " << node->id() << " with " << child->id(); + VLOG(4) << child->id() << " ready to merge " << node->id() << " with " + << child->id(); DataFlowGraphNode::MergeGroup(node, child); } queue.push_back(child); @@ -78,7 +80,8 @@ std::vector PartitionGraphByIterationDomain(common::Graph* graph) { for (auto* n : graph->nodes()) { auto* node = n->safe_as(); for (auto& compute_at : node->stage->compute_ats()) { - CHECK(compute_at.IsCompatible(node->stage.get())) << "The registered ComputeAt is not compatible"; + CHECK(compute_at.IsCompatible(node->stage.get())) + << "The registered ComputeAt is not compatible"; // check the endpoints of compute_at has data dependency. auto* node0 = node; auto* node1 = name2node[compute_at.stage->id()]; @@ -95,12 +98,12 @@ std::vector PartitionGraphByIterationDomain(common::Graph* graph) { std::map> node_groups; - auto topo_order = graph->topological_order(); + auto topo_order = graph->topological_order(); auto& nodes_in_order = std::get<0>(topo_order); auto& edges_in_order = std::get<1>(topo_order); for (auto* n : nodes_in_order) { - auto* node = n->safe_as(); + auto* node = n->safe_as(); auto* ancestor = node->group_ancestor(); if (!groups_gathered.count(ancestor)) { groups_gathered.insert(ancestor); @@ -121,7 +124,8 @@ std::vector PartitionGraphByIterationDomain(common::Graph* graph) { } // NOTE DEBUG - // check there are same count of nodes both in the orginal graph and the groups. + // check that the original graph and the groups contain the same number of + // nodes. // @{ int num_node_in_groups = 0; for (auto& group : groups) num_node_in_groups += group.nodes.size(); @@ -131,9 +135,10 @@ std::vector PartitionGraphByIterationDomain(common::Graph* graph) { return groups; } -//! Check whether a group partition is valid. The ComputeAt and some other transform may broke data dependency, use this -//! to check validity. -// TODO(Superjomn) Implement this and integrate it into ComputeAt transform for checking transform validity. +//! Check whether a group partition is valid. ComputeAt and some other +//! transforms may break data dependencies; use this to check validity. +// TODO(Superjomn) Implement this and integrate it into ComputeAt transform for +// checking transform validity. 
bool CheckGroupValid(const std::vector& groups) { CINN_NOT_IMPLEMENTED return false; @@ -164,7 +169,9 @@ bool IsLinkTo(const common::GraphNode* a, const common::GraphNode* b) { return false; } -bool IsBetween(const common::GraphNode* x, const common::GraphNode* a, const common::GraphNode* b) { +bool IsBetween(const common::GraphNode* x, + const common::GraphNode* a, + const common::GraphNode* b) { if (IsLinkTo(a, x) && IsLinkTo(x, b)) return true; if (IsLinkTo(x, a) && IsLinkTo(b, x)) return true; return false; @@ -178,13 +185,14 @@ std::vector TopoSortGroups(std::vector& groups) { std::vector group_order; absl::flat_hash_map node2group; for (int i = 0; i < groups.size(); i++) { - Group* group = &groups[i]; + Group* group = &groups[i]; int in_degree = 0; for (auto& node : group->nodes) { node2group[node->id()] = group; in_degree += node->inlinks().size(); for (auto& node2 : group->nodes) { - if (node2->as()->IsLinkedTo(node->as())) { + if (node2->as()->IsLinkedTo( + node->as())) { in_degree--; } } @@ -215,7 +223,8 @@ std::vector TopoSortGroups(std::vector& groups) { for (auto& edge : node->outlinks()) { CHECK_EQ(edge->source()->id(), node->id()); auto* sink = edge->sink(); - if (all_nodes.count(sink->id()) == 0 && (--group_indegree[node2group[sink->id()]]) == 0) { + if (all_nodes.count(sink->id()) == 0 && + (--group_indegree[node2group[sink->id()]]) == 0) { queue.push_back(node2group[sink->id()]); } } @@ -228,12 +237,12 @@ std::vector TopoSortGroups(std::vector& groups) { * Naive idea to split a graph. * * 1. treat each stage as a seperate group. - * 2. If ComputeAt is set between two stages and their iteration domain matches, the stages will be put in a group with - * relative order. + * 2. If ComputeAt is set between two stages and their iteration domain matches, + * the stages will be put in a group with relative order. */ std::vector NaivePartitionGraph(common::Graph* graph) { std::map> node_groups; - auto topo_order = graph->topological_order(); + auto topo_order = graph->topological_order(); auto& nodes_in_order = std::get<0>(topo_order); auto& edges_in_order = std::get<1>(topo_order); @@ -243,18 +252,21 @@ std::vector NaivePartitionGraph(common::Graph* graph) { } // process compute_at - absl::flat_hash_map node2score; // record each node's score for sorting. + absl::flat_hash_map + node2score; // record each node's score for sorting. int score = 0; for (auto* n : nodes_in_order) { - auto* node = n->safe_as(); + auto* node = n->safe_as(); node2score[node] = score++; for (ComputeAtRelation& compute_at : node->stage->compute_ats()) { - CHECK(compute_at.IsCompatible(node->stage.get())) << "The registered ComputeAt is not compatible"; + CHECK(compute_at.IsCompatible(node->stage.get())) + << "The registered ComputeAt is not compatible"; // check the endpoints of compute_at has data dependency. 
auto* node0 = node; if (name2node.count(compute_at.stage->id()) == 0) { continue; - LOG(FATAL) << "Didn't find node with name " << compute_at.stage->id() << " !"; + LOG(FATAL) << "Didn't find node with name " << compute_at.stage->id() + << " !"; } auto* node1 = name2node[compute_at.stage->id()]; VLOG(3) << "a -> b: " << node0->id() << " -> " << node1->id(); @@ -263,7 +275,8 @@ std::vector NaivePartitionGraph(common::Graph* graph) { // process single level of outlinks for (auto& outlink : node0->outlinks()) { if (IsBetween(outlink->sink(), node0, node1)) { - DataFlowGraphNode::MergeGroup(node0, outlink->sink()->safe_as()); + DataFlowGraphNode::MergeGroup( + node0, outlink->sink()->safe_as()); } } @@ -271,7 +284,9 @@ std::vector NaivePartitionGraph(common::Graph* graph) { } } // generate final groups. - absl::flat_hash_map> clusters; + absl::flat_hash_map> + clusters; for (auto* n : nodes_in_order) { auto* node = n->safe_as(); clusters[node->group_ancestor()].push_back(node); @@ -297,7 +312,8 @@ std::vector NaivePartitionGraph(common::Graph* graph) { graph_node_count += group.nodes.size(); } // check the groups contains all the nodes in graph. - CHECK_EQ(graph_node_count, graph->nodes().size()) << "the groups should contain all the nodes in the graph"; + CHECK_EQ(graph_node_count, graph->nodes().size()) + << "the groups should contain all the nodes in the graph"; #endif return group_order; @@ -314,13 +330,15 @@ std::unique_ptr PolyScheduler::BuildSchedule() { // transform the DFG groups to schedule groups. CHECK(!schedule_graph_.nodes().empty()); - CHECK_EQ(schedule_graph_.nodes().size(), dfg_->nodes().size()) << "DFG graph is not match schedule graph"; + CHECK_EQ(schedule_graph_.nodes().size(), dfg_->nodes().size()) + << "DFG graph is not match schedule graph"; schedule_groups_.clear(); for (auto& dfg_group : dfg_groups) { ScheduleGroup group; for (auto& node : dfg_group.nodes) { auto* schedule_node = schedule_graph_.RetrieveNode(node->id()); - CHECK(schedule_node) << "missing node " << node->id() << " in schedule graph"; + CHECK(schedule_node) << "missing node " << node->id() + << " in schedule graph"; group.nodes.push_back(schedule_node->safe_as()); } schedule_groups_.emplace_back(std::move(group)); @@ -335,15 +353,17 @@ std::unique_ptr PolyScheduler::BuildSchedule() { for (auto& group : schedule_groups_) { for (auto& node : group.nodes) { - res->schedule[node->id()] = node->time_schedule.to_isl(Context::isl_ctx()); + res->schedule[node->id()] = + node->time_schedule.to_isl(Context::isl_ctx()); } } return res; } -PolyScheduler::PolyScheduler(const std::vector& stages, - const std::vector>& extra_links) { +PolyScheduler::PolyScheduler( + const std::vector& stages, + const std::vector>& extra_links) { CHECK(!stages.empty()) << "No stage is provided"; // collect extra links @@ -360,7 +380,8 @@ PolyScheduler::PolyScheduler(const std::vector& stages, FinishStageAdd(); } -std::vector PolyScheduler::PartitionGroups(DataFlowGraph* graph) { +std::vector PolyScheduler::PartitionGroups( + DataFlowGraph* graph) { CHECK(graph); CHECK(!graph->nodes().empty()); return detail::NaivePartitionGraph(graph); @@ -377,7 +398,7 @@ void PolyScheduler::ScheduleAGroup(ScheduleGroup* group) { } PolyGroupScheduler scheduler(stages); - group->nodes = scheduler.Build(); + group->nodes = scheduler.Build(); group->dimension_names = scheduler.detailed_dimension_names(); } @@ -393,7 +414,7 @@ std::vector> PolyGroupScheduler::Build() { std::map stage_map; std::map compute_at_links; for (int i = 0; i < stages_.size(); i++) 
{ - auto& stage = stages_[i]; + auto& stage = stages_[i]; stage_map[stage->tensor_->name] = stage; for (auto& item : stage->compute_ats()) { compute_at_links[stage->tensor_->name] = item; @@ -401,7 +422,8 @@ std::vector> PolyGroupScheduler::Build() { } std::map stage_level; for (auto& link : compute_at_links) { - CHECK_NE(stage_map.count(link.first), 0) << link.first << " not found in stage_map"; + CHECK_NE(stage_map.count(link.first), 0) + << link.first << " not found in stage_map"; CHECK_NE(stage_map.count(link.second.stage->tensor_->name), 0) << link.second.stage->tensor_->name << " not found in stage_map"; auto* a = stage_map.at(link.first); @@ -419,17 +441,20 @@ std::vector> PolyGroupScheduler::Build() { // a -> b not in the compute_at_links if (!compute_at_links.count(a->tensor_->name) || - compute_at_links[a->tensor_->name].stage->tensor_->name != b->tensor_->name) { + compute_at_links[a->tensor_->name].stage->tensor_->name != + b->tensor_->name) { int min_level = INT_MAX; - if (stage_level.count(a->id())) min_level = std::min(min_level, stage_level[a->id()]); - if (stage_level.count(b->id())) min_level = std::min(min_level, stage_level[b->id()]); + if (stage_level.count(a->id())) + min_level = std::min(min_level, stage_level[a->id()]); + if (stage_level.count(b->id())) + min_level = std::min(min_level, stage_level[b->id()]); if (min_level < INT_MAX) { After(*a, *b, min_level); } } } - auto topo_order = schedule_graph_.topological_order(); + auto topo_order = schedule_graph_.topological_order(); auto& nodes_in_order = std::get<0>(topo_order); auto& edges_in_order = std::get<1>(topo_order); std::vector> res; @@ -438,9 +463,10 @@ std::vector> PolyGroupScheduler::Build() { for (auto& edge : edges_in_order) { auto* node0 = edge->source()->safe_as(); auto* node1 = edge->sink()->safe_as(); - int level = edge->as()->level; + int level = edge->as()->level; if (level < 0) continue; - VLOG(2) << "schedule " << node0->id() << " -> " << node1->id() << " level " << level; + VLOG(2) << "schedule " << node0->id() << " -> " << node1->id() << " level " + << level; node1->time_schedule.OrderAfter(node0->time_schedule, level); } @@ -450,7 +476,8 @@ std::vector> PolyGroupScheduler::Build() { return res; } -PolyGroupScheduler::PolyGroupScheduler(const std::vector& stages) : stages_(stages) { +PolyGroupScheduler::PolyGroupScheduler(const std::vector& stages) + : stages_(stages) { CHECK_GT(stages.size(), 0) << "No stage is provided"; for (auto* stage : stages) { AddStage(*stage); diff --git a/paddle/cinn/poly/poly_scheduler.h b/paddle/cinn/poly/poly_scheduler.h index df33a06178251..e44cb9f055ed0 100644 --- a/paddle/cinn/poly/poly_scheduler.h +++ b/paddle/cinn/poly/poly_scheduler.h @@ -33,7 +33,8 @@ namespace cinn { namespace poly { /** - * Schedule a single group with iterator domain considered and follow the stage order. + * Schedule a single group with iterator domain considered and follow the stage + * order. */ class PolyGroupScheduler : public SchedulerBase { public: @@ -49,7 +50,8 @@ class PolyGroupScheduler : public SchedulerBase { /** * PolyScheduler - Perform schedule on polyhedral model. - * It takes a normal schedule as input, merge two stages automatically if they have the same domain. + * It takes a normal schedule as input, merge two stages automatically if they + * have the same domain. */ class PolyScheduler : public SchedulerBase { public: @@ -61,8 +63,9 @@ class PolyScheduler : public SchedulerBase { * '{ S[i,j] -> [i_outer, i_inner, j]: i_outer=floor(i/4) and i_inner=i%4 }' * that's OK. 
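 * A minimal construction sketch (assumes stages s0 and s1 exist; the
 * statement ids "S0" and "S1" are illustrative):
 *   PolyScheduler scheduler({s0, s1}, {{"S0", "S1"}});
 *   auto schedule = scheduler.BuildSchedule();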
*/ - explicit PolyScheduler(const std::vector &stages, - const std::vector> &extra_links = {}); + explicit PolyScheduler( + const std::vector &stages, + const std::vector> &extra_links = {}); /** * Build and create schedule. diff --git a/paddle/cinn/poly/schedule.cc b/paddle/cinn/poly/schedule.cc index 0c2137d4d8d2b..43357dbdfb104 100644 --- a/paddle/cinn/poly/schedule.cc +++ b/paddle/cinn/poly/schedule.cc @@ -49,8 +49,12 @@ std::string TimeSchedule::__str__() const { std::vector conds; conds.push_back(utils::StringFormat("r=%d", root_time_)); for (int i = 0; i < time_dims_.size(); i++) { - conds.push_back(utils::StringFormat("%s=%s", cond_dims[2 * i].c_str(), std::to_string(time_dims_[i].time).c_str())); - conds.push_back(utils::StringFormat("%s=%s", cond_dims[2 * i + 1].c_str(), time_dims_[i].dim.c_str())); + conds.push_back( + utils::StringFormat("%s=%s", + cond_dims[2 * i].c_str(), + std::to_string(time_dims_[i].time).c_str())); + conds.push_back(utils::StringFormat( + "%s=%s", cond_dims[2 * i + 1].c_str(), time_dims_[i].dim.c_str())); } return utils::StringFormat("{ %s[%s] -> [%s]: %s }", @@ -69,9 +73,10 @@ std::vector TimeSchedule::final_axis_names() const { return dims; } -TimeSchedule::TimeSchedule(const std::string &id, const std::vector &dims) { +TimeSchedule::TimeSchedule(const std::string &id, + const std::vector &dims) { CHECK_LE(dims.size(), kMaxDims); - id_ = id; + id_ = id; domain_dims = dims; for (auto &dim : domain_dims) { CHECK(!dim.empty()); @@ -92,10 +97,12 @@ void TimeSchedule::OrderAfter(const TimeSchedule &other, int level) { } for (int i = 0; i < level; i++) { - this->time_dims_[i].time = std::max(other.time_dims_[i].time, this->time_dims_[i].time); + this->time_dims_[i].time = + std::max(other.time_dims_[i].time, this->time_dims_[i].time); } - this->time_dims_[level].time = std::max(this->time_dims_[level].time, other.time_dims_[level].time + 1); + this->time_dims_[level].time = + std::max(this->time_dims_[level].time, other.time_dims_[level].time + 1); } isl::map TimeSchedule::to_isl(isl::ctx ctx) const { @@ -116,16 +123,17 @@ void TimeSchedule::ResizeTimeSpace(int size) { } /* -std::unique_ptr CreateSchedule(const ir::Tensor &tensor, ScheduleKind schedule_kind) { - auto stages = GatherStagesInTensors({tensor}); - VLOG(3) << "collected " << stages.size() << " stages"; - return CreateSchedule(stages, schedule_kind); +std::unique_ptr CreateSchedule(const ir::Tensor &tensor, ScheduleKind +schedule_kind) { auto stages = GatherStagesInTensors({tensor}); VLOG(3) << +"collected " << stages.size() << " stages"; return CreateSchedule(stages, +schedule_kind); } */ -std::unique_ptr CreateSchedule(const std::vector &stages, - ScheduleKind schedule_kind, - const std::vector> &extra_links) { +std::unique_ptr CreateSchedule( + const std::vector &stages, + ScheduleKind schedule_kind, + const std::vector> &extra_links) { CHECK(!stages.empty()); for (auto &stage : stages) { VLOG(4) << "stage: " << stage->domain(); @@ -145,7 +153,8 @@ std::unique_ptr CreateSchedule(const std::vector &stages, return nullptr; } -std::map CollectScheduleMapFromGroup(const ScheduleGroup &group) { +std::map CollectScheduleMapFromGroup( + const ScheduleGroup &group) { std::map map; std::vector stages; @@ -162,19 +171,25 @@ std::map CollectScheduleMapFromGroup(const ScheduleGroup void SchedulerBase::AddStage(const Stage &x) { CHECK(!registration_finalized_) << "element registration has been finalized."; - space_size_ = std::max(space_size_, isl_map_dim(x.transform().get(), isl_dim_out)); + space_size_ = + 
std::max(space_size_, isl_map_dim(x.transform().get(), isl_dim_out)); VLOG(3) << "space_size: " << space_size_; VLOG(3) << "schedule: " << x.transform(); - // Use the dimensions from element's schedule's range as the new domain dimensions because in Element, the schedule is - // like '{ S0[i,j] -> S0[i_outer, i_inner, j] }', the scheduler should schedule base on the range. - auto dims = isl_get_dim_names(x.transform(), isl_dim_out); + // Use the dimensions from the element's schedule's range as the new domain + // dimensions because in Element, the schedule is like '{ S0[i,j] -> + // S0[i_outer, i_inner, j] }', the scheduler should schedule based on the + // range. + auto dims = isl_get_dim_names(x.transform(), isl_dim_out); std::string id = isl_map_get_tuple_name(x.transform().get(), isl_dim_in); schedule_graph_.RegisterNode( - x.id(), common::make_shared(id, isl_get_dim_names(x.transform(), isl_dim_out), &x)); + x.id(), + common::make_shared( + id, isl_get_dim_names(x.transform(), isl_dim_out), &x)); // record the longest dimensions. - if (dims.size() > detailed_dimension_names_.size()) detailed_dimension_names_ = dims; + if (dims.size() > detailed_dimension_names_.size()) + detailed_dimension_names_ = dims; if (!ctx_.get()) { ctx_ = x.domain().ctx(); @@ -191,10 +206,10 @@ void SchedulerBase::FinishStageAdd() { if (depend_node) { // some dependencies might be in another graph. auto *a_node = depend_node->safe_as(); auto *b_node = node->safe_as(); - auto _a_edge_b_edge_ = - a_node->LinkTo(b_node); // Add link from extra depend statment to current node. - auto &a_edge = std::get<0>(_a_edge_b_edge_); - auto &b_edge = std::get<1>(_a_edge_b_edge_); + auto _a_edge_b_edge_ = a_node->LinkTo( + b_node); // Add link from extra depend statement to current node. + auto &a_edge = std::get<0>(_a_edge_b_edge_); + auto &b_edge = std::get<1>(_a_edge_b_edge_); a_edge->as()->level = -1; b_edge->as()->level = -1; } @@ -202,17 +217,23 @@ void SchedulerBase::FinishStageAdd() { } CHECK(!schedule_graph_.nodes().empty()) - << "No node is registered to the graph, use RegisterElement to collect some elements"; + << "No node is registered to the graph, use RegisterElement to collect " "some elements"; registration_finalized_ = true; for (auto &item : schedule_graph_.nodes()) { - VLOG(6) << "original dims in time_schedule: " - << utils::Join(item->safe_as()->time_schedule.domain_dims, ", "); - item->safe_as()->time_schedule.ResizeTimeSpace(space_size_); + VLOG(6) + << "original dims in time_schedule: " + << utils::Join( + item->safe_as()->time_schedule.domain_dims, + ", "); + item->safe_as()->time_schedule.ResizeTimeSpace( + space_size_); } } -std::vector SchedulerBase::WrapIteratorNames(const std::vector &names) { +std::vector SchedulerBase::WrapIteratorNames( + const std::vector &names) { std::vector res; for (int i = 0; i < names.size(); i++) { res.push_back(""); // fake name for time space. 
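// e.g. names = {"i", "j"} is expected to come back as {"", "i", "", "j"}:
// one placeholder slot per time dimension, matching the time-space map
// format [t0, i, t1, j] described in schedule.h (a sketch of the intent,
// not a tested case).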
@@ -223,27 +244,35 @@ std::vector SchedulerBase::WrapIteratorNames(const std::vectorsafe_as(); - auto *b_node = schedule_graph_.RetrieveNode(b.id())->safe_as(); + auto *a_node = + schedule_graph_.RetrieveNode(a.id())->safe_as(); + auto *b_node = + schedule_graph_.RetrieveNode(b.id())->safe_as(); CHECK(a_node) << "no node called " << a.id() << " registered in the graph"; CHECK(b_node) << "no node called " << b.id() << " registered in the graph"; - auto _a_edge_b_edge_ = a_node->LinkTo(b_node); // NOLINT - auto &a_edge = std::get<0>(_a_edge_b_edge_); - auto &b_edge = std::get<1>(_a_edge_b_edge_); + auto _a_edge_b_edge_ = a_node->LinkTo(b_node); // NOLINT + auto &a_edge = std::get<0>(_a_edge_b_edge_); + auto &b_edge = std::get<1>(_a_edge_b_edge_); a_edge->as()->level = level; b_edge->as()->level = level; - VLOG(2) << "In After, Set [" << a.id() << "] -> [b: ]" << b.id() << "] with level = " << level; + VLOG(2) << "In After, Set [" << a.id() << "] -> [b: ]" << b.id() + << "] with level = " << level; return *this; } -SchedulerBase &SchedulerBase::Before(const Stage &a, const Stage &b, int level) { return After(b, a, level); } +SchedulerBase &SchedulerBase::Before(const Stage &a, + const Stage &b, + int level) { + return After(b, a, level); +} std::map SchedulerBase::schedule_map() const { std::map res; for (auto &node : schedule_graph_.nodes()) { - auto *schedule_node = node->safe_as(); - res[schedule_node->id()] = schedule_node->time_schedule.to_isl(Context::isl_ctx()); + auto *schedule_node = node->safe_as(); + res[schedule_node->id()] = + schedule_node->time_schedule.to_isl(Context::isl_ctx()); } return res; } diff --git a/paddle/cinn/poly/schedule.h b/paddle/cinn/poly/schedule.h index 43318a5f1bcbd..1c28c5961e4fd 100755 --- a/paddle/cinn/poly/schedule.h +++ b/paddle/cinn/poly/schedule.h @@ -45,7 +45,9 @@ struct TimeDim { std::string dim; TimeDim() = default; - TimeDim(const std::string &dim, int time) : dim(dim), time(time) { CHECK(!dim.empty()); } + TimeDim(const std::string &dim, int time) : dim(dim), time(time) { + CHECK(!dim.empty()); + } }; class ScheduleGraphNode; @@ -53,7 +55,8 @@ struct ScheduleGraph : public common::Graph {}; /** * ISL schedule map with time space, used to generate the final schedule. - * The map it generates is like { [x,y] -> [t0,x,t1,y] }, the t0 and t1 are time space. + * The map it generates is like { [x,y] -> [t0,x,t1,y] }, the t0 and t1 are time + * space. */ struct TimeSchedule { TimeSchedule(const std::string &id, const std::vector &dims); @@ -75,7 +78,8 @@ struct TimeSchedule { //! ISL range format, such as '[dup, t0, t1]: dup=0 and t0=0 and t1=i]' std::string __str__() const; - //! Get the axis names with the original dimension names and faked time dimensions. + //! Get the axis names with the original dimension names and faked time + //! dimensions. std::vector final_axis_names() const; std::vector domain_dims; @@ -91,7 +95,8 @@ struct TimeSchedule { struct ScheduleGroup; /** - * A container type to contain the schedule information of a graph(several groups). + * A container type to contain the schedule information of a graph(several + * groups). */ struct Schedule { //! The schedule groups partitioned from the graph. @@ -101,17 +106,20 @@ struct Schedule { }; /** - * The base class for all the Scheduler, it helps to schedule the nodes in a group(isl space). All the schedule in the - * same group should have the same number of dimensions, and each have some dependency with others. 
+ * The base class for all Schedulers; it helps to schedule the nodes in a + * group (isl space). All the schedules in the same group should have the same + * number of dimensions, and each has some dependency on the others. */ class SchedulerBase { public: /** - * Wrap the iterator names with time space fake names, it is used for isl AST to set iterator names. + * Wrap the iterator names with fake time-space names; it is used for isl AST + * to set iterator names. * @param names the original iterator names. * @return the iterator names with time space included. */ - static std::vector WrapIteratorNames(const std::vector &names); + static std::vector WrapIteratorNames( + const std::vector &names); /** * Mark this should schedule after another. @@ -129,7 +137,9 @@ class SchedulerBase { std::map schedule_map() const; - const std::vector &detailed_dimension_names() const { return detailed_dimension_names_; } + const std::vector &detailed_dimension_names() const { + return detailed_dimension_names_; + } protected: /** @@ -151,14 +161,15 @@ class SchedulerBase { protected: /** * The polyhedral schedule, any schedule is performed on it. - * We use the time-space map to record the schedule information, the format is borrowed from Tiramisu project: - * [time,dim,time,dim,time,dim ...] + * We use the time-space map to record the schedule information; the format is + * borrowed from the Tiramisu project: [time,dim,time,dim,time,dim ...] */ int space_size_{0}; mutable isl::ctx ctx_{Context::isl_ctx()}; mutable ScheduleGraph schedule_graph_; - // Record the longest dimensions(of some stage) to be the final detailed dimension names. It might be used for ISL AST - // to set iterator names and generate readable code. + // Record the longest dimensions (of some stage) to be the final detailed + // dimension names. It might be used for ISL AST to set iterator names and + // generate readable code. mutable std::vector detailed_dimension_names_; private: @@ -176,12 +187,14 @@ enum class ScheduleKind { }; //! Create a schedule from a tensor. -// std::unique_ptr CreateSchedule(const ir::Tensor &tensor, ScheduleKind schedule_kind = ScheduleKind::Poly); -//! Create a schedule from a list of stages, it will schedule the stages using the information from data dependency, -//! iteration domains. -std::unique_ptr CreateSchedule(const std::vector &stages, - ScheduleKind schedule_kind = ScheduleKind::Poly, - const std::vector> &extra_links = {}); +// std::unique_ptr CreateSchedule(const ir::Tensor &tensor, +// ScheduleKind schedule_kind = ScheduleKind::Poly); +//! Create a schedule from a list of stages; it will schedule the stages using +//! the information from data dependencies and iteration domains. +std::unique_ptr CreateSchedule( + const std::vector &stages, + ScheduleKind schedule_kind = ScheduleKind::Poly, + const std::vector> &extra_links = {}); /** * Gather the stages in the input tensors and their dependencies * @param xs The input tensors. * @param with_placeholder Whether to include placeholders(default false). * @returns The stages in topological order follow the connection to `xs`. 
*/ -// std::vector GatherStagesInTensors(const std::vector &xs, bool with_placeholder = false); +// std::vector GatherStagesInTensors(const std::vector &xs, +// bool with_placeholder = false); struct ScheduleGraphEdge : public common::GraphEdge { - ScheduleGraphEdge(common::GraphNode *a, common::GraphNode *b) : common::GraphEdge(a, b) {} + ScheduleGraphEdge(common::GraphNode *a, common::GraphNode *b) + : common::GraphEdge(a, b) {} //! Dependency level. int level{-1}; @@ -206,10 +221,13 @@ struct ScheduleGraphNode : public common::GraphNode { Stage *stage{}; //! NOTE this id is not human-readable. - // std::string id() const override { return std::to_string(reinterpret_cast(this)); } + // std::string id() const override { return + // std::to_string(reinterpret_cast(this)); } std::string id() const override { return time_schedule.id(); } - explicit ScheduleGraphNode(const std::string &id, const std::vector &dims, const Stage *stage) + explicit ScheduleGraphNode(const std::string &id, + const std::vector &dims, + const Stage *stage) : time_schedule(id, dims), stage(const_cast(stage)) {} const char *type_info() const override { return __type_info__; } @@ -222,7 +240,8 @@ struct ScheduleGroup { std::vector dimension_names; }; -std::map CollectScheduleMapFromGroup(const ScheduleGroup &group); +std::map CollectScheduleMapFromGroup( + const ScheduleGroup &group); } // namespace poly } // namespace cinn diff --git a/paddle/cinn/poly/schedule_test.cc b/paddle/cinn/poly/schedule_test.cc index 3390b524e1aea..bc17551269585 100755 --- a/paddle/cinn/poly/schedule_test.cc +++ b/paddle/cinn/poly/schedule_test.cc @@ -65,7 +65,8 @@ TEST(CreateStages, buffer_bind_to_multiple_tensors_schedule) { Expr N(100); lang::Placeholder A("A", {N, N}); /* - * We create three tensors all binded to the same buffer, but has no depend in computation. + * We create three tensors, all bound to the same buffer, but with no + * dependency in computation. 
*/ auto B = lang::Compute( diff --git a/paddle/cinn/poly/stage.cc b/paddle/cinn/poly/stage.cc index 172edd1457825..7b6f18c7acc60 100644 --- a/paddle/cinn/poly/stage.cc +++ b/paddle/cinn/poly/stage.cc @@ -67,26 +67,34 @@ std::vector NamesToIterators(const std::vector &names) { void Stage::InitTransform() { std::string id = isl_set_get_tuple_name(domain_.get()); - auto dims = isl_get_dim_names(domain_); + auto dims = isl_get_dim_names(domain_); auto dims_repr = utils::Join(dims, ", "); - auto repr = utils::StringFormat("{ %s[%s] -> %s[%s] }", id.c_str(), dims_repr.c_str(), id.c_str(), dims_repr.c_str()); + auto repr = utils::StringFormat("{ %s[%s] -> %s[%s] }", + id.c_str(), + dims_repr.c_str(), + id.c_str(), + dims_repr.c_str()); transform_ = isl::map(domain_.ctx(), repr); // set dimension names for (int i = 0; i < dims.size(); i++) { - transform_ = isl::manage(isl_map_set_dim_name(transform_.release(), isl_dim_in, i, dims[i].c_str())); - transform_ = isl::manage(isl_map_set_dim_name(transform_.release(), isl_dim_out, i, dims[i].c_str())); + transform_ = isl::manage(isl_map_set_dim_name( + transform_.release(), isl_dim_in, i, dims[i].c_str())); + transform_ = isl::manage(isl_map_set_dim_name( + transform_.release(), isl_dim_out, i, dims[i].c_str())); } } -Stage::Stage(const isl::set &domain, Expr expr, ir::_Tensor_ *tensor) : domain_(domain), expr_(expr), tensor_(tensor) { +Stage::Stage(const isl::set &domain, Expr expr, ir::_Tensor_ *tensor) + : domain_(domain), expr_(expr), tensor_(tensor) { CHECK(!domain_.is_null()); CHECK(!domain_.is_empty()); InitTransform(); } -std::tuple Stage::SplitOuter(const std::string &level, int nparts) { +std::tuple Stage::SplitOuter(const std::string &level, + int nparts) { return std::move(SplitOuter(Iterator(level), nparts)); } @@ -99,24 +107,27 @@ std::tuple Stage::SplitOuter(int level, int nparts) { int Stage::GetDimRange(int level) { auto _minv_maxv_ = isl_set_get_axis_range(transformed_domain().get(), level); - auto &minv = std::get<0>(_minv_maxv_); - auto &maxv = std::get<1>(_minv_maxv_); - int max_iv = maxv.get_num_si(); - int min_iv = minv.get_num_si(); - CHECK_EQ(0, min_iv) << "The min range of level " << level << " in " << id() << " is not 0!"; + auto &minv = std::get<0>(_minv_maxv_); + auto &maxv = std::get<1>(_minv_maxv_); + int max_iv = maxv.get_num_si(); + int min_iv = minv.get_num_si(); + CHECK_EQ(0, min_iv) << "The min range of level " << level << " in " << id() + << " is not 0!"; return max_iv + 1; } -std::tuple Stage::SplitOuter(const Iterator &level, int nparts) { - int offset = isl_set_find_dim_by_name(transformed_domain().get(), isl_dim_set, level.id.c_str()); +std::tuple Stage::SplitOuter(const Iterator &level, + int nparts) { + int offset = isl_set_find_dim_by_name( + transformed_domain().get(), isl_dim_set, level.id.c_str()); CHECK_GE(offset, 0) << "iterator " << level << " not in " << domain_; AssertAxisIsNotLocked(offset); auto _minv_maxv_ = isl_set_get_axis_range(transformed_domain().get(), offset); - auto &minv = std::get<0>(_minv_maxv_); - auto &maxv = std::get<1>(_minv_maxv_); - int max_iv = maxv.get_num_si(); - auto dim_names = isl_get_dim_names(transform_, isl_dim_out); - double temp = static_cast(max_iv + 1.0) / static_cast(nparts); + auto &minv = std::get<0>(_minv_maxv_); + auto &maxv = std::get<1>(_minv_maxv_); + int max_iv = maxv.get_num_si(); + auto dim_names = isl_get_dim_names(transform_, isl_dim_out); + double temp = static_cast(max_iv + 1.0) / static_cast(nparts); int factor_inner = ceil(temp); return Split(level, 
factor_inner); } @@ -129,7 +140,8 @@ std::tuple Stage::Split(int level, int factor) { } std::tuple Stage::Split(const Iterator &level, int factor) { - int offset = isl_set_find_dim_by_name(transformed_domain().get(), isl_dim_set, level.id.c_str()); + int offset = isl_set_find_dim_by_name( + transformed_domain().get(), isl_dim_set, level.id.c_str()); CHECK_GE(offset, 0) << "iterator " << level << " not in " << domain_; AssertAxisIsNotLocked(offset); @@ -148,9 +160,11 @@ std::tuple Stage::Split(const Iterator &level, int factor) { to_iters.push_back(outer_iter); to_iters.push_back(inner_iter); - conds.emplace_back(utils::StringFormat("%s=floor(%s/%d)", outer_iter.id.c_str(), level.id.c_str(), factor)); + conds.emplace_back(utils::StringFormat( + "%s=floor(%s/%d)", outer_iter.id.c_str(), level.id.c_str(), factor)); VLOG(3) << "outer cond: " << conds.back(); - conds.emplace_back(utils::StringFormat("%s=%s %s %d", inner_iter.id.c_str(), level.id.c_str(), "%", factor)); + conds.emplace_back(utils::StringFormat( + "%s=%s %s %d", inner_iter.id.c_str(), level.id.c_str(), "%", factor)); VLOG(3) << "inner cond: " << conds.back(); } else { @@ -160,8 +174,9 @@ std::tuple Stage::Split(const Iterator &level, int factor) { Map transform(domain_.ctx(), id(), from_iters, to_iters, conds, id()); VLOG(3) << "transform: " << transform.__str__(); - transform_ = transform_.apply_range(transform.to_isl()); - auto range_dims = utils::Map, std::string>(to_iters, [](const Iterator &x) { return x.id; }); + transform_ = transform_.apply_range(transform.to_isl()); + auto range_dims = utils::Map, std::string>( + to_iters, [](const Iterator &x) { return x.id; }); isl_set_dim_names(&transform_, isl_dim_out, range_dims); VLOG(3) << "transform " << transform.to_isl(); @@ -179,7 +194,8 @@ void Stage::Reorder(const std::vector &order) { std::vector range_iters, domain_iters; for (auto &o : order) { - CHECK(in_name_set.count(o.id)) << "Iterator " << o.id << " not int the exsting axis"; + CHECK(in_name_set.count(o.id)) + << "Iterator " << o.id << " not int the exsting axis"; } int order_offset = 0; @@ -216,21 +232,21 @@ Stage::Tile(int level0, int level1, int factor0, int factor1) { return Tile(i0, i1, factor0, factor1); } -std::tuple Stage::Tile(const Iterator &level0, - const Iterator &level1, - int factor0, - int factor1) { +std::tuple Stage::Tile( + const Iterator &level0, const Iterator &level1, int factor0, int factor1) { auto _level0_outer_level0_inner_ = Split(level0, factor0); // NOLINT - auto &level0_outer = std::get<0>(_level0_outer_level0_inner_); - auto &level0_inner = std::get<1>(_level0_outer_level0_inner_); + auto &level0_outer = std::get<0>(_level0_outer_level0_inner_); + auto &level0_inner = std::get<1>(_level0_outer_level0_inner_); auto _level1_outer_level1_inner_ = Split(level1, factor1); // NOLINT - auto &level1_outer = std::get<0>(_level1_outer_level1_inner_); - auto &level1_inner = std::get<1>(_level1_outer_level1_inner_); - return std::make_tuple(level0_outer, level0_inner, level1_outer, level1_inner); + auto &level1_outer = std::get<0>(_level1_outer_level1_inner_); + auto &level1_inner = std::get<1>(_level1_outer_level1_inner_); + return std::make_tuple( + level0_outer, level0_inner, level1_outer, level1_inner); } void Stage::ComputeAtSchedule(Stage *other, int level, ComputeAtKind kind) { - // TODO(Superjomn) Check there are data dependency between `self` and `other`, or the `ComputeAt` is meaningless. 
+ // TODO(Superjomn) Check there are data dependency between `self` and `other`, + // or the `ComputeAt` is meaningless. CHECK(other->tensor()); CHECK(tensor()); @@ -238,9 +254,10 @@ void Stage::ComputeAtSchedule(Stage *other, int level, ComputeAtKind kind) { relation.stage = other; relation.level = level; - CHECK(relation.IsCompatible(this)) << "Cannot apply ComputeAtSchedule with level: " << level << " from \n" - << isl_set_to_str(this->transformed_domain().get()) << "\n to \n" - << isl_set_to_str(other->transformed_domain().get()); + CHECK(relation.IsCompatible(this)) + << "Cannot apply ComputeAtSchedule with level: " << level << " from \n" + << isl_set_to_str(this->transformed_domain().get()) << "\n to \n" + << isl_set_to_str(other->transformed_domain().get()); compute_ats_[other->id()] = relation; // Consider the order if provide. @@ -263,7 +280,8 @@ void Stage::ComputeAtSchedule(Stage *other, int level, ComputeAtKind kind) { } void Stage::ChangeIndex(Stage *other) { - auto indices = optim::CollectTensorIndex(&(other->expr_), this->tensor()->name); + auto indices = + optim::CollectTensorIndex(&(other->expr_), this->tensor()->name); RemoveDuplicate(indices); if (indices.empty()) { return; @@ -309,33 +327,41 @@ void Stage::AddForLoopInTransform(std::vector> &indices) { std::string dim_name = common::axis_name(i) + "_at"; Var dim_var(dim_name); - indices[0][i] = ir::Add::Make(indices[0][i], Expr(dim_var)); - std::string this_domain = isl_set_to_str(domain_.get()); + indices[0][i] = ir::Add::Make(indices[0][i], Expr(dim_var)); + std::string this_domain = isl_set_to_str(domain_.get()); std::string this_transform = isl_map_to_str(transform_.get()); - isl::ctx this_ctx = domain_.ctx(); + isl::ctx this_ctx = domain_.ctx(); isl::set domain2(this_ctx, this_domain); std::string tuple_name = isl_set_get_tuple_name(domain_.get()); - domain2 = isl::manage(isl_set_add_dims(domain2.release(), isl_dim_out, 1)); - int dim_size = isl_set_dim(domain2.get(), isl_dim_out); + domain2 = isl::manage(isl_set_add_dims(domain2.release(), isl_dim_out, 1)); + int dim_size = isl_set_dim(domain2.get(), isl_dim_out); - domain2 = isl::manage(isl_set_set_dim_name(domain2.release(), isl_dim_out, dim_size - 1, dim_name.c_str())); - domain2 = isl::manage(isl_set_set_tuple_name(domain2.release(), tuple_name.c_str())); + domain2 = isl::manage(isl_set_set_dim_name( + domain2.release(), isl_dim_out, dim_size - 1, dim_name.c_str())); + domain2 = isl::manage( + isl_set_set_tuple_name(domain2.release(), tuple_name.c_str())); std::string domain2_str = isl_set_to_str(domain2.get()); - domain2_str = domain2_str.substr(0, domain2_str.size() - 1) + "and 0 <= " + dim_name + - " <= " + std::to_string(int_range) + " }"; + domain2_str = domain2_str.substr(0, domain2_str.size() - 1) + + "and 0 <= " + dim_name + " <= " + std::to_string(int_range) + + " }"; VLOG(2) << "Edited domain is: " << domain2_str; isl::set domain_res(this_ctx, domain2_str); domain_ = domain_res; isl::map transform2(this_ctx, this_transform); - transform2 = isl::manage(isl_map_add_dims(transform2.release(), isl_dim_in, 1)); - dim_size = isl_map_dim(transform2.get(), isl_dim_in); - transform2 = isl::manage(isl_map_set_dim_name(transform2.release(), isl_dim_in, dim_size - 1, dim_name.c_str())); - transform2 = isl::manage(isl_map_set_tuple_name(transform2.release(), isl_dim_in, tuple_name.c_str())); + transform2 = + isl::manage(isl_map_add_dims(transform2.release(), isl_dim_in, 1)); + dim_size = isl_map_dim(transform2.get(), isl_dim_in); + transform2 = 
isl::manage(isl_map_set_dim_name( + transform2.release(), isl_dim_in, dim_size - 1, dim_name.c_str())); + transform2 = isl::manage(isl_map_set_tuple_name( + transform2.release(), isl_dim_in, tuple_name.c_str())); std::string transform2_str = isl_map_to_str(transform2.get()); - int found_index = transform2_str.find_last_of("]"); - transform2_str = transform2_str.substr(0, found_index) + ", " + dim_name + "' = " + dim_name + - transform2_str.substr(found_index, transform2_str.size() - found_index); + int found_index = transform2_str.find_last_of("]"); + transform2_str = + transform2_str.substr(0, found_index) + ", " + dim_name + + "' = " + dim_name + + transform2_str.substr(found_index, transform2_str.size() - found_index); VLOG(2) << "Edited transform is: " << transform2_str; isl::map trans_res(this_ctx, transform2_str); transform_ = trans_res; @@ -343,20 +369,22 @@ void Stage::AddForLoopInTransform(std::vector> &indices) { } /** * Change this stage's domain to be consistent with other's domain. - * @param level Change the domain lower than level to be consistent with other's domain. - * For example, when this->domain_ is "{ [i0, i1] : 0 <= i0 <= 9 and 0 <= i1 <= 9 }", - * other->domain_ is "{ [i0, i1] : 0 <= i0 <= 4 and 0 <= i1 <= 4 }" and level = 0. - * Then this->domain_ whill be changed to "{ [i0, i1] : 0 <= i0 <= 4 and 0 <= i1 <= 9 }". + * @param level Change the domain lower than level to be consistent with other's + * domain. For example, when this->domain_ is "{ [i0, i1] : 0 <= i0 <= 9 and 0 + * <= i1 <= 9 }", other->domain_ is "{ [i0, i1] : 0 <= i0 <= 4 and 0 <= i1 <= 4 + * }" and level = 0. Then this->domain_ will be changed to "{ [i0, i1] : 0 <= + * i0 <= 4 and 0 <= i1 <= 9 }". */ void Stage::ChangeDomain(Stage *other, int level) { - auto indices = optim::CollectTensorIndex(&(other->expr_), this->tensor()->name); + auto indices = + optim::CollectTensorIndex(&(other->expr_), this->tensor()->name); if (indices.empty()) { return; } std::string this_domain = isl_set_to_str(this->domain().get()); - isl::ctx this_ctx = domain_.ctx(); - auto dim_names = isl_get_dim_names(domain_.get()); - auto map_names = isl_get_dim_names(other->transform().get(), isl_dim_out); + isl::ctx this_ctx = domain_.ctx(); + auto dim_names = isl_get_dim_names(domain_.get()); + auto map_names = isl_get_dim_names(other->transform().get(), isl_dim_out); std::set uniq_names; for (int i = 0; i <= level; i++) { uniq_names.insert(map_names[i].substr(0, 1)); } @@ -365,21 +393,23 @@ void Stage::ChangeDomain(Stage *other, int level) { // instead of transformed axis(i_outer, i_inner, j, k, ...) level = uniq_names.size() - 1; for (int i = 0; i <= level; i++) { - auto _minv_maxv_ = isl_set_get_axis_range(domain_.get(), i); - auto &minv = std::get<0>(_minv_maxv_); - auto &maxv = std::get<1>(_minv_maxv_); - int min_iv = minv.get_num_si(); - int max_iv = maxv.get_num_si(); + auto _minv_maxv_ = isl_set_get_axis_range(domain_.get(), i); + auto &minv = std::get<0>(_minv_maxv_); + auto &maxv = std::get<1>(_minv_maxv_); + int min_iv = minv.get_num_si(); + int max_iv = maxv.get_num_si(); auto _minv2_maxv2_ = isl_set_get_axis_range(other->domain().get(), i); - auto &minv2 = std::get<0>(_minv2_maxv2_); - auto &maxv2 = std::get<1>(_minv2_maxv2_); - int min_tar = minv2.get_num_si(); - int max_tar = maxv2.get_num_si(); + auto &minv2 = std::get<0>(_minv2_maxv2_); + auto &maxv2 = std::get<1>(_minv2_maxv2_); + int min_tar = minv2.get_num_si(); + int max_tar = maxv2.get_num_si(); // Change each dim's range.
// e.g., from "0 <= i0 <= 9" to "0 <= i0 <= 4" utils::Replace(&this_domain, - std::to_string(min_iv) + " <= " + dim_names[i] + " <= " + std::to_string(max_iv), - std::to_string(min_tar) + " <= " + dim_names[i] + " <= " + std::to_string(max_tar)); + std::to_string(min_iv) + " <= " + dim_names[i] + + " <= " + std::to_string(max_iv), + std::to_string(min_tar) + " <= " + dim_names[i] + + " <= " + std::to_string(max_tar)); } VLOG(3) << "Final changed domain is: " << this_domain; isl::set res_set(this_ctx, this_domain); @@ -389,11 +419,12 @@ void Stage::ChangeDomain(Stage *other, int level) { /** * Edit temp tensor's shape, its buffer's shape and index when doing ComputeAt2. * @param level The level of dims to be changed. - * For example, when this->domain_ is "{ [i0, i1] : 0 <= i0 <= 9 and 0 <= i1 <= 9 }", - * and 1st loop is binded to threadIdx.x, then i0 will be erased in this temp tensor's axes. + * For example, when this->domain_ is "{ [i0, i1] : 0 <= i0 <= 9 and 0 <= i1 <= + * 9 }", and the 1st loop is bound to threadIdx.x, then i0 will be erased in this + * temp tensor's axes. */ void Stage::EditTempTensor(Stage *other, int level) { - auto bind_info = other->forloop_infos(); + auto bind_info = other->forloop_infos(); auto transform_domain_names = axis_names(); std::set erase_var; std::string tensor_name = this->tensor()->name; @@ -401,45 +432,56 @@ void Stage::EditTempTensor(Stage *other, int level) { if (isl_is_removed_axis(this->transformed_domain().get(), i)) { continue; } - int new_i = i - isl_get_precending_removed_axes_counts(this->transformed_domain().get(), i); + int new_i = i - isl_get_precending_removed_axes_counts( + this->transformed_domain().get(), i); if (bind_info.count(new_i) != 0) { - if (bind_info[new_i].for_type == ir::ForType::GPUThread && (this->scope() == ScopeKind::kShared)) { + if (bind_info[new_i].for_type == ir::ForType::GPUThread && + (this->scope() == ScopeKind::kShared)) { continue; } } // Iterators of loop within level will be erased. - auto related_dim_in = GetRelatedInputAxies(this->transform(), this->domain(), {transform_domain_names[i]}); + auto related_dim_in = GetRelatedInputAxies( + this->transform(), this->domain(), {transform_domain_names[i]}); for (auto &j : related_dim_in) { erase_var.insert(j); } } std::set undo_erase_var; - // Beyond level, if the loop is binded to certain thread/block, it will also be earsed. + // Beyond level, if the loop is bound to a certain thread/block, it will also + // be erased.
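The two loops in this function encode a small decision table for which axes of a temp tensor can be dropped: within `level`, related input axes are erased (except thread-bound axes of a shared-scope tensor); beyond `level`, the loop that follows erases block-bound axes for shared or local tensors and thread-bound axes for local ones. A minimal standalone sketch of that beyond-`level` rule; `ForType`, `ScopeKind`, and `EraseAxis` here are simplified stand-ins for the CINN types, not the real API:

```cpp
// Toy restatement of EditTempTensor's erase rule for axes beyond `level`.
#include <iostream>

enum class ForType { Default, GPUBlock, GPUThread };
enum class ScopeKind { kLocal, kShared, kGlobal };

// True if an axis bound to `bind` can be dropped from a temp tensor
// living in memory scope `scope`.
bool EraseAxis(ForType bind, ScopeKind scope) {
  if (bind == ForType::GPUBlock)
    return scope == ScopeKind::kShared || scope == ScopeKind::kLocal;
  if (bind == ForType::GPUThread) return scope == ScopeKind::kLocal;
  return false;  // unbound axes beyond `level` are kept (undo_erase_var)
}

int main() {
  std::cout << EraseAxis(ForType::GPUBlock, ScopeKind::kShared) << "\n";   // 1
  std::cout << EraseAxis(ForType::GPUThread, ScopeKind::kShared) << "\n";  // 0
  std::cout << EraseAxis(ForType::GPUThread, ScopeKind::kLocal) << "\n";   // 1
}
```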
for (int i = level + 1; i < transform_domain_names.size(); i++) { if (isl_is_removed_axis(this->transformed_domain().get(), i)) { continue; } - int new_i = i - isl_get_precending_removed_axes_counts(this->transformed_domain().get(), i); + int new_i = i - isl_get_precending_removed_axes_counts( + this->transformed_domain().get(), i); if (bind_info.count(new_i) != 0) { if (bind_info[new_i].for_type == ir::ForType::GPUBlock && - (this->scope() == ScopeKind::kShared || this->scope() == ScopeKind::kLocal)) { - auto related_dim_in = GetRelatedInputAxies(this->transform(), this->domain(), {transform_domain_names[i]}); + (this->scope() == ScopeKind::kShared || + this->scope() == ScopeKind::kLocal)) { + auto related_dim_in = GetRelatedInputAxies( + this->transform(), this->domain(), {transform_domain_names[i]}); for (auto &j : related_dim_in) { erase_var.insert(j); } - } else if (bind_info[new_i].for_type == ir::ForType::GPUThread && (this->scope() == ScopeKind::kLocal)) { - auto related_dim_in = GetRelatedInputAxies(this->transform(), this->domain(), {transform_domain_names[i]}); + } else if (bind_info[new_i].for_type == ir::ForType::GPUThread && + (this->scope() == ScopeKind::kLocal)) { + auto related_dim_in = GetRelatedInputAxies( + this->transform(), this->domain(), {transform_domain_names[i]}); for (auto &j : related_dim_in) { erase_var.insert(j); } } else { - auto related_dim_in = GetRelatedInputAxies(this->transform(), this->domain(), {transform_domain_names[i]}); + auto related_dim_in = GetRelatedInputAxies( + this->transform(), this->domain(), {transform_domain_names[i]}); for (auto &j : related_dim_in) { undo_erase_var.insert(j); } } } else { - auto related_dim_in = GetRelatedInputAxies(this->transform(), this->domain(), {transform_domain_names[i]}); + auto related_dim_in = GetRelatedInputAxies( + this->transform(), this->domain(), {transform_domain_names[i]}); for (auto &j : related_dim_in) { undo_erase_var.insert(j); } @@ -463,11 +505,11 @@ void Stage::EditTempTensor(Stage *other, int level) { std::map dim_to_range; std::vector this_dim_names = isl_get_dim_names(domain_); for (int i = 0; i < this_dim_names.size(); i++) { - auto _minv_maxv_ = isl_set_get_axis_range(domain_.get(), i); - auto &minv = std::get<0>(_minv_maxv_); - auto &maxv = std::get<1>(_minv_maxv_); - int min_iv = minv.get_num_si(); - int max_iv = maxv.get_num_si(); + auto _minv_maxv_ = isl_set_get_axis_range(domain_.get(), i); + auto &minv = std::get<0>(_minv_maxv_); + auto &maxv = std::get<1>(_minv_maxv_); + int min_iv = minv.get_num_si(); + int max_iv = maxv.get_num_si(); dim_to_range[this_dim_names[i]] = max_iv; } @@ -495,22 +537,29 @@ void Stage::EditTempTensor(Stage *other, int level) { void Stage::ComputeAt(Stage *other, int level) { isl::set this_domain(domain().ctx(), isl_set_to_str(domain().get())); - isl::set target_domain(other->domain().ctx(), isl_set_to_str(other->domain().get())); + isl::set target_domain(other->domain().ctx(), + isl_set_to_str(other->domain().get())); auto reduce_axes = origin_reduce_axis_names(); for (auto &i : reduce_axes) { - this_domain = isl::manage(isl_remove_axis_by_name(this_domain.release(), i.c_str())); + this_domain = + isl::manage(isl_remove_axis_by_name(this_domain.release(), i.c_str())); } isl::map write_access = isl::manage(isl_set_identity(this_domain.release())); - isl::map read_access = isl::manage(isl_set_identity(target_domain.release())); - read_access = - isl::manage(isl_map_set_tuple_name(read_access.release(), isl_dim_out, isl_set_get_tuple_name(domain().get()))); + 
isl::map read_access = isl::manage(isl_set_identity(target_domain.release())); + read_access = isl::manage( + isl_map_set_tuple_name(read_access.release(), + isl_dim_out, + isl_set_get_tuple_name(domain().get()))); int num_out_dim = isl_map_dim(read_access.get(), isl_dim_out); - read_access = isl::manage(isl_map_remove_dims(read_access.release(), isl_dim_out, 0, num_out_dim)); - auto indices = optim::CollectTensorIndex(&(other->expr_), this->tensor()->name); + read_access = isl::manage( + isl_map_remove_dims(read_access.release(), isl_dim_out, 0, num_out_dim)); + auto indices = + optim::CollectTensorIndex(&(other->expr_), this->tensor()->name); RemoveDuplicate(indices); if (indices.empty()) { - LOG(ERROR) << "No Access Relation between [" << other->id() << "] and [" << this->id() << "]! Please check."; + LOG(ERROR) << "No Access Relation between [" << other->id() << "] and [" + << this->id() << "]! Please check."; } CHECK_EQ(indices.size(), 1) << "indices.size > 1 is not supported yet"; std::vector target_dims = isl_get_dim_names(other->domain()); @@ -532,27 +581,37 @@ void Stage::ComputeAt(Stage *other, int level) { // W is the write access relation // R is the read access relation // S is the original schedule of Stage *other - read_access = isl::manage(isl_map_add_dims(read_access.release(), isl_dim_out, index_names.size())); + read_access = isl::manage( + isl_map_add_dims(read_access.release(), isl_dim_out, index_names.size())); isl_set_dim_names(&read_access, isl_dim_out, index_names); - read_access = - isl::manage(isl_map_set_tuple_name(read_access.release(), isl_dim_out, isl_set_get_tuple_name(domain().get()))); + read_access = isl::manage( + isl_map_set_tuple_name(read_access.release(), + isl_dim_out, + isl_set_get_tuple_name(domain().get()))); std::string read_access_str = isl_map_to_str(read_access.get()); isl::map read_access2(read_access.ctx(), read_access_str); read_access2 = isl::manage(isl_map_reverse(read_access2.release())); - auto new_map = isl::manage(isl_map_apply_range(write_access.release(), read_access2.release())); - isl::map new_target_transform(other->transform().ctx(), isl_map_to_str(other->transform().get())); - auto target_map_dims = isl_get_dim_names(new_target_transform.get(), isl_dim_out); - auto target_map_dims_in = isl_get_dim_names(new_target_transform.get(), isl_dim_in); - // For axis out of the level, we don't copy their transform except for they are related to axis within the level. + auto new_map = isl::manage( + isl_map_apply_range(write_access.release(), read_access2.release())); + isl::map new_target_transform(other->transform().ctx(), + isl_map_to_str(other->transform().get())); + auto target_map_dims = + isl_get_dim_names(new_target_transform.get(), isl_dim_out); + auto target_map_dims_in = + isl_get_dim_names(new_target_transform.get(), isl_dim_in); + // For axis out of the level, we don't copy their transform except when they + // are related to axis within the level.
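The access-relation algebra in this hunk is the core of `ComputeAt`: reverse the consumer's read access R and compose it with the producer's write access W (the `isl_map_reverse` followed by `isl_map_apply_range` above), yielding a map from producer iterations to the consumer iterations that read what they wrote. A toy version over finite integer relations, with a set of pairs standing in for isl maps purely for illustration:

```cpp
#include <iostream>
#include <set>
#include <utility>

using Rel = std::set<std::pair<int, int>>;  // finite integer relation

Rel Inverse(const Rel &r) {
  Rel out;
  for (auto &[a, b] : r) out.insert({b, a});
  return out;
}

// apply_range: { a -> c : exists b with a->b in lhs and b->c in rhs }
Rel ApplyRange(const Rel &lhs, const Rel &rhs) {
  Rel out;
  for (auto &[a, b] : lhs)
    for (auto &[b2, c] : rhs)
      if (b == b2) out.insert({a, c});
  return out;
}

int main() {
  Rel W = {{0, 0}, {1, 1}};          // producer iter -> element written
  Rel R = {{0, 0}, {0, 1}, {1, 1}};  // consumer iter -> element read
  for (auto &[p, c] : ApplyRange(W, Inverse(R)))
    std::cout << "producer " << p << " -> consumer " << c << "\n";
}
```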
std::vector level_out_dims; std::set related_output_dims_set; for (int i = 0; i <= level; i++) { level_out_dims.push_back(target_map_dims[i]); related_output_dims_set.insert(target_map_dims[i]); } - auto related_input_dims = GetRelatedInputAxies(new_target_transform, other->domain(), level_out_dims); - auto related_output_dims = GetRelatedOutputAxies(new_target_transform, other->domain(), related_input_dims); + auto related_input_dims = GetRelatedInputAxies( + new_target_transform, other->domain(), level_out_dims); + auto related_output_dims = GetRelatedOutputAxies( + new_target_transform, other->domain(), related_input_dims); for (auto &i : related_output_dims) { related_output_dims_set.insert(i); } @@ -562,25 +621,33 @@ void Stage::ComputeAt(Stage *other, int level) { } for (auto &i : target_map_dims) { if (related_output_dims_set.count(i) == 0) { - new_target_transform = - isl::manage(isl_remove_axis_by_name(new_target_transform.release(), isl_dim_out, i.c_str())); + new_target_transform = isl::manage(isl_remove_axis_by_name( + new_target_transform.release(), isl_dim_out, i.c_str())); } } for (auto &i : target_map_dims_in) { if (related_input_dims_set.count(i) == 0) { - new_target_transform = isl::manage(isl_map_add_dims(new_target_transform.release(), isl_dim_out, 1)); - int level = isl_map_dim(new_target_transform.get(), isl_dim_out); + new_target_transform = isl::manage( + isl_map_add_dims(new_target_transform.release(), isl_dim_out, 1)); + int level = isl_map_dim(new_target_transform.get(), isl_dim_out); std::string dim_name_add = i + "' = " + i; - new_target_transform = isl::manage( - isl_map_set_dim_name(new_target_transform.release(), isl_dim_out, level - 1, dim_name_add.c_str())); + new_target_transform = + isl::manage(isl_map_set_dim_name(new_target_transform.release(), + isl_dim_out, + level - 1, + dim_name_add.c_str())); } } - new_target_transform = isl::manage(isl_map_set_tuple_name(new_target_transform.release(), isl_dim_out, other->id())); + new_target_transform = isl::manage(isl_map_set_tuple_name( + new_target_transform.release(), isl_dim_out, other->id())); - isl::map f_target_transform(other->transform().ctx(), isl_map_to_str(new_target_transform.get())); - auto trans_res = isl::manage(isl_map_apply_range(new_map.release(), f_target_transform.release())); - trans_res = isl::manage(isl_map_set_tuple_name(trans_res.release(), isl_dim_out, this->id())); + isl::map f_target_transform(other->transform().ctx(), + isl_map_to_str(new_target_transform.get())); + auto trans_res = isl::manage( + isl_map_apply_range(new_map.release(), f_target_transform.release())); + trans_res = isl::manage( + isl_map_set_tuple_name(trans_res.release(), isl_dim_out, this->id())); // When there are reduce axes, we need to add these axes manually if (!reduce_axes.empty()) { @@ -588,32 +655,42 @@ void Stage::ComputeAt(Stage *other, int level) { for (auto &i : reduce_axes) { reduce_axes_out.push_back(i + "' = " + i); } - int map_dim_in = isl_map_dim(trans_res.get(), isl_dim_in); + int map_dim_in = isl_map_dim(trans_res.get(), isl_dim_in); int map_dim_out = isl_map_dim(trans_res.get(), isl_dim_out); - trans_res = isl::manage(isl_map_add_dims(trans_res.release(), isl_dim_in, reduce_axes.size())); + trans_res = isl::manage( + isl_map_add_dims(trans_res.release(), isl_dim_in, reduce_axes.size())); for (int i = 0; i < reduce_axes.size(); i++) { - trans_res = - isl::manage(isl_map_set_dim_name(trans_res.release(), isl_dim_in, map_dim_in + i, reduce_axes[i].c_str())); + trans_res = 
isl::manage(isl_map_set_dim_name(trans_res.release(), + isl_dim_in, + map_dim_in + i, + reduce_axes[i].c_str())); } - trans_res = isl::manage(isl_map_add_dims(trans_res.release(), isl_dim_out, reduce_axes_out.size())); + trans_res = isl::manage(isl_map_add_dims( + trans_res.release(), isl_dim_out, reduce_axes_out.size())); for (int i = 0; i < reduce_axes_out.size(); i++) { - trans_res = isl::manage( - isl_map_set_dim_name(trans_res.release(), isl_dim_out, map_dim_out + i, reduce_axes_out[i].c_str())); + trans_res = isl::manage(isl_map_set_dim_name(trans_res.release(), + isl_dim_out, + map_dim_out + i, + reduce_axes_out[i].c_str())); } - trans_res = isl::manage(isl_map_set_tuple_name(trans_res.release(), isl_dim_in, this->id())); - trans_res = isl::manage(isl_map_set_tuple_name(trans_res.release(), isl_dim_out, this->id())); + trans_res = isl::manage( + isl_map_set_tuple_name(trans_res.release(), isl_dim_in, this->id())); + trans_res = isl::manage( + isl_map_set_tuple_name(trans_res.release(), isl_dim_out, this->id())); std::string trans_res_str = isl_map_to_str(trans_res.get()); for (int i = 0; i < reduce_axes.size(); i++) { auto _minv_maxv_ = isl_set_get_axis_range(domain_.get(), i + map_dim_in); - auto &minv = std::get<0>(_minv_maxv_); - auto &maxv = std::get<1>(_minv_maxv_); - int min_iv = minv.get_num_si(); - int max_iv = maxv.get_num_si(); - - trans_res_str = trans_res_str.substr(0, trans_res_str.size() - 1) + "and " + std::to_string(min_iv) + - " <= " + reduce_axes[i] + " <= " + std::to_string(max_iv) + " }"; + auto &minv = std::get<0>(_minv_maxv_); + auto &maxv = std::get<1>(_minv_maxv_); + int min_iv = minv.get_num_si(); + int max_iv = maxv.get_num_si(); + + trans_res_str = trans_res_str.substr(0, trans_res_str.size() - 1) + + "and " + std::to_string(min_iv) + + " <= " + reduce_axes[i] + + " <= " + std::to_string(max_iv) + " }"; } isl::map temp_trans(trans_res.ctx(), trans_res_str); trans_res = temp_trans; @@ -622,26 +699,29 @@ void Stage::ComputeAt(Stage *other, int level) { VLOG(3) << "trans_res is : " << trans_res; { - auto trans_dim_out = isl_get_dim_names(trans_res.get(), isl_dim_out); + auto trans_dim_out = isl_get_dim_names(trans_res.get(), isl_dim_out); auto transformed_res = domain_.apply(trans_res); for (int i = level + 1; i < trans_dim_out.size(); i++) { - auto _minv_maxv_ = isl_set_get_axis_range(transformed_res.get(), i); - auto &minv = std::get<0>(_minv_maxv_); - auto &maxv = std::get<1>(_minv_maxv_); - int max_iv = maxv.get_num_si(); - int min_iv = minv.get_num_si(); - auto related_input_dims = GetRelatedInputAxies(trans_res, domain_, {trans_dim_out[i]}, true); + auto _minv_maxv_ = isl_set_get_axis_range(transformed_res.get(), i); + auto &minv = std::get<0>(_minv_maxv_); + auto &maxv = std::get<1>(_minv_maxv_); + int max_iv = maxv.get_num_si(); + int min_iv = minv.get_num_si(); + auto related_input_dims = + GetRelatedInputAxies(trans_res, domain_, {trans_dim_out[i]}, true); if (max_iv != min_iv && related_input_dims.empty()) { - trans_res = isl::manage(isl_remove_axis_by_name(trans_res.release(), isl_dim_out, trans_dim_out[i].c_str())); + trans_res = isl::manage(isl_remove_axis_by_name( + trans_res.release(), isl_dim_out, trans_dim_out[i].c_str())); } - VLOG(3) << "Input axis related to output axis [" << trans_dim_out[i] << "] (from " << min_iv << " to " << max_iv - << ") is : "; + VLOG(3) << "Input axis related to output axis [" << trans_dim_out[i] + << "] (from " << min_iv << " to " << max_iv << ") is : "; for (auto &j : related_input_dims) { VLOG(3) << j << ", "; 
} } } - VLOG(3) << "After removing redundant output axis, trans_res is : " << trans_res; + VLOG(3) << "After removing redundant output axis, trans_res is : " + << trans_res; transform_ = trans_res; CHECK(tensor_); @@ -650,15 +730,19 @@ void Stage::ComputeAt(Stage *other, int level) { relation.level = level; other->CtrlDepend(ir::Tensor(tensor())); - CHECK(relation.IsCompatible(this)) << "Cannot apply ComputeAt with level: " << level << " from \n" - << isl_set_to_str(this->transformed_domain().get()) << "\n to \n" - << isl_set_to_str(other->transformed_domain().get()); + CHECK(relation.IsCompatible(this)) + << "Cannot apply ComputeAt with level: " << level << " from \n" + << isl_set_to_str(this->transformed_domain().get()) << "\n to \n" + << isl_set_to_str(other->transformed_domain().get()); compute_ats_[other->id()] = relation; - for (int i = 0; i <= level; i++) AddForloopInfo(i, StageForloopInfo{ir::ForType::Default, DeviceAPI::UNK, 0}); + for (int i = 0; i <= level; i++) + AddForloopInfo(i, + StageForloopInfo{ir::ForType::Default, DeviceAPI::UNK, 0}); } void Stage::ComputeAt2(Stage *other, int level) { - // TODO(Superjomn) Check there are data dependency between `self` and `other`, or the `ComputeAt` is meaningless. + // TODO(Superjomn) Check there are data dependency between `self` and `other`, + // or the `ComputeAt` is meaningless. CHECK_GE(level, 0) << "level param of ComputeAt2 must be >= 0. Please check!"; this->ChangeDomain(other, level); this->CopyTransform(other, level); @@ -667,7 +751,8 @@ void Stage::ComputeAt2(Stage *other, int level) { other->CtrlDepend(ir::Tensor(tensor())); if (this->tensor()->buffer.defined()) { std::string t_name = this->tensor()->buffer->name; - if (utils::Endswith(t_name, "_read_cache") || utils::Endswith(t_name, "_write_cache")) { + if (utils::Endswith(t_name, "_read_cache") || + utils::Endswith(t_name, "_write_cache")) { EditTempTensor(other, level); } } @@ -676,9 +761,10 @@ void Stage::ComputeAt2(Stage *other, int level) { relation.level = level; other->CtrlDepend(ir::Tensor(tensor())); - CHECK(relation.IsCompatible(this)) << "Cannot apply ComputeAt2 with level: " << level << " from \n" - << isl_set_to_str(this->transformed_domain().get()) << "\n to \n" - << isl_set_to_str(other->transformed_domain().get()); + CHECK(relation.IsCompatible(this)) + << "Cannot apply ComputeAt2 with level: " << level << " from \n" + << isl_set_to_str(this->transformed_domain().get()) << "\n to \n" + << isl_set_to_str(other->transformed_domain().get()); compute_ats_[other->id()] = relation; } @@ -690,7 +776,8 @@ void Stage::ComputeAt3(Stage *other, int level) { other->CtrlDepend(ir::Tensor(tensor())); if (this->tensor()->buffer.defined()) { std::string t_name = this->tensor()->buffer->name; - if (utils::Endswith(t_name, "_read_cache") || utils::Endswith(t_name, "_write_cache")) { + if (utils::Endswith(t_name, "_read_cache") || + utils::Endswith(t_name, "_write_cache")) { EditTempTensor(other, level); } } @@ -701,7 +788,8 @@ void Stage::SimpleComputeAt(Stage *other, int level) { other->CtrlDepend(ir::Tensor(tensor())); if (this->tensor()->buffer.defined()) { std::string t_name = this->tensor()->buffer->name; - if (utils::Endswith(t_name, "_read_cache") || utils::Endswith(t_name, "_write_cache")) { + if (utils::Endswith(t_name, "_read_cache") || + utils::Endswith(t_name, "_write_cache")) { EditTempTensor(other, level); } } @@ -710,19 +798,26 @@ void Stage::SimpleComputeAt(Stage *other, int level) { relation.level = level; other->CtrlDepend(ir::Tensor(tensor())); - 
CHECK(relation.IsCompatible(this)) << "Cannot apply SimpleComputeAt with level: " << level << " from \n" - << isl_set_to_str(this->transformed_domain().get()) << "\n to \n" - << isl_set_to_str(other->transformed_domain().get()); + CHECK(relation.IsCompatible(this)) + << "Cannot apply SimpleComputeAt with level: " << level << " from \n" + << isl_set_to_str(this->transformed_domain().get()) << "\n to \n" + << isl_set_to_str(other->transformed_domain().get()); compute_ats_[other->id()] = relation; - auto other_expr = other->expr(); - auto find_tensors = ir::CollectIRNodesWithoutTensor( - other_expr, [&](const Expr *x) { return x->as_tensor() && x->as_tensor_ref()->name == tensor()->name; }); + auto other_expr = other->expr(); + auto find_tensors = + ir::CollectIRNodesWithoutTensor(other_expr, [&](const Expr *x) { + return x->as_tensor() && x->as_tensor_ref()->name == tensor()->name; + }); if (!find_tensors.empty()) { - for (int i = 0; i <= level; i++) AddForloopInfo(i, StageForloopInfo{ir::ForType::Default, DeviceAPI::UNK, 0}); + for (int i = 0; i <= level; i++) + AddForloopInfo(i, + StageForloopInfo{ir::ForType::Default, DeviceAPI::UNK, 0}); } } -std::tuple Stage::Skew(const Iterator &i, const Iterator &j, int factor) { +std::tuple Stage::Skew(const Iterator &i, + const Iterator &j, + int factor) { CINN_NOT_IMPLEMENTED Iterator i_new(i.id + "_skew"); Iterator j_new(j.id + "_skew"); @@ -778,10 +873,12 @@ Iterator Stage::Fuse(const std::vector &levels) { std::vector offsets; std::string new_iter_name; for (auto &level : levels) { - int offset = isl_set_find_dim_by_name(transformed_domain().get(), isl_dim_set, level.id.c_str()); + int offset = isl_set_find_dim_by_name( + transformed_domain().get(), isl_dim_set, level.id.c_str()); if (!offsets.empty()) CHECK_EQ(offsets.back() + 1, offset) - << "level [" << offsets.back() << "] and level [" << offset << "] should be adjancent"; + << "level [" << offsets.back() << "] and level [" << offset + << "] should be adjacent"; AssertAxisIsNotLocked(offset); offsets.push_back(offset); new_iter_name += utils::StringFormat("%s_", level.id.c_str()); @@ -791,12 +888,17 @@ Iterator Stage::Fuse(const std::vector &levels) { // Aff { s[i,j,k] -> [j] } and get the j's max value // to apply something like { S[i,j] -> S[k]: k = i+j } auto from_dim_names = isl_get_dim_names(transform_, isl_dim_out); - auto from_iters = NamesToIterators(from_dim_names); + auto from_iters = NamesToIterators(from_dim_names); std::vector iterator_max_val; for (auto &level : levels) { - Aff aff(domain_.ctx(), id(), from_iters, std::vector({Iterator(level.id)}), {}); - int level_max_val = transformed_domain().max_val(aff.to_isl()).get_num_si() + 1; + Aff aff(domain_.ctx(), + id(), + from_iters, + std::vector({Iterator(level.id)}), + {}); + int level_max_val = + transformed_domain().max_val(aff.to_isl()).get_num_si() + 1; iterator_max_val.push_back(level_max_val); } @@ -805,7 +907,8 @@ Iterator Stage::Fuse(const std::vector &levels) { { Iterator new_iter(new_iter_name); for (int i = 0; i < from_iters.size(); i++) { - int offset = isl_set_find_dim_by_name(transformed_domain().get(), isl_dim_set, from_iters[i].id.c_str()); + int offset = isl_set_find_dim_by_name( + transformed_domain().get(), isl_dim_set, from_iters[i].id.c_str()); if (i == offsets.back()) { to_iters.push_back(new_iter); } else if (i >= offsets.front() && i < offsets.back()) { @@ -816,21 +919,27 @@ Iterator Stage::Fuse(const std::vector &levels) { } auto my_prod = [=](int a, int b) { return a * b; }; std::vector conds; -
conds.emplace_back( - utils::StringFormat("%s = floor(%s / %d)", - levels.front().id.c_str(), - new_iter_name.c_str(), - (int)std::accumulate(iterator_max_val.begin() + 1, iterator_max_val.end(), 1, my_prod))); - conds.emplace_back( - utils::StringFormat("%s = %s mod %d", levels.back().id.c_str(), new_iter_name.c_str(), iterator_max_val.back())); + conds.emplace_back(utils::StringFormat( + "%s = floor(%s / %d)", + levels.front().id.c_str(), + new_iter_name.c_str(), + (int)std::accumulate( + iterator_max_val.begin() + 1, iterator_max_val.end(), 1, my_prod))); + conds.emplace_back(utils::StringFormat("%s = %s mod %d", + levels.back().id.c_str(), + new_iter_name.c_str(), + iterator_max_val.back())); for (int i = 1; i < levels.size() - 1; i++) { - conds.emplace_back( - utils::StringFormat("%s = floor(%s / %d) mod %d", - levels[i].id.c_str(), - new_iter_name.c_str(), - (int)std::accumulate(iterator_max_val.begin() + i + 1, iterator_max_val.end(), 1, my_prod), - iterator_max_val[i])); + conds.emplace_back(utils::StringFormat( + "%s = floor(%s / %d) mod %d", + levels[i].id.c_str(), + new_iter_name.c_str(), + (int)std::accumulate(iterator_max_val.begin() + i + 1, + iterator_max_val.end(), + 1, + my_prod), + iterator_max_val[i])); } Map trans(domain_.ctx(), id(), from_iters, to_iters, conds, id()); @@ -853,22 +962,31 @@ Iterator Stage::Fuse(const std::vector &levels) { * Fuse use a polyhedral transform. */ Iterator Stage::Fuse(const Iterator &level0, const Iterator &level1) { - int offset0 = isl_set_find_dim_by_name(transformed_domain().get(), isl_dim_set, level0.id.c_str()); - int offset1 = isl_set_find_dim_by_name(transformed_domain().get(), isl_dim_set, level1.id.c_str()); - CHECK_EQ(offset1, offset0 + 1) << "level [" << level0.id << "] and level [" << level1.id << "] should be adjancent"; + int offset0 = isl_set_find_dim_by_name( + transformed_domain().get(), isl_dim_set, level0.id.c_str()); + int offset1 = isl_set_find_dim_by_name( + transformed_domain().get(), isl_dim_set, level1.id.c_str()); + CHECK_EQ(offset1, offset0 + 1) << "level [" << level0.id << "] and level [" + << level1.id << "] should be adjacent"; AssertAxisIsNotLocked(offset0); AssertAxisIsNotLocked(offset1); - auto new_iter_name = utils::StringFormat("%s_%s_fused", level0.id.c_str(), level1.id.c_str()); + auto new_iter_name = + utils::StringFormat("%s_%s_fused", level0.id.c_str(), level1.id.c_str()); // Aff { s[i,j,k] -> [j] } and get the j's max value // to apply something like { S[i,j] -> S[k]: k = i+j } auto from_dim_names = isl_get_dim_names(transform_, isl_dim_out); - auto from_iters = NamesToIterators(from_dim_names); - Aff aff(domain_.ctx(), id(), from_iters, std::vector({Iterator(level1.id)}), {}); + auto from_iters = NamesToIterators(from_dim_names); + + Aff aff(domain_.ctx(), + id(), + from_iters, + std::vector({Iterator(level1.id)}), + {}); - int level1_max_val = transformed_domain().max_val(aff.to_isl()).get_num_si() + 1; + int level1_max_val = + transformed_domain().max_val(aff.to_isl()).get_num_si() + 1; // Map { s[i,j,k] -> s[n,k] : n = i * max_val + j } std::vector to_iters; @@ -885,8 +1003,11 @@ Iterator Stage::Fuse(const Iterator &level0, const Iterator &level1) { } std::vector conds; - conds.emplace_back(utils::StringFormat( - "%s = %s * %d + %s", new_iter_name.c_str(), level0.id.c_str(), level1_max_val, level1.id.c_str())); + conds.emplace_back(utils::StringFormat("%s = %s * %d + %s", + new_iter_name.c_str(), + level0.id.c_str(), + level1_max_val, + level1.id.c_str())); Map trans(domain_.ctx(), id(),
from_iters, to_iters, conds, id()); @@ -904,7 +1025,8 @@ Iterator Stage::Fuse(const Iterator &level0, const Iterator &level1) { std::vector Stage::input_statements() const { if (!expr_.defined()) return {}; VLOG(3) << "stage " << id() << " expr: " << expr_; - auto load_exprs = ir::CollectIRNodes(expr_, [](const Expr *x) { return x->As(); }); + auto load_exprs = ir::CollectIRNodes( + expr_, [](const Expr *x) { return x->As(); }); std::set statements; for (auto &expr : load_exprs) { auto *load_node = expr.As(); @@ -919,16 +1041,23 @@ std::vector Stage::input_statements() const { std::string InnerName(const std::string &name) { return name + "_inner"; } std::string OuterName(const std::string &name) { return name + "_outer"; } -std::string InnerName(const Iterator &iterator) { return InnerName(iterator.id); } -std::string OuterName(const Iterator &iterator) { return OuterName(iterator.id); } +std::string InnerName(const Iterator &iterator) { + return InnerName(iterator.id); +} +std::string OuterName(const Iterator &iterator) { + return OuterName(iterator.id); +} const char *Stage::id() const { return isl_set_get_tuple_name(domain_.get()); } -std::tuple Stage::Split(const std::string &level, int factor) { +std::tuple Stage::Split(const std::string &level, + int factor) { return std::move(Split(Iterator(level), factor)); } -Shared Stage::New(const isl::set &domain, Expr expr, ir::_Tensor_ *tensor) { +Shared Stage::New(const isl::set &domain, + Expr expr, + ir::_Tensor_ *tensor) { return new Stage(domain, expr, tensor); } @@ -939,8 +1068,10 @@ std::vector Stage::compute_ats() const { } void Stage::ShowISL() const { - LOG(INFO) << "Tensor " << id() << " domain is: " << isl_set_to_str(domain().get()); - LOG(INFO) << "transformed_domain is: " << isl_set_to_str(transformed_domain().get()); + LOG(INFO) << "Tensor " << id() + << " domain is: " << isl_set_to_str(domain().get()); + LOG(INFO) << "transformed_domain is: " + << isl_set_to_str(transformed_domain().get()); LOG(INFO) << "transform is: " << isl_map_to_str(transform().get()); } @@ -961,17 +1092,21 @@ bool ComputeAtRelation::IsCompatible(Stage *self) { selected_dims.push_back(i); } - auto stage_partial_set = SetGetDims(stage->transformed_domain(), selected_dims); - auto self_partial_set = SetGetDims(self->transformed_domain(), selected_dims); + auto stage_partial_set = + SetGetDims(stage->transformed_domain(), selected_dims); + auto self_partial_set = SetGetDims(self->transformed_domain(), selected_dims); - stage_partial_set = isl::manage(isl_set_set_tuple_name(stage_partial_set.release(), "")); - self_partial_set = isl::manage(isl_set_set_tuple_name(self_partial_set.release(), "")); + stage_partial_set = + isl::manage(isl_set_set_tuple_name(stage_partial_set.release(), "")); + self_partial_set = + isl::manage(isl_set_set_tuple_name(self_partial_set.release(), "")); // remove parameters, we don't consider them yet auto remove_params = [](isl::set &set) { int nparams = isl_set_dim(set.get(), isl_dim_param); if (nparams > 0) { - set = isl::manage(isl_set_remove_dims(set.release(), isl_dim_param, 0, nparams)); + set = isl::manage( + isl_set_remove_dims(set.release(), isl_dim_param, 0, nparams)); } }; @@ -997,24 +1132,29 @@ void Stage::Vectorize(int level, int factor) { VLOG(3) << "Vectorizing for-1 has no sense, skip it"; return; } - int removed_axes_counts = isl_get_precending_removed_axes_counts(transformed_domain.get(), level); - VLOG(3) << "removed_axes_counts are " << removed_axes_counts << " before axis " << ith_dim_name(level); - VLOG(3) << 
"vectorize level: " << level - removed_axes_counts << ", factor: " << factor; + int removed_axes_counts = + isl_get_precending_removed_axes_counts(transformed_domain.get(), level); + VLOG(3) << "removed_axes_counts are " << removed_axes_counts + << " before axis " << ith_dim_name(level); + VLOG(3) << "vectorize level: " << level - removed_axes_counts + << ", factor: " << factor; vectorize_info_.set(level - removed_axes_counts /*inner*/, factor); } void Stage::Vectorize(const std::string &axis, int factor) { auto dims = isl_get_dim_names(transformed_domain()); - auto it = std::find(dims.begin(), dims.end(), axis); + auto it = std::find(dims.begin(), dims.end(), axis); CHECK(it != dims.end()) << "No dimension called " << axis; Vectorize(std::distance(dims.begin(), it), factor); } -void Stage::Vectorize(const Iterator &axis, int factor) { return Vectorize(axis.id, factor); } +void Stage::Vectorize(const Iterator &axis, int factor) { + return Vectorize(axis.id, factor); +} void Stage::Parallel(const std::string &axis) { auto dims = isl_get_dim_names(transformed_domain()); - auto it = std::find(dims.begin(), dims.end(), axis); + auto it = std::find(dims.begin(), dims.end(), axis); CHECK(it != dims.end()) << "No dimension called " << axis; Parallel(std::distance(dims.begin(), it)); } @@ -1030,8 +1170,10 @@ void Stage::Parallel(int level) { VLOG(3) << "Paralleling for-1 has no sense, skip it"; return; } - int removed_axes_counts = isl_get_precending_removed_axes_counts(transformed_domain.get(), level); - VLOG(3) << "removed_axes_counts are " << removed_axes_counts << " before axis " << ith_dim_name(level); + int removed_axes_counts = + isl_get_precending_removed_axes_counts(transformed_domain.get(), level); + VLOG(3) << "removed_axes_counts are " << removed_axes_counts + << " before axis " << ith_dim_name(level); parallel_info_.insert(level - removed_axes_counts); } @@ -1043,8 +1185,10 @@ void Stage::Unroll(int level) { VLOG(1) << "Unrolling for-1 has no sense, skip it"; return; } - int removed_axes_counts = isl_get_precending_removed_axes_counts(transformed_domain.get(), level); - VLOG(3) << "removed_axes_counts are " << removed_axes_counts << " before axis " << ith_dim_name(level); + int removed_axes_counts = + isl_get_precending_removed_axes_counts(transformed_domain.get(), level); + VLOG(3) << "removed_axes_counts are " << removed_axes_counts + << " before axis " << ith_dim_name(level); unroll_info_.insert(level - removed_axes_counts); } @@ -1054,7 +1198,9 @@ std::string Stage::ith_dim_name(int level) { return dims[level]; } -Iterator Stage::ith_iterator(int level) { return Iterator(ith_dim_name(level)); } +Iterator Stage::ith_iterator(int level) { + return Iterator(ith_dim_name(level)); +} isl::set Stage::transformed_domain() const { CHECK(!domain_.is_null()); @@ -1062,7 +1208,8 @@ isl::set Stage::transformed_domain() const { return domain_.apply(transform_); } -std::vector> ExtractExtraDepLinksFromStages(const std::vector &stages) { +std::vector> ExtractExtraDepLinksFromStages( + const std::vector &stages) { std::vector> extra_links; for (auto &stage : stages) { for (auto &tensor : stage->ctrl_depends()) { @@ -1076,21 +1223,23 @@ std::vector> ExtractExtraDepLinksFromStages( void Stage::Unroll(const std::string &level) { auto dim_names = axis_names(); - auto it = std::find(dim_names.begin(), dim_names.end(), level); - int l = std::distance(dim_names.begin(), it); + auto it = std::find(dim_names.begin(), dim_names.end(), level); + int l = std::distance(dim_names.begin(), it); 
AssertAxisIsNotLocked(l); Unroll(l); } void Stage::Unroll(const Iterator &level) { auto dim_names = axis_names(); - auto it = std::find(dim_names.begin(), dim_names.end(), level.id); - int l = std::distance(dim_names.begin(), it); + auto it = std::find(dim_names.begin(), dim_names.end(), level.id); + int l = std::distance(dim_names.begin(), it); AssertAxisIsNotLocked(l); Unroll(l); } -std::vector Stage::axis_names() const { return isl_get_dim_names(transformed_domain()); } +std::vector Stage::axis_names() const { + return isl_get_dim_names(transformed_domain()); +} std::vector Stage::origin_reduce_axis_names() { auto reduce_axis_var = this->tensor()->reduce_axis; @@ -1107,10 +1256,14 @@ void Stage::Bind(int level, const std::string &axis) { if (axis == "threadIdx.x" || axis == "threadIdx.y" || axis == "threadIdx.z") { uint8_t offset = axis.back() - 'x'; - AddForloopInfo(level, StageForloopInfo{ir::ForType::GPUThread, DeviceAPI::GPU, offset}); - } else if (axis == "blockIdx.x" || axis == "blockIdx.y" || axis == "blockIdx.z") { + AddForloopInfo( + level, + StageForloopInfo{ir::ForType::GPUThread, DeviceAPI::GPU, offset}); + } else if (axis == "blockIdx.x" || axis == "blockIdx.y" || + axis == "blockIdx.z") { uint8_t offset = axis.back() - 'x'; - AddForloopInfo(level, StageForloopInfo{ir::ForType::GPUBlock, DeviceAPI::GPU, offset}); + AddForloopInfo( + level, StageForloopInfo{ir::ForType::GPUBlock, DeviceAPI::GPU, offset}); } else { CINN_NOT_IMPLEMENTED } @@ -1123,7 +1276,7 @@ Iterator Stage::axis(int i) const { } Iterator Stage::axis(const std::string &i) const { auto names = axis_names(); - auto it = std::find(names.begin(), names.end(), i); + auto it = std::find(names.begin(), names.end(), i); CHECK(it != names.end()); return Iterator(*it); } @@ -1139,7 +1292,9 @@ void Stage::SyncThreads(StageMap stages) { auto sync_threads = lang::Compute( {}, - [](const std::vector &axis) { return runtime::IntrinsicCall(Void(), "__syncthreads", {}); }, + [](const std::vector &axis) { + return runtime::IntrinsicCall(Void(), "__syncthreads", {}); + }, Context::Global().NewName("syncthreads")); stages->Insert(sync_threads, ir::CreateStage(sync_threads).get()); @@ -1147,12 +1302,17 @@ void Stage::SyncThreads(StageMap stages) { stages[sync_threads]->CtrlDepend(this_tensor); CHECK_LE(this->compute_ats().size(), 1); for (auto &compute_at : this->compute_ats()) { - isl::set sync_domain(compute_at.stage->domain().ctx(), - isl_set_to_str(compute_at.stage->transformed_domain().get())); + isl::set sync_domain( + compute_at.stage->domain().ctx(), + isl_set_to_str(compute_at.stage->transformed_domain().get())); int dim_num = isl_set_dim(sync_domain.get(), isl_dim_set); - sync_domain = isl::manage( - isl_set_remove_dims(sync_domain.release(), isl_dim_set, compute_at.level + 1, dim_num - compute_at.level - 1)); - sync_domain = isl::manage(isl_set_set_tuple_name(sync_domain.release(), sync_threads->name.c_str())); + sync_domain = + isl::manage(isl_set_remove_dims(sync_domain.release(), + isl_dim_set, + compute_at.level + 1, + dim_num - compute_at.level - 1)); + sync_domain = isl::manage(isl_set_set_tuple_name( + sync_domain.release(), sync_threads->name.c_str())); stages[sync_threads]->domain_ = sync_domain; stages[sync_threads]->InitTransform(); @@ -1161,10 +1321,12 @@ void Stage::SyncThreads(StageMap stages) { relation.level = compute_at.level; relation.stage->CtrlDepend(sync_threads); - CHECK(relation.IsCompatible(this)) << "Cannot create ComputeAtRelation in SyncThreads with level: " - << relation.level << " from \n" 
- << isl_set_to_str(stages[sync_threads]->transformed_domain().get()) << "\n to \n" - << isl_set_to_str(relation.stage->transformed_domain().get()); + CHECK(relation.IsCompatible(this)) + << "Cannot create ComputeAtRelation in SyncThreads with level: " + << relation.level << " from \n" + << isl_set_to_str(stages[sync_threads]->transformed_domain().get()) + << "\n to \n" + << isl_set_to_str(relation.stage->transformed_domain().get()); stages[sync_threads]->compute_ats_[relation.stage->id()] = relation; } @@ -1175,13 +1337,17 @@ void Stage::SyncThreads(StageMap stages) { } } -void Stage::SyncThreads(int level, const std::vector &before_tensors, StageMap stages) { +void Stage::SyncThreads(int level, + const std::vector &before_tensors, + StageMap stages) { CHECK(tensor_); auto this_tensor = ir::Tensor(tensor_); auto sync_threads = lang::Compute( {}, - [](const std::vector &axis) { return runtime::IntrinsicCall(Void(), "__syncthreads", {}); }, + [](const std::vector &axis) { + return runtime::IntrinsicCall(Void(), "__syncthreads", {}); + }, Context::Global().NewName("syncthreads")); stages->Insert(sync_threads, ir::CreateStage(sync_threads).get()); @@ -1192,10 +1358,13 @@ void Stage::SyncThreads(int level, const std::vector &before_tensors stages[sync_threads]->CtrlDepend(other); } - isl::set sync_domain(domain().ctx(), isl_set_to_str(transformed_domain().get())); + isl::set sync_domain(domain().ctx(), + isl_set_to_str(transformed_domain().get())); int dim_num = isl_set_dim(sync_domain.get(), isl_dim_set); - sync_domain = isl::manage(isl_set_remove_dims(sync_domain.release(), isl_dim_set, level + 1, dim_num - level - 1)); - sync_domain = isl::manage(isl_set_set_tuple_name(sync_domain.release(), sync_threads->name.c_str())); + sync_domain = isl::manage(isl_set_remove_dims( + sync_domain.release(), isl_dim_set, level + 1, dim_num - level - 1)); + sync_domain = isl::manage(isl_set_set_tuple_name(sync_domain.release(), + sync_threads->name.c_str())); stages[sync_threads]->domain_ = sync_domain; stages[sync_threads]->InitTransform(); @@ -1223,7 +1392,9 @@ struct CacheReplaceMutator : public ir::IRMutator<> { * @param cache the cache * @param read_or_write read or write cache */ - CacheReplaceMutator(const std::string &tensor_name, ir::Tensor cache, bool read_or_write) + CacheReplaceMutator(const std::string &tensor_name, + ir::Tensor cache, + bool read_or_write) : tensor_name(tensor_name), cache(cache), read_or_write(read_or_write) {} void operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); } @@ -1249,7 +1420,9 @@ struct CacheReplaceMutator : public ir::IRMutator<> { }; } // namespace -void CacheReadWriteReplace(std::vector &readers, ir::Tensor cache_tensor, std::string origin_tensor_name) { +void CacheReadWriteReplace(std::vector &readers, + ir::Tensor cache_tensor, + std::string origin_tensor_name) { for (auto k : readers) { auto op = k->operation->as()->body; for (auto j : op) { @@ -1280,15 +1453,21 @@ void Stage::SetBuffer(const std::string &memory_type) { * To create a read cache: * 1. create a cache write stage for cache assign. * 2. add extra deps between cache and tensor to keep SSA order - * 3. register the readers of the cache to the \p tensor, replace latter in Lower + * 3. 
register the readers of the cache to the \p tensor, replace latter in + * Lower */ -ir::Tensor Stage::CacheRead(const std::string &memory_type, std::vector &readers, StageMap stages) { +ir::Tensor Stage::CacheRead(const std::string &memory_type, + std::vector &readers, + StageMap stages) { CHECK(tensor_); - auto my_tensor = ir::Tensor(tensor_); - std::string cache_name = Context::Global().NewName(tensor_->name) + "_read_cache"; + auto my_tensor = ir::Tensor(tensor_); + std::string cache_name = + Context::Global().NewName(tensor_->name) + "_read_cache"; VLOG(4) << "cache_name " << cache_name; auto cache_tensor = lang::Compute( - tensor_->shape, [=](const std::vector &dims) { return my_tensor(dims); }, cache_name); + tensor_->shape, + [=](const std::vector &dims) { return my_tensor(dims); }, + cache_name); cache_tensor->WithBuffer(memory_type); stages->Insert(cache_tensor, CreateStage(cache_tensor).get()); @@ -1297,8 +1476,10 @@ ir::Tensor Stage::CacheRead(const std::string &memory_type, std::vector reader_names; - std::transform( - readers.begin(), readers.end(), std::back_inserter(reader_names), [](const ir::Tensor &x) { return x->name; }); + std::transform(readers.begin(), + readers.end(), + std::back_inserter(reader_names), + [](const ir::Tensor &x) { return x->name; }); CacheReadWriteReplace(readers, cache_tensor, tensor_->name); if (memory_type == "shared") { @@ -1315,22 +1496,29 @@ ir::Tensor Stage::CacheRead(const std::string &memory_type, std::vectorbuffer.defined()) << "This tensor is already binded to a buffer, cannot cache write"; - CHECK(!meta.compute_inline) << "Cannot create a write cache on an inlined tensor"; - auto ctrl_depend = stages[tensor_]->ctrl_depends(); + CHECK(!tensor_->buffer.defined()) + << "This tensor is already bound to a buffer, cannot cache write"; + CHECK(!meta.compute_inline) + << "Cannot create a write cache on an inlined tensor"; + auto ctrl_depend = stages[tensor_]->ctrl_depends(); std::string cache_name = tensor_->name + "_write_cache"; - auto original_name = tensor_->name; - tensor_->name = cache_name; - auto my_tensor = ir::Tensor(tensor_); + auto original_name = tensor_->name; + tensor_->name = cache_name; + auto my_tensor = ir::Tensor(tensor_); // make my_tensor a cache my_tensor->WithBuffer(memory_type); auto write_stage = lang::Compute( - tensor_->shape, [=](const std::vector &dims) { return my_tensor(dims); }, original_name); + tensor_->shape, + [=](const std::vector &dims) { return my_tensor(dims); }, + original_name); stages->Insert(my_tensor, CreateStage(my_tensor).get()); stages[my_tensor]->ctrl_depends_ = ctrl_depend; @@ -1339,7 +1527,9 @@ ir::Tensor Stage::CacheWrite(const std::string &memory_type, StageMap stages, ir stages[write_stage]->CtrlDepend(my_tensor); std::vector temp; for (auto &i : stages) { - if (i.second->tensor()->name == original_name || i.second->tensor()->name == cache_name) continue; + if (i.second->tensor()->name == original_name || + i.second->tensor()->name == cache_name) + continue; if (i.second->tensor()->is_compute_node()) { temp.push_back(ir::Tensor(i.second->tensor())); } @@ -1370,26 +1560,37 @@ void Stage::ShareBufferWith(Stage *other) { other->meta.tensors_to_share_buffer_with.insert(tensor_->name); } -isl_map *__isl_give GatherAccesses(Stage *stage, const std::string &tensor_name) { +isl_map *__isl_give GatherAccesses(Stage *stage, + const std::string &tensor_name) { CHECK(stage->tensor_); auto loads = ir::CollectIRNodes(stage->tensor_->body(), [&](const Expr *x) { return x->As() &&
x->As()->tensor.as_tensor()->name == tensor_name; + return x->As() && + x->As()->tensor.as_tensor()->name == tensor_name; }); auto vars = stage->tensor_->axis_with_reduce(); - std::string in_tuple_name = stage->tensor_->name; + std::string in_tuple_name = stage->tensor_->name; std::string out_tuple_name = tensor_name; std::vector in_dim_names, out_loads; - std::transform(vars.begin(), vars.end(), std::back_inserter(in_dim_names), [](const Var &x) { return x->name; }); - std::transform( - loads.begin(), loads.end(), std::back_inserter(out_loads), [](const Expr &x) { return utils::GetStreamCnt(x); }); + std::transform(vars.begin(), + vars.end(), + std::back_inserter(in_dim_names), + [](const Var &x) { return x->name; }); + std::transform(loads.begin(), + loads.end(), + std::back_inserter(out_loads), + [](const Expr &x) { return utils::GetStreamCnt(x); }); isl_map *res = nullptr; for (auto &load : out_loads) { - std::string repr = utils::StringFormat( - "{ %s[%s] -> %s }", in_tuple_name.c_str(), utils::Join(in_dim_names, ",").c_str(), load.c_str()); - isl_map *access = isl_map_read_from_str(stage->domain().ctx().get(), repr.c_str()); + std::string repr = + utils::StringFormat("{ %s[%s] -> %s }", + in_tuple_name.c_str(), + utils::Join(in_dim_names, ",").c_str(), + load.c_str()); + isl_map *access = + isl_map_read_from_str(stage->domain().ctx().get(), repr.c_str()); if (res) { res = isl_map_union(res, access); } else { @@ -1402,15 +1603,16 @@ isl_map *__isl_give GatherAccesses(Stage *stage, const std::string &tensor_name) void Stage::AddForloopInfo(int level, const StageForloopInfo &info) { cuda_bind_info_ = true; - int num_levels = isl_map_dim(transform_.get(), isl_dim_out); + int num_levels = isl_map_dim(transform_.get(), isl_dim_out); CHECK_GE(level, 0); CHECK_LT(level, num_levels); auto transformed_domain = this->transformed_domain(); - int removed_axes_counts = isl_get_precending_removed_axes_counts(transformed_domain.get(), level); + int removed_axes_counts = + isl_get_precending_removed_axes_counts(transformed_domain.get(), level); if (isl_is_removed_axis(transformed_domain.get(), level)) { - // For scalar case, forloop info will be lost after for-1 and reduce-axis elimination. We record the forloop info in - // the -1th level for backup. + // For scalar case, forloop info will be lost after for-1 and reduce-axis + // elimination. We record the forloop info in the -1th level for backup. 
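The branch that follows implements the backup the comment describes: bind info is keyed by the compacted axis index, and an axis 0 that was eliminated entirely is parked under key -1 so a scalar stage does not lose its binding. A compact restatement under simplified types (a `std::string` stands in for `StageForloopInfo`):

```cpp
#include <iostream>
#include <map>
#include <string>

std::map<int, std::string> forloop_infos;

void AddForloopInfo(int level, bool axis_removed, int removed_before,
                    const std::string &info) {
  if (axis_removed) {
    if (level == 0) forloop_infos[-1] = info;  // -1 slot backs up level 0
    return;  // other removed for-1 axes: the binding is dropped
  }
  forloop_infos[level - removed_before] = info;
}

int main() {
  AddForloopInfo(0, /*axis_removed=*/true, 0, "threadIdx.x");
  std::cout << forloop_infos.count(-1) << "\n";  // prints 1
}
```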
if (level == 0) { VLOG(3) << "add forloop_infos in the -1 level for backup"; forloop_infos_[-1] = info; @@ -1418,35 +1620,43 @@ void Stage::AddForloopInfo(int level, const StageForloopInfo &info) { VLOG(3) << "for-1 has no sense, skip it"; return; } - VLOG(3) << "removed_axes_counts are " << removed_axes_counts << " before axis " << ith_dim_name(level); + VLOG(3) << "removed_axes_counts are " << removed_axes_counts + << " before axis " << ith_dim_name(level); forloop_infos_[level - removed_axes_counts] = info; } void Stage::CopyTransform(Stage *other, int level) { - auto target_transform = - RemoveAxiesByInputNames(other->transform(), other->domain(), other->origin_reduce_axis_names()); - isl::set target_origin_domain(other->domain().ctx(), isl_set_to_str(other->domain().get())); + auto target_transform = RemoveAxiesByInputNames( + other->transform(), other->domain(), other->origin_reduce_axis_names()); + isl::set target_origin_domain(other->domain().ctx(), + isl_set_to_str(other->domain().get())); for (auto &i : other->origin_reduce_axis_names()) { - target_origin_domain = isl::manage(isl_remove_axis_by_name(target_origin_domain.release(), i.c_str())); + target_origin_domain = isl::manage( + isl_remove_axis_by_name(target_origin_domain.release(), i.c_str())); } std::string str_target_trans = isl_map_to_str(target_transform.get()); std::string this_tensor_name = isl_set_get_tuple_name(domain_.get()); - isl::ctx this_ctx = domain_.ctx(); + isl::ctx this_ctx = domain_.ctx(); isl::map temp_transform_(this_ctx, str_target_trans); - auto this_map_dims = isl_get_dim_names(transform_.get(), isl_dim_in); + auto this_map_dims = isl_get_dim_names(transform_.get(), isl_dim_in); auto target_map_dims = isl_get_dim_names(target_transform.get(), isl_dim_in); - // Edit level. e.g. if A->Split(0,10) and B->CopyTransform(A,0), level should increase to 1. + // Edit level. e.g. if A->Split(0,10) and B->CopyTransform(A,0), level should + // increase to 1. isl::map temp_target_trans(this_ctx, str_target_trans); if (level + 1 < isl_map_dim(temp_target_trans.get(), isl_dim_out)) { - std::string pivot_dim_out = isl_map_get_dim_name(temp_target_trans.get(), isl_dim_out, level + 1); + std::string pivot_dim_out = + isl_map_get_dim_name(temp_target_trans.get(), isl_dim_out, level + 1); std::vector dim_out_level; for (int i = 0; i <= level; i++) { - dim_out_level.push_back(isl_map_get_dim_name(temp_target_trans.get(), isl_dim_out, i)); + dim_out_level.push_back( + isl_map_get_dim_name(temp_target_trans.get(), isl_dim_out, i)); } - auto related_dim_in = GetRelatedInputAxies(temp_target_trans, target_origin_domain, dim_out_level); - auto related_dim_out = GetRelatedOutputAxies(temp_target_trans, target_origin_domain, related_dim_in); + auto related_dim_in = GetRelatedInputAxies( + temp_target_trans, target_origin_domain, dim_out_level); + auto related_dim_out = GetRelatedOutputAxies( + temp_target_trans, target_origin_domain, related_dim_in); for (auto &i : related_dim_out) { if (i == pivot_dim_out) { this->CopyTransform(other, level + 1); @@ -1455,16 +1665,19 @@ void Stage::CopyTransform(Stage *other, int level) { } } else if (level >= isl_map_dim(temp_target_trans.get(), isl_dim_out)) { LOG(ERROR) << "ComputeAt level: " << level - << " is not less than the axis number : " << isl_map_dim(temp_target_trans.get(), isl_dim_out) + << " is not less than the axis number : " + << isl_map_dim(temp_target_trans.get(), isl_dim_out) << ", please check."; } - //! 
When this->tensor's dim is more than other->tensor, we need to supplment dims. + //! When this->tensor's dim is more than other->tensor, we need to supplement + //! dims. std::vector sup_dims; for (int i = target_map_dims.size(); i < this_map_dims.size(); i++) { sup_dims.push_back(this_map_dims[i]); } - //! Check the dim range in this domain and target domain. Correspoding dim's range must be equal. + //! Check the dim range in this domain and target domain. Corresponding dim's + //! range must be equal. auto dim_names = isl_get_dim_names(domain_.get()); std::set this_dim_names; @@ -1472,21 +1685,26 @@ void Stage::CopyTransform(Stage *other, int level) { for (int i = 0; i < isl_set_dim(domain_.get(), isl_dim_set); i++) { this_dim_names.insert(isl_set_get_dim_name(domain_.get(), isl_dim_set, i)); } - //! Delete redundant input dims in transform_ (e,g. B[i,j] -> CopyTransform(C[i,j,k]) , Redundant dim k will be - //! deleted.) + //! Delete redundant input dims in transform_ (e.g. B[i,j] -> + //! CopyTransform(C[i,j,k]) , Redundant dim k will be deleted.) for (int i = 0; i < isl_map_dim(temp_transform_.get(), isl_dim_in); i++) { - if (this_dim_names.count(isl_map_get_dim_name(temp_transform_.get(), isl_dim_in, i)) == 0) { - temp_transform_ = isl::manage(isl_map_remove_dims(temp_transform_.release(), isl_dim_in, i, 1)); + if (this_dim_names.count( + isl_map_get_dim_name(temp_transform_.get(), isl_dim_in, i)) == 0) { + temp_transform_ = isl::manage( + isl_map_remove_dims(temp_transform_.release(), isl_dim_in, i, 1)); i--; } } - //! Check related output dims in transform_ and delete them (e,g. C[i,j,k] -> C[i,j,k1,k2] , Redundant output dim k1 - //! nad k2 will be deleted.) + //! Check related output dims in transform_ and delete them (e.g. C[i,j,k] -> + //! C[i,j,k1,k2] , Redundant output dim k1 and k2 will be deleted.)
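Note: once the redundant input and output dims have been pruned (the loops directly above and the keep_names/sup_dims handling further below), CopyTransform only has to re-label the map's tuples with this stage's tensor name. A hedged sketch of that end state with the raw isl C API, using an invented Split(0, 4) transform rather than a real stage:

#include <isl/ctx.h>
#include <isl/map.h>
#include <stdio.h>

int main() {
  isl_ctx *ctx = isl_ctx_alloc();
  // Suppose A carried a Split(0, 4); replaying it on B is a tuple rename.
  isl_map *trans = isl_map_read_from_str(
      ctx,
      "{ A[i, j] -> A[i_outer, i_inner, j' = j] : "
      "4i_outer <= i <= 3 + 4i_outer and i_inner = i - 4i_outer }");
  trans = isl_map_set_tuple_name(trans, isl_dim_in, "B");
  trans = isl_map_set_tuple_name(trans, isl_dim_out, "B");
  printf("%s\n", isl_map_to_str(trans));  // now a B -> B split transform
  isl_map_free(trans);
  isl_ctx_free(ctx);
  return 0;
}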
std::string new_target_trans = isl_map_to_str(temp_transform_.get()); for (int i = 0; i < isl_map_dim(temp_transform_.get(), isl_dim_out); i++) { - std::string temp_dim = isl_map_get_dim_name(temp_transform_.get(), isl_dim_out, i); - if (utils::Count(&new_target_trans, temp_dim) != utils::Count(&str_target_trans, temp_dim)) { - temp_transform_ = isl::manage(isl_map_remove_dims(temp_transform_.release(), isl_dim_out, i, 1)); + std::string temp_dim = + isl_map_get_dim_name(temp_transform_.get(), isl_dim_out, i); + if (utils::Count(&new_target_trans, temp_dim) != + utils::Count(&str_target_trans, temp_dim)) { + temp_transform_ = isl::manage( + isl_map_remove_dims(temp_transform_.release(), isl_dim_out, i, 1)); i--; } } @@ -1495,56 +1713,74 @@ void Stage::CopyTransform(Stage *other, int level) { std::set keep_names; int dim_size = isl_map_dim(temp_transform_.get(), isl_dim_out); for (int i = level + 1; i < dim_size; i++) { - std::string temp = isl_map_get_dim_name(temp_transform_.get(), isl_dim_out, i); - temp = temp.substr(0, 1); - temp = temp + "' = " + temp; + std::string temp = + isl_map_get_dim_name(temp_transform_.get(), isl_dim_out, i); + temp = temp.substr(0, 1); + temp = temp + "' = " + temp; keep_names.insert(temp); } - temp_transform_ = - isl::manage(isl_map_remove_dims(temp_transform_.release(), isl_dim_out, level + 1, dim_size - level - 1)); + temp_transform_ = isl::manage(isl_map_remove_dims(temp_transform_.release(), + isl_dim_out, + level + 1, + dim_size - level - 1)); for (auto i : keep_names) { VLOG(3) << "i in keep_names is: " << i; - temp_transform_ = isl::manage(isl_map_add_dims(temp_transform_.release(), isl_dim_out, 1)); - temp_transform_ = isl::manage(isl_map_set_dim_name(temp_transform_.release(), isl_dim_out, level + 1, i.c_str())); + temp_transform_ = isl::manage( + isl_map_add_dims(temp_transform_.release(), isl_dim_out, 1)); + temp_transform_ = isl::manage(isl_map_set_dim_name( + temp_transform_.release(), isl_dim_out, level + 1, i.c_str())); level++; } } if (sup_dims.size() > 0) { - int level_in = isl_map_dim(temp_transform_.get(), isl_dim_in); + int level_in = isl_map_dim(temp_transform_.get(), isl_dim_in); int level_out = isl_map_dim(temp_transform_.get(), isl_dim_out); for (auto i : sup_dims) { VLOG(3) << "i in sup_dims is: " << i; - temp_transform_ = isl::manage(isl_map_add_dims(temp_transform_.release(), isl_dim_in, 1)); - temp_transform_ = isl::manage(isl_map_set_dim_name(temp_transform_.release(), isl_dim_in, level_in, i.c_str())); + temp_transform_ = isl::manage( + isl_map_add_dims(temp_transform_.release(), isl_dim_in, 1)); + temp_transform_ = isl::manage(isl_map_set_dim_name( + temp_transform_.release(), isl_dim_in, level_in, i.c_str())); level_in++; std::string i_dim_out = i + "' = " + i; - temp_transform_ = isl::manage(isl_map_add_dims(temp_transform_.release(), isl_dim_out, 1)); + temp_transform_ = isl::manage( + isl_map_add_dims(temp_transform_.release(), isl_dim_out, 1)); temp_transform_ = - isl::manage(isl_map_set_dim_name(temp_transform_.release(), isl_dim_out, level_out, i_dim_out.c_str())); + isl::manage(isl_map_set_dim_name(temp_transform_.release(), + isl_dim_out, + level_out, + i_dim_out.c_str())); level_out++; } } - isl_map_set_tuple_name(temp_transform_.get(), isl_dim_in, this_tensor_name.c_str()); - isl_map_set_tuple_name(temp_transform_.get(), isl_dim_out, this_tensor_name.c_str()); + isl_map_set_tuple_name( + temp_transform_.get(), isl_dim_in, this_tensor_name.c_str()); + isl_map_set_tuple_name( + temp_transform_.get(), isl_dim_out, 
this_tensor_name.c_str()); std::string res_trans = isl_map_to_str(temp_transform_.get()); isl::map res_map(this_ctx, res_trans); VLOG(2) << "This domain is: " << isl_set_to_str(domain_.get()); - VLOG(2) << "After Copytransform this trans is : " << isl_map_to_str(res_map.get()); - VLOG(2) << "Target transform is : " << isl_map_to_str(other->transform().get()); + VLOG(2) << "After Copytransform this trans is : " + << isl_map_to_str(res_map.get()); + VLOG(2) << "Target transform is : " + << isl_map_to_str(other->transform().get()); VLOG(2) << "CopyTransform Level is : " << level; transform_ = res_map; } void Stage::CopyLoopInfo(Stage *other) { // copy other stage's forloop_infos - auto &target_forloop_infos = other->forloop_infos(); - auto target_transformed_domain = other->transformed_domain(); - std::vector this_dim_names = isl_get_dim_names(transformed_domain()); - std::vector target_dim_names = isl_get_dim_names(target_transformed_domain); + auto &target_forloop_infos = other->forloop_infos(); + auto target_transformed_domain = other->transformed_domain(); + std::vector this_dim_names = + isl_get_dim_names(transformed_domain()); + std::vector target_dim_names = + isl_get_dim_names(target_transformed_domain); for (auto &i : target_forloop_infos) { for (int j = 0; j < this_dim_names.size(); j++) { - int new_i = poly::isl_get_original_axes_from_optimized_level(target_transformed_domain.get(), i.first); + int new_i = poly::isl_get_original_axes_from_optimized_level( + target_transformed_domain.get(), i.first); if (target_dim_names[new_i] == this_dim_names[j]) { this->AddForloopInfo(j, i.second); } @@ -1552,11 +1788,11 @@ void Stage::CopyLoopInfo(Stage *other) { } // copy other stage's vectorize/unroll/parallel info auto &target_vectorize_info = other->vectorize_info(); - auto &target_unroll_info = other->unroll_info(); - auto &target_parallel_info = other->parallel_info(); - vectorize_info_ = target_vectorize_info; - unroll_info_ = target_unroll_info; - parallel_info_ = target_parallel_info; + auto &target_unroll_info = other->unroll_info(); + auto &target_parallel_info = other->parallel_info(); + vectorize_info_ = target_vectorize_info; + unroll_info_ = target_unroll_info; + parallel_info_ = target_parallel_info; } void Stage::LockAxis(uint32_t level) { @@ -1575,13 +1811,15 @@ bool Stage::is_axis_locked(uint32_t level) const { } void Stage::AssertAxisIsNotLocked(uint32_t level) { - CHECK(!is_axis_locked(level)) << "The " << level << "-th axis is locked, cannot perform schedule"; + CHECK(!is_axis_locked(level)) + << "The " << level << "-th axis is locked, cannot perform schedule"; } int Stage::GetTransformedLevel(int level) { if (!compute_ats().empty()) { - // The ComputeAt schedule will insert some consumer axis in the preceding of this, so the raw before ComputeAt - // should add the numbers of axis inserted. + // The ComputeAt schedule will insert some consumer axis in the preceding of + // this, so the raw before ComputeAt should add the numbers of axis + // inserted. 
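Note: the ComputeAt branch of Stage::GetTransformedLevel, continued just below, encodes a rule that is easy to miss: a ComputeAt at consumer level N prepends N + 1 consumer axes to this stage's loop nest, so a raw level must be shifted by N + 1. A restatement of just that arithmetic; the helper name is invented, and the no-ComputeAt branch falls outside this hunk, so identity is assumed:

// Not CINN code: only the level shift performed by Stage::GetTransformedLevel.
int TransformedLevelSketch(int raw_level, bool has_compute_at, int ca_level) {
  // A ComputeAt at consumer level ca_level inserts ca_level + 1 axes in front.
  return has_compute_at ? ca_level + raw_level + 1 : raw_level;
}
// e.g. computed at level 1: raw level 0 maps to 1 + 0 + 1 == 2.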
CHECK_EQ(compute_ats().size(), 1UL); auto &compute_at = compute_ats().front(); return compute_at.level + level + 1; @@ -1593,17 +1831,22 @@ int Stage::GetTransformedLevel(int level) { void Stage::CtrlDepend(const ir::Tensor &t) { ctrl_depends_.insert(t); } -const std::set &Stage::ctrl_depends() const { return ctrl_depends_; } +const std::set &Stage::ctrl_depends() const { + return ctrl_depends_; +} ir::Tensor Stage::LookupCtrlDepend(const std::string &tensor_name) const { auto it = std::find_if( - ctrl_depends_.begin(), ctrl_depends_.end(), [&](const ir::Tensor &x) { return x->name == tensor_name; }); + ctrl_depends_.begin(), ctrl_depends_.end(), [&](const ir::Tensor &x) { + return x->name == tensor_name; + }); if (it == ctrl_depends_.end()) return ir::Tensor(); return *it; } Stage *_StageMap_::operator[](const ir::Tensor &tensor) { - CHECK(data_.count(tensor->name)) << "StageMap has no stage for tensor [" << tensor->name << "]"; + CHECK(data_.count(tensor->name)) + << "StageMap has no stage for tensor [" << tensor->name << "]"; return data_[tensor->name].get(); } const Stage *_StageMap_::operator[](const ir::Tensor &tensor) const { @@ -1611,11 +1854,13 @@ const Stage *_StageMap_::operator[](const ir::Tensor &tensor) const { return data_.at(tensor->name).get(); } Stage *_StageMap_::operator[](const ir::_Tensor_ *tensor) { - CHECK(data_.count(tensor->name)) << "StageMap has no stage for tensor [" << tensor->name << "]"; + CHECK(data_.count(tensor->name)) + << "StageMap has no stage for tensor [" << tensor->name << "]"; return data_[tensor->name].get(); } const Stage *_StageMap_::operator[](const ir::_Tensor_ *tensor) const { - CHECK(data_.count(tensor->name)) << "StageMap has no stage for tensor [" << tensor->name << "]"; + CHECK(data_.count(tensor->name)) + << "StageMap has no stage for tensor [" << tensor->name << "]"; return data_.at(tensor->name).get(); } @@ -1643,7 +1888,8 @@ StageMap CreateStages(const std::vector &tensors) { std::set all_tensors(tensors.begin(), tensors.end()); for (auto &tensor : tensors) { - auto used_tensors = ir::CollectIRNodes(tensor->body(), [](const Expr *x) { return x->as_tensor(); }); + auto used_tensors = ir::CollectIRNodes( + tensor->body(), [](const Expr *x) { return x->as_tensor(); }); for (const Expr &x : used_tensors) { all_tensors.insert(x.as_tensor_ref()); } diff --git a/paddle/cinn/poly/stage.h b/paddle/cinn/poly/stage.h index 8c11b813d06a1..869f8f038de5e 100755 --- a/paddle/cinn/poly/stage.h +++ b/paddle/cinn/poly/stage.h @@ -40,7 +40,7 @@ using ir::DeviceAPI; struct ComputeAtRelation; enum class ScopeKind { - kLocal = 0, + kLocal = 0, kShared = 1, kGlobal = 2, }; @@ -53,7 +53,8 @@ struct StageForloopInfo { : for_type(for_type), device(device), offset(offset) {} ir::ForType for_type; - //! The offset in the \p for_type. e.g. for GPUBlock, 0 represents blockIdx.x, 1 is blockIdx.y, 2 is blockIdx.z. + //! The offset in the \p for_type. e.g. for GPUBlock, 0 represents blockIdx.x, + //! 1 is blockIdx.y, 2 is blockIdx.z. uint8_t offset; ir::DeviceAPI device; }; @@ -76,8 +77,9 @@ struct ComputeAtInfo { //! The shape of the buffer belong to the producer tensor after compute_at. //! NOTE this doesn't support dynamic dimension yet. std::vector adjusted_producer_shape; - //! The preceding offsets for the indice in the Loads for the producers, the offset will make the minimum indice to be - //! 0, size of this should equal to level+1. + //! The preceding offsets for the indices in the Loads for the producers, the + //!
offset will make the minimum index to be 0, size of this should equal to + //! level+1. std::vector preceding_offset_for_producer_load; //! the level of the consumer tensor's transformed range. int level{-1}; @@ -87,7 +89,8 @@ struct ComputeAtInfo { * Meta infomation for tensor. */ struct TensorScheduleMeta { - //! Store the information of all the other producer tensors `compute_at` this tensor. + //! Store the information of all the other producer tensors `compute_at` this + //! tensor. std::vector compute_at_infos; bool compute_inline{false}; @@ -102,7 +105,9 @@ struct TensorScheduleMeta { */ class Stage : public Object { public: - static Shared New(const isl::set& domain, Expr expr = Expr(), ir::_Tensor_* tensor = nullptr); + static Shared New(const isl::set& domain, + Expr expr = Expr(), + ir::_Tensor_* tensor = nullptr); TensorScheduleMeta meta; @@ -188,7 +193,10 @@ class Stage : public Object { * @return the new iterators. */ std::tuple // - Tile(const Iterator& level0, const Iterator& level1, int factor0, int factor1); + Tile(const Iterator& level0, + const Iterator& level1, + int factor0, + int factor1); std::tuple // Tile(int level0, int level1, int factor0, int factor1); @@ -226,7 +234,8 @@ class Stage : public Object { }; /** - * Apply loop skewing on the loop levels \p i and \p j with a skewing factor of \p factor. + * Apply loop skewing on the loop levels \p i and \p j with a skewing factor + * of \p factor. * TODO(Superjomn) Refine this transform. */ std::tuple // @@ -239,46 +248,53 @@ class Stage : public Object { /** * Set the memory type of this stage's tensor. - * @param memory_type the memory type of this tensor. For example, memory_type="shared". + * @param memory_type the memory type of this tensor. For example, + * memory_type="shared". */ void SetBuffer(const std::string& memory_type); /** - * Given two stages already satisfy ComputeAtRelation.IsCompatible, set compute_ats_ for them. + * Given two stages already satisfy ComputeAtRelation.IsCompatible, set + * compute_ats_ for them. * @param other the other stage to set compute_ats_. * @param level the level of ComputeAtRelation. */ void SimpleComputeAt(Stage* other, int level); /** - * Create a cache Tensor and load the \p source into this buffer, replace all the reading in the readers with the - * cache. + * Create a cache Tensor and load the \p source into this buffer, replace all + * the reading in the readers with the cache. * @param tensor the source memory to cache. - * @param memory_type the memory type, "share" for CUDA share memory, "local" for CUDA local memory. + * @param memory_type the memory type, "share" for CUDA share memory, "local" + * for CUDA local memory. * @param readers the readers of the \p tensor */ - ir::Tensor CacheRead(const std::string& memory_type, std::vector& readers, poly::StageMap stages); + ir::Tensor CacheRead(const std::string& memory_type, + std::vector& readers, + poly::StageMap stages); /** - * \brief Mark the stage compute at the level of some other stage. Usually used when there is no access relation - * between two tensors. + * \brief Mark the stage compute at the level of some other stage. Usually + * used when there is no access relation between two tensors. * - * The difference bewteen ComputeAt2 and ComputeAt is that ComputeAt2 can be used when there is no access relation - * between two tensors. + * The difference between ComputeAt2 and ComputeAt is that ComputeAt2 can be + * used when there is no access relation between two tensors.
* * @param other the target stage to compute at. * @param level the level of \p other's forloop to compute at */ void ComputeAt2(Stage* other, int level); - // Do ComputeAt2 except for setting the ComputeAt level, which is moving the computations together. + // Do ComputeAt2 except for setting the ComputeAt level, which is moving the + // computations together. void ComputeAt3(Stage* other, int level); /** * \brief Mark the stage compute at the level of some other stage. * - * NOTE This can only be called after all transformations are preformed, and once called, no further transform can - * perform for that if the iterators are changed, the original `ComputeAt` level will become invalid. + * NOTE This can only be called after all transformations are performed, and + * once called, no further transform can be performed, because if the iterators + * are changed, the original `ComputeAt` level will become invalid. * * @param other the target stage to compute at. * @param level the level of \p other's forloop to compute at @@ -291,13 +307,17 @@ class Stage : public Object { /** * Create a cache for write to the original tensor. * @param tensor the tensor to create the cache for. - * @param memory_type "share" for CUDA share memory, "local" for CUDA local memory. + * @param memory_type "share" for CUDA share memory, "local" for CUDA local + * memory. */ - ir::Tensor CacheWrite(const std::string& memory_type, poly::StageMap stages, ir::Tensor& key_tensor); + ir::Tensor CacheWrite(const std::string& memory_type, + poly::StageMap stages, + ir::Tensor& key_tensor); /** * Generate the `syncthreads()` code to sync all threads on CUDA backends. - * For other backends like Opencl, generate corresponding code to sync multi threads. + * For other backends like Opencl, generate corresponding code to sync multi + * threads. * @param tensor the exact tensor computed just before syncthreads. * @param stages the stagemap of all tensor. */ @@ -305,8 +325,10 @@ class Stage : public Object { /** * Generate the `syncthreads()` code to sync all threads on CUDA backends. - * For other backends like Opencl, generate corresponding code to sync multi threads. - * @param level the ComputeAt level of syncthreads in this tensor's computation. + * For other backends like Opencl, generate corresponding code to sync multi + * threads. + * @param level the ComputeAt level of syncthreads in this tensor's + * computation. * @param before_tensors the tensors computed before syncthreads. * @param stages the stagemap of all tensor. * Example Code : @@ -320,7 +342,9 @@ class Stage : public Object { * for (j = 0:9) * A[i,j] */ - void SyncThreads(int level, const std::vector& before_tensors, StageMap stages); + void SyncThreads(int level, + const std::vector& before_tensors, + StageMap stages); /** * Set thread scope. @@ -357,7 +381,8 @@ class Stage : public Object { Iterator ith_iterator(int level); /** Get the final level after all the transforms. - * The level will be affected by some schedule like ComputeAt, this will return the right level. + * The level will be affected by some schedule like ComputeAt, this will + * return the right level. * * @param level the level in schedule.
*/ @@ -368,38 +393,58 @@ class Stage : public Object { virtual const char* type_info() const { return __type_info__; } - inline const ir::VectorizeInfo& vectorize_info() const { return vectorize_info_; } + inline const ir::VectorizeInfo& vectorize_info() const { + return vectorize_info_; + } inline const std::set& unroll_info() const { return unroll_info_; } inline const std::set& parallel_info() const { return parallel_info_; } - inline std::map& GetComputeAts() { return compute_ats_; } - inline void SetComputeAts(const std::map& compute_ats) { compute_ats_ = compute_ats; } + inline std::map& GetComputeAts() { + return compute_ats_; + } + inline void SetComputeAts( + const std::map& compute_ats) { + compute_ats_ = compute_ats; + } /* - const std::set& extra_depend_stages() const { return extra_depend_stages_; } - void set_extra_depend_stages(const std::set& x) { extra_depend_stages_ = x; } - void add_extra_depend_stage(const std::string& statement) { extra_depend_stages_.insert(statement); } + const std::set& extra_depend_stages() const { return + extra_depend_stages_; } void set_extra_depend_stages(const + std::set& x) { extra_depend_stages_ = x; } void + add_extra_depend_stage(const std::string& statement) { + extra_depend_stages_.insert(statement); } */ - const std::map& forloop_infos() const { return forloop_infos_; } + const std::map& forloop_infos() const { + return forloop_infos_; + } bool has_expression() const; Stage() = default; - void ComputeAtSchedule(Stage* other, int level, ComputeAtKind kind = kComputeAtAuto); + void ComputeAtSchedule(Stage* other, + int level, + ComputeAtKind kind = kComputeAtAuto); ir::Tensor LookupCtrlDepend(const std::string& tensor_name) const; - //! Get number of transform output dimensions, this equals to the number of forloops in generated code. - inline int n_in_dims() const { return isl_map_dim(transform_.get(), isl_dim_in); } - //! Get number of transform output dimensions, this equals to the number of dimensions of corresponding tensor. - inline int n_out_dims() const { return isl_map_dim(transform_.get(), isl_dim_out); } + //! Get number of transform input dimensions, this equals to the number of + //! forloops in generated code. + inline int n_in_dims() const { + return isl_map_dim(transform_.get(), isl_dim_in); + } + //! Get number of transform output dimensions, this equals to the number of + //! dimensions of corresponding tensor. + inline int n_out_dims() const { + return isl_map_dim(transform_.get(), isl_dim_out); + } //! Copy other stage's transform. //! For example, if the target_transform is `Split(0,1)`, //! this api will apply `Split(0,1)` on itself. void CopyTransform(Stage* other, int level = -1); - //! Edit temp tensor's shape, its buffer's shape and index when doing ComputeAt2. + //! Edit temp tensor's shape, its buffer's shape and index when doing + //! ComputeAt2. void EditTempTensor(Stage* other, int level); //! Copy other stage's LoopInfo. //! For example, if the target_forloop_infos is `Bind(0,"threadIdx.x")`, @@ -408,12 +453,16 @@ class Stage : public Object { //! Set stage's transform_ void SetTransform(isl::map new_transform) { transform_ = new_transform; } //!
Set stage's forloop_infos_ - void SetForloopInfo(std::map forloop_infos) { forloop_infos_ = forloop_infos; } + void SetForloopInfo(std::map forloop_infos) { + forloop_infos_ = forloop_infos; + } void AddForloopInfo(int level, const StageForloopInfo& info); bool IfCudaBind() { return cuda_bind_info_; } private: - explicit Stage(const isl::set& domain, Expr expr = Expr(), ir::_Tensor_* tensor = nullptr); + explicit Stage(const isl::set& domain, + Expr expr = Expr(), + ir::_Tensor_* tensor = nullptr); /** * Initialize with an identity schedule. @@ -453,11 +502,13 @@ class Stage : public Object { std::set locked_axis_; bool cuda_bind_info_{false}; - friend isl_map* __isl_give GatherAccesses(Stage* stage, const std::string& tensor_name); + friend isl_map* __isl_give GatherAccesses(Stage* stage, + const std::string& tensor_name); friend class PolyGroupScheduler; }; -std::vector> ExtractExtraDepLinksFromStages(const std::vector& stages); +std::vector> ExtractExtraDepLinksFromStages( + const std::vector& stages); //! This stage compute_at some other stage. struct ComputeAtRelation { @@ -476,12 +527,15 @@ inline std::string InnerName(const Iterator& iterator); inline std::string OuterName(const std::string& name); inline std::string OuterName(const Iterator& iterator); -inline Iterator DefaultIterator(int i) { return Iterator(common::axis_name(i)); } +inline Iterator DefaultIterator(int i) { + return Iterator(common::axis_name(i)); +} /** * Collect the access to a tensor named \p tensor_name in \p stage. */ -std::vector GatherAccesses(const Stage* stage, const std::string& tensor_name); +std::vector GatherAccesses(const Stage* stage, + const std::string& tensor_name); class _StageMap_ : public Object { public: @@ -523,9 +577,13 @@ class StageMap : public Shared<_StageMap_> { StageMap() : Shared(new _StageMap_) {} Stage* operator[](const ir::Tensor& tensor) { return (*self())[tensor]; } - const Stage* operator[](const ir::Tensor& tensor) const { return (*self())[tensor]; } + const Stage* operator[](const ir::Tensor& tensor) const { + return (*self())[tensor]; + } Stage* operator[](const ir::_Tensor_* tensor) { return (*self())[tensor]; } - const Stage* operator[](const ir::_Tensor_* tensor) const { return (*self())[tensor]; } + const Stage* operator[](const ir::_Tensor_* tensor) const { + return (*self())[tensor]; + } auto begin() const { return self()->data_.begin(); } auto end() const { return self()->data_.end(); } diff --git a/paddle/cinn/poly/stage_test.cc b/paddle/cinn/poly/stage_test.cc index cf0a629858c88..693fef4783724 100755 --- a/paddle/cinn/poly/stage_test.cc +++ b/paddle/cinn/poly/stage_test.cc @@ -32,7 +32,8 @@ namespace poly { // Create a call. 
Expr CreateCall(const std::string& name, const std::vector& args) { - auto expr = ir::Call::Make(Float(32), name, args, {}, ir::CallType::CINN, ir::FunctionRef(), 0); + auto expr = ir::Call::Make( + Float(32), name, args, {}, ir::CallType::CINN, ir::FunctionRef(), 0); return expr; } @@ -40,13 +41,14 @@ TEST(Stage, split) { isl::ctx ctx(isl_ctx_alloc()); isl::set domain(ctx, "{ S[i,j]: 0<=i,j<=100 }"); - auto ele = Stage::New(domain); + auto ele = Stage::New(domain); auto _outer_inner_ = ele->Split(Iterator("i"), 4); // NOLINT - auto& outer = std::get<0>(_outer_inner_); - auto& inner = std::get<1>(_outer_inner_); + auto& outer = std::get<0>(_outer_inner_); + auto& inner = std::get<1>(_outer_inner_); LOG(INFO) << ele->transform(); EXPECT_EQ(utils::GetStreamCnt(ele->transform()), - "{ S[i, j] -> S[i_outer, i_inner, j' = j] : (-i + i_inner) mod 4 = 0 and -3 + i <= 4i_outer <= i and 0 <= " + "{ S[i, j] -> S[i_outer, i_inner, j' = j] : (-i + i_inner) mod 4 = " + "0 and -3 + i <= 4i_outer <= i and 0 <= " "i_inner <= 3 }"); EXPECT_EQ(outer.id, "i_outer"); @@ -58,20 +60,22 @@ TEST(Stage, tile) { isl::set domain(ctx, "{ S[i,j,k]: 0<=i,j,k<=100 }"); auto ele = Stage::New(domain); - auto _outer0_inner0_outer1_inner1_ = ele->Tile(Iterator("i"), Iterator("j"), 4, 6); // NOLINT - auto& outer0 = std::get<0>(_outer0_inner0_outer1_inner1_); - auto& inner0 = std::get<1>(_outer0_inner0_outer1_inner1_); - auto& outer1 = std::get<2>(_outer0_inner0_outer1_inner1_); - auto& inner1 = std::get<3>(_outer0_inner0_outer1_inner1_); + auto _outer0_inner0_outer1_inner1_ = + ele->Tile(Iterator("i"), Iterator("j"), 4, 6); // NOLINT + auto& outer0 = std::get<0>(_outer0_inner0_outer1_inner1_); + auto& inner0 = std::get<1>(_outer0_inner0_outer1_inner1_); + auto& outer1 = std::get<2>(_outer0_inner0_outer1_inner1_); + auto& inner1 = std::get<3>(_outer0_inner0_outer1_inner1_); LOG(INFO) << ele->transform(); EXPECT_EQ(outer0.id, "i_outer"); EXPECT_EQ(outer1.id, "j_outer"); EXPECT_EQ(inner0.id, "i_inner"); EXPECT_EQ(outer1.id, "j_outer"); - EXPECT_EQ( - utils::GetStreamCnt(ele->transform()), - "{ S[i, j, k] -> S[i_outer, i_inner, j_outer, j_inner, k' = k] : (-i + i_inner) mod 4 = 0 and (-j + j_inner) mod " - "6 = 0 and -3 + i <= 4i_outer <= i and 0 <= i_inner <= 3 and -5 + j <= 6j_outer <= j and 0 <= j_inner <= 5 }"); + EXPECT_EQ(utils::GetStreamCnt(ele->transform()), + "{ S[i, j, k] -> S[i_outer, i_inner, j_outer, j_inner, k' = k] : " + "(-i + i_inner) mod 4 = 0 and (-j + j_inner) mod " + "6 = 0 and -3 + i <= 4i_outer <= i and 0 <= i_inner <= 3 and -5 + " + "j <= 6j_outer <= j and 0 <= j_inner <= 5 }"); } TEST(Stage, reorder) { @@ -86,10 +90,10 @@ TEST(Stage, reorder) { TEST(Stage, split_reorder) { isl::ctx ctx(isl_ctx_alloc()); isl::set domain(ctx, "{ S[i,j,k]: 0<=i,j,k<=100 }"); - auto ele = Stage::New(domain); + auto ele = Stage::New(domain); auto _outer_inner_ = ele->Split(Iterator("i"), 4); // NOLINT - auto& outer = std::get<0>(_outer_inner_); - auto& inner = std::get<1>(_outer_inner_); + auto& outer = std::get<0>(_outer_inner_); + auto& inner = std::get<1>(_outer_inner_); Iterator i("i"), j("j"), k("k"); ele->Reorder(std::vector{{outer, k, inner, j}}); @@ -113,10 +117,10 @@ TEST(ComputeAtRelation, basic) { TEST(Stage, Fuse) { isl::ctx ctx(isl_ctx_alloc()); isl::set domain(ctx, "{ S[i,j,k]: 0<=i,j,k<=100 }"); - auto ele = Stage::New(domain); + auto ele = Stage::New(domain); auto _outer_inner_ = ele->Split(Iterator("i"), 4); // NOLINT - auto& outer = std::get<0>(_outer_inner_); - auto& inner = std::get<1>(_outer_inner_); + 
auto& outer = std::get<0>(_outer_inner_); + auto& inner = std::get<1>(_outer_inner_); LOG(INFO) << "split: " << ele->transform(); ele->Fuse(outer, inner); LOG(INFO) << "fused: " << ele->transform(); @@ -140,7 +144,9 @@ TEST(ComputeAt, Before) { auto A_cache = Compute( {M, N}, [&](Expr i, Expr j) { return A(i, j); }, "cache"); auto C = Compute( - {Expr(10), Expr(10)}, [&](Expr i, Expr j) { return A_cache(i, j) + B(i, j); }, "C"); + {Expr(10), Expr(10)}, + [&](Expr i, Expr j) { return A_cache(i, j) + B(i, j); }, + "C"); auto stages = CreateStages({A_cache, C}); @@ -174,7 +180,9 @@ TEST(ComputeAt, simple) { auto A1 = Compute( {n, n}, [&](Expr i, Expr j) { return A(i, j); }, "A1"); auto B = Compute( - {n / 2, n / 2}, [&](Expr i, Expr j) { return A1(i, j) + A1(i + 1, j) + A1(i + 2, j); }, "B"); + {n / 2, n / 2}, + [&](Expr i, Expr j) { return A1(i, j) + A1(i + 1, j) + A1(i + 2, j); }, + "B"); auto stages = CreateStages({B}); stages[B]->Split(0, 16); @@ -212,7 +220,9 @@ function fn (_A, _A1, _B) CodeGenC codegen(common::DefaultHostTarget()); codegen.SetInlineBuiltinCodes(false); - LOG(INFO) << "source:\n" << codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + LOG(INFO) << "source:\n" + << codegen.Compile(builder.Build(), + backends::CodeGenC::OutputKind::CImpl); } } @@ -223,7 +233,10 @@ TEST(ComputeAt, Before1) { auto create_module = [&] { // cached compute way - auto cache_prepare = Compute({M, N} /*domain*/, [&](Var i, Var j) { return A(i, j); }, "cache", {N} /*shape*/); + auto cache_prepare = Compute({M, N} /*domain*/, + [&](Var i, Var j) { return A(i, j); }, + "cache", + {N} /*shape*/); auto transformed_compute = Compute( {M, N}, [&](Var i, Var j) { return Expr(1.f); }, "transformed"); @@ -233,8 +246,9 @@ TEST(ComputeAt, Before1) { { // C_init Before C auto _cache_prepare_transformed_compute_ = create_module(); - auto& cache_prepare = std::get<0>(_cache_prepare_transformed_compute_); - auto& transformed_compute = std::get<1>(_cache_prepare_transformed_compute_); + auto& cache_prepare = std::get<0>(_cache_prepare_transformed_compute_); + auto& transformed_compute = + std::get<1>(_cache_prepare_transformed_compute_); auto stages = CreateStages({cache_prepare, transformed_compute}); stages[cache_prepare]->ComputeAt2(stages[transformed_compute], 1); @@ -261,8 +275,9 @@ function fn (_A, _cache, _transformed) } { // C_init After C auto _cache_prepare_transformed_compute_ = create_module(); - auto& cache_prepare = std::get<0>(_cache_prepare_transformed_compute_); - auto& transformed_compute = std::get<1>(_cache_prepare_transformed_compute_); + auto& cache_prepare = std::get<0>(_cache_prepare_transformed_compute_); + auto& transformed_compute = + std::get<1>(_cache_prepare_transformed_compute_); auto stages = CreateStages({cache_prepare, transformed_compute}); stages[transformed_compute]->ComputeAt2(stages[cache_prepare], 1); @@ -288,7 +303,8 @@ function fn (_A, _cache, _transformed) } } -void TestElementwiseAddJitPrecession(std::function&& scheduler) { +void TestElementwiseAddJitPrecession( + std::function&& scheduler) { Expr M(30); Expr N(40); Placeholder A("A", {M, N}); @@ -313,10 +329,17 @@ void TestElementwiseAddJitPrecession(std::function& auto* fn_handler = reinterpret_cast(_fn_handler); // create buffer and args - auto A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - auto B_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - auto C_buf = common::BufferBuilder(Float(32), 
{M.as_int32(), N.as_int32()}).set_zero().Build(); - auto arg_pack = common::ArgsBuilder().Add(A_buf).Add(B_buf).Add(C_buf).Build(); + auto A_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); + auto B_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); + auto C_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_zero() + .Build(); + auto arg_pack = + common::ArgsBuilder().Add(A_buf).Add(B_buf).Add(C_buf).Build(); fn_handler(arg_pack.data(), arg_pack.size()); @@ -334,21 +357,23 @@ void TestElementwiseAddJitPrecession(std::function& // use an elementwise_add to test fuse precision TEST(Fuse, jit_precision_test) { - TestElementwiseAddJitPrecession([](ir::Tensor* C, StageMap stages) { stages[(*C)]->Fuse(0, 1); }); + TestElementwiseAddJitPrecession( + [](ir::Tensor* C, StageMap stages) { stages[(*C)]->Fuse(0, 1); }); } // split test fuse precision TEST(Fuse, jit_precision_test2) { TestElementwiseAddJitPrecession([](ir::Tensor* C, StageMap stages) { auto _i_outer_i_inner_ = stages[(*C)]->Split(0, 4); - auto& i_outer = std::get<0>(_i_outer_i_inner_); - auto& i_inner = std::get<1>(_i_outer_i_inner_); + auto& i_outer = std::get<0>(_i_outer_i_inner_); + auto& i_inner = std::get<1>(_i_outer_i_inner_); stages[(*C)]->Fuse(i_outer, i_inner); }); } TEST(Tile, jit_precision_test) { - TestElementwiseAddJitPrecession([](ir::Tensor* C, StageMap stages) { stages[(*C)]->Tile(0, 1, 4, 4); }); + TestElementwiseAddJitPrecession( + [](ir::Tensor* C, StageMap stages) { stages[(*C)]->Tile(0, 1, 4, 4); }); } TEST(Reorder, jit_precision_test) { @@ -359,11 +384,13 @@ TEST(Reorder, jit_precision_test) { } TEST(Unroll, jit_precision_test) { - TestElementwiseAddJitPrecession([](ir::Tensor* C, StageMap stages) { stages[(*C)]->Unroll(1); }); + TestElementwiseAddJitPrecession( + [](ir::Tensor* C, StageMap stages) { stages[(*C)]->Unroll(1); }); } TEST(Unroll, jit_precision_test1) { - TestElementwiseAddJitPrecession([](ir::Tensor* C, StageMap stages) { stages[*C]->Unroll(0); }); + TestElementwiseAddJitPrecession( + [](ir::Tensor* C, StageMap stages) { stages[*C]->Unroll(0); }); } TEST(ComputeInline, basic) { @@ -490,15 +517,20 @@ TEST(ShareBufferWith, basic) { CodeGenC codegen(common::DefaultHostTarget()); codegen.SetInlineBuiltinCodes(false); - LOG(INFO) << "\n" << codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + LOG(INFO) << "\n" + << codegen.Compile(builder.Build(), + backends::CodeGenC::OutputKind::CImpl); } TEST(isl, test) { isl::ctx ctx(isl_ctx_alloc()); - isl::set domain( - ctx, "[p0, p1] -> { p[i, j] : p0 = 0 and 0 <= p1 <= 2 and 4p1 <= i <= 1 + 4p1 and 0 <= j <= 9 + 4p1 - i }"); + isl::set domain(ctx, + "[p0, p1] -> { p[i, j] : p0 = 0 and 0 <= p1 <= 2 and 4p1 <= " + "i <= 1 + 4p1 and 0 <= j <= 9 + 4p1 - i }"); - isl::map schedule(ctx, "[p0, p1] -> { p[i, j] -> p[t0, t1, t2 = j] : 2t1 = i and (t0) mod 2 = 0 and 0 <= t0 <= 1 }"); + isl::map schedule(ctx, + "[p0, p1] -> { p[i, j] -> p[t0, t1, t2 = j] : 2t1 = i and " + "(t0) mod 2 = 0 and 0 <= t0 <= 1 }"); auto schedule_intersected = schedule.intersect_domain(domain); LOG(INFO) << "schedule_intersected: " << schedule_intersected.coalesce(); @@ -507,46 +539,55 @@ TEST(isl, test) { LOG(INFO) << "space: " << context.space(); auto* build = isl_ast_build_from_context(context.release()); - auto* node = isl_ast_build_node_from_schedule_map(build, isl_union_map_from_map(schedule_intersected.release())); + auto* node = 
isl_ast_build_node_from_schedule_map( + build, isl_union_map_from_map(schedule_intersected.release())); LOG(INFO) << "code:\n" << isl_ast_node_to_C_str(node); } TEST(isl, test1) { isl::ctx ctx(isl_ctx_alloc()); - isl::set domain( - ctx, "[p0, p1] -> { p[i, j] : p0 = 0 and 0 <= p1 <= 2 and 4p1 <= i <= 1 + 4p1 and 0 <= j <= 9 + 4p1 - i }"); - isl::map schedule( - ctx, - "[p0, p1] -> { p[i, j] -> p[o0, o1, t0, t1, t2 = j] : 2t1 = i and (o0) mod 4 = 0 and (t0) mod 2 = 0 " - "and 0 <= o0 <= 3 and 0 <= o1 <= 2 and 0 <= t0 <= 1 }"); + isl::set domain(ctx, + "[p0, p1] -> { p[i, j] : p0 = 0 and 0 <= p1 <= 2 and 4p1 <= " + "i <= 1 + 4p1 and 0 <= j <= 9 + 4p1 - i }"); + isl::map schedule(ctx, + "[p0, p1] -> { p[i, j] -> p[o0, o1, t0, t1, t2 = j] : 2t1 " + "= i and (o0) mod 4 = 0 and (t0) mod 2 = 0 " + "and 0 <= o0 <= 3 and 0 <= o1 <= 2 and 0 <= t0 <= 1 }"); isl::map schedule_t(ctx, - "[p0,p1] -> { p[i0,i1,i2,i3,i4] -> [t0,t1,t2,t3,t30,t4] : t0 =i0 and t1 = i1 and t2 = i2 and t3 " + "[p0,p1] -> { p[i0,i1,i2,i3,i4] -> [t0,t1,t2,t3,t30,t4] " + ": t0 =i0 and t1 = i1 and t2 = i2 and t3 " "= i3 and t4 = i4 and t30=0 }"); isl::set cdomain(ctx, "[p0,p1] -> { c[a,b,c]: 0<=a,b,c<10 }"); - isl::map cschedule(ctx, "[p0,p1] -> { c[a,b,c] -> c[t0,t1,t2,t3]: t0=a%4 and t1=a/4 and t2=b and t3=c }"); + isl::map cschedule(ctx, + "[p0,p1] -> { c[a,b,c] -> c[t0,t1,t2,t3]: t0=a%4 and " + "t1=a/4 and t2=b and t3=c }"); isl::map schedule_t1(ctx, - "[p0,p1] -> { c[i0,i1,i2,i3] -> [t0,t1,t2,t3,t30,t4] : t0 =i0 and t1 = i1 and t2 = i2 and t3=i3 " + "[p0,p1] -> { c[i0,i1,i2,i3] -> [t0,t1,t2,t3,t30,t4] : " + "t0 =i0 and t1 = i1 and t2 = i2 and t3=i3 " "and t4=0 and t30=1 }"); - schedule = schedule.apply_range(schedule_t); + schedule = schedule.apply_range(schedule_t); cschedule = cschedule.apply_range(schedule_t1); auto whole_domain = isl::manage(isl_union_set_from_set(domain.copy())); - whole_domain = isl::manage(isl_union_set_add_set(whole_domain.release(), cdomain.copy())); + whole_domain = isl::manage( + isl_union_set_add_set(whole_domain.release(), cdomain.copy())); auto whole_schedule = isl::manage(isl_union_map_from_map(schedule.copy())); - whole_schedule = isl::manage(isl_union_map_add_map(whole_schedule.release(), cschedule.copy())); + whole_schedule = isl::manage( + isl_union_map_add_map(whole_schedule.release(), cschedule.copy())); auto intersect_schedule = whole_schedule.intersect_domain(whole_domain); isl::set context(ctx, "[p0,p1]->{:p0<100 and p1<100}"); auto* build = isl_ast_build_from_context(context.release()); - auto* node = isl_ast_build_node_from_schedule_map(build, intersect_schedule.release()); + auto* node = + isl_ast_build_node_from_schedule_map(build, intersect_schedule.release()); LOG(INFO) << "code:\n\n" << isl_ast_node_to_C_str(node); } diff --git a/paddle/cinn/pybind/backends.cc b/paddle/cinn/pybind/backends.cc index bd5a116bdbb89..4e589380223df 100644 --- a/paddle/cinn/pybind/backends.cc +++ b/paddle/cinn/pybind/backends.cc @@ -41,30 +41,38 @@ void BindExecutionEngine(py::module *m) { .def_readwrite("enable_debug_info", &ExecutionOptions::enable_debug_info); auto lookup = [](ExecutionEngine &self, absl::string_view name) { - auto *function_ptr = reinterpret_cast(self.Lookup(name)); - auto function_wrapper = [function_ptr](std::vector &args) { - function_ptr(reinterpret_cast(args.data()), args.size()); - }; - return std::function &)>(function_wrapper); + auto *function_ptr = + reinterpret_cast(self.Lookup(name)); + auto function_wrapper = + [function_ptr](std::vector &args) { + 
function_ptr(reinterpret_cast(args.data()), args.size()); + }; + return std::function &)>( + function_wrapper); }; py::class_ engine(*m, "ExecutionEngine"); engine - .def_static("create", - py::overload_cast(&ExecutionEngine::Create), - py::arg("options") = ExecutionOptions()) - .def(py::init(py::overload_cast(&ExecutionEngine::Create)), + .def_static( + "create", + py::overload_cast(&ExecutionEngine::Create), + py::arg("options") = ExecutionOptions()) + .def(py::init(py::overload_cast( + &ExecutionEngine::Create)), py::arg("options") = ExecutionOptions()) .def("lookup", lookup) .def("link", &ExecutionEngine::Link); { auto lookup = [](Compiler &self, absl::string_view name) { - auto *function_ptr = reinterpret_cast(self.Lookup(name)); - auto function_wrapper = [function_ptr](std::vector &args) { - function_ptr(reinterpret_cast(args.data()), args.size()); - }; - return std::function &)>(function_wrapper); + auto *function_ptr = + reinterpret_cast(self.Lookup(name)); + auto function_wrapper = + [function_ptr](std::vector &args) { + function_ptr(reinterpret_cast(args.data()), args.size()); + }; + return std::function &)>( + function_wrapper); }; py::class_ compiler(*m, "Compiler"); diff --git a/paddle/cinn/pybind/bind.cc b/paddle/cinn/pybind/bind.cc index 3ee99abacd8ad..bf1285957e245 100644 --- a/paddle/cinn/pybind/bind.cc +++ b/paddle/cinn/pybind/bind.cc @@ -24,17 +24,23 @@ namespace cinn::pybind { PYBIND11_MODULE(core_api, m) { m.doc() = "CINN core API"; - py::module runtime = m.def_submodule("runtime", "bind cinn_runtime"); - py::module common = m.def_submodule("common", "namespace cinn::common"); - py::module lang = m.def_submodule("lang", "namespace cinn::lang"); - py::module ir = m.def_submodule("ir", "namespace cinn::ir"); - py::module poly = m.def_submodule("poly", "namespace cinn::poly, polyhedral"); - py::module backends = m.def_submodule("backends", "namespace cinn::backends, execution backends"); - py::module optim = m.def_submodule("optim", "namespace cinn::optim, CINN IR optimization"); - py::module pe = m.def_submodule("pe", "namespace cinn::hlir::pe, CINN Primitive Emitters"); - py::module frontend = m.def_submodule("frontend", "namespace cinn::frontend, CINN frontend"); - py::module framework = m.def_submodule("framework", "namespace cinn::hlir::framework, CINN framework"); - py::module utils = m.def_submodule("utils", "namespace cinn::utils, CINN framework"); + py::module runtime = m.def_submodule("runtime", "bind cinn_runtime"); + py::module common = m.def_submodule("common", "namespace cinn::common"); + py::module lang = m.def_submodule("lang", "namespace cinn::lang"); + py::module ir = m.def_submodule("ir", "namespace cinn::ir"); + py::module poly = m.def_submodule("poly", "namespace cinn::poly, polyhedral"); + py::module backends = m.def_submodule( + "backends", "namespace cinn::backends, execution backends"); + py::module optim = + m.def_submodule("optim", "namespace cinn::optim, CINN IR optimization"); + py::module pe = m.def_submodule( + "pe", "namespace cinn::hlir::pe, CINN Primitive Emitters"); + py::module frontend = + m.def_submodule("frontend", "namespace cinn::frontend, CINN frontend"); + py::module framework = m.def_submodule( + "framework", "namespace cinn::hlir::framework, CINN framework"); + py::module utils = + m.def_submodule("utils", "namespace cinn::utils, CINN framework"); BindRuntime(&runtime); BindCommon(&common); diff --git a/paddle/cinn/pybind/bind.h b/paddle/cinn/pybind/bind.h index 2d0ed01db09f4..cb56cae0096cf 100644 --- 
a/paddle/cinn/pybind/bind.h +++ b/paddle/cinn/pybind/bind.h @@ -23,12 +23,19 @@ namespace pybind11 { namespace detail { -template +template struct type_caster> - : map_caster, Key, Value> {}; + : map_caster, + Key, + Value> {}; template <> -struct type_caster : string_caster {}; +struct type_caster : string_caster { +}; } // namespace detail } // namespace pybind11 diff --git a/paddle/cinn/pybind/bind_utils.h b/paddle/cinn/pybind/bind_utils.h index 3b00029aaa31f..71374d41e1205 100644 --- a/paddle/cinn/pybind/bind_utils.h +++ b/paddle/cinn/pybind/bind_utils.h @@ -36,7 +36,7 @@ using common::Type; using ir::Expr; using ir::ExprNode; -using ExprOp = absl::variant; using BinaryOp = absl::variant<>; -using UnaryOp = absl::variant<>; +using UnaryOp = absl::variant<>; // hold CINNValue -using ValueVar = absl::variant; +using ValueVar = + absl::variant; inline ValueVar ConvertToVar(const CINNValue &value) { auto type_code = value.type_code(); @@ -90,7 +91,9 @@ auto DefineShared(py::module *m, absl::string_view obj_name) { std::string name = "Shared" + std::string(obj_name); py::class_> shared(*m, name.c_str()); - shared.def(py::init<>()).def(py::init()).def(py::init &>()); + shared.def(py::init<>()) + .def(py::init()) + .def(py::init &>()); return shared; } @@ -100,14 +103,20 @@ void DefineExprNode(py::module *m, absl::string_view node_name) { std::string prefix{"ExprNode"}; std::string name = prefix + std::string(node_name); - py::class_ expr_node(*m, name.c_str(), py::module_local()); + py::class_ expr_node( + *m, name.c_str(), py::module_local()); expr_node.def(py::init<>()) .def(py::init()) .def(py::init()) .def("operands_mutable", py::overload_cast<>(&ExprNodeT::operands)) - .def("operands_const", py::overload_cast<>(&ExprNodeT::operands, py::const_)) - .def("operand_mutable", py::overload_cast(&ExprNodeT::operand), py::return_value_policy::reference) - .def("operand_const", py::overload_cast(&ExprNodeT::operand, py::const_), py::return_value_policy::reference) + .def("operands_const", + py::overload_cast<>(&ExprNodeT::operands, py::const_)) + .def("operand_mutable", + py::overload_cast(&ExprNodeT::operand), + py::return_value_policy::reference) + .def("operand_const", + py::overload_cast(&ExprNodeT::operand, py::const_), + py::return_value_policy::reference) .def("copy", &ExprNodeT::Copy) .def("node_type", &ExprNodeT::node_type); } @@ -116,18 +125,29 @@ template void DefineBinaryOpNode(py::module *m, absl::string_view node_name) { DefineExprNode(m, node_name); std::string prefix{"BinaryOpNode"}; - std::string name = prefix + std::string(node_name); + std::string name = prefix + std::string(node_name); using BinaryOpNodeT = ir::BinaryOpNode; - py::class_> binary_op_node(*m, name.c_str()); + py::class_> binary_op_node( + *m, name.c_str()); binary_op_node.def(py::init<>()) .def(py::init()) - .def("a_mutable", py::overload_cast<>(&BinaryOpNodeT::a), py::return_value_policy::reference) - .def("a_const", py::overload_cast<>(&BinaryOpNodeT::a, py::const_), py::return_value_policy::reference) - .def("b_mutable", py::overload_cast<>(&BinaryOpNodeT::b), py::return_value_policy::reference) - .def("b_const", py::overload_cast<>(&BinaryOpNodeT::b, py::const_), py::return_value_policy::reference) + .def("a_mutable", + py::overload_cast<>(&BinaryOpNodeT::a), + py::return_value_policy::reference) + .def("a_const", + py::overload_cast<>(&BinaryOpNodeT::a, py::const_), + py::return_value_policy::reference) + .def("b_mutable", + py::overload_cast<>(&BinaryOpNodeT::b), + py::return_value_policy::reference) + 
.def("b_const", + py::overload_cast<>(&BinaryOpNodeT::b, py::const_), + py::return_value_policy::reference) .def("type", &BinaryOpNodeT::type) - .def("expr_fields_mutable", py::overload_cast<>(&BinaryOpNodeT::expr_fields)) - .def("expr_fields_const", py::overload_cast<>(&BinaryOpNodeT::expr_fields, py::const_)); + .def("expr_fields_mutable", + py::overload_cast<>(&BinaryOpNodeT::expr_fields)) + .def("expr_fields_const", + py::overload_cast<>(&BinaryOpNodeT::expr_fields, py::const_)); } template @@ -136,15 +156,24 @@ void DefineUnaryOpNode(py::module *m, absl::string_view node_name) { DefineExprNode(m, node_name); std::string name = "UnaryOpNode" + std::string(node_name); - py::class_> unary_op_node(*m, name.c_str()); + py::class_> unary_op_node(*m, + name.c_str()); unary_op_node.def(py::init<>()) .def(py::init()) .def("type", &UnaryOpNodeT::type) - .def("v_mutable", py::overload_cast<>(&UnaryOpNodeT::v), py::return_value_policy::reference) - .def("v_const", py::overload_cast<>(&UnaryOpNodeT::v, py::const_), py::return_value_policy::reference) - .def("expr_fields_mutable", py::overload_cast<>(&UnaryOpNodeT::expr_fields)) - .def("expr_fields_const", py::overload_cast<>(&UnaryOpNodeT::expr_fields, py::const_)) - .def("operands_mutable", py::overload_cast<>(&UnaryOpNodeT::operands), py::return_value_policy::reference) + .def("v_mutable", + py::overload_cast<>(&UnaryOpNodeT::v), + py::return_value_policy::reference) + .def("v_const", + py::overload_cast<>(&UnaryOpNodeT::v, py::const_), + py::return_value_policy::reference) + .def("expr_fields_mutable", + py::overload_cast<>(&UnaryOpNodeT::expr_fields)) + .def("expr_fields_const", + py::overload_cast<>(&UnaryOpNodeT::expr_fields, py::const_)) + .def("operands_mutable", + py::overload_cast<>(&UnaryOpNodeT::operands), + py::return_value_policy::reference) .def("operands_const", py::overload_cast<>(&UnaryOpNodeT::operands, py::const_), py::return_value_policy::reference); @@ -154,7 +183,9 @@ class ObjectWrapper : public Object { public: using Object::Object; - const char *type_info() const override { PYBIND11_OVERLOAD_PURE(const char *, Object, type_info); } + const char *type_info() const override { + PYBIND11_OVERLOAD_PURE(const char *, Object, type_info); + } }; class IrNodeWrapper : ir::IrNode { @@ -163,6 +194,8 @@ class IrNodeWrapper : ir::IrNode { class _Operation_Wrapper : ir::_Operation_ { public: - const char *func_type() const override { PYBIND11_OVERLOAD_PURE(const char *, ir::_Operation_, func_type); } + const char *func_type() const override { + PYBIND11_OVERLOAD_PURE(const char *, ir::_Operation_, func_type); + } }; } // namespace cinn::pybind diff --git a/paddle/cinn/pybind/common.cc b/paddle/cinn/pybind/common.cc index 08bc8afb72dc6..994308433bc7f 100644 --- a/paddle/cinn/pybind/common.cc +++ b/paddle/cinn/pybind/common.cc @@ -52,7 +52,10 @@ void BindTarget(py::module *m) { .def_readwrite("bits", &Target::bits) .def_readwrite("features", &Target::features) .def(py::init<>()) - .def(py::init &>()) + .def(py::init &>()) .def("defined", &Target::defined) .def("runtime_arch", &Target::runtime_arch); @@ -61,10 +64,14 @@ void BindTarget(py::module *m) { .def("DefaultTarget", &common::DefaultTarget); m->def("get_target", &cinn::runtime::CurrentTarget::GetCurrentTarget); - m->def("set_target", &cinn::runtime::CurrentTarget::SetCurrentTarget, py::arg("target")); + m->def("set_target", + &cinn::runtime::CurrentTarget::SetCurrentTarget, + py::arg("target")); py::enum_ os(target, "OS"); - os.value("Unk", Target::OS::Unk).value("Linux", 
Target::OS::Linux).value("Windows", Target::OS::Windows); + os.value("Unk", Target::OS::Unk) + .value("Linux", Target::OS::Linux) + .value("Windows", Target::OS::Windows); py::enum_ arch(target, "Arch"); arch.value("Unk", Target::Arch::Unk) @@ -73,10 +80,13 @@ void BindTarget(py::module *m) { .value("NVGPU", Target::Arch::NVGPU); py::enum_ bit(target, "Bit"); - bit.value("Unk", Target::Bit::Unk).value("k32", Target::Bit::k32).value("k64", Target::Bit::k64); + bit.value("Unk", Target::Bit::Unk) + .value("k32", Target::Bit::k32) + .value("k64", Target::Bit::k64); py::enum_ feature(target, "Feature"); - feature.value("JIT", Target::Feature::JIT).value("Debug", Target::Feature::Debug); + feature.value("JIT", Target::Feature::JIT) + .value("Debug", Target::Feature::Debug); m->def("is_compiled_with_cuda", cinn::runtime::IsCompiledWithCUDA); m->def("is_compiled_with_cudnn", cinn::runtime::IsCompiledWithCUDNN); @@ -85,7 +95,8 @@ void BindTarget(py::module *m) { void BindType(py::module *m) { py::class_ type(*m, "Type"); - type.def(py::init<>()).def(py::init()); + type.def(py::init<>()) + .def(py::init()); #define DEFINE_TYPE_METHOD(__name) (type = type.def(#__name, &Type::__name)) DEFINE_TYPE_METHOD(is_primitive); DEFINE_TYPE_METHOD(is_unk); @@ -116,7 +127,9 @@ void BindType(py::module *m) { .def("element_of", &Type::ElementOf) .def("pointer_of", &Type::PointerOf) .def("__str__", [](const Type &self) { return GetStreamCnt(self); }) - .def("__repr__", [](const Type &self) { return StringFormat("", GetStreamCnt(self).c_str()); }); + .def("__repr__", [](const Type &self) { + return StringFormat("", GetStreamCnt(self).c_str()); + }); py::enum_ type_t(type, "type_t"); type_t.value("unk", Type::type_t::Unk) @@ -144,7 +157,11 @@ void BindType(py::module *m) { m->def("Void", &common::Void) .def("Int", &common::Int, py::arg("bits"), py::arg("lanes") = 1) .def("UInt", &common::UInt, py::arg("bits"), py::arg("lanes") = 1) - .def("Float", &common::Float, py::arg("bits"), py::arg("lanes") = 1, py::arg("st") = Type::specific_type_t::None) + .def("Float", + &common::Float, + py::arg("bits"), + py::arg("lanes") = 1, + py::arg("st") = Type::specific_type_t::None) .def("Float16", &common::Float16, py::arg("lanes") = 1) .def("BFloat16", &common::BFloat16, py::arg("lanes") = 1) .def("Bool", &common::Bool, py::arg("lanes") = 1) @@ -152,31 +169,43 @@ void BindType(py::module *m) { m->def( "make_const", - [](const Type &type, int32_t val) -> Expr { return common::make_const(type, val); }, + [](const Type &type, int32_t val) -> Expr { + return common::make_const(type, val); + }, py::arg("type"), py::arg("val")) .def( "make_const", - [](const Type &type, int64_t val) -> Expr { return common::make_const(type, val); }, + [](const Type &type, int64_t val) -> Expr { + return common::make_const(type, val); + }, py::arg("type"), py::arg("val")) .def( "make_const", - [](const Type &type, float val) -> Expr { return common::make_const(type, val); }, + [](const Type &type, float val) -> Expr { + return common::make_const(type, val); + }, py::arg("type"), py::arg("val")) .def( "make_const", - [](const Type &type, double val) -> Expr { return common::make_const(type, val); }, + [](const Type &type, double val) -> Expr { + return common::make_const(type, val); + }, py::arg("type"), py::arg("val")) .def( "make_const", - [](const Type &type, bool val) -> Expr { return common::make_const(type, val); }, + [](const Type &type, bool val) -> Expr { + return common::make_const(type, val); + }, py::arg("type"), py::arg("val")); - 
m->def("type_of", [](absl::string_view dtype) { return common::Str2Type(dtype.data()); }); + m->def("type_of", [](absl::string_view dtype) { + return common::Str2Type(dtype.data()); + }); } void BindObject(py::module *m) { @@ -194,7 +223,8 @@ void BindShared(py::module *m) { .def("val", &common::RefCount::val); } -// TODO(wanghaipeng03) using true_type or false_type as tag disptcher losses semantic context +// TODO(wanghaipeng03) using true_type or false_type as tag disptcher losses +// semantic context template inline auto __binary_op_fn_dispatch(T1 x, T2 y, F fn, std::true_type) { return fn(ir::Expr(x), ir::Expr(y)).as_var_ref(); @@ -205,11 +235,13 @@ inline auto __binary_op_fn_dispatch(T1 x, T2 y, F fn, std::false_type) { } template -inline void __binary_op_visitor_dispatch(CINNValue &v, T1 lhs, T2 rhs, F fn, std::true_type) { +inline void __binary_op_visitor_dispatch( + CINNValue &v, T1 lhs, T2 rhs, F fn, std::true_type) { v = CINNValue(); } template -inline void __binary_op_visitor_dispatch(CINNValue &v, T1 lhs, T2 rhs, F fn, std::false_type) { +inline void __binary_op_visitor_dispatch( + CINNValue &v, T1 lhs, T2 rhs, F fn, std::false_type) { v.Set(fn(lhs, rhs)); } @@ -221,18 +253,26 @@ void BindCinnValue(py::module *m) { py::class_<_CINNValuePack_> cinn_value_pack(*m, "_CINNValuePack_"); cinn_value_pack.def_static("make", &_CINNValuePack_::Make) - .def("__getitem__", [](_CINNValuePack_ &self, int offset) { return self[offset]; }) - .def("__setitem__", [](_CINNValuePack_ &self, int offset, CINNValue &v) { self[offset] = v; }) + .def("__getitem__", + [](_CINNValuePack_ &self, int offset) { return self[offset]; }) + .def("__setitem__", + [](_CINNValuePack_ &self, int offset, CINNValue &v) { + self[offset] = v; + }) .def("add_value", &_CINNValuePack_::AddValue) .def("clear", &_CINNValuePack_::Clear) .def("size", &_CINNValuePack_::size) .def("__len__", &_CINNValuePack_::size) .def("type_info", &_CINNValuePack_::type_info); - py::class_> cinn_value_pack_shared(*m, "CINNValuePack"); + py::class_> + cinn_value_pack_shared(*m, "CINNValuePack"); cinn_value_pack_shared.def(py::init<_CINNValuePack_ *>()) - .def("__getitem__", [](CINNValuePack &self, int offset) { return self[offset]; }) - .def("__setitem__", [](CINNValuePack &self, int offset, CINNValue &v) { self[offset] = v; }); + .def("__getitem__", + [](CINNValuePack &self, int offset) { return self[offset]; }) + .def("__setitem__", [](CINNValuePack &self, int offset, CINNValue &v) { + self[offset] = v; + }); py::class_ cinn_value(*m, "CINNValue"); cinn_value.def(py::init<>()) @@ -251,16 +291,22 @@ void BindCinnValue(py::module *m) { .def(py::init()) .def(py::init()) .def("defined", &CINNValue::defined) - .def("to_double", [](CINNValue &self) { return static_cast(self); }) + .def("to_double", + [](CINNValue &self) { return static_cast(self); }) .def("to_float", [](CINNValue &self) { return static_cast(self); }) .def("to_int8", [](CINNValue &self) { return static_cast(self); }) - .def("to_int32", [](CINNValue &self) { return static_cast(self); }) - .def("to_int64", [](CINNValue &self) { return static_cast(self); }) - .def("to_void_p", [](CINNValue &self) { return static_cast(self); }) - .def("to_cinn_buffer_p", [](CINNValue &self) { return static_cast(self); }) + .def("to_int32", + [](CINNValue &self) { return static_cast(self); }) + .def("to_int64", + [](CINNValue &self) { return static_cast(self); }) + .def("to_void_p", + [](CINNValue &self) { return static_cast(self); }) + .def("to_cinn_buffer_p", + [](CINNValue &self) { return 
static_cast(self); }) .def("to_str", [](CINNValue &self) { return static_cast(self); }) .def("to_var", [](CINNValue &self) { return self.operator ir::Var(); }) - .def("to_expr", [](CINNValue &self) { return ir::Expr(self.operator ir::Expr()); }) + .def("to_expr", + [](CINNValue &self) { return ir::Expr(self.operator ir::Expr()); }) .def("set", &CINNValue::Set) .def("set", &CINNValue::Set) .def("set", &CINNValue::Set) @@ -277,24 +323,30 @@ void BindCinnValue(py::module *m) { using lhs_t = decltype(lhs); using rhs_t = decltype(rhs); using tag_t = - std::conditional_t::value || std::is_same::value || + std::conditional_t::value || + std::is_same::value || !std::is_same::value, std::true_type, std::false_type>; __binary_op_visitor_dispatch(v, lhs, rhs, fn, tag_t{}); }; -#define DEFINE_BINARY_OP(__op, __fn) \ - auto __op##_fn = [&](auto x, auto y) { \ - constexpr auto is_var_x = std::is_same, ir::Var>::value; \ - constexpr auto is_var_y = std::is_same, ir::Var>::value; \ - using tag_t = std::conditional_t; \ - return __binary_op_fn_dispatch(x, y, __fn, tag_t{}); \ - }; \ - cinn_value.def(#__op, [&](CINNValue &self, CINNValue &other) { \ - auto visitor = [&](auto x, auto y) { return binary_op_visitor(self, x, y, __op##_fn); }; \ - absl::visit(visitor, ConvertToVar(self), ConvertToVar(other)); \ - return self; \ +#define DEFINE_BINARY_OP(__op, __fn) \ + auto __op##_fn = [&](auto x, auto y) { \ + constexpr auto is_var_x = \ + std::is_same, ir::Var>::value; \ + constexpr auto is_var_y = \ + std::is_same, ir::Var>::value; \ + using tag_t = std:: \ + conditional_t; \ + return __binary_op_fn_dispatch(x, y, __fn, tag_t{}); \ + }; \ + cinn_value.def(#__op, [&](CINNValue &self, CINNValue &other) { \ + auto visitor = [&](auto x, auto y) { \ + return binary_op_visitor(self, x, y, __op##_fn); \ + }; \ + absl::visit(visitor, ConvertToVar(self), ConvertToVar(other)); \ + return self; \ }) DEFINE_BINARY_OP(__add__, [](auto x, auto y) { return x + y; }); diff --git a/paddle/cinn/pybind/framework.cc b/paddle/cinn/pybind/framework.cc index 0ea698694f5e4..e2fcc2c6b3cb7 100755 --- a/paddle/cinn/pybind/framework.cc +++ b/paddle/cinn/pybind/framework.cc @@ -36,8 +36,13 @@ namespace py = pybind11; using namespace cinn::hlir::framework; // NOLINT void BindFramework(pybind11::module *m) { py::class_(*m, "Operator") - .def("get_op_attrs", [](const std::string &key) { return Operator::GetAttrs(key); }) - .def("get_op_shape_attrs", [](const std::string &key) { return Operator::GetAttrs(key); }); + .def("get_op_attrs", + [](const std::string &key) { + return Operator::GetAttrs(key); + }) + .def("get_op_shape_attrs", [](const std::string &key) { + return Operator::GetAttrs(key); + }); py::class_>(*m, "OpValueType") .def("apply_strategy", @@ -49,7 +54,8 @@ void BindFramework(pybind11::module *m) { const std::vector> &output_shapes, const common::Target &target) { const Operator *op_ptr = Operator::Get(key); - auto impl = OpStrategy::SelectImpl(self[op_ptr](attrs, inputs, out_types, output_shapes, target)); + auto impl = OpStrategy::SelectImpl( + self[op_ptr](attrs, inputs, out_types, output_shapes, target)); std::vector temp_inputs; std::vector res; for (auto &tensor : inputs) { @@ -66,14 +72,22 @@ void BindFramework(pybind11::module *m) { input_output_names.push_back(input->name); } input_output_names.push_back(output_name); - std::vector funcs = hlir::framework::GetFuncFromImpl( - impl, common::CINNValuePack{temp_inputs}, res, input_output_names, key, target); + std::vector funcs = + hlir::framework::GetFuncFromImpl( + 
impl, + common::CINNValuePack{temp_inputs}, + res, + input_output_names, + key, + target); CHECK_EQ(funcs.size(), 1U); func = funcs[0]; } else { - common::CINNValuePack C = impl->fcompute(common::CINNValuePack{temp_inputs}); - poly::StageMap stages = C.back(); - // make sure all the tensors in the stages before schedule launch. + common::CINNValuePack C = + impl->fcompute(common::CINNValuePack{temp_inputs}); + poly::StageMap stages = C.back(); + // make sure all the tensors in the stages before schedule + // launch. for (int i = 0; i < C->size() - 1; i++) { ir::Expr temp = C[i]; stages->InsertLazily(temp.as_tensor_ref()); @@ -95,7 +109,7 @@ void BindFramework(pybind11::module *m) { const std::vector> &input_shapes, const AttrMapType &attrs) { const Operator *op_ptr = Operator::Get(key); - auto shapes = self[op_ptr](input_shapes, attrs); + auto shapes = self[op_ptr](input_shapes, attrs); return shapes; }); @@ -103,10 +117,13 @@ void BindFramework(pybind11::module *m) { .def(py::init<>()) .def_readwrite("attr_store", &NodeAttr::attr_store) .def("set_attr", - [](NodeAttr &self, const std::string &key, NodeAttr::attr_t value) { self.attr_store[key] = value; }) + [](NodeAttr &self, const std::string &key, NodeAttr::attr_t value) { + self.attr_store[key] = value; + }) .def("get_attr", [](NodeAttr &self, const std::string &key) { - CHECK_EQ(self.attr_store.count(key), 1) << "Didn't find value with key [" << key << "]."; + CHECK_EQ(self.attr_store.count(key), 1) + << "Didn't find value with key [" << key << "]."; return self.attr_store[key]; }) .def("__str__", [](NodeAttr &self) { return utils::GetStreamCnt(self); }); @@ -117,17 +134,21 @@ void BindFramework(pybind11::module *m) { [](Scope &self, const std::string &name, const Target &target) { auto t = self.GetTensor(name); py::dtype dt(common::Type2Str(t->type())); - py::array::ShapeContainer shape(t->shape().data().begin(), t->shape().data().end()); + py::array::ShapeContainer shape(t->shape().data().begin(), + t->shape().data().end()); py::array array(std::move(dt), std::move(shape)); auto *mutable_data = array.mutable_data(); if (target.arch == Target::Arch::X86) { - std::memcpy(mutable_data, t->data(), t->shape().numel() * t->type().bytes()); + std::memcpy(mutable_data, + t->data(), + t->shape().numel() * t->type().bytes()); } else if (target.arch == Target::Arch::NVGPU) { #ifdef CINN_WITH_CUDA - CUDA_CALL(cudaMemcpy(mutable_data, - reinterpret_cast(t->mutable_data(target, t->type())), - t->shape().numel() * t->type().bytes(), - cudaMemcpyDeviceToHost)); + CUDA_CALL(cudaMemcpy( + mutable_data, + reinterpret_cast(t->mutable_data(target, t->type())), + t->shape().numel() * t->type().bytes(), + cudaMemcpyDeviceToHost)); #else LOG(FATAL) <<"To use CUDA backends, you need to set WITH_CUDA ON!"; #endif @@ -141,56 +162,74 @@ void BindFramework(pybind11::module *m) { py::class_>(*m, "SharedTensor"); py::class_>(*m, "Tensor") .def(py::init<>()) - .def("shape", [](hlir::framework::Tensor &self) { return self->shape().data(); }) - .def("set_type", [](hlir::framework::Tensor &self, Type type) { self->set_type(type); }) - .def("numpy", - [](hlir::framework::Tensor &self, const common::Target &target) { - std::string type_str = common::Type2Str(self->type()); - if (type_str == "bfloat16") { - type_str = "uint16"; - } - py::dtype dt(type_str); - py::array::ShapeContainer shape(self->shape().data().begin(), self->shape().data().end()); - py::array array(std::move(dt), std::move(shape)); - void *array_data = array.mutable_data(); - if (target.arch == 
Target::Arch::X86) { - std::memcpy(array_data, self->data(), self->shape().numel() * self->type().bytes()); - } else if (target.arch == Target::Arch::NVGPU) { + .def("shape", + [](hlir::framework::Tensor &self) { return self->shape().data(); }) + .def("set_type", + [](hlir::framework::Tensor &self, Type type) { + self->set_type(type); + }) + .def( + "numpy", + [](hlir::framework::Tensor &self, const common::Target &target) { + std::string type_str = common::Type2Str(self->type()); + if (type_str == "bfloat16") { + type_str = "uint16"; + } + py::dtype dt(type_str); + py::array::ShapeContainer shape(self->shape().data().begin(), + self->shape().data().end()); + py::array array(std::move(dt), std::move(shape)); + void *array_data = array.mutable_data(); + if (target.arch == Target::Arch::X86) { + std::memcpy(array_data, + self->data(), + self->shape().numel() * self->type().bytes()); + } else if (target.arch == Target::Arch::NVGPU) { #ifdef CINN_WITH_CUDA - CUDA_CALL(cudaMemcpy(array_data, - self->data(), - self->shape().numel() * self->type().bytes(), - cudaMemcpyDeviceToHost)); + CUDA_CALL(cudaMemcpy(array_data, + self->data(), + self->shape().numel() * self->type().bytes(), + cudaMemcpyDeviceToHost)); #else LOG(FATAL) <<"To use CUDA backends, you need to set WITH_CUDA ON!"; #endif - } else { - CINN_NOT_IMPLEMENTED - } - return array; - }) - .def("from_numpy", [](hlir::framework::Tensor &self, py::array array, const common::Target &target) { - CHECK(array.dtype().is(py::dtype(common::Type2Str(self->type())))) - << "currently only support float32 data type as input"; - hlir::framework::shape_t shape; - std::copy_n(array.shape(), array.ndim(), std::back_inserter(shape)); - CHECK_EQ(std::accumulate(shape.begin(), shape.end(), 1, [](int32_t a, int32_t b) { return a * b; }), - self->shape().numel()); - auto *data = self->mutable_data(target, self->type()); - if (target.arch == Target::Arch::X86) { - std::memcpy(data, array.data(), self->shape().numel() * self->type().bytes()); - } else if (target.arch == Target::Arch::NVGPU) { + } else { + CINN_NOT_IMPLEMENTED + } + return array; + }) + .def( + "from_numpy", + [](hlir::framework::Tensor &self, + py::array array, + const common::Target &target) { + CHECK(array.dtype().is(py::dtype(common::Type2Str(self->type())))) + << "currently only support float32 data type as input"; + hlir::framework::shape_t shape; + std::copy_n(array.shape(), array.ndim(), std::back_inserter(shape)); + CHECK_EQ( + std::accumulate(shape.begin(), + shape.end(), + 1, + [](int32_t a, int32_t b) { return a * b; }), + self->shape().numel()); + auto *data = self->mutable_data(target, self->type()); + if (target.arch == Target::Arch::X86) { + std::memcpy(data, + array.data(), + self->shape().numel() * self->type().bytes()); + } else if (target.arch == Target::Arch::NVGPU) { #ifdef CINN_WITH_CUDA - CUDA_CALL(cudaMemcpy(reinterpret_cast(data), - reinterpret_cast(array.data()), - self->shape().numel() * self->type().bytes(), - cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy(reinterpret_cast(data), + reinterpret_cast(array.data()), + self->shape().numel() * self->type().bytes(), + cudaMemcpyHostToDevice)); #else LOG(FATAL) <<"To use CUDA backends, you need to set WITH_CUDA ON!"; #endif - } else { - CINN_NOT_IMPLEMENTED - } - }); + } else { + CINN_NOT_IMPLEMENTED + } + }); } } // namespace cinn::pybind diff --git a/paddle/cinn/pybind/frontend.cc b/paddle/cinn/pybind/frontend.cc index 386ef9957a021..dcddbd47e2765 100644 --- a/paddle/cinn/pybind/frontend.cc +++ 
b/paddle/cinn/pybind/frontend.cc @@ -49,7 +49,7 @@ using namespace cinn::frontend; // NOLINT // used in this file only for py function register static const char *SnakeName(const char *name) { static char buf[256]; - char *p = buf; + char *p = buf; const char *q = name; for (; *q; q++, p++) { if ((*q >= 'A') && (*q <= 'Z')) { @@ -94,12 +94,15 @@ void BindFrontend(pybind11::module *m) { }); py::class_(*m, "Placeholder") // - .def(py::init &, absl::string_view>(), + .def(py::init &, + absl::string_view>(), py::arg("type"), py::arg("shape"), py::arg("id") = "") .def("shape", &Placeholder::shape) - .def("type", [](Placeholder &self) { return common::Type2Str(self.type()); }) + .def("type", + [](Placeholder &self) { return common::Type2Str(self.type()); }) .def("id", &Placeholder::id) .def("name", &Placeholder::id) .def("__str__", [](const Placeholder &self) { return self.id(); }); @@ -107,29 +110,48 @@ void BindFrontend(pybind11::module *m) { py::implicitly_convertible(); py::class_(*m, "Instruction") // - .def("set_attr", [](Instruction &self, const std::string &key, int x) { self.SetAttr(key, x); }) - .def("set_attr", [](Instruction &self, const std::string &key, float x) { self.SetAttr(key, x); }) - .def("set_attr", [](Instruction &self, const std::string &key, const std::string &x) { self.SetAttr(key, x); }) .def("set_attr", - [](Instruction &self, const std::string &key, const std::vector &x) { self.SetAttr(key, x); }) + [](Instruction &self, const std::string &key, int x) { + self.SetAttr(key, x); + }) + .def("set_attr", + [](Instruction &self, const std::string &key, float x) { + self.SetAttr(key, x); + }) + .def("set_attr", + [](Instruction &self, const std::string &key, const std::string &x) { + self.SetAttr(key, x); + }) + .def("set_attr", + [](Instruction &self, + const std::string &key, + const std::vector &x) { self.SetAttr(key, x); }) .def("set_attr", - [](Instruction &self, const std::string &key, const std::vector &x) { self.SetAttr(key, x); }) + [](Instruction &self, + const std::string &key, + const std::vector &x) { self.SetAttr(key, x); }) .def("set_attr", - [](Instruction &self, const std::string &key, const std::vector &x) { self.SetAttr(key, x); }) + [](Instruction &self, + const std::string &key, + const std::vector &x) { self.SetAttr(key, x); }) .def("get_attr_int32", &Instruction::GetAttrs) .def("get_attr_fp32", &Instruction::GetAttrs) .def("get_attr_str", &Instruction::GetAttrs) .def("get_attr_int32s", &Instruction::GetAttrs>) .def("get_attr_fp32s", &Instruction::GetAttrs>) .def("get_attr_strs", &Instruction::GetAttrs>) - .def("__str__", [](Instruction &self) { return utils::GetStreamCnt(self); }) + .def("__str__", + [](Instruction &self) { return utils::GetStreamCnt(self); }) .def("get_op_type", [](Instruction &self) { return self->op_type; }) .def("get_inputs", [](Instruction &self) { return self->inputs; }) .def("get_outputs", [](Instruction &self) { return self->outputs; }); - m->def("get_default_program_pass", []() { return DefaultTrainingOptimizeOptions().program_passes; }) - .def("get_default_graph_pass", []() { return DefaultTrainingOptimizeOptions().graph_passes; }) - .def("get_default_opfusion_pass", []() { return DefaultOpFusionPasses(); }); + m->def("get_default_program_pass", + []() { return DefaultTrainingOptimizeOptions().program_passes; }) + .def("get_default_graph_pass", + []() { return DefaultTrainingOptimizeOptions().graph_passes; }) + .def("get_default_opfusion_pass", + []() { return DefaultOpFusionPasses(); }); py::class_(*m, "Program") 
.def(py::init<>()) @@ -160,7 +182,8 @@ void BindFrontend(pybind11::module *m) { const std::vector &tensor_inputs, const std::vector &input_data, const std::vector &tensor_outputs, - const std::vector &passes = std::vector{}, + const std::vector &passes = + std::vector{}, std::shared_ptr scope = nullptr) { cinn::runtime::CurrentTarget::SetCurrentTarget(target); std::unordered_set fetch_ids; @@ -185,28 +208,32 @@ void BindFrontend(pybind11::module *m) { // Keep compile option same as paddle hlir::framework::GraphCompiler::CompileOptions options; options.with_instantiate_variables = true; - options.remove_unused_variables = false; - auto gc_fetch_ids = fetch_ids; - const auto &result = gc.Build(options, std::move(gc_fetch_ids)); - const auto &program = result.runtime_program; + options.remove_unused_variables = false; + auto gc_fetch_ids = fetch_ids; + const auto &result = gc.Build(options, std::move(gc_fetch_ids)); + const auto &program = result.runtime_program; for (size_t i = 0; i < tensor_inputs.size(); i++) { auto in_tensor = scope->GetTensor(tensor_inputs[i]->id); - auto dtype = tensor_inputs[i]->type; - auto *data = in_tensor->mutable_data(target, dtype); + auto dtype = tensor_inputs[i]->type; + auto *data = in_tensor->mutable_data(target, dtype); CHECK_EQ(input_data[i].size(), in_tensor->shape().numel()) << "The size of tensor [" << tensor_inputs[i]->id << "] is different with the input data's size! Please check."; if (target.arch == Target::Arch::NVGPU) { #ifdef CINN_WITH_CUDA - CUDA_CALL(cudaMemcpy( - data, input_data[i].data(), in_tensor->shape().numel() * dtype.bytes(), cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy(data, + input_data[i].data(), + in_tensor->shape().numel() * dtype.bytes(), + cudaMemcpyHostToDevice)); #else LOG(FATAL) <<"To use CUDA backends, you need to set WITH_CUDA ON!"; #endif } else if (target.arch == Target::Arch::X86) { - memcpy(data, input_data[i].data(), - in_tensor->shape().numel() * dtype.bytes()); // All random data + memcpy(data, + input_data[i].data(), + in_tensor->shape().numel() * + dtype.bytes()); // All random data } else { CINN_NOT_IMPLEMENTED } @@ -219,7 +246,8 @@ void BindFrontend(pybind11::module *m) { outputs.back()->set_type(tensor_outputs[i]->type); // Change Tensor from 1D to 0D if (outputs.back()->shape().numel() == 1 && - zero_dim_outputs.find(tensor_outputs[i]->id) != zero_dim_outputs.end()) { + zero_dim_outputs.find(tensor_outputs[i]->id) != + zero_dim_outputs.end()) { outputs.back()->Resize({}); } } @@ -231,7 +259,7 @@ void BindFrontend(pybind11::module *m) { py::arg("feed_datas"), py::arg("fetch_list"), py::arg("passes") = std::vector{}, - py::arg("scope") = nullptr) + py::arg("scope") = nullptr) .def("apply_pass", [](Program &self, const std::unordered_set &fetch_ids, @@ -245,121 +273,136 @@ void BindFrontend(pybind11::module *m) { * @brief Test the performance of a single-op program * @param self The program built with only one op * @param target The Target that controls the backends to execute on - * @param tensor_inputs The vector that contains all input Variables. Must be on CPU - * @param input_data The vector that contains each input Variable's data(stored as py::array) + * @param tensor_inputs The vector that contains all input Variables. Must + * be on CPU + * @param input_data The vector that contains each input Variable's + * data(stored as py::array) * @param tensor_out The output Variable. - * @param repeat_ The number of executing time. Increase it to avoid testing noise. 
- * @param info The string to be print before testing. Usually it implyies the kind of op and
- * input variable's shape.
+ * @param repeat_ The number of times to execute. Increase it to avoid
+ * testing noise.
+ * @param info The string to be printed before testing. Usually it implies
+ * the kind of op and input variable's shape.
  *
  * @return The output tensor after executing the op.
  *
  * @note
  * This function is for user to test single op performance on python.
- * To learn more about how to test op's benchmark, see '/python/tests/test_op_benchmark.py'
+ * To learn more about how to test op's benchmark, see
+ * '/python/tests/test_op_benchmark.py'
  *
  */
-      .def("test_benchmark",
-           [](Program &self,
-              const common::Target &target,
-              const std::vector<Variable> &tensor_inputs,
-              const std::vector<py::array> &input_data,
-              const Variable &tensor_out,
-              int repeat_,
-              const std::string &info) {
-             std::shared_ptr<hlir::framework::Graph> g(new hlir::framework::Graph(self, target));
-             hlir::framework::ApplyPass(g.get(), "InferShape");
-             std::shared_ptr<hlir::framework::Scope> scope = hlir::framework::BuildScope(target, g);
-             hlir::framework::GraphCompiler gc(target, scope, g);
-             auto program = gc.Build();
-             for (size_t i = 0; i < tensor_inputs.size(); i++) {
-               auto in_tensor = scope->GetTensor(tensor_inputs[i]->id);
-               auto *data = in_tensor->mutable_data<float>(target);
-               CHECK_EQ(input_data[i].size(), in_tensor->shape().numel())
-                   << "The size of tensor [" << tensor_inputs[i]->id
-                   << "] is different with the input data's size! Please check.";
-               if (target.arch == Target::Arch::NVGPU) {
+      .def(
+          "test_benchmark",
+          [](Program &self,
+             const common::Target &target,
+             const std::vector<Variable> &tensor_inputs,
+             const std::vector<py::array> &input_data,
+             const Variable &tensor_out,
+             int repeat_,
+             const std::string &info) {
+            std::shared_ptr<hlir::framework::Graph> g(
+                new hlir::framework::Graph(self, target));
+            hlir::framework::ApplyPass(g.get(), "InferShape");
+            std::shared_ptr<hlir::framework::Scope> scope =
+                hlir::framework::BuildScope(target, g);
+            hlir::framework::GraphCompiler gc(target, scope, g);
+            auto program = gc.Build();
+            for (size_t i = 0; i < tensor_inputs.size(); i++) {
+              auto in_tensor = scope->GetTensor(tensor_inputs[i]->id);
+              auto *data = in_tensor->mutable_data<float>(target);
+              CHECK_EQ(input_data[i].size(), in_tensor->shape().numel())
+                  << "The size of tensor [" << tensor_inputs[i]->id
+                  << "] is different with the input data's size! 
Please check."; + if (target.arch == Target::Arch::NVGPU) { #ifdef CINN_WITH_CUDA - CUDA_CALL(cudaMemcpy(reinterpret_cast(data), - input_data[i].data(), - in_tensor->shape().numel() * sizeof(float), - cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy(reinterpret_cast(data), + input_data[i].data(), + in_tensor->shape().numel() * sizeof(float), + cudaMemcpyHostToDevice)); #else LOG(FATAL) <<"To use CUDA backends, you need to set WITH_CUDA ON!"; #endif - } else if (target.arch == Target::Arch::X86) { - for (size_t j = 0; j < in_tensor->shape().numel(); j++) { - data[j] = reinterpret_cast(input_data[i].data())[j]; // All random data - } - } else { - CINN_NOT_IMPLEMENTED - } - } - VLOG(3) << info; - program->ExecuteTest(repeat_); - auto out = scope->GetTensor(tensor_out->id); - return out; - }) - .def("test_benchmark_with_code", - [](Program &self, - const common::Target &target, - const std::vector &tensor_inputs, - const std::vector &input_data, - const Variable &tensor_out, - int repeat_, - const std::string &info, - const std::string &code) { - // std::shared_ptr g(new hlir::framework::Graph(self, target)); - // hlir::framework::ApplyPass(g.get(), "InferShape"); - std::unordered_set fetch_ids; - auto graph = cinn::frontend::Optimize(&self, fetch_ids, target); - std::shared_ptr scope = hlir::framework::BuildScope(target, graph); + } else if (target.arch == Target::Arch::X86) { + for (size_t j = 0; j < in_tensor->shape().numel(); j++) { + data[j] = reinterpret_cast( + input_data[i].data())[j]; // All random data + } + } else { + CINN_NOT_IMPLEMENTED + } + } + VLOG(3) << info; + program->ExecuteTest(repeat_); + auto out = scope->GetTensor(tensor_out->id); + return out; + }) + .def( + "test_benchmark_with_code", + [](Program &self, + const common::Target &target, + const std::vector &tensor_inputs, + const std::vector &input_data, + const Variable &tensor_out, + int repeat_, + const std::string &info, + const std::string &code) { + // std::shared_ptr g(new + // hlir::framework::Graph(self, target)); + // hlir::framework::ApplyPass(g.get(), "InferShape"); + std::unordered_set fetch_ids; + auto graph = cinn::frontend::Optimize(&self, fetch_ids, target); + std::shared_ptr scope = + hlir::framework::BuildScope(target, graph); - hlir::framework::GraphCompiler gc(target, scope, graph); - auto program = gc.Build(code); - for (size_t i = 0; i < tensor_inputs.size(); i++) { - auto in_tensor = scope->GetTensor(tensor_inputs[i]->id); - auto *data = in_tensor->mutable_data(target); - CHECK_EQ(input_data[i].size(), in_tensor->shape().numel()) - << "The size of tensor [" << tensor_inputs[i]->id - << "] is different with the input data's size! Please check."; - if (target.arch == Target::Arch::NVGPU) { + hlir::framework::GraphCompiler gc(target, scope, graph); + auto program = gc.Build(code); + for (size_t i = 0; i < tensor_inputs.size(); i++) { + auto in_tensor = scope->GetTensor(tensor_inputs[i]->id); + auto *data = in_tensor->mutable_data(target); + CHECK_EQ(input_data[i].size(), in_tensor->shape().numel()) + << "The size of tensor [" << tensor_inputs[i]->id + << "] is different with the input data's size! 
Please check."; + if (target.arch == Target::Arch::NVGPU) { #ifdef CINN_WITH_CUDA - CUDA_CALL(cudaMemcpy(reinterpret_cast(data), - input_data[i].data(), - in_tensor->shape().numel() * sizeof(float), - cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy(reinterpret_cast(data), + input_data[i].data(), + in_tensor->shape().numel() * sizeof(float), + cudaMemcpyHostToDevice)); #else LOG(FATAL) <<"To use CUDA backends, you need to set WITH_CUDA ON!"; #endif - } else if (target.arch == Target::Arch::X86) { - for (size_t j = 0; j < in_tensor->shape().numel(); j++) { - data[j] = reinterpret_cast(input_data[i].data())[j]; // All random data - } - } else { - CINN_NOT_IMPLEMENTED - } - } - VLOG(3) << info; - program->ExecuteTest(repeat_); - auto out = scope->GetTensor(tensor_out->id); - return out; - }) + } else if (target.arch == Target::Arch::X86) { + for (size_t j = 0; j < in_tensor->shape().numel(); j++) { + data[j] = reinterpret_cast( + input_data[i].data())[j]; // All random data + } + } else { + CINN_NOT_IMPLEMENTED + } + } + VLOG(3) << info; + program->ExecuteTest(repeat_); + auto out = scope->GetTensor(tensor_out->id); + return out; + }) .def("test_generate_code", [](Program &self, const common::Target &target, const std::vector &tensor_inputs, const std::vector &input_data, const Variable &tensor_out) { - std::shared_ptr g(new hlir::framework::Graph(self, target)); + std::shared_ptr g( + new hlir::framework::Graph(self, target)); hlir::framework::ApplyPass(g.get(), "InferShape"); - std::shared_ptr scope = hlir::framework::BuildScope(target, g); + std::shared_ptr scope = + hlir::framework::BuildScope(target, g); hlir::framework::GraphCompiler gc(target, scope, g); return gc.GenSourceCode(); }); py::class_(*m, "Interpreter") - .def(py::init &, const std::vector &>(), + .def(py::init &, + const std::vector &>(), py::arg("input_names"), py::arg("input_shapes")) // .def("load_paddle_model", @@ -449,29 +492,40 @@ void BindFrontend(pybind11::module *m) { #undef PY_REGISTER_REDUCE_CINN_FUNC // clang-format on .def(py::init(), py::arg("name") = "") + .def("create_input", + static_cast &, + const std::string &)>( + &NetBuilder::CreateInput), + py::arg("type"), + py::arg("shape"), + py::arg("id_hint")) .def( "create_input", - static_cast &, const std::string &)>( - &NetBuilder::CreateInput), - py::arg("type"), - py::arg("shape"), - py::arg("id_hint")) - .def( - "create_input", - [](NetBuilder &self, const std::string &type, const std::vector &shape, const std::string &id) { + [](NetBuilder &self, + const std::string &type, + const std::vector &shape, + const std::string &id) { return self.CreateInput(cinn::common::Str2Type(type), shape, id); }, py::arg("type"), py::arg("shape"), py::arg("id_hint")) - .def("create_input", static_cast(&NetBuilder::CreateInput)) + .def("create_input", + static_cast( + &NetBuilder::CreateInput)) .def("build", &NetBuilder::Build, py::arg("in_reverse") = false) .def("name", &NetBuilder::name) .def("__str__", [](NetBuilder &self) { return self.name(); }) - .def("append_instruction", &NetBuilder::AppendInstruction, py::arg("instr")) + .def("append_instruction", + &NetBuilder::AppendInstruction, + py::arg("instr")) .def("fill_constant", - static_cast &, const std::string &, const std::string &, const std::string &, bool)>( + static_cast &, + const std::string &, + const std::string &, + const std::string &, + bool)>( &NetBuilder::FillConstant), py::arg("shape"), py::arg("value"), @@ -479,11 +533,15 @@ void BindFrontend(pybind11::module *m) { py::arg("dtype"), py::arg("force_cpu") 
= false) .def("broadcast_to", - static_cast &)>(&NetBuilder::BroadcastTo), + static_cast &)>( + &NetBuilder::BroadcastTo), py::arg("x"), py::arg("out_shape")) .def("broadcast_to", - static_cast &, const std::vector &)>( + static_cast &, + const std::vector &)>( &NetBuilder::BroadcastTo), py::arg("x"), py::arg("out_shape"), @@ -491,23 +549,52 @@ void BindFrontend(pybind11::module *m) { .def("concat", &NetBuilder::Concat, py::arg("xs"), py::arg("axis") = 0) .def("reshape", &NetBuilder::Reshape, py::arg("x"), py::arg("shape")) .def("transpose", &NetBuilder::Transpose, py::arg("x"), py::arg("axis")) - .def("top_k", &NetBuilder::TopK, py::arg("x"), py::arg("k"), py::arg("axis"), py::arg("largest")) - .def("sort", &NetBuilder::Sort, py::arg("operand"), py::arg("axis"), py::arg("is_ascend")) - .def("argsort", &NetBuilder::ArgSort, py::arg("operand"), py::arg("axis"), py::arg("is_ascend")) + .def("top_k", + &NetBuilder::TopK, + py::arg("x"), + py::arg("k"), + py::arg("axis"), + py::arg("largest")) + .def("sort", + &NetBuilder::Sort, + py::arg("operand"), + py::arg("axis"), + py::arg("is_ascend")) + .def("argsort", + &NetBuilder::ArgSort, + py::arg("operand"), + py::arg("axis"), + py::arg("is_ascend")) .def("slice", &NetBuilder::Slice, py::arg("x"), py::arg("axes"), py::arg("starts"), py::arg("ends"), - py::arg("infer_flags") = std::vector{}, - py::arg("strides") = std::vector{}, + py::arg("infer_flags") = std::vector{}, + py::arg("strides") = std::vector{}, py::arg("decrease_axis") = std::vector{}) .def("reverse", &NetBuilder::Reverse, py::arg("x"), py::arg("axis")) - .def("resize", &NetBuilder::Resize, py::arg("x"), py::arg("out_shape"), py::arg("mode") = "bilinear") - .def("select", &NetBuilder::Select, py::arg("condition"), py::arg("true_value"), py::arg("false_value")) - .def("split", &NetBuilder::Split, py::arg("x"), py::arg("num_or_sections"), py::arg("axis") = 0) - .def("gather", &NetBuilder::Gather, py::arg("x"), py::arg("index"), py::arg("axis") = 0) + .def("resize", + &NetBuilder::Resize, + py::arg("x"), + py::arg("out_shape"), + py::arg("mode") = "bilinear") + .def("select", + &NetBuilder::Select, + py::arg("condition"), + py::arg("true_value"), + py::arg("false_value")) + .def("split", + &NetBuilder::Split, + py::arg("x"), + py::arg("num_or_sections"), + py::arg("axis") = 0) + .def("gather", + &NetBuilder::Gather, + py::arg("x"), + py::arg("index"), + py::arg("axis") = 0) .def("slice_assign", &NetBuilder::SliceAssign, py::arg("x"), @@ -532,8 +619,8 @@ void BindFrontend(pybind11::module *m) { &NetBuilder::IsClose, py::arg("x"), py::arg("y"), - py::arg("rtol") = 1e-05f, - py::arg("atol") = 1e-08f, + py::arg("rtol") = 1e-05f, + py::arg("atol") = 1e-08f, py::arg("equal_nan") = false) .def("mul", &NetBuilder::Mul, @@ -541,60 +628,79 @@ void BindFrontend(pybind11::module *m) { py::arg("y"), py::arg("x_num_col_dims") = 1, py::arg("y_num_col_dims") = 1, - py::arg("is_infer") = false) + py::arg("is_infer") = false) .def("elementwise_add_grad", &NetBuilder::ElementwiseAddGrad, py::arg("dout"), py::arg("x"), py::arg("y"), py::arg("axis") = -1) - .def("relu6", &NetBuilder::Relu6, py::arg("a"), py::arg("threshold") = 6.0f) + .def("relu6", + &NetBuilder::Relu6, + py::arg("a"), + py::arg("threshold") = 6.0f) .def("gelu", &NetBuilder::Gelu, py::arg("x")) - .def("squeeze", &NetBuilder::Squeeze, py::arg("a"), py::arg("axes") = std::vector{}) - .def("expand_dims", &NetBuilder::ExpandDims, py::arg("x"), py::arg("axes")) - .def("argmax", &NetBuilder::Argmax, py::arg("x"), py::arg("axis"), 
py::arg("keep_dim") = false) - .def("argmin", &NetBuilder::Argmin, py::arg("x"), py::arg("axis"), py::arg("keep_dim") = false) - .def("lookup_table", &NetBuilder::LookupTable, py::arg("table"), py::arg("ids"), py::arg("padding_idx")) + .def("squeeze", + &NetBuilder::Squeeze, + py::arg("a"), + py::arg("axes") = std::vector{}) + .def( + "expand_dims", &NetBuilder::ExpandDims, py::arg("x"), py::arg("axes")) + .def("argmax", + &NetBuilder::Argmax, + py::arg("x"), + py::arg("axis"), + py::arg("keep_dim") = false) + .def("argmin", + &NetBuilder::Argmin, + py::arg("x"), + py::arg("axis"), + py::arg("keep_dim") = false) + .def("lookup_table", + &NetBuilder::LookupTable, + py::arg("table"), + py::arg("ids"), + py::arg("padding_idx")) .def("one_hot", &NetBuilder::OneHot, py::arg("indices"), py::arg("on_value"), py::arg("off_value"), py::arg("depth"), - py::arg("axis") = -1, + py::arg("axis") = -1, py::arg("dtype") = "float32") .def("conv2d", &NetBuilder::Conv2d, py::arg("x"), py::arg("w"), - py::arg("strides") = std::vector{1, 1}, - py::arg("paddings") = std::vector{0, 0}, - py::arg("dilations") = std::vector{1, 1}, - py::arg("groups") = 1, - py::arg("data_format") = "NCHW", + py::arg("strides") = std::vector{1, 1}, + py::arg("paddings") = std::vector{0, 0}, + py::arg("dilations") = std::vector{1, 1}, + py::arg("groups") = 1, + py::arg("data_format") = "NCHW", py::arg("padding_algorithm") = "EXPLICIT") .def("depthwise_conv2d", &NetBuilder::DepthwiseConv2d, py::arg("x"), py::arg("w"), - py::arg("strides") = std::vector{1, 1}, - py::arg("paddings") = std::vector{0, 0}, - py::arg("dilations") = std::vector{1, 1}, - py::arg("groups") = 1, - py::arg("data_format") = "NCHW", + py::arg("strides") = std::vector{1, 1}, + py::arg("paddings") = std::vector{0, 0}, + py::arg("dilations") = std::vector{1, 1}, + py::arg("groups") = 1, + py::arg("data_format") = "NCHW", py::arg("padding_algorithm") = "EXPLICIT") .def("pool2d", &NetBuilder::Pool2d, py::arg("x"), py::arg("pooling_type"), py::arg("kernel_size"), - py::arg("stride") = std::vector{1, 1}, - py::arg("padding") = std::vector{0, 0}, - py::arg("ceil_mode") = false, - py::arg("exclusive") = true, - py::arg("global_pooling") = false, - py::arg("data_format") = "NCHW", - py::arg("adaptive") = false, + py::arg("stride") = std::vector{1, 1}, + py::arg("padding") = std::vector{0, 0}, + py::arg("ceil_mode") = false, + py::arg("exclusive") = true, + py::arg("global_pooling") = false, + py::arg("data_format") = "NCHW", + py::arg("adaptive") = false, py::arg("padding_algorithm") = "EXPLICIT") .def("pool2d_grad", &NetBuilder::Pool2dGrad, @@ -603,13 +709,13 @@ void BindFrontend(pybind11::module *m) { py::arg("dy"), py::arg("pooling_type"), py::arg("kernel_size"), - py::arg("stride") = std::vector{1, 1}, - py::arg("padding") = std::vector{0, 0}, - py::arg("ceil_mode") = false, - py::arg("exclusive") = true, - py::arg("global_pooling") = false, - py::arg("data_format") = "NCHW", - py::arg("adaptive") = false, + py::arg("stride") = std::vector{1, 1}, + py::arg("padding") = std::vector{0, 0}, + py::arg("ceil_mode") = false, + py::arg("exclusive") = true, + py::arg("global_pooling") = false, + py::arg("data_format") = "NCHW", + py::arg("adaptive") = false, py::arg("padding_algorithm") = "EXPLICIT") .def("batchnorm", &NetBuilder::BatchNorm, @@ -618,10 +724,10 @@ void BindFrontend(pybind11::module *m) { py::arg("bias"), py::arg("mean"), py::arg("variance"), - py::arg("epsilon") = 1e-5f, - py::arg("momentum") = 0.9f, + py::arg("epsilon") = 1e-5f, + py::arg("momentum") = 0.9f, 
py::arg("data_layout") = "NCHW", - py::arg("is_test") = true) + py::arg("is_test") = true) .def("batch_norm_grad", &NetBuilder::BatchNormGrad, py::arg("dy"), @@ -629,24 +735,24 @@ void BindFrontend(pybind11::module *m) { py::arg("scale"), py::arg("save_mean"), py::arg("save_variance"), - py::arg("epsilon") = 1e-5, + py::arg("epsilon") = 1e-5, py::arg("data_layout") = "NCHW") .def("scale", &NetBuilder::Scale, py::arg("x"), - py::arg("scale") = 1.0f, - py::arg("bias") = 0.0f, + py::arg("scale") = 1.0f, + py::arg("bias") = 0.0f, py::arg("bias_after_scale") = true) .def("softmax", &NetBuilder::Softmax, py::arg("x"), - py::arg("axes") = std::vector{-1}, - py::arg("mode") = "fast", + py::arg("axes") = std::vector{-1}, + py::arg("mode") = "fast", py::arg("data_format") = "AnyLayout") .def("dropout_infer", &NetBuilder::DropoutInfer, py::arg("x"), - py::arg("dropout_prob") = 0.5f, + py::arg("dropout_prob") = 0.5f, py::arg("dropout_implementation") = "downgrade_in_infer") .def("relu_grad", &NetBuilder::ReluGrad, py::arg("dout"), py::arg("x")) .def("sum", &NetBuilder::Sum, py::arg("inputs")) @@ -656,22 +762,30 @@ void BindFrontend(pybind11::module *m) { py::arg("y"), py::arg("transpose_x") = false, py::arg("transpose_y") = false, - py::arg("alpha") = 1.0f) + py::arg("alpha") = 1.0f) .def("conv", &NetBuilder::Conv, py::arg("x"), py::arg("w"), - py::arg("strides") = std::vector{1, 1}, - py::arg("paddings") = std::vector{0, 0}, - py::arg("dilations") = std::vector{1, 1}, - py::arg("groups") = 1, - py::arg("conv_type") = "forward", - py::arg("data_format") = "NCHW", + py::arg("strides") = std::vector{1, 1}, + py::arg("paddings") = std::vector{0, 0}, + py::arg("dilations") = std::vector{1, 1}, + py::arg("groups") = 1, + py::arg("conv_type") = "forward", + py::arg("data_format") = "NCHW", py::arg("padding_algorithm") = "EXPLICIT", - py::arg("output_shape") = std::vector{}) + py::arg("output_shape") = std::vector{}) .def("cast", &NetBuilder::Cast, py::arg("x"), py::arg("dtype")) - .def("bitcast_convert", &NetBuilder::BitcastConvert, py::arg("x"), py::arg("dtype")) - .def("arange", &NetBuilder::Arange, py::arg("start"), py::arg("stop"), py::arg("step"), py::arg("dtype")) + .def("bitcast_convert", + &NetBuilder::BitcastConvert, + py::arg("x"), + py::arg("dtype")) + .def("arange", + &NetBuilder::Arange, + py::arg("start"), + py::arg("stop"), + py::arg("step"), + py::arg("dtype")) .def("gather_nd", &NetBuilder::GatherNd, py::arg("x"), py::arg("index")) .def("cbrt", &NetBuilder::Cbrt, py::arg("x")) .def("clz", &NetBuilder::Clz, py::arg("x")) @@ -680,52 +794,66 @@ void BindFrontend(pybind11::module *m) { .def("gaussian_random", &NetBuilder::GaussianRandom, py::arg("shape"), - py::arg("mean") = 0.0f, - py::arg("std") = 1.0f, - py::arg("seed") = 0, + py::arg("mean") = 0.0f, + py::arg("std") = 1.0f, + py::arg("seed") = 0, py::arg("dtype") = "float32") .def("uniform_random", &NetBuilder::UniformRandom, py::arg("shape"), - py::arg("min") = -1.0f, - py::arg("max") = 1.0f, - py::arg("seed") = 0, - py::arg("dtype") = "float32", - py::arg("diag_num") = 0, + py::arg("min") = -1.0f, + py::arg("max") = 1.0f, + py::arg("seed") = 0, + py::arg("dtype") = "float32", + py::arg("diag_num") = 0, py::arg("diag_step") = 0, - py::arg("diag_val") = 1.0f) + py::arg("diag_val") = 1.0f) .def("randint", &NetBuilder::RandInt, py::arg("shape"), - py::arg("min") = 0, - py::arg("max") = 0, - py::arg("seed") = 0, + py::arg("min") = 0, + py::arg("max") = 0, + py::arg("seed") = 0, py::arg("dtype") = "int64") - .def("repeat", &NetBuilder::Repeat, 
py::arg("x"), py::arg("repeats"), py::arg("axis")) + .def("repeat", + &NetBuilder::Repeat, + py::arg("x"), + py::arg("repeats"), + py::arg("axis")) .def("flip", &NetBuilder::Flip, py::arg("x"), py::arg("axis")) - .def("cholesky", &NetBuilder::Cholesky, py::arg("x"), py::arg("upper") = false) + .def("cholesky", + &NetBuilder::Cholesky, + py::arg("x"), + py::arg("upper") = false) .def("triangular_solve", &NetBuilder::TriangularSolve, py::arg("input1"), py::arg("input2"), - py::arg("left_side") = true, - py::arg("upper") = false, - py::arg("transpose_a") = false, + py::arg("left_side") = true, + py::arg("upper") = false, + py::arg("transpose_a") = false, py::arg("unit_diagonal") = false); - auto computation = py::class_>(*m, "Computation"); + auto computation = + py::class_>( + *m, "Computation"); py::class_(computation, "CompileOptions") - .def_readwrite("use_decomposer", &CinnComputation::CompileOptions::use_decomposer) + .def_readwrite("use_decomposer", + &CinnComputation::CompileOptions::use_decomposer) .def_readwrite("do_prerun", &CinnComputation::CompileOptions::do_prerun) - .def_readwrite("use_default_passes", &CinnComputation::CompileOptions::use_default_passes) + .def_readwrite("use_default_passes", + &CinnComputation::CompileOptions::use_default_passes) .def_readwrite("passes", &CinnComputation::CompileOptions::passes); computation .def("default_compile_options", &CinnComputation::DefaultCompileOptions) - // currently stream param is not exported to python, the default stream is used always + // currently stream param is not exported to python, the default stream is + // used always .def_static( "build_and_compile", - [](const common::Target &target, NetBuilder &builder, const CinnComputation::CompileOptions &options) { + [](const common::Target &target, + NetBuilder &builder, + const CinnComputation::CompileOptions &options) { return CinnComputation::BuildAndCompile(target, builder, options); }, py::arg("target"), @@ -733,7 +861,9 @@ void BindFrontend(pybind11::module *m) { py::arg("options") = CinnComputation::DefaultCompileOptions()) .def_static( "compile", - [](const common::Target &target, Program &program, const CinnComputation::CompileOptions &options) { + [](const common::Target &target, + Program &program, + const CinnComputation::CompileOptions &options) { return CinnComputation::Compile(target, program, options); }, py::arg("target"), @@ -747,8 +877,12 @@ void BindFrontend(pybind11::module *m) { const std::vector &input_shapes, bool params_combined, const CinnComputation::CompileOptions &options) { - return CinnComputation::CompilePaddleModel( - target, model_path, input_names, input_shapes, params_combined, options); + return CinnComputation::CompilePaddleModel(target, + model_path, + input_names, + input_shapes, + params_combined, + options); }, py::arg("target"), py::arg("model_path"), @@ -762,22 +896,30 @@ void BindFrontend(pybind11::module *m) { py::class_(*m, "PaddleModelConvertor") .def(py::init<>()) - .def(py::init, std::shared_ptr>(), + .def(py::init, + std::shared_ptr>(), py::arg("target"), py::arg("builder") = nullptr, - py::arg("scope") = nullptr) + py::arg("scope") = nullptr) .def("__call__", &PaddleModelConvertor::operator()) .def("load_model", &PaddleModelConvertor::LoadModel, py::arg("model_dir"), py::arg("is_combined") = false, - py::arg("feed") = std::unordered_map>()) - .def("create_input", &PaddleModelConvertor::CreateInput, py::arg("dtype"), py::arg("shape"), py::arg("name")) + py::arg("feed") = + std::unordered_map>()) + .def("create_input", + 
&PaddleModelConvertor::CreateInput,
+           py::arg("dtype"),
+           py::arg("shape"),
+           py::arg("name"))
       .def("append_op",
-           static_cast> &,
-               const std::map> &,
-               const std::map &)>(
+           static_cast> &,
+               const std::map> &,
+               const std::map &)>(
                &PaddleModelConvertor::RunOp),
            py::arg("type"),
            py::arg("inputs"),
@@ -786,11 +928,13 @@ void BindFrontend(pybind11::module *m) {
       .def("get_fetch_list",
            &PaddleModelConvertor::GetFetchList,
            py::arg("fetch_list") = std::unordered_set{})
-      .def("get_cinn_name", [](PaddleModelConvertor &self, const std::string &paddle_name) {
-        CHECK(self.var_model_to_program_map().count(paddle_name))
-            << "Cannot find variabel " << paddle_name << " in CINN! Please check.";
-        return self.var_model_to_program_map().at(paddle_name);
-      });
+      .def("get_cinn_name",
+           [](PaddleModelConvertor &self, const std::string &paddle_name) {
+             CHECK(self.var_model_to_program_map().count(paddle_name))
+                 << "Cannot find variable " << paddle_name
+                 << " in CINN! Please check.";
+             return self.var_model_to_program_map().at(paddle_name);
+           });
 }  // namespace frontend
diff --git a/paddle/cinn/pybind/ir.cc b/paddle/cinn/pybind/ir.cc
index 767d0eb11af03..a79da48c88727 100755
--- a/paddle/cinn/pybind/ir.cc
+++ b/paddle/cinn/pybind/ir.cc
@@ -61,10 +61,16 @@ void BindLoweredFunc(py::module *m) {
   py::class_<Argument> argument(*m, "Argument");
   py::enum_<Argument::IO> io(argument, "IO");
-  io.value("kInput", Argument::IO::kInput).value("kOutput", Argument::IO::kOutput);
-
-  argument.def(py::init(), py::arg("buffer"), py::arg("io") = Argument::IO::kInput)
-      .def(py::init(), py::arg("var"), py::arg("io") = Argument::IO::kInput)
+  io.value("kInput", Argument::IO::kInput)
+      .value("kOutput", Argument::IO::kOutput);
+
+  argument
+      .def(py::init(),
+           py::arg("buffer"),
+           py::arg("io") = Argument::IO::kInput)
+      .def(py::init(),
+           py::arg("var"),
+           py::arg("io") = Argument::IO::kInput)
       .def("set_buffer", &Argument::set_buffer)
       .def("set_var", &Argument::set_var)
       .def("is_input", &Argument::is_input)
@@ -80,10 +86,16 @@ void BindLoweredFunc(py::module *m) {
   py::class_ lowered_func(*m, "LoweredFunc");
   lowered_func.def(py::init<>())
       .def(py::init())
-      .def("name", [](const ir::LoweredFunc &self) -> std::string { return self->name; })
-      .def("__str__", [](const ir::LoweredFunc &self) -> std::string { return utils::GetStreamCnt(Expr(self)); })
+      .def(
+          "name",
+          [](const ir::LoweredFunc &self) -> std::string { return self->name; })
+      .def("__str__",
+           [](const ir::LoweredFunc &self) -> std::string {
+             return utils::GetStreamCnt(Expr(self));
+           })
       .def("__repr__", [](const ir::LoweredFunc &self) -> std::string {
-        return llvm::formatv("", self.get(), self->name.c_str());
+        return llvm::formatv(
+            "", self.get(), self->name.c_str());
       });
 }
 
@@ -96,7 +108,8 @@ void BindNode(py::module *m) {
 #undef DECLARE_IR_NODE_TY
 
   // class IrNode
-  py::class_ ir_node(*m, "IrNode", py::module_local());
+  py::class_ ir_node(
+      *m, "IrNode", py::module_local());
   ir_node.def(py::init<>())
       .def(py::init())
       .def_readwrite("operands", &ir::IrNode::operands)
@@ -104,14 +117,16 @@ void BindNode(py::module *m) {
       .def("type", &ir::IrNode::type)
       .def("set_type", &ir::IrNode::set_type)
       .def("expr_fields_mutable", py::overload_cast<>(&ir::IrNode::expr_fields))
-      .def("expr_fields_const", py::overload_cast<>(&ir::IrNode::expr_fields, py::const_))
+      .def("expr_fields_const",
+           py::overload_cast<>(&ir::IrNode::expr_fields, py::const_))
       .def("type_info", &ir::IrNode::type_info);
 
   // class Shared
   DefineShared(m, "IrNode");
 
   // class IrNodeRef : public Shared
-  py::class_>
ir_node_ref(*m, "IrNodeRef"); + py::class_> ir_node_ref(*m, + "IrNodeRef"); ir_node_ref.def(py::init<>()) .def(py::init()) .def(py::init()) @@ -122,24 +137,31 @@ void BindNode(py::module *m) { py::class_> int_imm(*m, "IntImm"); int_imm.def_readwrite("value", &ir::IntImm::value) .def(py::init()) - .def("__str__", [](const ir::IntImm &self) { return std::to_string(self.value); }) - .def("__repr__", - [](ir::IntImm &self) -> std::string { return llvm::formatv("", self.self(), self.value); }); + .def("__str__", + [](const ir::IntImm &self) { return std::to_string(self.value); }) + .def("__repr__", [](ir::IntImm &self) -> std::string { + return llvm::formatv("", self.self(), self.value); + }); // struct UIntImm : ExprNode DefineExprNode(m, "UIntImm"); py::class_> uint_imm(*m, "UIntImm"); - uint_imm.def_readwrite("value", &ir::UIntImm::value).def(py::init()); + uint_imm.def_readwrite("value", &ir::UIntImm::value) + .def(py::init()); // struct FloatImm : ExprNode DefineExprNode(m, "FloatImm"); - py::class_> float_imm(*m, "FloatImm"); - float_imm.def_readwrite("value", &ir::FloatImm::value).def(py::init()); + py::class_> float_imm(*m, + "FloatImm"); + float_imm.def_readwrite("value", &ir::FloatImm::value) + .def(py::init()); // struct StringImm : ExprNode DefineExprNode(m, "StringImm"); - py::class_> string_imm(*m, "StringImm"); - string_imm.def_readwrite("value", &ir::StringImm::value).def(py::init()); + py::class_> string_imm( + *m, "StringImm"); + string_imm.def_readwrite("value", &ir::StringImm::value) + .def(py::init()); auto expr = py::class_(*m, "Expr"); @@ -159,20 +181,30 @@ void BindNode(py::module *m) { .def("as_float", &ir::Expr::as_float) .def("as_double", &ir::Expr::as_double) .def("int", [](ir::Expr &self) { return self.As()->value; }) - .def("float", [](ir::Expr &self) { return self.As()->value; }) + .def("float", + [](ir::Expr &self) { return self.As()->value; }) - .def("__str__", [](const Expr &self) { return utils::GetStreamCnt(self); }) + .def("__str__", + [](const Expr &self) { return utils::GetStreamCnt(self); }) .def("__repr__", [](const Expr &self) -> std::string { std::string content = self.get() ? 
utils::GetStreamCnt(self) : ""; return llvm::formatv("", content); }); - expr.def("as_var_mutable", py::overload_cast<>(&ir::Expr::as_var), py::return_value_policy::reference) - .def("as_var_const", py::overload_cast<>(&ir::Expr::as_var, py::const_), py::return_value_policy::reference) + expr.def("as_var_mutable", + py::overload_cast<>(&ir::Expr::as_var), + py::return_value_policy::reference) + .def("as_var_const", + py::overload_cast<>(&ir::Expr::as_var, py::const_), + py::return_value_policy::reference) .def("as_var_ref", &ir::Expr::as_var_ref); - expr.def("as_buffer_mutable", py::overload_cast<>(&ir::Expr::as_buffer), py::return_value_policy::reference) - .def("as_buffer_const", py::overload_cast<>(&ir::Expr::as_buffer, py::const_), py::return_value_policy::reference) + expr.def("as_buffer_mutable", + py::overload_cast<>(&ir::Expr::as_buffer), + py::return_value_policy::reference) + .def("as_buffer_const", + py::overload_cast<>(&ir::Expr::as_buffer, py::const_), + py::return_value_policy::reference) .def("as_buffer_ref", &ir::Expr::as_buffer_ref); expr.def("is_constant", &ir::Expr::is_constant) @@ -207,16 +239,30 @@ void BindNode(py::module *m) { BIND_POD_BINARY_OP(int()) // BIND_POD_BINARY_OP(float()); - expr.def("__add__", [](const Expr &self, const Var &other) -> Expr { return self + other; }) - .def("__sub__", [](const Expr &self, const Var &other) -> Expr { return self - other; }) - .def("__mul__", [](const Expr &self, const Var &other) -> Expr { return self * other; }) - .def("__div__", [](const Expr &self, const Var &other) -> Expr { return self / other; }); + expr.def("__add__", + [](const Expr &self, const Var &other) -> Expr { + return self + other; + }) + .def("__sub__", + [](const Expr &self, const Var &other) -> Expr { + return self - other; + }) + .def("__mul__", + [](const Expr &self, const Var &other) -> Expr { + return self * other; + }) + .def("__div__", [](const Expr &self, const Var &other) -> Expr { + return self / other; + }); } void BindIrVisitor(py::module *m) { py::class_ ir_visitor(*m, "IRVisitor"); - ir_visitor.def(py::init<>()).def("visit", py::overload_cast(&ir::IRVisitor::Visit)); -#define DEFINE_VISIT_FN(__ty) ir_visitor.def("visit", py::overload_cast(&ir::IRVisitor::Visit)); + ir_visitor.def(py::init<>()) + .def("visit", py::overload_cast(&ir::IRVisitor::Visit)); +#define DEFINE_VISIT_FN(__ty) \ + ir_visitor.def("visit", \ + py::overload_cast(&ir::IRVisitor::Visit)); NODETY_FORALL(DEFINE_VISIT_FN) #undef DEFINE_VISIT_FN } @@ -232,8 +278,12 @@ void BindIrIr(py::module *m) { DefineExprNode(m, "Cast"); py::class_> cast(*m, "Cast"); cast.def(py::init<>()) - .def("v_mutable", py::overload_cast<>(&ir::Cast::v), py::return_value_policy::reference) - .def("v_const", py::overload_cast<>(&ir::Cast::v, py::const_), py::return_value_policy::reference); + .def("v_mutable", + py::overload_cast<>(&ir::Cast::v), + py::return_value_policy::reference) + .def("v_const", + py::overload_cast<>(&ir::Cast::v, py::const_), + py::return_value_policy::reference); // struct Let : ExprNode DefineExprNode(m, "Let"); @@ -244,7 +294,8 @@ void BindIrIr(py::module *m) { .def_static("make", &ir::Let::Make) .def("type", &ir::Let::type) .def("expr_fields_mutable", py::overload_cast<>(&ir::Let::expr_fields)) - .def("expr_fields_const", py::overload_cast<>(&ir::Let::expr_fields, py::const_)); + .def("expr_fields_const", + py::overload_cast<>(&ir::Let::expr_fields, py::const_)); // struct Reduce : ExprNode DefineExprNode(m, "Reduce"); @@ -266,7 +317,8 @@ void BindIrIr(py::module *m) { 
.def_static("make", &ir::Reduce::Make) .def("type", &ir::Reduce::type) .def("expr_fields_mutable", py::overload_cast<>(&ir::Reduce::expr_fields)) - .def("expr_fields_const", py::overload_cast<>(&ir::Reduce::expr_fields, py::const_)); + .def("expr_fields_const", + py::overload_cast<>(&ir::Reduce::expr_fields, py::const_)); // enum class CallType py::enum_ call_type(*m, "CallType"); @@ -292,7 +344,8 @@ void BindIrIr(py::module *m) { .def("is_intrinsic_call", &ir::Call::is_intrinsic_call) .def("is_isl_call", &ir::Call::is_isl_call) .def("expr_fields_mutable", py::overload_cast<>(&ir::Call::expr_fields)) - .def("expr_fields_const", py::overload_cast<>(&ir::Call::expr_fields, py::const_)); + .def("expr_fields_const", + py::overload_cast<>(&ir::Call::expr_fields, py::const_)); // struct _Var_ : ExprNode<_Var_> DefineExprNode(m, "_Var_"); @@ -304,8 +357,13 @@ void BindIrIr(py::module *m) { .def_readwrite("tag", &ir::_Var_::tag) .def(py::init<>()) .def(py::init()) - .def_static("make", py::overload_cast(&ir::_Var_::Make)) - .def_static("make", py::overload_cast(&ir::_Var_::Make)) + .def_static("make", + py::overload_cast( + &ir::_Var_::Make)) + .def_static( + "make", + py::overload_cast( + &ir::_Var_::Make)) .def("copy", &ir::_Var_::Copy); // struct Select @@ -318,40 +376,50 @@ void BindIrIr(py::module *m) { .def_static("make", &ir::Select::Make) .def("type", &ir::Select::type) .def("expr_fields_mutable", py::overload_cast<>(&ir::Select::expr_fields)) - .def("expr_fields_const", py::overload_cast<>(&ir::Select::expr_fields, py::const_)); + .def("expr_fields_const", + py::overload_cast<>(&ir::Select::expr_fields, py::const_)); // struct LoadStoreAddrMnger - py::class_ load_store_addr_manager(*m, "LoadStoreAddrMnger"); - load_store_addr_manager.def_readwrite("tensor", &ir::LoadStoreAddrMnger::tensor) + py::class_ load_store_addr_manager( + *m, "LoadStoreAddrMnger"); + load_store_addr_manager + .def_readwrite("tensor", &ir::LoadStoreAddrMnger::tensor) .def("is_addr_tensor", &ir::LoadStoreAddrMnger::is_addr_tensor) .def("is_addr_scalar", &ir::LoadStoreAddrMnger::is_addr_scalar); // struct Load : ExprNode, LoadStoreAddrMnger DefineExprNode(m, "Load"); - py::class_, ir::LoadStoreAddrMnger> load(*m, "Load"); + py::class_, ir::LoadStoreAddrMnger> load(*m, + "Load"); load.def_readwrite("indices", &ir::Load::indices) .def("index", &ir::Load::index) .def_static("make", &ir::Load::Make) .def("expr_fields_mutable", py::overload_cast<>(&ir::Load::expr_fields)) - .def("expr_fields_const", py::overload_cast<>(&ir::Load::expr_fields, py::const_)) + .def("expr_fields_const", + py::overload_cast<>(&ir::Load::expr_fields, py::const_)) .def("name", &ir::Load::name) .def("type", &ir::Load::type); // struct Store : ExprNode, LoadStoreAddrMnger DefineExprNode(m, "Store"); - py::class_, ir::LoadStoreAddrMnger> store(*m, "Store"); + py::class_, ir::LoadStoreAddrMnger> store( + *m, "Store"); store.def_readwrite("value", &ir::Store::value) .def_readwrite("indices", &ir::Store::indices) .def_static("make", &ir::Store::Make) .def("expr_fields_mutable", py::overload_cast<>(&ir::Store::expr_fields)) - .def("expr_fields_const", py::overload_cast<>(&ir::Store::expr_fields, py::const_)) + .def("expr_fields_const", + py::overload_cast<>(&ir::Store::expr_fields, py::const_)) .def("type", &ir::Store::type) .def("index", &ir::Store::index); -#define DEFINE_BINARY_NODE(__node) \ - DefineBinaryOpNode(m, #__node); \ - py::class_> py_##__node(*m, #__node); \ - py_##__node.def(py::init()).def_static("make", &ir::__node::Make).def("type", 
&ir::__node::type) +#define DEFINE_BINARY_NODE(__node) \ + DefineBinaryOpNode(m, #__node); \ + py::class_> py_##__node(*m, \ + #__node); \ + py_##__node.def(py::init()) \ + .def_static("make", &ir::__node::Make) \ + .def("type", &ir::__node::type) DEFINE_BINARY_NODE(Add); DEFINE_BINARY_NODE(Sub); @@ -374,11 +442,14 @@ void BindIrIr(py::module *m) { // FracOp DefineBinaryOpNode(m, "FracOp"); py::class_> frac_op(*m, "FracOp"); - frac_op.def(py::init<>()).def_static("make", &ir::FracOp::Make).def("type", &ir::FracOp::type); - -#define DEFINE_UNARY_NODE(__node) \ - DefineUnaryOpNode(m, #__node); \ - py::class_> py_##__node(*m, #__node); \ + frac_op.def(py::init<>()) + .def_static("make", &ir::FracOp::Make) + .def("type", &ir::FracOp::type); + +#define DEFINE_UNARY_NODE(__node) \ + DefineUnaryOpNode(m, #__node); \ + py::class_> py_##__node(*m, \ + #__node); \ py_##__node.def(py::init()).def_static("make", &ir::__node::Make) DEFINE_UNARY_NODE(Minus); @@ -388,26 +459,37 @@ void BindIrIr(py::module *m) { py::class_ var(*m, "Var"); var.def(py::init<>()) .def(py::init()) - .def(py::init(), arg("name_hint"), arg("t") = common::type_of()) + .def(py::init(), + arg("name_hint"), + arg("t") = common::type_of()) .def(py::init()) .def(py::init()) .def(py::init()) - .def("get_mutable", py::overload_cast<>(&Var::get), py::return_value_policy::reference) - .def("get_const", py::overload_cast<>(&Var::get, py::const_), py::return_value_policy::reference) + .def("get_mutable", + py::overload_cast<>(&Var::get), + py::return_value_policy::reference) + .def("get_const", + py::overload_cast<>(&Var::get, py::const_), + py::return_value_policy::reference) .def("to_expr_mutable", py::overload_cast<>(&Var::operator ir::Expr)) - .def("to_expr_const", py::overload_cast<>(&Var::operator ir::Expr, py::const_)) - .def("__repr__", [](Var &self) -> std::string { return llvm::formatv("", self->name); }) + .def("to_expr_const", + py::overload_cast<>(&Var::operator ir::Expr, py::const_)) + .def("__repr__", + [](Var &self) -> std::string { + return llvm::formatv("", self->name); + }) .def("expr", [](Var &self) -> Expr { return Expr(self->self()); }) BIND_POD_BINARY_OP(int()) // BIND_POD_BINARY_OP(int32_t()) // BIND_POD_BINARY_OP(float()) -#define BINARY_OP(type__) \ - .def("__add__", [](Var &self, type__ v) -> Expr { return self + v; }) \ - .def("__sub__", [](Var &self, type__ v) -> Expr { return self - v; }) \ - .def("__truediv__", [](Var &self, type__ v) -> Expr { return self / v; }) \ - .def("__mul__", [](Var &self, type__ v) -> Expr { return self * v; }) \ +#define BINARY_OP(type__) \ + .def("__add__", [](Var &self, type__ v) -> Expr { return self + v; }) \ + .def("__sub__", [](Var &self, type__ v) -> Expr { return self - v; }) \ + .def("__truediv__", \ + [](Var &self, type__ v) -> Expr { return self / v; }) \ + .def("__mul__", [](Var &self, type__ v) -> Expr { return self * v; }) \ .def("__mod__", [](Var &self, type__ v) -> Expr { return self % v; }) BINARY_OP(int32_t) // @@ -420,7 +502,9 @@ void BindIrIr(py::module *m) { py::class_> product(*m, "Product"); product.def_static("make", &ir::Product::Make) .def("type", &ir::Product::type) - .def("operand_mutable", py::overload_cast(&ir::Product::operand), py::return_value_policy::reference) + .def("operand_mutable", + py::overload_cast(&ir::Product::operand), + py::return_value_policy::reference) .def("operand_const", py::overload_cast(&ir::Product::operand, py::const_), py::return_value_policy::reference); @@ -428,8 +512,12 @@ void BindIrIr(py::module *m) { DefineExprNode(m, 
"Sum"); py::class_> sum(*m, "Sum"); sum.def_static("make", &ir::Sum::Make) - .def("operand_mutable", py::overload_cast(&ir::Sum::operand), py::return_value_policy::reference) - .def("operand_const", py::overload_cast(&ir::Sum::operand, py::const_), py::return_value_policy::reference) + .def("operand_mutable", + py::overload_cast(&ir::Sum::operand), + py::return_value_policy::reference) + .def("operand_const", + py::overload_cast(&ir::Sum::operand, py::const_), + py::return_value_policy::reference) .def("type", &ir::Sum::type); DefineExprNode(m, "Block"); @@ -438,7 +526,8 @@ void BindIrIr(py::module *m) { .def(py::init<>()) .def_static("make", &ir::Block::Make) .def("expr_fields_mutable", py::overload_cast<>(&ir::Block::expr_fields)) - .def("expr_fields_const", py::overload_cast<>(&ir::Block::expr_fields, py::const_)); + .def("expr_fields_const", + py::overload_cast<>(&ir::Block::expr_fields, py::const_)); DefineExprNode(m, "_Module_"); py::class_> _module_(*m, "_Module_"); @@ -450,7 +539,8 @@ void BindIrIr(py::module *m) { } void BindOperation(py::module *m) { - py::class_ placeholder_op(*m, "PlaceholderOp"); + py::class_ placeholder_op( + *m, "PlaceholderOp"); placeholder_op.def_readwrite("shape", &ir::PlaceholderOp::shape) .def_readwrite("dtype", &ir::PlaceholderOp::dtype) .def_static("make", &ir::PlaceholderOp::Make) @@ -460,9 +550,11 @@ void BindOperation(py::module *m) { call_op.def("target", &ir::CallOp::target) .def_readwrite("call_expr", &ir::CallOp::call_expr) .def("read_args_mutable", py::overload_cast<>(&ir::CallOp::read_args)) - .def("read_args_const", py::overload_cast<>(&ir::CallOp::read_args, py::const_)) + .def("read_args_const", + py::overload_cast<>(&ir::CallOp::read_args, py::const_)) .def("write_args_mutable", py::overload_cast<>(&ir::CallOp::write_args)) - .def("write_args_const", py::overload_cast<>(&ir::CallOp::write_args, py::const_)) + .def("write_args_const", + py::overload_cast<>(&ir::CallOp::write_args, py::const_)) .def("args", &ir::CallOp::args) .def_readwrite("func", &ir::CallOp::func) .def_readwrite("value_slot", &ir::CallOp::value_slot) @@ -472,7 +564,8 @@ void BindOperation(py::module *m) { .def_static("make", &ir::CallOp::Make) .def("func_type", &ir::CallOp::func_type); - py::class_ compute_op(*m, "ComputeOp"); + py::class_ compute_op(*m, + "ComputeOp"); compute_op.def_readwrite("reduce_axis", &ir::ComputeOp::reduce_axis) .def_readwrite("shape", &ir::ComputeOp::shape) .def_readwrite("body", &ir::ComputeOp::body) @@ -488,9 +581,15 @@ void BindIrTensor(py::module *m) { .def(py::init()) .def("ndims", &ir::Tensor::ndims) .def("__call__", [](ir::Tensor &self, Expr a) { return self(a); }) - .def("__call__", [](ir::Tensor &self, Expr a, Expr b) { return self(a, b); }) - .def("__call__", [](ir::Tensor &self, Expr a, Expr b, Expr c) { return self(a, b, c); }) - .def("__call__", [](ir::Tensor &self, Expr a, Expr b, Expr c, Expr d) { return self(a, b, c, d); }); + .def("__call__", + [](ir::Tensor &self, Expr a, Expr b) { return self(a, b); }) + .def("__call__", + [](ir::Tensor &self, Expr a, Expr b, Expr c) { + return self(a, b, c); + }) + .def("__call__", [](ir::Tensor &self, Expr a, Expr b, Expr c, Expr d) { + return self(a, b, c, d); + }); DefineExprNode(m, "_Tensor_"); py::class_> _tensor_(*m, "_Tensor_"); @@ -500,7 +599,8 @@ void BindIrTensor(py::module *m) { .def_readwrite("name", &ir::_Tensor_::name) .def_readwrite("buffer", &ir::_Tensor_::buffer) .def("domain_with_reduce_axis", &ir::_Tensor_::domain_without_reduce_axis) - .def("domain_without_reduce_axis", 
&ir::_Tensor_::domain_without_reduce_axis) + .def("domain_without_reduce_axis", + &ir::_Tensor_::domain_without_reduce_axis) .def_static("make", &ir::_Tensor_::Make) .def("is_tuple", &ir::_Tensor_::is_tuple) .def("is_tuple_get", &ir::_Tensor_::is_tuple_get) @@ -519,14 +619,18 @@ void BindIrTensor(py::module *m) { .def("get_compute_op", &ir::_Tensor_::get_compute_op) .def("get_placeholder_op", &ir::_Tensor_::get_placeholder_op) .def("body", &ir::_Tensor_::body) - .def("tensor_store_expanded_body", &ir::_Tensor_::tensor_store_expanded_body) + .def("tensor_store_expanded_body", + &ir::_Tensor_::tensor_store_expanded_body) .def("inline_expanded", &ir::_Tensor_::inline_expanded) .def("contains_reduce_axis", &ir::_Tensor_::contains_reduce_axis) - .def("expr_fields_mutable", py::overload_cast<>(&ir::_Tensor_::expr_fields)) - .def("expr_fields_const", py::overload_cast<>(&ir::_Tensor_::expr_fields, py::const_)) + .def("expr_fields_mutable", + py::overload_cast<>(&ir::_Tensor_::expr_fields)) + .def("expr_fields_const", + py::overload_cast<>(&ir::_Tensor_::expr_fields, py::const_)) .def("axis", &ir::_Tensor_::axis) .def("axis_with_reduce", &ir::_Tensor_::axis_with_reduce) - .def("buffer_depended_tensor_names", &ir::_Tensor_::buffer_depended_tensor_names) + .def("buffer_depended_tensor_names", + &ir::_Tensor_::buffer_depended_tensor_names) .def(py::init<>()) .def("has_expression", &ir::_Tensor_::has_expression) .def("reshape", &ir::_Tensor_::Reshape) @@ -535,16 +639,22 @@ void BindIrTensor(py::module *m) { py::overload_cast(&ir::_Tensor_::WithBuffer), py::arg("type") = Type::type_t::Void) .def("with_buffer", - py::overload_cast(&ir::_Tensor_::WithBuffer), + py::overload_cast(&ir::_Tensor_::WithBuffer), py::arg("memory_type"), py::arg("buffer_name") = "", - py::arg("type") = Type::type_t::Void) + py::arg("type") = Type::type_t::Void) .def("bind", py::overload_cast(&ir::_Tensor_::Bind)) .def("bind", py::overload_cast(&ir::_Tensor_::Bind)) - .def("__str__", [](const ir::Tensor &self) { return "name + ">"; }); + .def("__str__", [](const ir::Tensor &self) { + return "name + ">"; + }); py::class_ operation(*m, "Operation"); - operation.def(py::init<>()).def(py::init()).def_readwrite("name", &ir::Operation::name); + operation.def(py::init<>()) + .def(py::init()) + .def_readwrite("name", &ir::Operation::name); } auto PackedFuncCall(lang::PackedFunc &self, py::args args) { // NOLINT @@ -560,7 +670,8 @@ auto PackedFuncCall(lang::PackedFunc &self, py::args args) { // NOLINT } else if (py::isinstance(handle)) { cinn_args.Append(CINNValue(py::cast(handle))); } else { - LOG(FATAL) << "unsupported type: " << std::string(py::str(handle.get_type())); + LOG(FATAL) << "unsupported type: " + << std::string(py::str(handle.get_type())); } } lang::RetValue ret_value; @@ -576,8 +687,11 @@ void BindPackedFunc(py::module *m) { .def("size", &lang::Args::size) .def("__len__", &lang::Args::size) .def( - "__getitem__", [](lang::Args &self, int i) { return self[i]; }, py::return_value_policy::reference) - .def("__setitem__", [](lang::Args &self, int i, common::CINNValue &v) { self[i] = v; }); + "__getitem__", + [](lang::Args &self, int i) { return self[i]; }, + py::return_value_policy::reference) + .def("__setitem__", + [](lang::Args &self, int i, common::CINNValue &v) { self[i] = v; }); py::class_ packed_func(*m, "PackedFunc"); packed_func.def(py::init<>()) @@ -595,30 +709,37 @@ void BindRegistry(py::module *m) { py::arg("name"), py::arg("override") = false, py::return_value_policy::reference) - .def_static("register", 
&ir::Registry::Register, py::return_value_policy::reference) + .def_static("register", + &ir::Registry::Register, + py::return_value_policy::reference) .def_static("remove", &ir::Registry::Remove) .def_static("get", &ir::Registry::Get, py::return_value_policy::reference) .def_static("list_names", &ir::Registry::ListNames) - .def("set_body", py::overload_cast(&ir::Registry::SetBody), py::return_value_policy::reference); + .def("set_body", + py::overload_cast(&ir::Registry::SetBody), + py::return_value_policy::reference); #ifdef CINN_WITH_TEST - ir::Registry::Register("test_add_int64").SetBody([](lang::Args args, lang::RetValue *rv) { - int64_t x = args[0]; - int64_t y = args[1]; - *rv = x + y; - }); - - ir::Registry::Register("test_add_expr").SetBody([](lang::Args args, lang::RetValue *rv) { - ir::Expr x = args[0]; - ir::Expr y = args[1]; - *rv = x + y; - }); - - ir::Registry::Register("test_mul_float").SetBody([](lang::Args args, lang::RetValue *rv) { - float x = args[0]; - float y = args[1]; - *rv = x * y; - }); + ir::Registry::Register("test_add_int64") + .SetBody([](lang::Args args, lang::RetValue *rv) { + int64_t x = args[0]; + int64_t y = args[1]; + *rv = x + y; + }); + + ir::Registry::Register("test_add_expr") + .SetBody([](lang::Args args, lang::RetValue *rv) { + ir::Expr x = args[0]; + ir::Expr y = args[1]; + *rv = x + y; + }); + + ir::Registry::Register("test_mul_float") + .SetBody([](lang::Args args, lang::RetValue *rv) { + float x = args[0]; + float y = args[1]; + *rv = x * y; + }); #endif } } // namespace diff --git a/paddle/cinn/pybind/lang.cc b/paddle/cinn/pybind/lang.cc index 94819339a4243..3ff7b4d318e5f 100644 --- a/paddle/cinn/pybind/lang.cc +++ b/paddle/cinn/pybind/lang.cc @@ -49,7 +49,10 @@ void BindBuiltin(py::module *); void BindBuffer(py::module *m) { py::class_ buffer(*m, "Buffer"); - buffer.def(py::init(), py::arg("type"), py::arg("name") = "") + buffer + .def(py::init(), + py::arg("type"), + py::arg("name") = "") .def(py::init()) .def("buffer", &lang::Buffer::buffer); } @@ -61,10 +64,10 @@ void BindLower(py::module *m) { arg("name"), arg("stages"), arg("tensor_args"), - arg("scalar_args") = std::vector(), - arg("temp_tensors") = std::vector(), - arg("b") = nullptr, - arg("target") = common::DefaultHostTarget(), + arg("scalar_args") = std::vector(), + arg("temp_tensors") = std::vector(), + arg("b") = nullptr, + arg("target") = common::DefaultHostTarget(), arg("supprt_ir_schedule") = false); } @@ -75,24 +78,26 @@ void BindLowerVec(py::module *m) { arg("name"), arg("stages"), arg("tensor_args"), - arg("scalar_args") = std::vector(), - arg("temp_tensors") = std::vector(), - arg("b") = nullptr, - arg("target") = common::DefaultHostTarget(), + arg("scalar_args") = std::vector(), + arg("temp_tensors") = std::vector(), + arg("b") = nullptr, + arg("target") = common::DefaultHostTarget(), arg("supprt_ir_schedule") = false); } void BindCompute(py::module *m) { -#define MAKE_COMPUTE_FN(__fn) \ - py::overload_cast &, __fn, const std::string &, const std::vector &>( \ - &lang::Compute) +#define MAKE_COMPUTE_FN(__fn) \ + py::overload_cast &, \ + __fn, \ + const std::string &, \ + const std::vector &>(&lang::Compute) #define DEFINE_COMPUTE(__fn) \ m->def("compute", \ MAKE_COMPUTE_FN(__fn), \ arg("domin"), \ arg("fn"), \ - arg("name") = "", \ + arg("name") = "", \ arg("shape") = std::vector()) // DEFINE_COMPUTE(std::function); @@ -100,8 +105,9 @@ void BindCompute(py::module *m) { DEFINE_COMPUTE(std::function &)>); // DEFINE_COMPUTE(std::function); // DEFINE_COMPUTE(std::function); 
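
Note: the hunk above reflows the CINN_WITH_TEST helpers that register packed functions such as test_add_int64 via ir::Registry::Register(...).SetBody(...). For readers unfamiliar with the pattern, here is a minimal self-contained sketch of a registry of packed functions; MiniRegistry, MiniValue, and MiniBody are illustrative stand-ins, not CINN's actual classes.

// --- illustrative sketch, not part of this patch ---
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <variant>
#include <vector>

using MiniValue = std::variant<int64_t, double>;  // cf. cinn_pod_value_t / CINNValue
using MiniBody = std::function<void(const std::vector<MiniValue>&, MiniValue*)>;

class MiniRegistry {
 public:
  // cf. ir::Registry::Register(name, override): create-or-return the entry.
  static MiniRegistry& Register(const std::string& name) { return Global()[name]; }
  // cf. ir::Registry::Get: nullptr when the name was never registered.
  static MiniRegistry* Get(const std::string& name) {
    auto it = Global().find(name);
    return it == Global().end() ? nullptr : &it->second;
  }
  // cf. Registry::SetBody: attach the packed-function body, chainable.
  MiniRegistry& SetBody(MiniBody body) { body_ = std::move(body); return *this; }
  MiniValue Call(const std::vector<MiniValue>& args) const {
    MiniValue rv;
    body_(args, &rv);
    return rv;
  }
 private:
  static std::map<std::string, MiniRegistry>& Global() {
    static std::map<std::string, MiniRegistry> g;
    return g;
  }
  MiniBody body_;
};

int main() {
  // Mirrors the "test_add_int64" helper registered above under CINN_WITH_TEST.
  MiniRegistry::Register("test_add_int64")
      .SetBody([](const std::vector<MiniValue>& args, MiniValue* rv) {
        *rv = std::get<int64_t>(args[0]) + std::get<int64_t>(args[1]);
      });
  MiniValue r = MiniRegistry::Get("test_add_int64")->Call({int64_t{2}, int64_t{3}});
  assert(std::get<int64_t>(r) == 5);
}
// --- end sketch ---

The real Registry additionally supports Remove, ListNames, and an override flag on Register, as the bindings above show.
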
- // DEFINE_COMPUTE(std::function); - // DEFINE_COMPUTE(std::function); + // DEFINE_COMPUTE(std::function); DEFINE_COMPUTE(std::function); DEFINE_COMPUTE(lang::compute_handler_t); #undef DEFINE_COMPUTE @@ -113,12 +119,16 @@ void BindCompute(py::module *m) { .def_readwrite("name", &lang::ReturnType::name); m->def("call_lowered", - py::overload_cast &, const std::vector &>( - &lang::CallLowered)); - m->def("call_extern", py::overload_cast &, - const std::map> &>( + const std::vector &>( + &lang::CallLowered)); + m->def("call_extern", + py::overload_cast< + const std::string &, + const std::vector &, + const std::map> &>( &lang::CallExtern)); } @@ -146,7 +156,8 @@ void BindModule(py::module *m) { class PlaceholderWrapper { public: #define DEFINE_PLACEHOLDER(__dtype, __type) \ - if (dtype == #__dtype) placeholder_ = std::make_unique>(name, shape) + if (dtype == #__dtype) \ + placeholder_ = std::make_unique>(name, shape) #define INIT_PLACEHOLDER \ DEFINE_PLACEHOLDER(int32, int32_t); \ @@ -154,11 +165,15 @@ class PlaceholderWrapper { DEFINE_PLACEHOLDER(float32, float); \ DEFINE_PLACEHOLDER(float64, double) - PlaceholderWrapper(absl::string_view dtype, const std::string &name, const std::vector &shape) { + PlaceholderWrapper(absl::string_view dtype, + const std::string &name, + const std::vector &shape) { INIT_PLACEHOLDER; } - PlaceholderWrapper(absl::string_view dtype, const std::string &name, const std::vector &shape) { + PlaceholderWrapper(absl::string_view dtype, + const std::string &name, + const std::vector &shape) { INIT_PLACEHOLDER; } #undef INIT_PLACEHOLDER @@ -204,29 +219,51 @@ class PlaceholderWrapper { void BindPlaceholder(py::module *m) { py::class_ placeholder(*m, "Placeholder"); - placeholder.def(py::init &>()) - .def(py::init &>()) + placeholder + .def(py::init &>()) + .def(py::init &>()) .def("type", &PlaceholderWrapper::type) .def("tensor", &PlaceholderWrapper::tensor) - .def("__call__", [](PlaceholderWrapper &self, ir::Expr a) { return self(std::move(a)); }) .def("__call__", - [](PlaceholderWrapper &self, ir::Expr a, ir::Expr b) { return self(std::move(a), std::move(b)); }) + [](PlaceholderWrapper &self, ir::Expr a) { + return self(std::move(a)); + }) + .def("__call__", + [](PlaceholderWrapper &self, ir::Expr a, ir::Expr b) { + return self(std::move(a), std::move(b)); + }) .def("__call__", [](PlaceholderWrapper &self, ir::Expr a, ir::Expr b, ir::Expr c) { return self(std::move(a), std::move(b), std::move(c)); }) - .def("__call__", [](PlaceholderWrapper &self, const std::vector &indices) { return self(indices); }) + .def("__call__", + [](PlaceholderWrapper &self, const std::vector &indices) { + return self(indices); + }) .def("to_expr", [](PlaceholderWrapper &self) { return ir::Expr(self); }) - .def("to_tensor", [](PlaceholderWrapper &self) { return ir::Tensor(self); }); + .def("to_tensor", + [](PlaceholderWrapper &self) { return ir::Tensor(self); }); m->def("create_placeholder", - static_cast &, Type, const std::string &)>(&lang::CreatePlaceHolder)); + static_cast &, Type, const std::string &)>( + &lang::CreatePlaceHolder)); m->def("create_placeholder", - static_cast &, Type, const std::string &)>(&lang::CreatePlaceHolder)); + static_cast &, Type, const std::string &)>( + &lang::CreatePlaceHolder)); } void BindBuiltin(py::module *m) { - m->def("reduce_sum", &lang::ReduceSum, py::arg("e"), py::arg("reduce_axis"), py::arg("init") = Expr()); + m->def("reduce_sum", + &lang::ReduceSum, + py::arg("e"), + py::arg("reduce_axis"), + py::arg("init") = Expr()); m->def("reduce_mul", 
&lang::ReduceMul); m->def("reduce_max", &lang::ReduceMax); m->def("reduce_min", &lang::ReduceMin); diff --git a/paddle/cinn/pybind/pe.cc b/paddle/cinn/pybind/pe.cc index b91976029dc19..94204ae4b3e44 100644 --- a/paddle/cinn/pybind/pe.cc +++ b/paddle/cinn/pybind/pe.cc @@ -33,7 +33,11 @@ using utils::GetStreamCnt; using utils::StringFormat; void BindPE(py::module* m) { -#define BIND_UNARY(name__, fn__) m->def(#name__, &hlir::pe::fn__, py::arg("x"), py::arg("out") = "T_" #name__ "_out") +#define BIND_UNARY(name__, fn__) \ + m->def(#name__, \ + &hlir::pe::fn__, \ + py::arg("x"), \ + py::arg("out") = "T_" #name__ "_out") BIND_UNARY(exp, Exp); BIND_UNARY(erf, Erf); BIND_UNARY(sqrt, Sqrt); @@ -70,7 +74,12 @@ void BindPE(py::module* m) { BIND_UNARY(rsqrt, Rsqrt); #define BIND_BINARY(name__, fn__) \ - m->def(#name__, &hlir::pe::fn__, py::arg("x"), py::arg("y"), py::arg("out"), py::arg("axis") = Expr(-1)) + m->def(#name__, \ + &hlir::pe::fn__, \ + py::arg("x"), \ + py::arg("y"), \ + py::arg("out"), \ + py::arg("axis") = Expr(-1)) BIND_BINARY(add, Add); BIND_BINARY(atan2, Atan2); @@ -103,7 +112,7 @@ void BindPE(py::module* m) { py::arg("x"), \ py::arg("axes"), \ py::arg("keep_dims") = false, \ - py::arg("out") = "T_" #name__ "_out") + py::arg("out") = "T_" #name__ "_out") BIND_REDUCE(reduce_sum, ReduceSum); BIND_REDUCE(reduce_prod, ReduceProd); BIND_REDUCE(reduce_max, ReduceMax); @@ -117,8 +126,8 @@ void BindPE(py::module* m) { py::arg("tensor_b"), py::arg("trans_a") = false, py::arg("trans_b") = false, - py::arg("alpha") = 1, - py::arg("out") = "T_Matmul_out"); + py::arg("alpha") = 1, + py::arg("out") = "T_Matmul_out"); m->def("matmul_mkl", &hlir::pe::MatmulMKL, @@ -126,9 +135,9 @@ void BindPE(py::module* m) { py::arg("tensor_b"), py::arg("trans_a") = false, py::arg("trans_b") = false, - py::arg("alpha") = 1, - py::arg("out") = "T_Matmul_mkl_out", - py::arg("target") = common::DefaultHostTarget()); + py::arg("alpha") = 1, + py::arg("out") = "T_Matmul_mkl_out", + py::arg("target") = common::DefaultHostTarget()); } } // namespace pybind diff --git a/paddle/cinn/pybind/poly.cc b/paddle/cinn/pybind/poly.cc index 24c782df3deaa..e9e0edba2fd3b 100644 --- a/paddle/cinn/pybind/poly.cc +++ b/paddle/cinn/pybind/poly.cc @@ -38,22 +38,31 @@ void BindMap(py::module *m) { .def(py::init<>()) .def(py::init()) .def(py::init()) - .def("__eq__", [](Iterator &self, Iterator &other) { return self == other; }) - .def("__ne__", [](Iterator &self, Iterator &other) { return self != other; }) + .def("__eq__", + [](Iterator &self, Iterator &other) { return self == other; }) + .def("__ne__", + [](Iterator &self, Iterator &other) { return self != other; }) .def("__str__", [](Iterator &self) { return self.id; }) - .def("__repr__", [](Iterator &self) -> std::string { return llvm::formatv("", self.id); }); + .def("__repr__", [](Iterator &self) -> std::string { + return llvm::formatv("", self.id); + }); py::class_ condition(*m, "Condition"); - condition.def_readwrite("cond", &Condition::cond).def(py::init()).def("__str__", &Condition::__str__); + condition.def_readwrite("cond", &Condition::cond) + .def(py::init()) + .def("__str__", &Condition::__str__); } void BindStageMap(py::module *m) { DefineShared(m, "StageMap"); - py::class_> stage_map(*m, "StageMap"); + py::class_> stage_map(*m, + "StageMap"); stage_map // .def( "__getitem__", - [](poly::StageMap self, ir::Tensor &t) -> Stage & { return *self[t]; }, + [](poly::StageMap self, ir::Tensor &t) -> Stage & { + return *self[t]; + }, py::return_value_policy::reference); 
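
Note: the BIND_UNARY / BIND_BINARY / BIND_REDUCE macros reflowed above stamp out one m->def(...) per operator, deriving both the Python symbol and the default output-tensor name from the macro argument via stringization ("T_" #name__ "_out"). A compilable sketch of the same stamping pattern, using pybind11 as the bindings above do; the pe_sketch module and its toy Exp/Sqrt wrappers are hypothetical.

// --- illustrative sketch, not part of this patch ---
#include <cmath>
#include <string>
#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace pe {
// Toy stand-ins for hlir::pe kernels; the real ones build IR, not doubles.
double Exp(double x, const std::string& out) { (void)out; return std::exp(x); }
double Sqrt(double x, const std::string& out) { (void)out; return std::sqrt(x); }
}  // namespace pe

// #name__ is stringized twice: once for the Python name, once to build the
// default output-tensor name, exactly as BIND_UNARY does above.
#define BIND_UNARY(name__, fn__)                          \
  m.def(#name__, &pe::fn__, py::arg("x"),                 \
        py::arg("out") = std::string("T_" #name__ "_out"))

PYBIND11_MODULE(pe_sketch, m) {
  BIND_UNARY(exp, Exp);
  BIND_UNARY(sqrt, Sqrt);
}
// --- end sketch ---

From Python this yields pe_sketch.exp(1.0) with out defaulting to "T_exp_out", which is how the long list of unary ops above gets uniform defaults without repetition.
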
m->def("create_stages", &poly::CreateStages, py::arg("tensors")); @@ -71,20 +80,34 @@ void BindStage(py::module *m) { stage.def("id", &Stage::id) .def("expr", &Stage::expr) .def("axis", py::overload_cast(&Stage::axis, py::const_)) - .def("axis", py::overload_cast(&Stage::axis, py::const_)) + .def("axis", + py::overload_cast(&Stage::axis, py::const_)) .def("axis_names", &Stage::axis_names) .def("bind", &Stage::Bind) .def("compute_inline", &Stage::ComputeInline, - "Mark this tensor as inline, and will expand in-place in where it is used") + "Mark this tensor as inline, and will expand in-place in where it " + "is used") .def( "share_buffer_with", [](Stage &self, Stage &other) { self.ShareBufferWith(&other); }, "Share the underlying buffer with another tensor") - .def("split", py::overload_cast(&Stage::Split), arg("level"), arg("factor")) - .def("split", py::overload_cast(&Stage::Split), arg("level"), arg("factor")) - .def("split", py::overload_cast(&Stage::Split), arg("level"), arg("factor")) - .def("fuse", py::overload_cast(&Stage::Fuse), arg("level0"), arg("level1")) + .def("split", + py::overload_cast(&Stage::Split), + arg("level"), + arg("factor")) + .def("split", + py::overload_cast(&Stage::Split), + arg("level"), + arg("factor")) + .def("split", + py::overload_cast(&Stage::Split), + arg("level"), + arg("factor")) + .def("fuse", + py::overload_cast(&Stage::Fuse), + arg("level0"), + arg("level1")) .def("fuse", py::overload_cast &>(&Stage::Fuse)) .def("reorder", py::overload_cast &>(&Stage::Reorder), @@ -92,25 +115,36 @@ void BindStage(py::module *m) { .def("reorder", py::overload_cast &>(&Stage::Reorder), "Reorder the axis in the computation") - .def("tile", py::overload_cast(&Stage::Tile)) + .def("tile", + py::overload_cast( + &Stage::Tile)) .def("tile", py::overload_cast(&Stage::Tile)) .def("vectorize", py::overload_cast(&Stage::Vectorize)) - .def("vectorize", py::overload_cast(&Stage::Vectorize)) - .def("vectorize", py::overload_cast(&Stage::Vectorize)) + .def("vectorize", + py::overload_cast(&Stage::Vectorize)) + .def("vectorize", + py::overload_cast(&Stage::Vectorize)) .def("unroll", py::overload_cast(&Stage::Unroll)) .def("unroll", py::overload_cast(&Stage::Unroll)) .def("unroll", py::overload_cast(&Stage::Unroll)) .def("parallel", py::overload_cast(&Stage::Parallel)) .def("parallel", py::overload_cast(&Stage::Parallel)) .def("parallel", py::overload_cast(&Stage::Parallel)) - .def("compute_at", &Stage::ComputeAtSchedule, arg("other"), arg("level"), arg("kind") = Stage::kComputeAtAuto) + .def("compute_at", + &Stage::ComputeAtSchedule, + arg("other"), + arg("level"), + arg("kind") = Stage::kComputeAtAuto) .def("skew", &Stage::Skew) .def("ctrl_depend", &Stage::CtrlDepend) .def("cache_read", &Stage::CacheRead) .def("cache_write", &Stage::CacheWrite) - .def("sync_threads", py::overload_cast(&Stage::SyncThreads)) .def("sync_threads", - py::overload_cast &, poly::StageMap>(&Stage::SyncThreads)); + py::overload_cast(&Stage::SyncThreads)) + .def("sync_threads", + py::overload_cast &, + poly::StageMap>(&Stage::SyncThreads)); } } // namespace diff --git a/paddle/cinn/pybind/runtime.cc b/paddle/cinn/pybind/runtime.cc index 98c1afb542fd4..9d562298e5bfc 100644 --- a/paddle/cinn/pybind/runtime.cc +++ b/paddle/cinn/pybind/runtime.cc @@ -53,7 +53,9 @@ cinn_type_t NumpyTypeToCinn(py::dtype dt) { return cinn_unk_t(); } -cinn_buffer_t *CreateBufferFromNumpy(py::array data, cinn_device_kind_t device, int align = 0) { +cinn_buffer_t *CreateBufferFromNumpy(py::array data, + cinn_device_kind_t device, + 
int align = 0) { cinn_type_t type = NumpyTypeToCinn(data.dtype()); std::vector shape; std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape)); @@ -106,8 +108,10 @@ void BindSpecialTypes(py::module *m) { py::class_ void_ptr(*m, "VoidPointer"); void_ptr.def(py::init<>()); -#define VOID_PTR_SUPPORT_TYPE(__type) \ - void_ptr.def("set", [](VoidPointer &self, __type *p) { self.ptr = static_cast(p); }) +#define VOID_PTR_SUPPORT_TYPE(__type) \ + void_ptr.def("set", [](VoidPointer &self, __type *p) { \ + self.ptr = static_cast(p); \ + }) VOID_PTR_SUPPORT_TYPE(char); VOID_PTR_SUPPORT_TYPE(int8_t); @@ -135,7 +139,10 @@ void BindCinnRuntime(py::module *m) { .def_readwrite("bits", &cinn_type_t::bits) .def_readwrite("lanes", &cinn_type_t::lanes) .def(py::init<>()) - .def(py::init(), arg("code"), arg("bits"), arg("lanes") = 1) + .def(py::init(), + arg("code"), + arg("bits"), + arg("lanes") = 1) .def(py::self == cinn_type_t()) .def(py::self != cinn_type_t()) .def("bytes", &cinn_type_t::bytes); @@ -162,7 +169,8 @@ void BindCinnRuntime(py::module *m) { .value("cinn_buffer_on_device", cinn_buffer_on_device) .export_values(); - py::class_ cinn_device_interface(*m, "cinn_device_interface_t"); + py::class_ cinn_device_interface( + *m, "cinn_device_interface_t"); m->def("cinn_device_release", &cinn_device_release); m->def("cinn_buffer_copy_to_host", &cinn_buffer_copy_to_host); @@ -170,10 +178,13 @@ void BindCinnRuntime(py::module *m) { m->def("cinn_buffer_copy", &cinn_buffer_copy); m->def("cinn_device_sync", &cinn_device_sync); m->def("cinn_buffer_malloc", &cinn_buffer_malloc); - m->def("cinn_buffer_malloc", [](VoidPointer &p, cinn_buffer_t *buffer) { return cinn_buffer_malloc(p.ptr, buffer); }); + m->def("cinn_buffer_malloc", [](VoidPointer &p, cinn_buffer_t *buffer) { + return cinn_buffer_malloc(p.ptr, buffer); + }); m->def("cinn_buffer_free", &cinn_buffer_free); m->def("cinn_buffer_get_data_handle", &cinn_buffer_get_data_handle); - m->def("cinn_buffer_get_data_const_handle", &cinn_buffer_get_data_const_handle); + m->def("cinn_buffer_get_data_const_handle", + &cinn_buffer_get_data_const_handle); py::class_ cinn_buffer(*m, "cinn_buffer_t"); cinn_buffer.def_readwrite("device", &cinn_buffer_t::device) @@ -209,7 +220,10 @@ void BindCinnRuntime(py::module *m) { .def("set_flag", &cinn_buffer_t::set_flag) // Python methods .def("numpy", &BufferHostMemoryToNumpy) - .def(py::init(&CreateBufferFromNumpy), arg("data"), arg("device"), arg("align") = 0); + .def(py::init(&CreateBufferFromNumpy), + arg("data"), + arg("device"), + arg("align") = 0); m->def("cinn_x86_device_interface", &cinn_x86_device_interface) .def("cinn_buffer_load_float32", &cinn_buffer_load_float32) @@ -255,7 +269,8 @@ void BindCinnRuntime(py::module *m) { .def("to_void_p", &cinn_pod_value_t::operator void *) .def("to_cinn_buffer_t_p", &cinn_pod_value_t::operator cinn_buffer_t *) .def("to_char_p", &cinn_pod_value_t::operator char *) - .def("type_code", py::overload_cast<>(&cinn_pod_value_t::type_code, py::const_)) + .def("type_code", + py::overload_cast<>(&cinn_pod_value_t::type_code, py::const_)) .def("data_addr", &cinn_pod_value_t::data_addr); m->def("cinn_pod_value_to_float", &cinn_pod_value_to_float) @@ -266,7 +281,9 @@ void BindCinnRuntime(py::module *m) { .def("cinn_pod_value_to_void_p", &cinn_pod_value_to_void_p) .def("cinn_pod_value_to_buffer_p", &cinn_pod_value_to_buffer_p); - m->def("set_cinn_cudnn_deterministic", &cinn::runtime::SetCinnCudnnDeterministic, py::arg("state") = true); + m->def("set_cinn_cudnn_deterministic", + 
&cinn::runtime::SetCinnCudnnDeterministic, + py::arg("state") = true); m->def("seed", &cinn::runtime::RandomSeed::GetOrSet, py::arg("seed") = 0); m->def("clear_seed", &cinn::runtime::RandomSeed::Clear); } diff --git a/paddle/cinn/runtime/buffer.cc b/paddle/cinn/runtime/buffer.cc index 4fdb93cf1d6e3..6f9e6d51ecaa8 100755 --- a/paddle/cinn/runtime/buffer.cc +++ b/paddle/cinn/runtime/buffer.cc @@ -17,7 +17,8 @@ namespace cinn { namespace runtime { -Shape::Shape(const Shape &other) : data_(new value_type[other.ndims()]), ndims_(other.ndims()) { +Shape::Shape(const Shape &other) + : data_(new value_type[other.ndims()]), ndims_(other.ndims()) { if (ndims() > 0) { memcpy(data_, other.data(), ndims_ * sizeof(value_type)); } diff --git a/paddle/cinn/runtime/buffer.h b/paddle/cinn/runtime/buffer.h index ba47e6e2d4578..c3eb5c43b58e3 100755 --- a/paddle/cinn/runtime/buffer.h +++ b/paddle/cinn/runtime/buffer.h @@ -88,7 +88,8 @@ class Buffer { } T& operator()(int i0, int i1, int i2) { CHECK_EQ(shape_.ndims(), 3); - return static_cast(data_)[i0 * shape_[1] * shape_[2] + i1 * shape_[2] + i2]; + return static_cast( + data_)[i0 * shape_[1] * shape_[2] + i1 * shape_[2] + i2]; } private: diff --git a/paddle/cinn/runtime/cinn_runtime.cc b/paddle/cinn/runtime/cinn_runtime.cc index 7af3b7163b234..51c9ac0866cfa 100644 --- a/paddle/cinn/runtime/cinn_runtime.cc +++ b/paddle/cinn/runtime/cinn_runtime.cc @@ -34,7 +34,8 @@ int cinn_buffer_malloc(void* context, struct cinn_buffer_t* buf) { int cinn_buffer_free(void* context, struct cinn_buffer_t* buf) { // ASSERT_NOT_NULL(context) ASSERT_NOT_NULL(buf) - // If buffer is lazy, then we will not free this buffer, that will greatly improve performance. + // If buffer is lazy, then we will not free this buffer, that will greatly + // improve performance. 
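
Note: the reflowed Buffer::operator()(int i0, int i1, int i2) in buffer.h above computes a row-major offset, i0 * shape_[1] * shape_[2] + i1 * shape_[2] + i2, with the innermost axis contiguous. A small self-contained check that this mapping touches every slot of a d0 x d1 x d2 buffer exactly once; Shape3 and Offset are illustrative stand-ins.

// --- illustrative sketch, not part of this patch ---
#include <cassert>
#include <vector>

struct Shape3 { int d0, d1, d2; };

int Offset(const Shape3& s, int i0, int i1, int i2) {
  return i0 * s.d1 * s.d2 + i1 * s.d2 + i2;  // innermost axis i2 is contiguous
}

int main() {
  Shape3 s{3, 4, 5};
  std::vector<float> data(3 * 4 * 5, 0.f);
  // 60 writes into 60 slots: if every slot ends at exactly 1, the map is bijective.
  for (int i0 = 0; i0 < s.d0; ++i0)
    for (int i1 = 0; i1 < s.d1; ++i1)
      for (int i2 = 0; i2 < s.d2; ++i2)
        data[Offset(s, i0, i1, i2)] += 1.f;
  for (float v : data) assert(v == 1.f);
}
// --- end sketch ---
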
if (buf->lazy) return 0; return buf->device_interface->impl->free(context, buf); } @@ -54,7 +55,8 @@ int cinn_device_sync(void* context, struct cinn_buffer_t* buf) { return 0; } -int cinn_device_release(void* context, const struct cinn_device_interface_t* device_interface) { +int cinn_device_release( + void* context, const struct cinn_device_interface_t* device_interface) { // ASSERT_NOT_NULL(context) ASSERT_NOT_NULL(device_interface) CINN_RUNTIME_NOT_IMPLEMENTED @@ -73,7 +75,9 @@ int cinn_buffer_copy_to_device(void* context, struct cinn_buffer_t* buf) { ASSERT_NOT_NULL(buf->device_interface) return buf->device_interface->impl->copy_to_device(context, buf); } -int cinn_buffer_copy(void* context, struct cinn_buffer_t* src, struct cinn_buffer_t* dst) { +int cinn_buffer_copy(void* context, + struct cinn_buffer_t* src, + struct cinn_buffer_t* dst) { // ASSERT_NOT_NULL(context); ASSERT_NOT_NULL(src); ASSERT_NOT_NULL(dst); @@ -90,17 +94,20 @@ void* cinn_buffer_get_data_const_handle(const struct cinn_buffer_t* buf) { return buf->memory; } -cinn_buffer_t* cinn_buffer_new_default(int target, uint64_t memory_size, int align) { - struct cinn_buffer_t* buf = (struct cinn_buffer_t*)malloc(sizeof(struct cinn_buffer_t)); - buf->type = cinn_float32_t(); - buf->device = (cinn_device_kind_t)target; - buf->memory = nullptr; - buf->memory_size = memory_size; - buf->align = align; - buf->lazy = true; +cinn_buffer_t* cinn_buffer_new_default(int target, + uint64_t memory_size, + int align) { + struct cinn_buffer_t* buf = + (struct cinn_buffer_t*)malloc(sizeof(struct cinn_buffer_t)); + buf->type = cinn_float32_t(); + buf->device = (cinn_device_kind_t)target; + buf->memory = nullptr; + buf->memory_size = memory_size; + buf->align = align; + buf->lazy = true; #ifdef __cplusplus buf->external_malloc = nullptr; - buf->external_free = nullptr; + buf->external_free = nullptr; #endif // __cplusplus // NOTE set device_interface for each buffer. 
switch (buf->device) { @@ -120,22 +127,48 @@ cinn_buffer_t* cinn_buffer_new_default(int target, uint64_t memory_size, int ali } cinn_type_t cinn_unk_t() { return cinn_type_t(cinn_type_unk, 0); } -cinn_type_t cinn_bool_t(int num_asterisks) { return cinn_type_t(cinn_type_int, 1, num_asterisks); } +cinn_type_t cinn_bool_t(int num_asterisks) { + return cinn_type_t(cinn_type_int, 1, num_asterisks); +} -cinn_type_t cinn_int8_t(int num_asterisks) { return cinn_type_t(cinn_type_int, 8, num_asterisks); } -cinn_type_t cinn_int16_t(int num_asterisks) { return cinn_type_t(cinn_type_int, 16, num_asterisks); } -cinn_type_t cinn_int32_t(int num_asterisks) { return cinn_type_t(cinn_type_int, 32, num_asterisks); } -cinn_type_t cinn_int64_t(int num_asterisks) { return cinn_type_t(cinn_type_int, 64, num_asterisks); } +cinn_type_t cinn_int8_t(int num_asterisks) { + return cinn_type_t(cinn_type_int, 8, num_asterisks); +} +cinn_type_t cinn_int16_t(int num_asterisks) { + return cinn_type_t(cinn_type_int, 16, num_asterisks); +} +cinn_type_t cinn_int32_t(int num_asterisks) { + return cinn_type_t(cinn_type_int, 32, num_asterisks); +} +cinn_type_t cinn_int64_t(int num_asterisks) { + return cinn_type_t(cinn_type_int, 64, num_asterisks); +} -cinn_type_t cinn_uint8_t(int num_asterisks) { return cinn_type_t(cinn_type_uint, 8, num_asterisks); } -cinn_type_t cinn_uint16_t(int num_asterisks) { return cinn_type_t(cinn_type_uint, 16, num_asterisks); } -cinn_type_t cinn_uint32_t(int num_asterisks) { return cinn_type_t(cinn_type_uint, 32, num_asterisks); } -cinn_type_t cinn_uint64_t(int num_asterisks) { return cinn_type_t(cinn_type_uint, 64, num_asterisks); } +cinn_type_t cinn_uint8_t(int num_asterisks) { + return cinn_type_t(cinn_type_uint, 8, num_asterisks); +} +cinn_type_t cinn_uint16_t(int num_asterisks) { + return cinn_type_t(cinn_type_uint, 16, num_asterisks); +} +cinn_type_t cinn_uint32_t(int num_asterisks) { + return cinn_type_t(cinn_type_uint, 32, num_asterisks); +} +cinn_type_t cinn_uint64_t(int num_asterisks) { + return cinn_type_t(cinn_type_uint, 64, num_asterisks); +} -cinn_type_t cinn_bfloat16_t(int num_asterisks) { return cinn_type_t(cinn_type_bfloat, 16, num_asterisks); } -cinn_type_t cinn_float16_t(int num_asterisks) { return cinn_type_t(cinn_type_float, 16, num_asterisks); } -cinn_type_t cinn_float32_t(int num_asterisks) { return cinn_type_t(cinn_type_float, 32, num_asterisks); } -cinn_type_t cinn_float64_t(int num_asterisks) { return cinn_type_t(cinn_type_float, 64, num_asterisks); } +cinn_type_t cinn_bfloat16_t(int num_asterisks) { + return cinn_type_t(cinn_type_bfloat, 16, num_asterisks); +} +cinn_type_t cinn_float16_t(int num_asterisks) { + return cinn_type_t(cinn_type_float, 16, num_asterisks); +} +cinn_type_t cinn_float32_t(int num_asterisks) { + return cinn_type_t(cinn_type_float, 32, num_asterisks); +} +cinn_type_t cinn_float64_t(int num_asterisks) { + return cinn_type_t(cinn_type_float, 64, num_asterisks); +} } // extern "C" @@ -146,13 +179,14 @@ struct cinn_buffer_t* cinn_buffer_t::new_(cinn_device_kind_t device, int32_t dimensions = shape.size(); CINN_CHECK(shape.size() < CINN_BUFFER_MAX_DIMS); - struct cinn_buffer_t* buf = (struct cinn_buffer_t*)malloc(sizeof(struct cinn_buffer_t)); + struct cinn_buffer_t* buf = + (struct cinn_buffer_t*)malloc(sizeof(struct cinn_buffer_t)); memcpy(&(buf->dims[0]), shape.data(), shape.size() * sizeof(int)); - buf->type = type; - buf->device = device; - buf->memory = nullptr; + buf->type = type; + buf->device = device; + buf->memory = nullptr; buf->memory_size = 0; 
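
Note: cinn_buffer_free above returns early for lazy buffers, and both cinn_buffer_new_default and cinn_buffer_t::new_ set buf->lazy = true, so host memory is kept resident for reuse instead of being released per call, trading footprint for allocation churn. A minimal sketch of that convention; MiniBuffer and MiniFree are illustrative only.

// --- illustrative sketch, not part of this patch ---
#include <cassert>
#include <cstdlib>

struct MiniBuffer {
  void* memory = nullptr;
  unsigned long long memory_size = 0;
  bool lazy = true;  // default matches buf->lazy = true above
};

int MiniFree(MiniBuffer* buf) {
  if (buf->lazy) return 0;  // cf. "if (buf->lazy) return 0;" — keep for reuse
  std::free(buf->memory);
  buf->memory = nullptr;
  return 0;
}

int main() {
  MiniBuffer buf;
  buf.memory_size = 1024;
  buf.memory = std::malloc(buf.memory_size);
  MiniFree(&buf);                  // no-op while lazy: buffer stays resident
  assert(buf.memory != nullptr);
  buf.lazy = false;
  MiniFree(&buf);                  // now actually released
  assert(buf.memory == nullptr);
}
// --- end sketch ---
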
- buf->lazy = true; + buf->lazy = true; // NOTE set device_interface for each buffer. switch (buf->device) { case cinn_x86_device: @@ -168,15 +202,18 @@ struct cinn_buffer_t* cinn_buffer_t::new_(cinn_device_kind_t device, } buf->dimensions = dimensions; - buf->align = align; + buf->align = align; #ifdef __cplusplus buf->external_malloc = nullptr; - buf->external_free = nullptr; + buf->external_free = nullptr; #endif // __cplusplus return buf; } -cinn_buffer_t* cinn_buffer_new(cinn_device_kind_t device, cinn_type_t type, const std::vector& shape, int align) { +cinn_buffer_t* cinn_buffer_new(cinn_device_kind_t device, + cinn_type_t type, + const std::vector& shape, + int align) { return cinn_buffer_t::new_(device, type, shape, align); } @@ -249,38 +286,73 @@ cinn_pod_value_t::operator char*() const { return static_cast(value_.v_handle); } -cinn_pod_value_t::cinn_pod_value_t(cinn_value_t value, int type_code) : value_(value), type_code_(type_code) {} -cinn_pod_value_t::cinn_pod_value_t(cinn_buffer_t* value) : type_code_(::cinn_type_code()) { +cinn_pod_value_t::cinn_pod_value_t(cinn_value_t value, int type_code) + : value_(value), type_code_(type_code) {} +cinn_pod_value_t::cinn_pod_value_t(cinn_buffer_t* value) + : type_code_(::cinn_type_code()) { value_.v_handle = value; } -cinn_pod_value_t::cinn_pod_value_t(bool value) : type_code_(::cinn_type_code()) { value_.v_int64 = value; } +cinn_pod_value_t::cinn_pod_value_t(bool value) + : type_code_(::cinn_type_code()) { + value_.v_int64 = value; +} -cinn_pod_value_t::cinn_pod_value_t(int8_t value) : type_code_(::cinn_type_code()) { value_.v_int64 = value; } -cinn_pod_value_t::cinn_pod_value_t(int16_t value) : type_code_(::cinn_type_code()) { value_.v_int64 = value; } -cinn_pod_value_t::cinn_pod_value_t(int32_t value) : type_code_(::cinn_type_code()) { value_.v_int64 = value; } -cinn_pod_value_t::cinn_pod_value_t(int64_t value) : type_code_(::cinn_type_code()) { value_.v_int64 = value; } +cinn_pod_value_t::cinn_pod_value_t(int8_t value) + : type_code_(::cinn_type_code()) { + value_.v_int64 = value; +} +cinn_pod_value_t::cinn_pod_value_t(int16_t value) + : type_code_(::cinn_type_code()) { + value_.v_int64 = value; +} +cinn_pod_value_t::cinn_pod_value_t(int32_t value) + : type_code_(::cinn_type_code()) { + value_.v_int64 = value; +} +cinn_pod_value_t::cinn_pod_value_t(int64_t value) + : type_code_(::cinn_type_code()) { + value_.v_int64 = value; +} -cinn_pod_value_t::cinn_pod_value_t(uint8_t value) : type_code_(::cinn_type_code()) { value_.v_int64 = value; } -cinn_pod_value_t::cinn_pod_value_t(uint16_t value) : type_code_(::cinn_type_code()) { +cinn_pod_value_t::cinn_pod_value_t(uint8_t value) + : type_code_(::cinn_type_code()) { value_.v_int64 = value; } -cinn_pod_value_t::cinn_pod_value_t(uint32_t value) : type_code_(::cinn_type_code()) { +cinn_pod_value_t::cinn_pod_value_t(uint16_t value) + : type_code_(::cinn_type_code()) { value_.v_int64 = value; } -cinn_pod_value_t::cinn_pod_value_t(uint64_t value) : type_code_(::cinn_type_code()) { +cinn_pod_value_t::cinn_pod_value_t(uint32_t value) + : type_code_(::cinn_type_code()) { + value_.v_int64 = value; +} +cinn_pod_value_t::cinn_pod_value_t(uint64_t value) + : type_code_(::cinn_type_code()) { value_.v_int64 = value; } -cinn_pod_value_t::cinn_pod_value_t(float value) : type_code_(::cinn_type_code()) { value_.v_float64 = value; } -cinn_pod_value_t::cinn_pod_value_t(bfloat16 value) : type_code_(::cinn_type_code()) { +cinn_pod_value_t::cinn_pod_value_t(float value) + : type_code_(::cinn_type_code()) { + 
value_.v_float64 = value; +} +cinn_pod_value_t::cinn_pod_value_t(bfloat16 value) + : type_code_(::cinn_type_code()) { + value_.v_float64 = value; +} +cinn_pod_value_t::cinn_pod_value_t(float16 value) + : type_code_(::cinn_type_code()) { value_.v_float64 = value; } -cinn_pod_value_t::cinn_pod_value_t(float16 value) : type_code_(::cinn_type_code()) { +cinn_pod_value_t::cinn_pod_value_t(double value) + : type_code_(::cinn_type_code()) { value_.v_float64 = value; } -cinn_pod_value_t::cinn_pod_value_t(double value) : type_code_(::cinn_type_code()) { value_.v_float64 = value; } -cinn_pod_value_t::cinn_pod_value_t(void* value) : type_code_(::cinn_type_code()) { value_.v_handle = value; } -cinn_pod_value_t::cinn_pod_value_t(const char* value) : type_code_(::cinn_type_code()) { +cinn_pod_value_t::cinn_pod_value_t(void* value) + : type_code_(::cinn_type_code()) { + value_.v_handle = value; +} +cinn_pod_value_t::cinn_pod_value_t(const char* value) + : type_code_(::cinn_type_code()) { value_.v_handle = const_cast(value); } @@ -303,28 +375,58 @@ uint8_t cinn_pod_value_to_uint8(cinn_pod_value_t* value) { return *value; } bool cinn_pod_value_to_bool(cinn_pod_value_t* value) { return *value; } void* cinn_pod_value_to_void_p(cinn_pod_value_t* value) { return *value; } -cinn_buffer_t* cinn_pod_value_to_buffer_p(cinn_pod_value_t* value) { return *value; } +cinn_buffer_t* cinn_pod_value_to_buffer_p(cinn_pod_value_t* value) { + return *value; +} // @} // @{ -void float_to_cinn_pod_value(float v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } -void bfloat16_to_cinn_pod_value(bfloat16 v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } -void float16_to_cinn_pod_value(float16 v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } -void double_to_cinn_pod_value(double v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } +void float_to_cinn_pod_value(float v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} +void bfloat16_to_cinn_pod_value(bfloat16 v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} +void float16_to_cinn_pod_value(float16 v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} +void double_to_cinn_pod_value(double v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} -void bool_to_cinn_pod_value(bool v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } +void bool_to_cinn_pod_value(bool v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} -void int8_to_cinn_pod_value(int8_t v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } -void int16_to_cinn_pod_value(int16_t v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } -void int32_to_cinn_pod_value(int32_t v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } -void int64_to_cinn_pod_value(int64_t v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } +void int8_to_cinn_pod_value(int8_t v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} +void int16_to_cinn_pod_value(int16_t v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} +void int32_to_cinn_pod_value(int32_t v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} +void int64_to_cinn_pod_value(int64_t v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} -void uint8_to_cinn_pod_value(uint8_t v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } -void uint16_to_cinn_pod_value(uint16_t v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } -void uint32_to_cinn_pod_value(uint32_t v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } -void uint64_to_cinn_pod_value(uint64_t v, 
cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } +void uint8_to_cinn_pod_value(uint8_t v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} +void uint16_to_cinn_pod_value(uint16_t v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} +void uint32_to_cinn_pod_value(uint32_t v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} +void uint64_to_cinn_pod_value(uint64_t v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} -void handle_to_cinn_pod_value(void* v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(v); } +void handle_to_cinn_pod_value(void* v, cinn_pod_value_t* out) { + *out = cinn_pod_value_t(v); +} void buffer_p_to_cinn_pod_value(const cinn_buffer_t* v, cinn_pod_value_t* out) { *out = cinn_pod_value_t(const_cast(v)); } @@ -385,7 +487,7 @@ void cinn_args_construct(cinn_pod_value_t* arr, int count, ...) { va_start(args, count); for (int i = 0; i < count; i++) { cinn_pod_value_t* elem_addr = va_arg(args, cinn_pod_value_t*); - arr[i] = *elem_addr; + arr[i] = *elem_addr; // debug_pod_value(*elem_addr, i); } va_end(args); diff --git a/paddle/cinn/runtime/cinn_runtime.h b/paddle/cinn/runtime/cinn_runtime.h index cdd75f7895d93..39ed8cbe5ee09 100755 --- a/paddle/cinn/runtime/cinn_runtime.h +++ b/paddle/cinn/runtime/cinn_runtime.h @@ -13,8 +13,8 @@ // limitations under the License. /** - * This file contains some core runtime concepts, the basic definition is used in C so that it can be deployed in some - * light-weight devices. + * This file contains some core runtime concepts, the basic definition is used + * in C so that it can be deployed in some light-weight devices. */ #ifndef CINN_RUNTIME_CINN_RUNTIME_H_ #define CINN_RUNTIME_CINN_RUNTIME_H_ @@ -49,12 +49,12 @@ extern "C" { //! Code for the primitive types supported in CINN. typedef enum cinn_type_code_t { - cinn_type_unk = -1, //! Unknown type - cinn_type_int = 0, //! signed int - cinn_type_uint = 1, //! unsigned int - cinn_type_float = 2, //! floating point - cinn_type_handle = 3, //! void* - cinn_type_bfloat = 4 //! bfloat16 + cinn_type_unk = -1, //! Unknown type + cinn_type_int = 0, //! signed int + cinn_type_uint = 1, //! unsigned int + cinn_type_float = 2, //! floating point + cinn_type_handle = 3, //! void* + cinn_type_bfloat = 4 //! bfloat16 } cinn_type_code_t; #ifndef CINN_ATTRIBUTE_ALIGN @@ -77,17 +77,23 @@ typedef struct cinn_type_t { //! Number of elements in a vector, 1 for scalar. uint16_t lanes; - //! Number of '*', e.g. for `float*`, the num_asterisks is 1, `float**` it is 2. + //! Number of '*', e.g. for `float*`, the num_asterisks is 1, `float**` it + //! is 2. 
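
Note: cinn_type_t, whose comments are reflowed around here, packs a type code, bit width, vector lanes, and pointer depth (num_asterisks), and its bytes() accessor just below rounds bits up to whole bytes via (bits + 7) / 8. A self-contained mirror of the documented fields with a few worked values; MiniType is a sketch, not the real header.

// --- illustrative sketch, not part of this patch ---
#include <cassert>
#include <cstdint>

enum MiniTypeCode { kInt = 0, kUInt = 1, kFloat = 2, kHandle = 3, kBFloat = 4 };

struct MiniType {
  MiniTypeCode code;
  uint8_t bits;
  uint16_t lanes;         // number of elements in a vector, 1 for scalar
  uint8_t num_asterisks;  // pointer depth: float* -> 1, float** -> 2
  uint16_t bytes() const { return (bits + 7) / 8; }  // round up to whole bytes
};

int main() {
  assert((MiniType{kInt, 1, 1, 0}.bytes()) == 1);      // bool: 1 bit -> 1 byte
  assert((MiniType{kBFloat, 16, 1, 0}.bytes()) == 2);  // bfloat16 -> 2 bytes
  assert((MiniType{kFloat, 32, 1, 0}.bytes()) == 4);   // float32 -> 4 bytes
  assert((MiniType{kFloat, 64, 1, 0}.bytes()) == 8);   // float64 -> 8 bytes
}
// --- end sketch ---
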
uint8_t num_asterisks{0}; #ifdef __cplusplus CINN_ALWAYS_INLINE cinn_type_t() : code(cinn_type_int), bits(0), lanes(0) {} - CINN_ALWAYS_INLINE cinn_type_t(cinn_type_code_t code, uint8_t bits, uint16_t lanes = 1, uint8_t num_asterisks = 0) + CINN_ALWAYS_INLINE cinn_type_t(cinn_type_code_t code, + uint8_t bits, + uint16_t lanes = 1, + uint8_t num_asterisks = 0) : code(code), bits(bits), lanes(lanes), num_asterisks(num_asterisks) {} CINN_ALWAYS_INLINE bool operator==(const cinn_type_t& other) const { return code == other.code && bits == other.bits && lanes == other.lanes; } - CINN_ALWAYS_INLINE bool operator!=(const cinn_type_t& other) const { return !(*this == other); } + CINN_ALWAYS_INLINE bool operator!=(const cinn_type_t& other) const { + return !(*this == other); + } CINN_ALWAYS_INLINE uint16_t bytes() const { return (bits + 7) / 8; } #endif // __cplusplus } cinn_type_t; @@ -113,21 +119,21 @@ extern cinn_type_t cinn_float32_t(int num_asterisks = 0); extern cinn_type_t cinn_float64_t(int num_asterisks = 0); // @} -//! Help to define the size of a dimension, due to polyhedral representation, we no need to record the extend or -//! min(default to 0). +//! Help to define the size of a dimension, due to polyhedral representation, we +//! no need to record the extend or min(default to 0). typedef int cinn_dimension_t; //! Help to tell the kind of the device. typedef enum cinn_device_kind_t { - cinn_unk_device = -1, // Undefined device. - cinn_x86_device = 0, // X86 device - cinn_opencl_device = 1, // OpenCL device - cinn_arm_device = 2 // ARM device + cinn_unk_device = -1, // Undefined device. + cinn_x86_device = 0, // X86 device + cinn_opencl_device = 1, // OpenCL device + cinn_arm_device = 2 // ARM device } cinn_device_kind_t; //! Help to tell where the buffer locates. typedef enum cinn_buffer_kind_t { - cinn_buffer_on_host = 0, //! buffer on host + cinn_buffer_on_host = 0, //! buffer on host cinn_buffer_on_device = 1 << 1 // ! buffer on device e.g. GPU. } cinn_buffer_kind_t; @@ -142,17 +148,21 @@ struct cinn_device_interface_t { int (*malloc)(void* context, struct cinn_buffer_t* buf); int (*free)(void* context, struct cinn_buffer_t* buf); int (*sync)(void* context, struct cinn_buffer_t* buf); - int (*release)(void* context, const struct cinn_device_interface_t* device_interface); + int (*release)(void* context, + const struct cinn_device_interface_t* device_interface); int (*copy_to_host)(void* context, struct cinn_buffer_t* buf); int (*copy_to_device)(void* context, struct cinn_buffer_t* buf); - int (*buffer_copy)(void* context, struct cinn_buffer_t* src, struct cinn_buffer_t* dst); + int (*buffer_copy)(void* context, + struct cinn_buffer_t* src, + struct cinn_buffer_t* dst); struct cinn_device_interface_impl_t* impl; }; /** * Release all data associated with the given interface. */ -extern int cinn_device_release(void* context, const struct cinn_device_interface_t* device_interface); +extern int cinn_device_release( + void* context, const struct cinn_device_interface_t* device_interface); /* * Copy image data from device to host memory. @@ -163,7 +173,9 @@ extern int cinn_buffer_copy_to_host(void* context, struct cinn_buffer_t* buf); extern int cinn_buffer_copy_to_device(void* context, struct cinn_buffer_t* buf); //! Copy data from one buffer to another. -extern int cinn_buffer_copy(void* context, struct cinn_buffer_t* src, struct cinn_buffer_t* dst); +extern int cinn_buffer_copy(void* context, + struct cinn_buffer_t* src, + struct cinn_buffer_t* dst); //! 
Wait for current device operations to complete. extern int cinn_device_sync(void* context, struct cinn_buffer_t* buf); @@ -179,7 +191,9 @@ extern void* cinn_buffer_get_data_handle(struct cinn_buffer_t* buf); extern void* cinn_buffer_get_data_const_handle(const struct cinn_buffer_t* buf); //! Create a new default cinn_buffer. -extern cinn_buffer_t* cinn_buffer_new_default(int target, uint64_t memory_size, int align = 32); +extern cinn_buffer_t* cinn_buffer_new_default(int target, + uint64_t memory_size, + int align = 32); //! The raw representation of a buffer,used in the generated code/lib. #define CINN_BUFFER_MAX_DIMS 8 @@ -243,7 +257,8 @@ typedef struct cinn_buffer_t { // NOTE the buffer should be resized first. static void alloc(struct cinn_buffer_t*); - //! Set the shape of the buffer. NOTE this just record the shape, not allocate the memory. + //! Set the shape of the buffer. NOTE this just record the shape, not allocate + //! the memory. CINN_ALWAYS_INLINE void resize(const cinn_dimension_t* dims, int dimensions) { this->dimensions = dimensions; memcpy(this->dims, dims, dimensions * sizeof(cinn_dimension_t)); @@ -257,10 +272,18 @@ typedef struct cinn_buffer_t { return res; } - CINN_ALWAYS_INLINE bool on_host() const { return get_flag(cinn_buffer_on_host); } - CINN_ALWAYS_INLINE bool on_device() const { return get_flag(cinn_buffer_on_device); } - CINN_ALWAYS_INLINE void set_on_host(bool x = true) { set_flag(cinn_buffer_on_host, x); } - CINN_ALWAYS_INLINE void set_on_device(bool x = true) { set_flag(cinn_buffer_on_device, x); } + CINN_ALWAYS_INLINE bool on_host() const { + return get_flag(cinn_buffer_on_host); + } + CINN_ALWAYS_INLINE bool on_device() const { + return get_flag(cinn_buffer_on_device); + } + CINN_ALWAYS_INLINE void set_on_host(bool x = true) { + set_flag(cinn_buffer_on_host, x); + } + CINN_ALWAYS_INLINE void set_on_device(bool x = true) { + set_flag(cinn_buffer_on_device, x); + } CINN_ALWAYS_INLINE int device_sync(void* ctx = NULL) { if (device_interface && device_interface->sync) { @@ -270,9 +293,13 @@ typedef struct cinn_buffer_t { } CINN_ALWAYS_INLINE uint8_t* begin() const { return 0; } - CINN_ALWAYS_INLINE uint8_t* end() const { return memory + num_elements() * type.bytes(); } + CINN_ALWAYS_INLINE uint8_t* end() const { + return memory + num_elements() * type.bytes(); + } - CINN_ALWAYS_INLINE bool get_flag(cinn_buffer_kind_t flag) const { return (this->flag & flag) != 0; } + CINN_ALWAYS_INLINE bool get_flag(cinn_buffer_kind_t flag) const { + return (this->flag & flag) != 0; + } CINN_ALWAYS_INLINE void set_flag(cinn_buffer_kind_t flag, bool value) { if (value) this->flag |= flag; @@ -305,22 +332,28 @@ struct cinn_device_interface_impl_t { int (*release)(void* context); int (*copy_to_host)(void* context, struct cinn_buffer_t* buf); int (*copy_to_device)(void* context, struct cinn_buffer_t* buf); - int (*buffer_copy)(void* context, struct cinn_buffer_t* src, struct cinn_buffer_t* dst); + int (*buffer_copy)(void* context, + struct cinn_buffer_t* src, + struct cinn_buffer_t* dst); }; // The device implementations extern struct cinn_device_interface_t* cinn_x86_device_interface(); -inline cinn::common::bfloat16 cinn_buffer_load_bfloat16(struct cinn_buffer_t* buf, uint32_t index) { +inline cinn::common::bfloat16 cinn_buffer_load_bfloat16( + struct cinn_buffer_t* buf, uint32_t index) { return ((cinn::common::bfloat16*)buf->memory)[index]; // NOLINT } -inline cinn::common::float16 cinn_buffer_load_float16(struct cinn_buffer_t* buf, uint32_t index) { +inline 
cinn::common::float16 cinn_buffer_load_float16(struct cinn_buffer_t* buf, + uint32_t index) { return ((cinn::common::float16*)buf->memory)[index]; // NOLINT } -inline float cinn_buffer_load_float32(struct cinn_buffer_t* buf, uint32_t index) { +inline float cinn_buffer_load_float32(struct cinn_buffer_t* buf, + uint32_t index) { return ((float*)buf->memory)[index]; // NOLINT } -inline double cinn_buffer_load_float64(struct cinn_buffer_t* buf, uint32_t index) { +inline double cinn_buffer_load_float64(struct cinn_buffer_t* buf, + uint32_t index) { return ((double*)buf->memory)[index]; // NOLINT } #endif // __cplusplus @@ -329,7 +362,8 @@ inline double cinn_buffer_load_float64(struct cinn_buffer_t* buf, uint32_t index extern "C" { #endif -CINN_ALWAYS_INLINE void* cinn_buffer_slice(struct cinn_buffer_t* buf, uint32_t offset); +CINN_ALWAYS_INLINE void* cinn_buffer_slice(struct cinn_buffer_t* buf, + uint32_t offset); #ifdef __cplusplus } @@ -355,9 +389,14 @@ static inline int32_t cinn_max(int32_t a, int32_t b) { return a > b ? a : b; } fprintf(stderr, #v__ " is null"); \ return -1; \ } -#define CINN_LOG(fmt, ...) \ - do { \ - fprintf(stderr, "%s:%d:%s(): " fmt, __FILE__, __LINE__, __func__, __VA_ARGS__); \ +#define CINN_LOG(fmt, ...) \ + do { \ + fprintf(stderr, \ + "%s:%d:%s(): " fmt, \ + __FILE__, \ + __LINE__, \ + __func__, \ + __VA_ARGS__); \ } while (0) #define CINN_CHECK(cond) \ @@ -527,7 +566,8 @@ cinn_buffer_t* cinn_pod_value_to_buffer_p(cinn_pod_value_t* value); //! other specific types to cinn_pod_value // @{ void float_to_cinn_pod_value(float v, cinn_pod_value_t* out); -void bfloat16_to_cinn_pod_value(cinn::common::bfloat16 v, cinn_pod_value_t* out); +void bfloat16_to_cinn_pod_value(cinn::common::bfloat16 v, + cinn_pod_value_t* out); void float16_to_cinn_pod_value(cinn::common::float16 v, cinn_pod_value_t* out); void double_to_cinn_pod_value(double v, cinn_pod_value_t* out); @@ -544,7 +584,8 @@ void uint32_to_cinn_pod_value(uint32_t v, cinn_pod_value_t* out); void uint64_to_cinn_pod_value(uint64_t v, cinn_pod_value_t* out); void handle_to_cinn_pod_value(void* v, cinn_pod_value_t* out); -void buffer_p_to_cinn_pod_value(const struct cinn_buffer_t* v, cinn_pod_value_t* out); +void buffer_p_to_cinn_pod_value(const struct cinn_buffer_t* v, + cinn_pod_value_t* out); // @} void cinn_print_debug_string(const char* s, ...); diff --git a/paddle/cinn/runtime/cinn_runtime_test.cc b/paddle/cinn/runtime/cinn_runtime_test.cc index 73bf9b7359e8b..3e6f9b67d4e34 100644 --- a/paddle/cinn/runtime/cinn_runtime_test.cc +++ b/paddle/cinn/runtime/cinn_runtime_test.cc @@ -17,14 +17,15 @@ #include TEST(buffer, basic) { - auto* buffer = cinn_buffer_t::new_(cinn_x86_device, cinn_float32_t(), {3, 10}); + auto* buffer = + cinn_buffer_t::new_(cinn_x86_device, cinn_float32_t(), {3, 10}); ASSERT_TRUE(buffer); ASSERT_TRUE(buffer->device_interface); ASSERT_EQ(buffer->device_interface, cinn_x86_device_interface()); buffer->device_interface->impl->malloc(NULL, buffer); auto* data = reinterpret_cast(buffer->memory); - data[0] = 0.f; - data[1] = 1.f; + data[0] = 0.f; + data[1] = 1.f; EXPECT_EQ(data[0], 0.f); EXPECT_EQ(data[1], 1.f); } diff --git a/paddle/cinn/runtime/cinn_x86_device_impl.cc b/paddle/cinn/runtime/cinn_x86_device_impl.cc index 3581251affa83..dd5bb812c70a9 100644 --- a/paddle/cinn/runtime/cinn_x86_device_impl.cc +++ b/paddle/cinn/runtime/cinn_x86_device_impl.cc @@ -54,13 +54,18 @@ int cinn_x86_free(void* context, cinn_buffer_t* buf) { return 0; } -// All the following operations are not support by X86 device, 
just leave them empty. +// All the following operations are not support by X86 device, just leave them +// empty. // @{ int cinn_x86_sync(void* context, cinn_buffer_t* buf) { return 0; } int cinn_x86_release(void* context) { return 0; } int cinn_x86_copy_to_host(void* context, cinn_buffer_t* buf) { return 0; } int cinn_x86_copy_to_device(void* context, cinn_buffer_t* buf) { return 0; } -int cinn_x86_buffer_copy(void* context, cinn_buffer_t* src, cinn_buffer_t* dst) { return 0; } +int cinn_x86_buffer_copy(void* context, + cinn_buffer_t* src, + cinn_buffer_t* dst) { + return 0; +} // @} cinn_device_interface_impl_t cinn_x86_device_impl{&cinn_x86_malloc, @@ -71,14 +76,15 @@ cinn_device_interface_impl_t cinn_x86_device_impl{&cinn_x86_malloc, &cinn_x86_copy_to_device, &cinn_x86_buffer_copy}; -cinn_device_interface_t cinn_x86_device_interface_interface{&cinn_buffer_malloc, - &cinn_buffer_free, - &cinn_device_sync, - &cinn_device_release, - &cinn_buffer_copy_to_host, - &cinn_buffer_copy_to_device, - &cinn_buffer_copy, - &cinn_x86_device_impl}; +cinn_device_interface_t cinn_x86_device_interface_interface{ + &cinn_buffer_malloc, + &cinn_buffer_free, + &cinn_device_sync, + &cinn_device_release, + &cinn_buffer_copy_to_host, + &cinn_buffer_copy_to_device, + &cinn_buffer_copy, + &cinn_x86_device_impl}; struct cinn_device_interface_t* cinn_x86_device_interface() { return &cinn_x86_device_interface_interface; diff --git a/paddle/cinn/runtime/cpu/cblas.cc b/paddle/cinn/runtime/cpu/cblas.cc index 8c0594f49b3d3..8a9f7be63083c 100644 --- a/paddle/cinn/runtime/cpu/cblas.cc +++ b/paddle/cinn/runtime/cpu/cblas.cc @@ -21,7 +21,9 @@ namespace { -inline CBLAS_TRANSPOSE ToCblasTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; } +inline CBLAS_TRANSPOSE ToCblasTranspose(bool trans) { + return trans ? CblasTrans : CblasNoTrans; +} } // namespace @@ -101,20 +103,22 @@ void cinn_cpu_mkl_gemm_batch_fp32(float alpha, } /** - * This function is temporarily unavailable, see the error message in the following PR for details. - * The specific reason may be that the custom call does not support host op. - * See: https://github.com/PaddlePaddle/CINN/pull/1133 + * This function is temporarily unavailable, see the error message in the + * following PR for details. The specific reason may be that the custom call + * does not support host op. See: https://github.com/PaddlePaddle/CINN/pull/1133 */ -void cinn_call_cholesky_host(void* v_args, int num_args, int batch_size, int m, bool upper) { +void cinn_call_cholesky_host( + void* v_args, int num_args, int batch_size, int m, bool upper) { #ifdef CINN_WITH_MKL_CBLAS cinn_pod_value_t* args = static_cast(v_args); - cinn_buffer_t* x = args[0].operator cinn_buffer_t*(); + cinn_buffer_t* x = args[0].operator cinn_buffer_t*(); cinn_buffer_t* out = args[1].operator cinn_buffer_t*(); memcpy(out->memory, x->memory, x->memory_size); uint8_t bits = x->type.bits; - CHECK(bits == 32 || bits == 64) << "Unsupported bits = " << bits << " float data type for cholesky"; + CHECK(bits == 32 || bits == 64) + << "Unsupported bits = " << bits << " float data type for cholesky"; char uplo = upper ? 
'U' : 'L'; for (int i = 0; i < batch_size; i++) { if (bits == 32) { @@ -135,43 +139,45 @@ CINN_REGISTER_HELPER(cinn_cpu_mkl) { using backends::FunctionProto; auto host_target = common::DefaultHostTarget(); - FunctionProto::shape_inference_t inference_shape_gemm = [](const std::vector& args, int offset) { - CHECK_EQ(offset, 0UL) << "Only one output"; - CHECK_EQ(args.size(), 12UL) << "Wrong number of arguments passed in"; - auto M = common::AutoSimplify(args[1]); - auto N = common::AutoSimplify(args[2]); - std::vector shape; - shape.push_back(M); - shape.push_back(N); - return shape; - }; - - FunctionProto::shape_inference_t inference_shape_gemm_batch = [](const std::vector& args, int offset) { - CHECK_EQ(offset, 0UL) << "Only one output"; - CHECK_EQ(args.size(), 16UL) << "Wrong number of arguments passed in"; - auto& A = args[14]; - auto A_tensor = A.as_tensor(); - CHECK(A_tensor); - - auto batch_size = common::AutoSimplify(args[1]); - int32_t batch_size_val = batch_size.as_int32(); - - auto M = common::AutoSimplify(args[2]); - auto N = common::AutoSimplify(args[3]); - - std::vector shape; - int total = 1; - for (auto& v : A_tensor->shape) { - auto val = common::AutoSimplify(v); - CHECK(val.is_constant()); - shape.push_back(val); - total *= val.as_int32(); - if (total >= batch_size_val) break; - } - shape.push_back(M); - shape.push_back(N); - return shape; - }; + FunctionProto::shape_inference_t inference_shape_gemm = + [](const std::vector& args, int offset) { + CHECK_EQ(offset, 0UL) << "Only one output"; + CHECK_EQ(args.size(), 12UL) << "Wrong number of arguments passed in"; + auto M = common::AutoSimplify(args[1]); + auto N = common::AutoSimplify(args[2]); + std::vector shape; + shape.push_back(M); + shape.push_back(N); + return shape; + }; + + FunctionProto::shape_inference_t inference_shape_gemm_batch = + [](const std::vector& args, int offset) { + CHECK_EQ(offset, 0UL) << "Only one output"; + CHECK_EQ(args.size(), 16UL) << "Wrong number of arguments passed in"; + auto& A = args[14]; + auto A_tensor = A.as_tensor(); + CHECK(A_tensor); + + auto batch_size = common::AutoSimplify(args[1]); + int32_t batch_size_val = batch_size.as_int32(); + + auto M = common::AutoSimplify(args[2]); + auto N = common::AutoSimplify(args[3]); + + std::vector shape; + int total = 1; + for (auto& v : A_tensor->shape) { + auto val = common::AutoSimplify(v); + CHECK(val.is_constant()); + shape.push_back(val); + total *= val.as_int32(); + if (total >= batch_size_val) break; + } + shape.push_back(M); + shape.push_back(N); + return shape; + }; REGISTER_EXTERN_FUNC_HELPER(cinn_cpu_mkl_gemm_fp32, host_target) .SetRetType() diff --git a/paddle/cinn/runtime/cpu/cblas.h b/paddle/cinn/runtime/cpu/cblas.h index 96126545113f3..7e8249f04504a 100644 --- a/paddle/cinn/runtime/cpu/cblas.h +++ b/paddle/cinn/runtime/cpu/cblas.h @@ -28,8 +28,9 @@ extern "C" { /** * \brief Do GEMM on buffer A and B and write result to buffer C. - * We pass the \param M, \param N, \param K although the shape can retrieve from cinn_buffer_t because the size of a - * matrix not equals the shape of a buffer it is stored. + * We pass the \param M, \param N, \param K although the shape can retrieve from + * cinn_buffer_t because the size of a matrix not equals the shape of a buffer + * it is stored. 
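
Note: the reflowed cblas.h comments around here document the GEMM entry points: M, N, K are passed explicitly because a buffer's stored shape need not equal the matrix extents, and, per the @param notes that follow, the batch strides a_stride/b_stride/c_stride count elements, not bytes. A naive reference implementation of that strided-batch layout, useful as a mental model for the MKL-backed kernel; NaiveBatchedGemm is illustrative, not the real function.

// --- illustrative sketch, not part of this patch ---
#include <cassert>
#include <vector>

void NaiveBatchedGemm(int batch, int M, int N, int K,
                      const float* A, int a_stride,
                      const float* B, int b_stride,
                      float* C, int c_stride) {
  for (int b = 0; b < batch; ++b) {
    // Batch b starts *_stride elements (not bytes) after batch b-1.
    const float* a = A + b * a_stride;
    const float* bm = B + b * b_stride;
    float* c = C + b * c_stride;
    for (int i = 0; i < M; ++i)
      for (int j = 0; j < N; ++j) {
        float acc = 0.f;
        for (int k = 0; k < K; ++k) acc += a[i * K + k] * bm[k * N + j];
        c[i * N + j] = acc;
      }
  }
}

int main() {
  const int batch = 2, M = 2, N = 2, K = 2;
  std::vector<float> A = {1, 0, 0, 1,  2, 0, 0, 2};  // two 2x2 matrices: I, 2I
  std::vector<float> B = {1, 2, 3, 4,  1, 2, 3, 4};
  std::vector<float> C(batch * M * N, 0.f);
  NaiveBatchedGemm(batch, M, N, K, A.data(), M * K, B.data(), K * N,
                   C.data(), M * N);
  assert(C[0] == 1 && C[3] == 4);  // batch 0: I * B == B
  assert(C[4] == 2 && C[7] == 8);  // batch 1: 2I * B == 2B
}
// --- end sketch ---
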
* @param alpha The scaling factor of the product of A and B * @param M Number of the rows of A * @param N the number of the columns in both B and C @@ -60,8 +61,9 @@ void cinn_cpu_mkl_gemm_fp32(float alpha, /** * \brief Do GEMM on buffer A and B and write result to buffer C. - * We pass the \param M, \param N, \param K although the shape can retrieve from cinn_buffer_t because the size of a - * matrix not equals the shape of a buffer it is stored. + * We pass the \param M, \param N, \param K although the shape can retrieve from + * cinn_buffer_t because the size of a matrix not equals the shape of a buffer + * it is stored. * @param alpha The scaling factor of the product of A and B * @param batch_size the batch size of A and B * @param M Number of the rows of A @@ -72,9 +74,12 @@ void cinn_cpu_mkl_gemm_fp32(float alpha, * @param lda The size of the first dimension of A * @param ldb The size of the first dimension of B * @param ldc The size of the first dimension of C - * @param a_stride The stride of A(number of elements, not bytes) between batches - * @param b_stride The stride of B(number of elements, not bytes) between batches - * @param c_stride The stride of C(number of elements, not bytes) between batches + * @param a_stride The stride of A(number of elements, not bytes) between + * batches + * @param b_stride The stride of B(number of elements, not bytes) between + * batches + * @param c_stride The stride of C(number of elements, not bytes) between + * batches * @param beta The scaling factor of C * @param A The matrix A * @param B The matrix B @@ -98,5 +103,6 @@ void cinn_cpu_mkl_gemm_batch_fp32(float alpha, cinn_buffer_t* B, cinn_buffer_t* C); -void cinn_call_cholesky_host(void* v_args, int num_args, int batch_size, int m, bool upper); +void cinn_call_cholesky_host( + void* v_args, int num_args, int batch_size, int m, bool upper); } // extern "C" diff --git a/paddle/cinn/runtime/cpu/host_intrinsics.cc b/paddle/cinn/runtime/cpu/host_intrinsics.cc index 4bfa08fbe2211..5f8d68e8d9cee 100644 --- a/paddle/cinn/runtime/cpu/host_intrinsics.cc +++ b/paddle/cinn/runtime/cpu/host_intrinsics.cc @@ -30,20 +30,21 @@ extern "C" { void __cinn_host_tanh_v(const cinn_buffer_t* x, cinn_buffer_t* out) { CINN_CHECK_EQ(x->num_elements(), out->num_elements()); - int xn = x->num_elements(); - auto* x_data = (float*)(x->memory); + int xn = x->num_elements(); + auto* x_data = (float*)(x->memory); auto* out_data = (float*)(out->memory); for (int i = 0; i < x->num_elements(); i++) { out_data[i] = tanhf(x_data[i]); } } -#define __cinn_host_find_kernel(buf, size, num, type, begin, stride) \ - do { \ - for (int i = (size - 1) * stride + begin; i >= begin; i -= stride) { \ - if (reinterpret_cast(buf->memory)[i] == num) return (i - begin) / stride; \ - } \ - return -1; \ +#define __cinn_host_find_kernel(buf, size, num, type, begin, stride) \ + do { \ + for (int i = (size - 1) * stride + begin; i >= begin; i -= stride) { \ + if (reinterpret_cast(buf->memory)[i] == num) \ + return (i - begin) / stride; \ + } \ + return -1; \ } while (0) inline int cinn_host_find_int(const cinn_buffer_t* buf, int size, int num) { @@ -54,20 +55,24 @@ inline int cinn_host_find_float(const cinn_buffer_t* buf, int size, float num) { __cinn_host_find_kernel(buf, size, num, float, 0, 1); } -inline int cinn_host_find_int_nd(const cinn_buffer_t* buf, int size, int num, int begin, int stride) { +inline int cinn_host_find_int_nd( + const cinn_buffer_t* buf, int size, int num, int begin, int stride) { __cinn_host_find_kernel(buf, size, num, 
int, begin, stride); } -inline int cinn_host_find_float_nd(const cinn_buffer_t* buf, int size, float num, int begin, int stride) { +inline int cinn_host_find_float_nd( + const cinn_buffer_t* buf, int size, float num, int begin, int stride) { __cinn_host_find_kernel(buf, size, num, float, begin, stride); } #undef __cinn_host_find_kernel -inline int cinn_host_next_smallest_int32(cinn_buffer_t* buf, int size, int num, int begin, int stride) { +inline int cinn_host_next_smallest_int32( + cinn_buffer_t* buf, int size, int num, int begin, int stride) { int id = -1; for (int i = begin; i < begin + size * stride; i += stride) { - if (id == -1 || reinterpret_cast(buf->memory)[i] < reinterpret_cast(buf->memory)[id]) { + if (id == -1 || reinterpret_cast(buf->memory)[i] < + reinterpret_cast(buf->memory)[id]) { id = i; } } @@ -78,14 +83,17 @@ inline int cinn_host_next_smallest_int32(cinn_buffer_t* buf, int size, int num, return -1; } -#define CINN_HOST_LT_NUM(TYPE_SUFFIX, TYPE) \ - inline int cinn_host_lt_num_##TYPE_SUFFIX( \ - const cinn_buffer_t* buf, const int size, const TYPE num, const int offset, const int stride) { \ - int out = 0; \ - for (int i = (size - 1) * stride + offset; i >= offset; i -= stride) { \ - if (reinterpret_cast(buf->memory)[i] < num) out++; \ - } \ - return out; \ +#define CINN_HOST_LT_NUM(TYPE_SUFFIX, TYPE) \ + inline int cinn_host_lt_num_##TYPE_SUFFIX(const cinn_buffer_t* buf, \ + const int size, \ + const TYPE num, \ + const int offset, \ + const int stride) { \ + int out = 0; \ + for (int i = (size - 1) * stride + offset; i >= offset; i -= stride) { \ + if (reinterpret_cast(buf->memory)[i] < num) out++; \ + } \ + return out; \ } CINN_HOST_LT_NUM(fp32, float) @@ -95,14 +103,17 @@ CINN_HOST_LT_NUM(int64, int64_t) #undef CINN_HOST_LT_NUM -#define CINN_HOST_GT_NUM(TYPE_SUFFIX, TYPE) \ - inline int cinn_host_gt_num_##TYPE_SUFFIX( \ - const cinn_buffer_t* buf, const int size, const TYPE num, const int offset, const int stride) { \ - int out = 0; \ - for (int i = (size - 1) * stride + offset; i >= offset; i -= stride) { \ - if (reinterpret_cast(buf->memory)[i] > num) out++; \ - } \ - return out; \ +#define CINN_HOST_GT_NUM(TYPE_SUFFIX, TYPE) \ + inline int cinn_host_gt_num_##TYPE_SUFFIX(const cinn_buffer_t* buf, \ + const int size, \ + const TYPE num, \ + const int offset, \ + const int stride) { \ + int out = 0; \ + for (int i = (size - 1) * stride + offset; i >= offset; i -= stride) { \ + if (reinterpret_cast(buf->memory)[i] > num) out++; \ + } \ + return out; \ } CINN_HOST_GT_NUM(fp32, float) @@ -125,28 +136,29 @@ int cinn_host_resize_bilinear(const cinn_buffer_t* buf, // same with paddle resize when use cv2 backend float scale_y = static_cast(in_h) / out_h; float scale_x = static_cast(in_w) / out_w; - float in_y = (y + 0.5F) * scale_y - 0.5F; - float in_x = (x + 0.5F) * scale_x - 0.5F; - int in_y_int = static_cast(std::floor(in_y)); - int in_x_int = static_cast(std::floor(in_x)); - float y_lerp = in_y - in_y_int; - float x_lerp = in_x - in_x_int; + float in_y = (y + 0.5F) * scale_y - 0.5F; + float in_x = (x + 0.5F) * scale_x - 0.5F; + int in_y_int = static_cast(std::floor(in_y)); + int in_x_int = static_cast(std::floor(in_x)); + float y_lerp = in_y - in_y_int; + float x_lerp = in_x - in_x_int; float p[2][2]; for (int i = 0; i < 2; ++i) { for (int j = 0; j < 2; ++j) { int near_y = in_y_int + i; int near_x = in_x_int + j; - near_y = std::max(std::min(near_y, in_h - 1), 0); - near_x = std::max(std::min(near_x, in_w - 1), 0); - p[i][j] = - reinterpret_cast(buf->memory)[n * c_size 
* in_h * in_w + c * in_h * in_w + near_y * in_w + near_x]; + near_y = std::max(std::min(near_y, in_h - 1), 0); + near_x = std::max(std::min(near_x, in_w - 1), 0); + p[i][j] = reinterpret_cast( + buf->memory)[n * c_size * in_h * in_w + c * in_h * in_w + + near_y * in_w + near_x]; } } - float top = p[0][0] * (1.0F - x_lerp) + p[0][1] * x_lerp; + float top = p[0][0] * (1.0F - x_lerp) + p[0][1] * x_lerp; float bottom = p[1][0] * (1.0F - x_lerp) + p[1][1] * x_lerp; - float value = top * (1.0F - y_lerp) + bottom * y_lerp; + float value = top * (1.0F - y_lerp) + bottom * y_lerp; return value; } @@ -163,10 +175,10 @@ int cinn_host_resize_bicubic(const cinn_buffer_t* buf, // same with paddle resize when use cv2 backend float scale_y = static_cast(in_h) / out_h; float scale_x = static_cast(in_w) / out_w; - float in_y = (y + 0.5F) * scale_y - 0.5F; - float in_x = (x + 0.5F) * scale_x - 0.5F; - int in_y_int = static_cast(std::floor(in_y)); - int in_x_int = static_cast(std::floor(in_x)); + float in_y = (y + 0.5F) * scale_y - 0.5F; + float in_x = (x + 0.5F) * scale_x - 0.5F; + int in_y_int = static_cast(std::floor(in_y)); + int in_x_int = static_cast(std::floor(in_x)); float y_fract = in_y - std::floor(in_y); float x_fract = in_x - std::floor(in_x); float p[4][4]; @@ -175,10 +187,11 @@ int cinn_host_resize_bicubic(const cinn_buffer_t* buf, for (int j = 0; j < 4; ++j) { int near_y = in_y_int + i - 1; int near_x = in_x_int + j - 1; - near_y = std::max(std::min(near_y, in_h - 1), 0); - near_x = std::max(std::min(near_x, in_w - 1), 0); - p[i][j] = - reinterpret_cast(buf->memory)[n * c_size * in_h * in_w + c * in_h * in_w + near_y * in_w + near_x]; + near_y = std::max(std::min(near_y, in_h - 1), 0); + near_x = std::max(std::min(near_x, in_w - 1), 0); + p[i][j] = reinterpret_cast( + buf->memory)[n * c_size * in_h * in_w + c * in_h * in_w + + near_y * in_w + near_x]; } } @@ -186,13 +199,13 @@ int cinn_host_resize_bicubic(const cinn_buffer_t* buf, float w[2][4]; for (int i = 0; i < 2; ++i) { - float t = (i == 0 ? x_fract : y_fract); + float t = (i == 0 ? 
x_fract : y_fract); float t2 = t * t; float t3 = t * t * t; - w[i][0] = alpha * (t3 - 2 * t2 + t); - w[i][1] = (alpha + 2) * t3 - (3 + alpha) * t2 + 1; - w[i][2] = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t; - w[i][3] = -alpha * t3 + alpha * t2; + w[i][0] = alpha * (t3 - 2 * t2 + t); + w[i][1] = (alpha + 2) * t3 - (3 + alpha) * t2 + 1; + w[i][2] = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t; + w[i][3] = -alpha * t3 + alpha * t2; } float col[4]; @@ -242,7 +255,9 @@ inline int FN_INT32(clz)(int x) { return __builtin_clz(x); } inline int FN_INT32(popc)(int x) { return __builtin_popcount(x); } -inline int FN_INT32(logical_right_shift)(int x, int y) { return ((unsigned int)x >> y); } +inline int FN_INT32(logical_right_shift)(int x, int y) { + return ((unsigned int)x >> y); +} #undef FN_INT32 @@ -254,7 +269,9 @@ inline int64_t FN_INT64(popc)(int64_t x) { return __builtin_popcountll(x); } inline int64_t FN_INT64(pow)(int64_t x, int64_t y) { return pow(x, y); } -inline int64_t FN_INT64(logical_right_shift)(int64_t x, int64_t y) { return ((uint64_t)x >> y); } +inline int64_t FN_INT64(logical_right_shift)(int64_t x, int64_t y) { + return ((uint64_t)x >> y); +} #undef FN_INT64 } // extern "C" @@ -262,8 +279,16 @@ inline int64_t FN_INT64(logical_right_shift)(int64_t x, int64_t y) { return ((ui namespace cinn { namespace runtime { -void cinn_assert_true_host(void* v_args, int num_args, int msg, bool only_warning) { - cinn::runtime::cinn_assert_true(v_args, num_args, msg, only_warning, nullptr, cinn::common::DefaultHostTarget()); +void cinn_assert_true_host(void* v_args, + int num_args, + int msg, + bool only_warning) { + cinn::runtime::cinn_assert_true(v_args, + num_args, + msg, + only_warning, + nullptr, + cinn::common::DefaultHostTarget()); } } // namespace runtime } // namespace cinn @@ -272,7 +297,8 @@ CINN_REGISTER_HELPER(host_intrinsics) { auto host_target = cinn::common::DefaultHostTarget(); using cinn::backends::FunctionProto; -#define REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32(func__) REGISTER_EXTERN_FUNC_1_IN_1_OUT(func__, host_target, float, float); +#define REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32(func__) \ + REGISTER_EXTERN_FUNC_1_IN_1_OUT(func__, host_target, float, float); REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32(erff); REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32(acosf); @@ -297,28 +323,32 @@ CINN_REGISTER_HELPER(host_intrinsics) { #undef REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32_INT -#define REGISTER_EXTERN_FUNC_2_IN_1_F(func__) REGISTER_EXTERN_FUNC_2_IN_1_OUT(func__, host_target, float, float, float); +#define REGISTER_EXTERN_FUNC_2_IN_1_F(func__) \ + REGISTER_EXTERN_FUNC_2_IN_1_OUT(func__, host_target, float, float, float); REGISTER_EXTERN_FUNC_2_IN_1_F(powf) #undef REGISTER_EXTERN_FUNC_2_IN_1_F #define REGISTER_EXTERN_FUNC_2_IN_1_FP32(func__) \ - REGISTER_EXTERN_FUNC_2_IN_1_OUT(cinn_host_##func__##_fp32, host_target, float, float, float); + REGISTER_EXTERN_FUNC_2_IN_1_OUT( \ + cinn_host_##func__##_fp32, host_target, float, float, float); REGISTER_EXTERN_FUNC_2_IN_1_FP32(pow) #undef REGISTER_EXTERN_FUNC_2_IN_1_FP32 #define REGISTER_EXTERN_FUNC_2_IN_1_FP64(func__) \ - REGISTER_EXTERN_FUNC_2_IN_1_OUT(cinn_host_##func__##_fp64, host_target, double, double, double); + REGISTER_EXTERN_FUNC_2_IN_1_OUT( \ + cinn_host_##func__##_fp64, host_target, double, double, double); REGISTER_EXTERN_FUNC_2_IN_1_FP64(pow) #undef REGISTER_EXTERN_FUNC_2_IN_1_FP64 #define REGISTER_EXTERN_FUNC_2_IN_1_INT32(func__) \ - REGISTER_EXTERN_FUNC_2_IN_1_OUT(cinn_host_##func__##_int32, host_target, int, int, int); + 
REGISTER_EXTERN_FUNC_2_IN_1_OUT( \
+      cinn_host_##func__##_int32, host_target, int, int, int);
 
   REGISTER_EXTERN_FUNC_2_IN_1_INT32(pow)
 
@@ -327,7 +357,8 @@ CINN_REGISTER_HELPER(host_intrinsics) {
 #undef REGISTER_EXTERN_FUNC_2_IN_1_INT32
 
 #define REGISTER_EXTERN_FUNC_2_IN_1_INT64(func__) \
-  REGISTER_EXTERN_FUNC_2_IN_1_OUT(cinn_host_##func__##_int64, host_target, int64_t, int64_t, int64_t);
+  REGISTER_EXTERN_FUNC_2_IN_1_OUT( \
+      cinn_host_##func__##_int64, host_target, int64_t, int64_t, int64_t);
 
   REGISTER_EXTERN_FUNC_2_IN_1_INT64(pow)
 
@@ -337,11 +368,13 @@ CINN_REGISTER_HELPER(host_intrinsics) {
 
   REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_clz_int32, host_target, int, int);
 
-  REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_clz_int64, host_target, int64_t, int64_t);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT(
+      cinn_host_clz_int64, host_target, int64_t, int64_t);
 
   REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_popc_int32, host_target, int, int);
 
-  REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_popc_int64, host_target, int64_t, int64_t);
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT(
+      cinn_host_popc_int64, host_target, int64_t, int64_t);
 
   REGISTER_EXTERN_FUNC_HELPER(cinn_host_find_int, host_target)
       .SetRetType<int>()
@@ -446,7 +479,8 @@ CINN_REGISTER_HELPER(host_intrinsics) {
       .AddInputType()
       .End();
 
-  // TODO(thisjiang): change msg type from 'int' to 'std::string' when custom call support 'std::string' type
+  // TODO(thisjiang): change msg type from 'int' to 'std::string' when custom
+  // call supports the 'std::string' type
   using cinn::runtime::cinn_assert_true_host;
   REGISTER_EXTERN_FUNC_HELPER(cinn_assert_true_host, host_target)
       .SetRetType<void>()
diff --git a/paddle/cinn/runtime/cpu/host_intrinsics.h b/paddle/cinn/runtime/cpu/host_intrinsics.h
index 057e60316ef2c..5135c43f082ce 100644
--- a/paddle/cinn/runtime/cpu/host_intrinsics.h
+++ b/paddle/cinn/runtime/cpu/host_intrinsics.h
@@ -14,7 +14,8 @@
 #pragma once
 
 /**
- * \file This file implements some intrinsic functions for math operation in host device.
+ * \file This file implements some intrinsic functions for math operations on
+ * the host device.
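The registration chains above (.SetRetType<T>(), .AddInputType<T>(), .End()) follow a fluent builder pattern: each call records one piece of the extern function's prototype so JIT-compiled modules can resolve and type-check calls against it. A rough sketch under assumed types (ProtoBuilder and its members are hypothetical stand-ins, not CINN's FunctionProto API):

#include <iostream>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <vector>

class ProtoBuilder {
 public:
  explicit ProtoBuilder(std::string name) : name_(std::move(name)) {}
  template <typename T>
  ProtoBuilder& SetRetType() {
    ret_ = typeid(T);
    return *this;  // returning *this is what makes the chaining work
  }
  template <typename T>
  ProtoBuilder& AddInputType() {
    args_.emplace_back(typeid(T));
    return *this;
  }
  void End() const {  // finalize: here we just print what was recorded
    std::cout << name_ << ": " << args_.size() << " arg(s) -> "
              << ret_.name() << "\n";
  }

 private:
  std::string name_;
  std::type_index ret_{typeid(void)};
  std::vector<std::type_index> args_;
};

int main() {
  ProtoBuilder("cinn_host_find_int")
      .SetRetType<int>()
      .AddInputType<void*>()  // stand-in for cinn_buffer_t*
      .AddInputType<int>()
      .AddInputType<int>()
      .End();
}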
*/ #include "paddle/cinn/runtime/cinn_runtime.h" @@ -29,13 +30,18 @@ inline int cinn_host_find_int(const cinn_buffer_t* buf, int size, int num); inline int cinn_host_find_float(const cinn_buffer_t* buf, int size, float num); -inline int cinn_host_find_int_nd(const cinn_buffer_t* buf, int size, int num, int begin, int stride); +inline int cinn_host_find_int_nd( + const cinn_buffer_t* buf, int size, int num, int begin, int stride); -inline int cinn_host_find_float_nd(const cinn_buffer_t* buf, int size, float num, int begin, int stride); +inline int cinn_host_find_float_nd( + const cinn_buffer_t* buf, int size, float num, int begin, int stride); -#define CINN_HOST_LT_NUM(TYPE_SUFFIX, TYPE) \ - inline int cinn_host_lt_num_##TYPE_SUFFIX( \ - const cinn_buffer_t* buf, const int size, const TYPE num, const int offset, const int stride); +#define CINN_HOST_LT_NUM(TYPE_SUFFIX, TYPE) \ + inline int cinn_host_lt_num_##TYPE_SUFFIX(const cinn_buffer_t* buf, \ + const int size, \ + const TYPE num, \ + const int offset, \ + const int stride); CINN_HOST_LT_NUM(fp32, float) CINN_HOST_LT_NUM(fp64, double) @@ -44,9 +50,12 @@ CINN_HOST_LT_NUM(int64, int64_t) #undef CINN_HOST_LT_NUM -#define CINN_HOST_GT_NUM(TYPE_SUFFIX, TYPE) \ - inline int cinn_host_gt_num_##TYPE_SUFFIX( \ - const cinn_buffer_t* buf, const int size, const TYPE num, const int offset, const int stride); +#define CINN_HOST_GT_NUM(TYPE_SUFFIX, TYPE) \ + inline int cinn_host_gt_num_##TYPE_SUFFIX(const cinn_buffer_t* buf, \ + const int size, \ + const TYPE num, \ + const int offset, \ + const int stride); CINN_HOST_GT_NUM(fp32, float) CINN_HOST_GT_NUM(fp64, double) @@ -117,6 +126,9 @@ inline double FN_FP64(cbrt)(double x); namespace cinn { namespace runtime { -void cinn_assert_true_host(void* v_args, int num_args, int msg, bool only_warning); +void cinn_assert_true_host(void* v_args, + int num_args, + int msg, + bool only_warning); } // namespace runtime } // namespace cinn diff --git a/paddle/cinn/runtime/cpu/host_intrinsics_test.cc b/paddle/cinn/runtime/cpu/host_intrinsics_test.cc index 6520c8ede851a..22e13f8b0c3ab 100644 --- a/paddle/cinn/runtime/cpu/host_intrinsics_test.cc +++ b/paddle/cinn/runtime/cpu/host_intrinsics_test.cc @@ -33,7 +33,9 @@ TEST(tanh, basic) { Expr M(10), N(20); Placeholder x("x", {M, N}); auto y = Compute( - {M, N}, [&](Expr i, Expr j) { return CallExtern("tanh", {x(i, j)}); }, "y"); + {M, N}, + [&](Expr i, Expr j) { return CallExtern("tanh", {x(i, j)}); }, + "y"); auto stages = CreateStages({y}); @@ -49,15 +51,19 @@ TEST(tanh, basic) { jit->Link(builder.Build()); auto fn_ptr = jit->Lookup("fn"); - auto fnp = reinterpret_cast(fn_ptr); + auto fnp = reinterpret_cast(fn_ptr); ASSERT_TRUE(fnp); - auto* x_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - auto* out_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_zero().Build(); - auto args = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build(); + auto* x_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); + auto* out_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_zero() + .Build(); + auto args = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build(); fnp(args.data(), args.size()); - auto* x_buf_data = reinterpret_cast(x_buf->memory); + auto* x_buf_data = reinterpret_cast(x_buf->memory); auto* out_buf_data = reinterpret_cast(out_buf->memory); for (int i = 0; i < x_buf->num_elements(); i++) { @@ -72,7 +78,8 @@ TEST(find_value_nd, basic) { 
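For orientation before the test body below: cinn_host_find_float_nd is backed by the __cinn_host_find_kernel macro shown earlier, i.e. an exact-match scan over `size` elements that start at `begin` and sit `stride` apart, walking backwards and returning the logical index of the last occurrence, or -1 when absent. Restated as plain C++ (illustrative only, not the registered kernel):

#include <cstdio>

int find_float_nd(const float* data, int size, float num, int begin,
                  int stride) {
  // Same loop shape as __cinn_host_find_kernel: iterate from the last strided
  // slot back down to `begin`; exact float comparison, as in the macro.
  for (int i = (size - 1) * stride + begin; i >= begin; i -= stride) {
    if (data[i] == num) return (i - begin) / stride;
  }
  return -1;
}

int main() {
  // Column j = 1 of a 3x2 row-major matrix: begin = 1, stride = 2 -- the same
  // column-wise access pattern the test drives through CallExtern.
  float m[6] = {0.f, 7.f, 0.f, 5.f, 0.f, 7.f};
  std::printf("%d\n", find_float_nd(m, 3, 7.f, 1, 2));  // prints 2
}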
auto y = Compute( {N}, [&](Expr i) { - return CallExtern("cinn_host_find_float_nd", {x, M, x({Expr(5), Expr(3)}), i, N}); + return CallExtern("cinn_host_find_float_nd", + {x, M, x({Expr(5), Expr(3)}), i, N}); }, "y"); @@ -90,21 +97,25 @@ TEST(find_value_nd, basic) { jit->Link(builder.Build()); auto fn_ptr = jit->Lookup("fn"); - auto fnp = reinterpret_cast(fn_ptr); + auto fnp = reinterpret_cast(fn_ptr); ASSERT_TRUE(fnp); - auto* x_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - auto* out_buf = common::BufferBuilder(Int(32), {N.as_int32()}).set_zero().Build(); - auto args = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build(); + auto* x_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); + auto* out_buf = + common::BufferBuilder(Int(32), {N.as_int32()}).set_zero().Build(); + auto args = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build(); fnp(args.data(), args.size()); - auto* x_buf_data = reinterpret_cast(x_buf->memory); + auto* x_buf_data = reinterpret_cast(x_buf->memory); auto* out_buf_data = reinterpret_cast(out_buf->memory); for (int i = 0; i < out_buf->num_elements(); i++) { LOG_FIRST_N(INFO, 3) << out_buf_data[i]; if (out_buf_data[i] != -1) { - ASSERT_NEAR(x_buf_data[out_buf_data[i] * 20 + i], x_buf_data[5 * 20 + 3], 1e-5); + ASSERT_NEAR( + x_buf_data[out_buf_data[i] * 20 + i], x_buf_data[5 * 20 + 3], 1e-5); } } } @@ -115,7 +126,8 @@ TEST(cinn_host_lt_num_fp32, basic) { auto y = Compute( {N}, [&](Expr j) { - return CallExtern("cinn_host_lt_num_fp32", {x, M, x({Expr(0), j}), j, N}); + return CallExtern("cinn_host_lt_num_fp32", + {x, M, x({Expr(0), j}), j, N}); }, "y"); @@ -133,15 +145,18 @@ TEST(cinn_host_lt_num_fp32, basic) { jit->Link(builder.Build()); auto fn_ptr = jit->Lookup("fn"); - auto fnp = reinterpret_cast(fn_ptr); + auto fnp = reinterpret_cast(fn_ptr); ASSERT_TRUE(fnp); - auto* x_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - auto* out_buf = common::BufferBuilder(Int(32), {N.as_int32()}).set_zero().Build(); - auto args = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build(); + auto* x_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); + auto* out_buf = + common::BufferBuilder(Int(32), {N.as_int32()}).set_zero().Build(); + auto args = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build(); fnp(args.data(), args.size()); - auto* x_buf_data = reinterpret_cast(x_buf->memory); + auto* x_buf_data = reinterpret_cast(x_buf->memory); auto* out_buf_data = reinterpret_cast(out_buf->memory); for (int j = 0; j < 20; j++) { @@ -162,7 +177,8 @@ TEST(cinn_host_gt_num_fp32, basic) { auto y = Compute( {N}, [&](Expr j) { - return CallExtern("cinn_host_gt_num_fp32", {x, M, x({Expr(0), j}), j, N}); + return CallExtern("cinn_host_gt_num_fp32", + {x, M, x({Expr(0), j}), j, N}); }, "y"); @@ -180,15 +196,18 @@ TEST(cinn_host_gt_num_fp32, basic) { jit->Link(builder.Build()); auto fn_ptr = jit->Lookup("fn"); - auto fnp = reinterpret_cast(fn_ptr); + auto fnp = reinterpret_cast(fn_ptr); ASSERT_TRUE(fnp); - auto* x_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - auto* out_buf = common::BufferBuilder(Int(32), {N.as_int32()}).set_zero().Build(); - auto args = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build(); + auto* x_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_random() + .Build(); + auto* out_buf = + common::BufferBuilder(Int(32), 
{N.as_int32()}).set_zero().Build(); + auto args = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build(); fnp(args.data(), args.size()); - auto* x_buf_data = reinterpret_cast(x_buf->memory); + auto* x_buf_data = reinterpret_cast(x_buf->memory); auto* out_buf_data = reinterpret_cast(out_buf->memory); for (int j = 0; j < 20; j++) { diff --git a/paddle/cinn/runtime/cpu/mkl_math.cc b/paddle/cinn/runtime/cpu/mkl_math.cc index d375270803c87..f481ef072129d 100644 --- a/paddle/cinn/runtime/cpu/mkl_math.cc +++ b/paddle/cinn/runtime/cpu/mkl_math.cc @@ -24,14 +24,18 @@ #include "paddle/cinn/backends/function_prototype.h" #include "paddle/cinn/runtime/cpu/host_intrinsics.h" -#define CINN_MKL_VECTOR_MATH_FP(fn__, name__) \ - void cinn_mkl_##name__##_v_fp32(cinn_buffer_t *x, cinn_buffer_t *out) { \ - CHECK_EQ(x->num_elements(), out->num_elements()); \ - vs##fn__(x->num_elements(), reinterpret_cast(x->memory), reinterpret_cast(out->memory)); \ - } \ - void cinn_mkl_##name__##_v_fp64(cinn_buffer_t *x, cinn_buffer_t *out) { \ - CHECK_EQ(x->num_elements(), out->num_elements()); \ - vd##fn__(x->num_elements(), reinterpret_cast(x->memory), reinterpret_cast(out->memory)); \ +#define CINN_MKL_VECTOR_MATH_FP(fn__, name__) \ + void cinn_mkl_##name__##_v_fp32(cinn_buffer_t *x, cinn_buffer_t *out) { \ + CHECK_EQ(x->num_elements(), out->num_elements()); \ + vs##fn__(x->num_elements(), \ + reinterpret_cast(x->memory), \ + reinterpret_cast(out->memory)); \ + } \ + void cinn_mkl_##name__##_v_fp64(cinn_buffer_t *x, cinn_buffer_t *out) { \ + CHECK_EQ(x->num_elements(), out->num_elements()); \ + vd##fn__(x->num_elements(), \ + reinterpret_cast(x->memory), \ + reinterpret_cast(out->memory)); \ } CINN_MKL_VECTOR_MATH_FP(Exp, exp); diff --git a/paddle/cinn/runtime/cpu/mkl_math_test.cc b/paddle/cinn/runtime/cpu/mkl_math_test.cc index 720f200742578..f91a76ddd5411 100644 --- a/paddle/cinn/runtime/cpu/mkl_math_test.cc +++ b/paddle/cinn/runtime/cpu/mkl_math_test.cc @@ -29,7 +29,9 @@ namespace cinn { namespace runtime { namespace cpu { -cinn_buffer_t *CreateBuffer(const std::vector shape, bool random = true, int set_value = 0) { +cinn_buffer_t *CreateBuffer(const std::vector shape, + bool random = true, + int set_value = 0) { if (random) { return common::BufferBuilder(Float(32), shape).set_random().Build(); } else if (set_value != 0) { @@ -39,8 +41,11 @@ cinn_buffer_t *CreateBuffer(const std::vector shape, bool random = true, in } template -void TestCallElementwise( - const std::string &fn_name, FuncRuntime fn_runtime, bool is_elementwise, Type type = Float(32), int set_value = 0) { +void TestCallElementwise(const std::string &fn_name, + FuncRuntime fn_runtime, + bool is_elementwise, + Type type = Float(32), + int set_value = 0) { Expr M(10); Expr N(10); Placeholder x("x", {M, N}); @@ -50,11 +55,17 @@ void TestCallElementwise( std::vector lower_args({x}); if (is_elementwise) { out = Compute( - {M, N}, [&](Var i, Var j) -> Expr { return lang::CallExtern(fn_name, {x(i, j)}); }, fn_name + "_out"); + {M, N}, + [&](Var i, Var j) -> Expr { + return lang::CallExtern(fn_name, {x(i, j)}); + }, + fn_name + "_out"); lower_args.push_back(out); } else { auto comp_out = Compute( - {Expr(1)}, [&]() -> Expr { return lang::CallExtern(fn_name, {x}); }, fn_name + "_out"); + {Expr(1)}, + [&]() -> Expr { return lang::CallExtern(fn_name, {x}); }, + fn_name + "_out"); out = comp_out->TupleGet(0); out->WithBuffer(Float(32)); lower_args.push_back(out); @@ -71,7 +82,7 @@ void TestCallElementwise( LOG(INFO) << "func:\n" << func; - auto jit = 
backends::ExecutionEngine::Create({}); + auto jit = backends::ExecutionEngine::Create({}); auto module = builder.Build(); jit->Link(module); @@ -85,7 +96,8 @@ void TestCallElementwise( } else { A_buf = CreateBuffer({10, 10}); } - auto *B_buf = common::BufferBuilder(type, {10, 10}).set_align(type.bits()).Build(); + auto *B_buf = + common::BufferBuilder(type, {10, 10}).set_align(type.bits()).Build(); cinn_pod_value_t a_arg(A_buf), b_arg(B_buf); cinn_pod_value_t args[] = {a_arg, b_arg}; @@ -109,16 +121,24 @@ bool isnan(float e) { return std::isnan(e); } bool isfinite(float e) { return std::isfinite(e); } bool isinf(float e) { return std::isinf(e); } -#define TEST_MKL_MATH_FP32(test_name__, is_elementwise) \ - TEST(mkl_math, test_name__) { TestCallElementwise(#test_name__, test_name__##f, is_elementwise); } -#define TEST_CINN_MKL_MATH_FP32(test_name__, is_elementwise) \ - TEST(mkl_math, test_name__) { \ - TestCallElementwise("cinn_mkl_" #test_name__ "_v_fp32", test_name__##f, is_elementwise); \ +#define TEST_MKL_MATH_FP32(test_name__, is_elementwise) \ + TEST(mkl_math, test_name__) { \ + TestCallElementwise(#test_name__, test_name__##f, is_elementwise); \ + } +#define TEST_CINN_MKL_MATH_FP32(test_name__, is_elementwise) \ + TEST(mkl_math, test_name__) { \ + TestCallElementwise( \ + "cinn_mkl_" #test_name__ "_v_fp32", test_name__##f, is_elementwise); \ + } +#define TEST_MKL_MATH_FP32_BOOL(test_name__, is_elementwise) \ + TEST(mkl_math, test_name__) { \ + TestCallElementwise(#test_name__, test_name__, is_elementwise, Bool()); \ + } +#define TEST_MKL_MATH_FP32_SET(test_name__, is_elementwise, value) \ + TEST(mkl_math, test_name__) { \ + TestCallElementwise( \ + #test_name__, test_name__##f, is_elementwise, Float(32), value); \ } -#define TEST_MKL_MATH_FP32_BOOL(test_name__, is_elementwise) \ - TEST(mkl_math, test_name__) { TestCallElementwise(#test_name__, test_name__, is_elementwise, Bool()); } -#define TEST_MKL_MATH_FP32_SET(test_name__, is_elementwise, value) \ - TEST(mkl_math, test_name__) { TestCallElementwise(#test_name__, test_name__##f, is_elementwise, Float(32), value); } TEST_CINN_MKL_MATH_FP32(exp, false) TEST_CINN_MKL_MATH_FP32(erf, false) @@ -146,7 +166,9 @@ TEST_CINN_MKL_MATH_FP32(tanh, false) TEST_MKL_MATH_FP32_BOOL(isfinite, true) TEST_MKL_MATH_FP32_BOOL(isinf, true) -TEST(mkl_math, tanh_v_fp32) { TestCallElementwise("cinn_mkl_tanh_v_fp32", tanhf, false); } +TEST(mkl_math, tanh_v_fp32) { + TestCallElementwise("cinn_mkl_tanh_v_fp32", tanhf, false); +} TEST(cinn_cpu_mkl_gemm_fp32, test) { Expr M(30); @@ -191,17 +213,23 @@ TEST(cinn_cpu_mkl_gemm_fp32, test) { LOG(INFO) << "func:\n" << func; - auto jit = backends::SimpleJIT::Create(); + auto jit = backends::SimpleJIT::Create(); auto module = builder.Build(); jit->Link(module, /*optimize=*/true); - auto fn = jit->Lookup("fn"); + auto fn = jit->Lookup("fn"); auto fn_ptr = reinterpret_cast(fn); // test with real data - auto *A_buf = common::BufferBuilder(Float(32), {M.as_int32(), K.as_int32()}).set_random().Build(); - auto *B_buf = common::BufferBuilder(Float(32), {K.as_int32(), N.as_int32()}).set_random().Build(); - auto *C_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_zero().Build(); + auto *A_buf = common::BufferBuilder(Float(32), {M.as_int32(), K.as_int32()}) + .set_random() + .Build(); + auto *B_buf = common::BufferBuilder(Float(32), {K.as_int32(), N.as_int32()}) + .set_random() + .Build(); + auto *C_buf = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}) + .set_zero() + .Build(); auto args = 
common::ArgsBuilder().Add(A_buf).Add(B_buf).Add(C_buf).Build(); diff --git a/paddle/cinn/runtime/cpu/mkldnn_math.cc b/paddle/cinn/runtime/cpu/mkldnn_math.cc index 7bb4457979026..9bf107994b0e2 100644 --- a/paddle/cinn/runtime/cpu/mkldnn_math.cc +++ b/paddle/cinn/runtime/cpu/mkldnn_math.cc @@ -22,17 +22,22 @@ using dnnl::algorithm; using dnnl::memory; using tag = memory::format_tag; -using dt = memory::data_type; - -void cinn_cpu_mkldnn_softmax_fp32( - int batch, int channel, int h, int w, int axis, cinn_buffer_t* inputs, cinn_buffer_t* out) { +using dt = memory::data_type; + +void cinn_cpu_mkldnn_softmax_fp32(int batch, + int channel, + int h, + int w, + int axis, + cinn_buffer_t* inputs, + cinn_buffer_t* out) { auto engine = dnnl::engine(dnnl::engine::kind::cpu, 0); dnnl::stream engine_stream(engine); memory::dims src_dims = {batch, channel}; if (h != 1) src_dims.push_back(h); if (w != 1) src_dims.push_back(w); - int size = src_dims.size(); + int size = src_dims.size(); auto format_tag = tag::nc; switch (size) { case 2: @@ -49,14 +54,17 @@ void cinn_cpu_mkldnn_softmax_fp32( break; } - auto src_md = memory::desc(src_dims, dt::f32, format_tag); - auto src_mem = memory(src_md, engine, reinterpret_cast(inputs->memory)); - auto dst_mem = memory(src_md, engine, reinterpret_cast(out->memory)); - auto softmax_d = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, src_md, axis); - auto softmax_pd = dnnl::softmax_forward::primitive_desc(softmax_d, engine); + auto src_md = memory::desc(src_dims, dt::f32, format_tag); + auto src_mem = + memory(src_md, engine, reinterpret_cast(inputs->memory)); + auto dst_mem = memory(src_md, engine, reinterpret_cast(out->memory)); + auto softmax_d = dnnl::softmax_forward::desc( + dnnl::prop_kind::forward_inference, src_md, axis); + auto softmax_pd = dnnl::softmax_forward::primitive_desc(softmax_d, engine); auto softmax_prim = dnnl::softmax_forward(softmax_pd); - softmax_prim.execute(engine_stream, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}}); + softmax_prim.execute(engine_stream, + {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}}); engine_stream.wait(); } @@ -80,45 +88,52 @@ void cinn_cpu_mkldnn_conv2d_nchw_fp32(int batch_size, auto cpu_engine = dnnl::engine(dnnl::engine::kind::cpu, 0); dnnl::stream cpu_stream(cpu_engine); - memory::dims conv_src_tz = {batch_size, c_in, input_h, input_w}; + memory::dims conv_src_tz = {batch_size, c_in, input_h, input_w}; memory::dims conv_weights_tz = {c_out, c_in, filter_h, filter_w}; if (group > 1) { conv_weights_tz = {group, c_out / group, c_in / group, filter_h, filter_w}; } - int out_h = (input_h - ((filter_h - 1) * dilation_h + 1) + 2 * pad_h) / stride_h + 1; - int out_w = (input_w - ((filter_w - 1) * dilation_w + 1) + 2 * pad_w) / stride_w + 1; - memory::dims conv_dst_tz = {batch_size, c_out, out_h, out_w}; - memory::dims conv_strides = {stride_h, stride_w}; - memory::dims conv_paddings = {pad_h, pad_w}; + int out_h = + (input_h - ((filter_h - 1) * dilation_h + 1) + 2 * pad_h) / stride_h + 1; + int out_w = + (input_w - ((filter_w - 1) * dilation_w + 1) + 2 * pad_w) / stride_w + 1; + memory::dims conv_dst_tz = {batch_size, c_out, out_h, out_w}; + memory::dims conv_strides = {stride_h, stride_w}; + memory::dims conv_paddings = {pad_h, pad_w}; memory::dims conv_dilations = {dilation_h - 1, dilation_w - 1}; - auto conv_user_src_memory = - memory({{conv_src_tz}, dt::f32, tag::nchw}, cpu_engine, reinterpret_cast(inputs->memory)); - auto conv_user_weights_memory = memory({{conv_weights_tz}, dt::f32, group > 1 ? 
tag::goihw : tag::oihw}, - cpu_engine, - reinterpret_cast(weights->memory)); - auto conv_user_dst_memory = - memory({{conv_dst_tz}, dt::f32, tag::nchw}, cpu_engine, reinterpret_cast(out->memory)); - - auto conv_src_md = memory::desc({conv_src_tz}, dt::f32, tag::any); + auto conv_user_src_memory = memory({{conv_src_tz}, dt::f32, tag::nchw}, + cpu_engine, + reinterpret_cast(inputs->memory)); + auto conv_user_weights_memory = + memory({{conv_weights_tz}, dt::f32, group > 1 ? tag::goihw : tag::oihw}, + cpu_engine, + reinterpret_cast(weights->memory)); + auto conv_user_dst_memory = memory({{conv_dst_tz}, dt::f32, tag::nchw}, + cpu_engine, + reinterpret_cast(out->memory)); + + auto conv_src_md = memory::desc({conv_src_tz}, dt::f32, tag::any); auto conv_weights_md = memory::desc({conv_weights_tz}, dt::f32, tag::any); - auto conv_dst_md = memory::desc({conv_dst_tz}, dt::f32, tag::nchw); - - auto conv_desc = dnnl::convolution_forward::desc(dnnl::prop_kind::forward_inference, - dnnl::algorithm::convolution_direct, - conv_src_md, - conv_weights_md, - conv_dst_md, - conv_strides, - conv_dilations, - conv_paddings, - conv_paddings); - - auto conv_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, cpu_engine); - - auto conv_src_memory = conv_user_src_memory; + auto conv_dst_md = memory::desc({conv_dst_tz}, dt::f32, tag::nchw); + + auto conv_desc = + dnnl::convolution_forward::desc(dnnl::prop_kind::forward_inference, + dnnl::algorithm::convolution_direct, + conv_src_md, + conv_weights_md, + conv_dst_md, + conv_strides, + conv_dilations, + conv_paddings, + conv_paddings); + + auto conv_prim_desc = + dnnl::convolution_forward::primitive_desc(conv_desc, cpu_engine); + + auto conv_src_memory = conv_user_src_memory; auto conv_weights_memory = conv_user_weights_memory; - auto conv_dst_memory = conv_user_dst_memory; + auto conv_dst_memory = conv_user_dst_memory; if (conv_prim_desc.dst_desc() != conv_user_dst_memory.get_desc()) { conv_dst_memory = memory(conv_prim_desc.dst_desc(), cpu_engine); } @@ -128,7 +143,8 @@ void cinn_cpu_mkldnn_conv2d_nchw_fp32(int batch_size, {DNNL_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_DST, conv_dst_memory}}); if (conv_prim_desc.dst_desc() != conv_user_dst_memory.get_desc()) { - dnnl::reorder(conv_dst_memory, conv_user_dst_memory).execute(cpu_stream, conv_dst_memory, conv_user_dst_memory); + dnnl::reorder(conv_dst_memory, conv_user_dst_memory) + .execute(cpu_stream, conv_dst_memory, conv_user_dst_memory); } else { conv_user_dst_memory = conv_dst_memory; } @@ -141,30 +157,35 @@ CINN_REGISTER_HELPER(cinn_cpu_mkldnn) { using backends::FunctionProto; auto host_target = common::DefaultHostTarget(); - FunctionProto::shape_inference_t inference_shape_conv2d_nchw = [](const std::vector& args, int offset) { - CHECK_EQ(args.size(), 16UL) << "Wrong number of arguments passed in"; - auto N = common::AutoSimplify(args[0]); - int input_h = common::AutoSimplify(args[2]).as_int32(); - int input_w = common::AutoSimplify(args[3]).as_int32(); - auto c_out = common::AutoSimplify(args[4]); - int filter_h = common::AutoSimplify(args[6]).as_int32(); - int filter_w = common::AutoSimplify(args[7]).as_int32(); - int pad_h = common::AutoSimplify(args[8]).as_int32(); - int pad_w = common::AutoSimplify(args[9]).as_int32(); - int stride_h = common::AutoSimplify(args[10]).as_int32(); - int stride_w = common::AutoSimplify(args[11]).as_int32(); - int dilation_h = common::AutoSimplify(args[12]).as_int32(); - int dilation_w = common::AutoSimplify(args[13]).as_int32(); - int out_h = (input_h - 
((filter_h - 1) * dilation_h + 1) + 2 * pad_h) / stride_h + 1; - int out_w = (input_w - ((filter_w - 1) * dilation_w + 1) + 2 * pad_w) / stride_w + 1; - - std::vector shape; - shape.push_back(N); - shape.push_back(c_out); - shape.push_back(Expr(out_h)); - shape.push_back(Expr(out_w)); - return shape; - }; + FunctionProto::shape_inference_t inference_shape_conv2d_nchw = + [](const std::vector& args, int offset) { + CHECK_EQ(args.size(), 16UL) << "Wrong number of arguments passed in"; + auto N = common::AutoSimplify(args[0]); + int input_h = common::AutoSimplify(args[2]).as_int32(); + int input_w = common::AutoSimplify(args[3]).as_int32(); + auto c_out = common::AutoSimplify(args[4]); + int filter_h = common::AutoSimplify(args[6]).as_int32(); + int filter_w = common::AutoSimplify(args[7]).as_int32(); + int pad_h = common::AutoSimplify(args[8]).as_int32(); + int pad_w = common::AutoSimplify(args[9]).as_int32(); + int stride_h = common::AutoSimplify(args[10]).as_int32(); + int stride_w = common::AutoSimplify(args[11]).as_int32(); + int dilation_h = common::AutoSimplify(args[12]).as_int32(); + int dilation_w = common::AutoSimplify(args[13]).as_int32(); + int out_h = (input_h - ((filter_h - 1) * dilation_h + 1) + 2 * pad_h) / + stride_h + + 1; + int out_w = (input_w - ((filter_w - 1) * dilation_w + 1) + 2 * pad_w) / + stride_w + + 1; + + std::vector shape; + shape.push_back(N); + shape.push_back(c_out); + shape.push_back(Expr(out_h)); + shape.push_back(Expr(out_w)); + return shape; + }; REGISTER_EXTERN_FUNC_HELPER(cinn_cpu_mkldnn_conv2d_nchw_fp32, host_target) .SetRetType() diff --git a/paddle/cinn/runtime/cpu/mkldnn_math.h b/paddle/cinn/runtime/cpu/mkldnn_math.h index ebac2bf1f20e7..c8cb49c1cc329 100644 --- a/paddle/cinn/runtime/cpu/mkldnn_math.h +++ b/paddle/cinn/runtime/cpu/mkldnn_math.h @@ -21,8 +21,13 @@ // define some C APIs extern "C" { -void cinn_cpu_mkldnn_softmax_fp32( - int batch, int channel, int h, int w, int axis, cinn_buffer_t* inputs, cinn_buffer_t* out); +void cinn_cpu_mkldnn_softmax_fp32(int batch, + int channel, + int h, + int w, + int axis, + cinn_buffer_t* inputs, + cinn_buffer_t* out); void cinn_cpu_mkldnn_conv2d_nchw_fp32(int batch_size, int c_in, diff --git a/paddle/cinn/runtime/cpu/mkldnn_math_test.cc b/paddle/cinn/runtime/cpu/mkldnn_math_test.cc index 3bb576fd59ae8..26d06d715d550 100644 --- a/paddle/cinn/runtime/cpu/mkldnn_math_test.cc +++ b/paddle/cinn/runtime/cpu/mkldnn_math_test.cc @@ -29,7 +29,9 @@ namespace cinn { namespace runtime { namespace cpu { -cinn_buffer_t *CreateBuffer(const std::vector shape, bool random = true, int set_value = 0) { +cinn_buffer_t *CreateBuffer(const std::vector shape, + bool random = true, + int set_value = 0) { if (random) { return common::BufferBuilder(Float(32), shape).set_random().Build(); } else if (set_value != 0) { @@ -53,8 +55,10 @@ TEST(cinn_cpu_mkldnn_conv2d_nchw_fp32, test) { int dilation_h(1); int dilation_w(1); - Placeholder input("input", {Expr(n), Expr(c_in), Expr(i_h), Expr(i_w)}); - Placeholder weights("weights", {Expr(c_out), Expr(c_in), Expr(k_h), Expr(k_w)}); + Placeholder input("input", + {Expr(n), Expr(c_in), Expr(i_h), Expr(i_w)}); + Placeholder weights("weights", + {Expr(c_out), Expr(c_in), Expr(k_h), Expr(k_w)}); auto call = Compute( {Expr(1)}, @@ -95,19 +99,24 @@ TEST(cinn_cpu_mkldnn_conv2d_nchw_fp32, test) { LOG(INFO) << "func:\n" << func; - auto jit = backends::SimpleJIT::Create(); + auto jit = backends::SimpleJIT::Create(); auto module = builder.Build(); jit->Link(module, /*optimize=*/true); - auto fn = 
jit->Lookup("fn"); + auto fn = jit->Lookup("fn"); auto fn_ptr = reinterpret_cast(fn); // test with real data - int o_h = (i_h - ((k_h - 1) * dilation_h + 1) + pad_h * 2) / stride_h + 1; - int o_w = (i_w - ((k_w - 1) * dilation_w + 1) + pad_w * 2) / stride_w + 1; - auto *A_buf = common::BufferBuilder(Float(32), {n, c_in, i_h, i_w}).set_random().Build(); - auto *B_buf = common::BufferBuilder(Float(32), {c_out, c_in, k_h, k_w}).set_random().Build(); - auto *C_buf = common::BufferBuilder(Float(32), {n, c_out, o_h, o_w}).set_zero().Build(); + int o_h = (i_h - ((k_h - 1) * dilation_h + 1) + pad_h * 2) / stride_h + 1; + int o_w = (i_w - ((k_w - 1) * dilation_w + 1) + pad_w * 2) / stride_w + 1; + auto *A_buf = common::BufferBuilder(Float(32), {n, c_in, i_h, i_w}) + .set_random() + .Build(); + auto *B_buf = common::BufferBuilder(Float(32), {c_out, c_in, k_h, k_w}) + .set_random() + .Build(); + auto *C_buf = + common::BufferBuilder(Float(32), {n, c_out, o_h, o_w}).set_zero().Build(); auto args = common::ArgsBuilder().Add(A_buf).Add(B_buf).Add(C_buf).Build(); diff --git a/paddle/cinn/runtime/cpu/thread_backend.cc b/paddle/cinn/runtime/cpu/thread_backend.cc index 91c22959af05e..c6c49dfe5d505 100644 --- a/paddle/cinn/runtime/cpu/thread_backend.cc +++ b/paddle/cinn/runtime/cpu/thread_backend.cc @@ -28,7 +28,7 @@ int max_concurrency() { int max_concurrency = 1; - const char* val = getenv("CINN_NUM_THREADS"); + const char* val = getenv("CINN_NUM_THREADS"); if (val == nullptr) { val = getenv("OMP_NUM_THREADS"); } @@ -43,7 +43,9 @@ int max_concurrency() { return std::max(max_concurrency, 1); } -int cinn_backend_parallel_launch(FCINNParallelLambda flambda, void* datas, int num_task) { +int cinn_backend_parallel_launch(FCINNParallelLambda flambda, + void* datas, + int num_task) { int num_workers = max_concurrency(); if (num_task == 0) num_task = num_workers; #ifdef CINN_USE_OPENMP @@ -63,7 +65,8 @@ CINN_REGISTER_HELPER(cinn_backend_parallel) { using namespace cinn; // NOLINT using backends::FunctionProto; auto host_target = common::DefaultHostTarget(); - backends::GlobalSymbolRegistry::Global().RegisterFn(runtime::intrinsic::parallel_launch, - reinterpret_cast(&cinn_backend_parallel_launch)); + backends::GlobalSymbolRegistry::Global().RegisterFn( + runtime::intrinsic::parallel_launch, + reinterpret_cast(&cinn_backend_parallel_launch)); return true; } diff --git a/paddle/cinn/runtime/cpu/thread_backend.h b/paddle/cinn/runtime/cpu/thread_backend.h index e98e03fbd53a0..f1fac38f22dc8 100644 --- a/paddle/cinn/runtime/cpu/thread_backend.h +++ b/paddle/cinn/runtime/cpu/thread_backend.h @@ -41,6 +41,8 @@ typedef int (*FCINNParallelLambda)(int task_id, int num_task, void* datas); * * @return 0 when no error is thrown, -1 when failure happens */ -int cinn_backend_parallel_launch(FCINNParallelLambda flambda, void* datas, int num_task); +int cinn_backend_parallel_launch(FCINNParallelLambda flambda, + void* datas, + int num_task); } // extern "C" diff --git a/paddle/cinn/runtime/cuda/bfloat16.h b/paddle/cinn/runtime/cuda/bfloat16.h index 27501008bf5bf..40ed0fed07cd2 100644 --- a/paddle/cinn/runtime/cuda/bfloat16.h +++ b/paddle/cinn/runtime/cuda/bfloat16.h @@ -69,17 +69,17 @@ struct CINN_ALIGN(2) bfloat16 { #ifdef __cplusplus // Constructors - bfloat16() = default; + bfloat16() = default; bfloat16(const bfloat16& o) = default; bfloat16& operator=(const bfloat16& o) = default; - bfloat16(bfloat16&& o) = default; + bfloat16(bfloat16&& o) = default; bfloat16& operator=(bfloat16&& o) = default; - ~bfloat16() = default; + 
~bfloat16() = default; __host__ __device__ inline explicit bfloat16(float val) { #if defined(CINN_CUDA_BF16) __nv_bfloat16 tmp = __float2bfloat16(val); - x = *reinterpret_cast(&tmp); + x = *reinterpret_cast(&tmp); #else std::memcpy(&x, reinterpret_cast(&val) + 2, 2); #endif @@ -92,7 +92,8 @@ struct CINN_ALIGN(2) bfloat16 { #endif template - __host__ __device__ inline explicit bfloat16(const T& val) : x(bfloat16(static_cast(val)).x) {} + __host__ __device__ inline explicit bfloat16(const T& val) + : x(bfloat16(static_cast(val)).x) {} // Assignment operators #if defined(CINN_CUDA_BF16) @@ -162,9 +163,10 @@ struct CINN_ALIGN(2) bfloat16 { #ifdef CINN_CUDA_BF16 return __bfloat162float(*reinterpret_cast(&x)); #else - float val = 0.f; + float val = 0.f; uint16_t temp = x; - std::memcpy(reinterpret_cast(&val) + 2, reinterpret_cast(&temp), 2); + std::memcpy( + reinterpret_cast(&val) + 2, reinterpret_cast(&temp), 2); return val; #endif } @@ -175,9 +177,13 @@ struct CINN_ALIGN(2) bfloat16 { } #endif - __host__ __device__ inline explicit operator bool() const { return (x & 0x7fff) != 0; } + __host__ __device__ inline explicit operator bool() const { + return (x & 0x7fff) != 0; + } - __host__ __device__ inline explicit operator int8_t() const { return static_cast(static_cast(*this)); } + __host__ __device__ inline explicit operator int8_t() const { + return static_cast(static_cast(*this)); + } __host__ __device__ inline explicit operator uint8_t() const { return static_cast(static_cast(*this)); @@ -207,11 +213,14 @@ struct CINN_ALIGN(2) bfloat16 { return static_cast(static_cast(*this)); } - __host__ __device__ inline operator double() const { return static_cast(static_cast(*this)); } + __host__ __device__ inline operator double() const { + return static_cast(static_cast(*this)); + } #endif // __cplusplus }; -__host__ __device__ inline bfloat16 operator+(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bfloat16 operator+(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return bfloat16(__hadd(a.to_nv_bfloat16(), b.to_nv_bfloat16())); #else @@ -219,7 +228,8 @@ __host__ __device__ inline bfloat16 operator+(const bfloat16& a, const bfloat16& #endif } -__host__ __device__ inline bfloat16 operator-(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bfloat16 operator-(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return bfloat16(__hsub(a.to_nv_bfloat16(), b.to_nv_bfloat16())); #else @@ -227,7 +237,8 @@ __host__ __device__ inline bfloat16 operator-(const bfloat16& a, const bfloat16& #endif } -__host__ __device__ inline bfloat16 operator*(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bfloat16 operator*(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return bfloat16(__hmul(a.to_nv_bfloat16(), b.to_nv_bfloat16())); #else @@ -235,7 +246,8 @@ __host__ __device__ inline bfloat16 operator*(const bfloat16& a, const bfloat16& #endif } -__host__ __device__ inline bfloat16 operator/(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bfloat16 operator/(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return bfloat16(__hdiv(a.to_nv_bfloat16(), b.to_nv_bfloat16())); #else @@ -280,7 +292,8 @@ __host__ __device__ inline bfloat16 raw_uint16_to_bfloat16(uint16_t a) { } // 
Comparison operators -__host__ __device__ inline bool operator==(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bool operator==(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __heq(a.to_nv_bfloat16(), b.to_nv_bfloat16()); #else @@ -288,7 +301,8 @@ __host__ __device__ inline bool operator==(const bfloat16& a, const bfloat16& b) #endif } -__host__ __device__ inline bool operator!=(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bool operator!=(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __hne(a.to_nv_bfloat16(), b.to_nv_bfloat16()); #else @@ -296,7 +310,8 @@ __host__ __device__ inline bool operator!=(const bfloat16& a, const bfloat16& b) #endif } -__host__ __device__ inline bool operator<(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bool operator<(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __hlt(a.to_nv_bfloat16(), b.to_nv_bfloat16()); #else @@ -304,7 +319,8 @@ __host__ __device__ inline bool operator<(const bfloat16& a, const bfloat16& b) #endif } -__host__ __device__ inline bool operator<=(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bool operator<=(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __hle(a.to_nv_bfloat16(), b.to_nv_bfloat16()); #else @@ -312,7 +328,8 @@ __host__ __device__ inline bool operator<=(const bfloat16& a, const bfloat16& b) #endif } -__host__ __device__ inline bool operator>(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bool operator>(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __hgt(a.to_nv_bfloat16(), b.to_nv_bfloat16()); #else @@ -320,7 +337,8 @@ __host__ __device__ inline bool operator>(const bfloat16& a, const bfloat16& b) #endif } -__host__ __device__ inline bool operator>=(const bfloat16& a, const bfloat16& b) { +__host__ __device__ inline bool operator>=(const bfloat16& a, + const bfloat16& b) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __hge(a.to_nv_bfloat16(), b.to_nv_bfloat16()); #else @@ -344,7 +362,9 @@ __host__ __device__ inline bool(isinf)(const bfloat16& a) { #endif } -__host__ __device__ inline bool(isfinite)(const bfloat16& a) { return !((isnan)(a)) && !((isinf)(a)); } +__host__ __device__ inline bool(isfinite)(const bfloat16& a) { + return !((isnan)(a)) && !((isinf)(a)); +} __host__ __device__ inline bfloat16(abs)(const bfloat16& a) { #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 @@ -365,36 +385,43 @@ __device__ inline cinn::common::bfloat16 __shfl_sync(unsigned mask, cinn::common::bfloat16 var, int srcLane, int width = warpSize) { - return cinn::common::bfloat16(__shfl_sync(mask, var.to_nv_bfloat16(), srcLane, width)); + return cinn::common::bfloat16( + __shfl_sync(mask, var.to_nv_bfloat16(), srcLane, width)); } -__device__ inline cinn::common::bfloat16 __shfl_up_sync(unsigned mask, - cinn::common::bfloat16 var, - unsigned int delta, - int width = warpSize) { - return cinn::common::bfloat16(__shfl_up_sync(mask, var.to_nv_bfloat16(), delta, width)); +__device__ inline cinn::common::bfloat16 __shfl_up_sync( + unsigned mask, + cinn::common::bfloat16 var, + unsigned 
int delta, + int width = warpSize) { + return cinn::common::bfloat16( + __shfl_up_sync(mask, var.to_nv_bfloat16(), delta, width)); } -__device__ inline cinn::common::bfloat16 __shfl_down_sync(unsigned mask, - cinn::common::bfloat16 var, - unsigned int delta, - int width = warpSize) { - return cinn::common::bfloat16(__shfl_down_sync(mask, var.to_nv_bfloat16(), delta, width)); +__device__ inline cinn::common::bfloat16 __shfl_down_sync( + unsigned mask, + cinn::common::bfloat16 var, + unsigned int delta, + int width = warpSize) { + return cinn::common::bfloat16( + __shfl_down_sync(mask, var.to_nv_bfloat16(), delta, width)); } -__device__ inline cinn::common::bfloat16 __shfl_xor_sync(unsigned mask, - cinn::common::bfloat16 var, - int laneMask, - int width = warpSize) { - return cinn::common::bfloat16(__shfl_xor_sync(mask, var.to_nv_bfloat16(), laneMask, width)); +__device__ inline cinn::common::bfloat16 __shfl_xor_sync( + unsigned mask, + cinn::common::bfloat16 var, + int laneMask, + int width = warpSize) { + return cinn::common::bfloat16( + __shfl_xor_sync(mask, var.to_nv_bfloat16(), laneMask, width)); } -__host__ __device__ inline cinn::common::bfloat16 max(const cinn::common::bfloat16& a, - const cinn::common::bfloat16& b) { +__host__ __device__ inline cinn::common::bfloat16 max( + const cinn::common::bfloat16& a, const cinn::common::bfloat16& b) { return a > b ? a : b; } -__host__ __device__ inline cinn::common::bfloat16 min(const cinn::common::bfloat16& a, - const cinn::common::bfloat16& b) { +__host__ __device__ inline cinn::common::bfloat16 min( + const cinn::common::bfloat16& a, const cinn::common::bfloat16& b) { return a < b ? a : b; } #endif // __cplusplus && CINN_CUDA_FP16 diff --git a/paddle/cinn/runtime/cuda/cublas_util.h b/paddle/cinn/runtime/cuda/cublas_util.h index cd72cc61c2a6c..e23aca31ce1d2 100644 --- a/paddle/cinn/runtime/cuda/cublas_util.h +++ b/paddle/cinn/runtime/cuda/cublas_util.h @@ -15,8 +15,8 @@ #include -#include "paddle/cinn/common/type.h" #include "glog/logging.h" +#include "paddle/cinn/common/type.h" namespace cinn { namespace runtime { @@ -54,7 +54,7 @@ inline cublasStatus_t cublasGemm(cudaDataType_t dtype, ldc); } else if (dtype == CUDA_R_64F) { const double alpha_fp64 = static_cast(alpha); - const double beta_fp64 = static_cast(beta); + const double beta_fp64 = static_cast(beta); return cublasDgemm(handle, transa, transb, @@ -154,7 +154,7 @@ inline cublasStatus_t cublasGemmStridedBatched(cudaDataType_t dtype, batchCount); } else if (dtype == CUDA_R_64F) { const double alpha_fp64 = static_cast(alpha); - const double beta_fp64 = static_cast(beta); + const double beta_fp64 = static_cast(beta); return cublasDgemmStridedBatched(handle, transa, transb, @@ -176,24 +176,25 @@ inline cublasStatus_t cublasGemmStridedBatched(cudaDataType_t dtype, } else if (dtype == CUDA_R_16F) { common::float16 alpha_fp16{alpha}; common::float16 beta_fp16{beta}; - return cublasHgemmStridedBatched(handle, - transa, - transb, - m, - n, - k, - reinterpret_cast(&alpha_fp16), - reinterpret_cast(A), - lda, - strideA, - reinterpret_cast(B), - ldb, - strideB, - reinterpret_cast(&beta_fp16), - reinterpret_cast<__half *>(C), - ldc, - strideC, - batchCount); + return cublasHgemmStridedBatched( + handle, + transa, + transb, + m, + n, + k, + reinterpret_cast(&alpha_fp16), + reinterpret_cast(A), + lda, + strideA, + reinterpret_cast(B), + ldb, + strideB, + reinterpret_cast(&beta_fp16), + reinterpret_cast<__half *>(C), + ldc, + strideC, + batchCount); } else if (dtype == CUDA_R_16BF) { #if CUDA_VERSION 
>= 11000 return cublasGemmStridedBatchedEx(handle, @@ -220,7 +221,8 @@ inline cublasStatus_t cublasGemmStridedBatched(cudaDataType_t dtype, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); #else - LOG(FATAL) << "cublasGemmStridedBatched with bfloat16 is not supported on cuda <= 11"; + LOG(FATAL) << "cublasGemmStridedBatched with bfloat16 is not supported on " + "cuda <= 11"; #endif } LOG(FATAL) << "Unsupported cublasGemmStridedBatched precision."; @@ -260,7 +262,7 @@ inline cublasStatus_t cublasGemmBatched(cudaDataType_t dtype, batchCount); } else if (dtype == CUDA_R_64F) { const double alpha_fp64 = static_cast(alpha); - const double beta_fp64 = static_cast(beta); + const double beta_fp64 = static_cast(beta); return cublasDgemmBatched(handle, transa, transb, @@ -317,7 +319,8 @@ inline cublasStatus_t cublasGemmBatched(cudaDataType_t dtype, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); #else - LOG(FATAL) << "cublasGemmBatched with bfloat16 is not supported on cuda <= 11"; + LOG(FATAL) + << "cublasGemmBatched with bfloat16 is not supported on cuda <= 11"; #endif } LOG(FATAL) << "Unsupported cublasGemmBatched precision."; diff --git a/paddle/cinn/runtime/cuda/cuda_instrinsics_bfloat16.cc b/paddle/cinn/runtime/cuda/cuda_instrinsics_bfloat16.cc index eb0160552db0d..9e3610850d97c 100644 --- a/paddle/cinn/runtime/cuda/cuda_instrinsics_bfloat16.cc +++ b/paddle/cinn/runtime/cuda/cuda_instrinsics_bfloat16.cc @@ -28,7 +28,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics_bfloat16) { // bfloat16 #define REGISTER_EXTERN_FUNC_2_IN_1_BF16(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(cinn_nvgpu_##func__##_bf16, target, bfloat16, bfloat16, bfloat16); + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_nvgpu_##func__##_bf16, target, bfloat16, bfloat16, bfloat16); REGISTER_EXTERN_FUNC_2_IN_1_BF16(pow) REGISTER_EXTERN_FUNC_2_IN_1_BF16(mod) @@ -36,7 +37,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics_bfloat16) { #undef REGISTER_EXTERN_FUNC_2_IN_1_BF16 #define REGISTER_EXTERN_FUNC_1_IN_1_BF16(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_bf16, target, bfloat16, bfloat16); + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_bf16, target, bfloat16, bfloat16); REGISTER_EXTERN_FUNC_1_IN_1_BF16(ceil) REGISTER_EXTERN_FUNC_1_IN_1_BF16(floor) @@ -68,7 +70,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics_bfloat16) { #undef REGISTER_EXTERN_FUNC_1_IN_1_BF16 #define REGISTER_EXTERN_FUNC_1_IN_1_BF16_OUT_BOOL(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_bf16, target, bfloat16, bool); + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_bf16, target, bfloat16, bool); REGISTER_EXTERN_FUNC_1_IN_1_BF16_OUT_BOOL(isnan) REGISTER_EXTERN_FUNC_1_IN_1_BF16_OUT_BOOL(isinf) diff --git a/paddle/cinn/runtime/cuda/cuda_instrinsics_float16.cc b/paddle/cinn/runtime/cuda/cuda_instrinsics_float16.cc index 5910c9dcac0e1..62b74d56e1b3e 100644 --- a/paddle/cinn/runtime/cuda/cuda_instrinsics_float16.cc +++ b/paddle/cinn/runtime/cuda/cuda_instrinsics_float16.cc @@ -28,7 +28,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics_float16) { // float16 #define REGISTER_EXTERN_FUNC_2_IN_1_FP16(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(cinn_nvgpu_##func__##_fp16, target, float16, float16, float16); + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_nvgpu_##func__##_fp16, target, float16, float16, float16); REGISTER_EXTERN_FUNC_2_IN_1_FP16(pow) REGISTER_EXTERN_FUNC_2_IN_1_FP16(mod) @@ -36,7 +37,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics_float16) { #undef 
REGISTER_EXTERN_FUNC_2_IN_1_FP16 #define REGISTER_EXTERN_FUNC_1_IN_1_FP16(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_fp16, target, float16, float16); + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_fp16, target, float16, float16); REGISTER_EXTERN_FUNC_1_IN_1_FP16(ceil) REGISTER_EXTERN_FUNC_1_IN_1_FP16(floor) @@ -68,7 +70,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics_float16) { #undef REGISTER_EXTERN_FUNC_1_IN_1_FP16 #define REGISTER_EXTERN_FUNC_1_IN_1_FP16_OUT_BOOL(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_fp16, target, float16, bool); + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_fp16, target, float16, bool); REGISTER_EXTERN_FUNC_1_IN_1_FP16_OUT_BOOL(isnan) REGISTER_EXTERN_FUNC_1_IN_1_FP16_OUT_BOOL(isinf) @@ -104,16 +107,17 @@ CINN_REGISTER_HELPER(cuda_intrinsics_float16) { #undef REGISTER_CINN_NVGPU_LT_NUM -#define REGISTER_CINN_NVGPU_INDEX_ADD(TYPE_SUFFIX, TYPE) \ - REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_nvgpu_index_add_##TYPE_SUFFIX, target) \ - .SetRetType() \ - .AddInputType() \ - .AddInputType() \ - .AddInputType() \ - .AddInputType() \ - .AddInputType() \ - .AddInputType() \ - .AddInputType() \ +#define REGISTER_CINN_NVGPU_INDEX_ADD(TYPE_SUFFIX, TYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_nvgpu_index_add_##TYPE_SUFFIX, \ + target) \ + .SetRetType() \ + .AddInputType() \ + .AddInputType() \ + .AddInputType() \ + .AddInputType() \ + .AddInputType() \ + .AddInputType() \ + .AddInputType() \ .End(); REGISTER_CINN_NVGPU_INDEX_ADD(fp16, float16); diff --git a/paddle/cinn/runtime/cuda/cuda_intrinsics.cc b/paddle/cinn/runtime/cuda/cuda_intrinsics.cc index 2da1e46ed7b10..b46b48b27f55a 100644 --- a/paddle/cinn/runtime/cuda/cuda_intrinsics.cc +++ b/paddle/cinn/runtime/cuda/cuda_intrinsics.cc @@ -25,7 +25,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { // bool for 1 input 1 output #define REGISTER_EXTERN_FUNC_1_IN_1_OUT_BOOL(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_bool, target, bool, bool) + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_bool, target, bool, bool) REGISTER_EXTERN_FUNC_1_IN_1_OUT_BOOL(bitwise_not); @@ -33,7 +34,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { // bool for 2 input 1 output #define REGISTER_EXTERN_FUNC_2_IN_1_OUT_BOOL(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(cinn_nvgpu_##func__##_bool, target, bool, bool, bool) + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_nvgpu_##func__##_bool, target, bool, bool, bool) REGISTER_EXTERN_FUNC_2_IN_1_OUT_BOOL(bitwise_and); REGISTER_EXTERN_FUNC_2_IN_1_OUT_BOOL(bitwise_or); @@ -43,7 +45,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { // uint8 for 1 input 1 output #define REGISTER_EXTERN_FUNC_1_IN_1_OUT_UINT8(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_uint8, target, uint8_t, uint8_t) + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_uint8, target, uint8_t, uint8_t) REGISTER_EXTERN_FUNC_1_IN_1_OUT_UINT8(bitwise_not); @@ -51,7 +54,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { // uint8 for 2 input 1 output #define REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(cinn_nvgpu_##func__##_uint8, target, uint8_t, uint8_t, uint8_t); + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_nvgpu_##func__##_uint8, target, uint8_t, uint8_t, uint8_t); REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8(bitwise_and); REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8(bitwise_or); @@ -62,7 +66,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { 
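The hunks below repeat the same wrapper-macro rewrap once per integer width. The underlying idiom: the real macros paste tokens (##) to build the C identifier cinn_nvgpu_<op>_<type>, so one macro invocation per op stamps out a complete typed registration. The tiny self-contained demonstration below uses the sibling stringizing operator (#) just to show the generated names; DEMO_REGISTER_INT8 is made up, not a CINN macro:

#include <cstdio>

#define DEMO_REGISTER_INT8(func__)                          \
  std::printf("register %s : (int8_t, int8_t) -> int8_t\n", \
              "cinn_nvgpu_" #func__ "_int8")  // #func__ stringizes the op name

int main() {
  DEMO_REGISTER_INT8(bitwise_and);  // register cinn_nvgpu_bitwise_and_int8
  DEMO_REGISTER_INT8(bitwise_or);
  DEMO_REGISTER_INT8(bitwise_xor);
}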
// int8 for 1 input 1 output #define REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT8(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_int8, target, int8_t, int8_t) + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_int8, target, int8_t, int8_t) REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT8(bitwise_not); @@ -70,7 +75,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { // int8 for 2 input 1 output #define REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(cinn_nvgpu_##func__##_int8, target, int8_t, int8_t, int8_t); + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_nvgpu_##func__##_int8, target, int8_t, int8_t, int8_t); REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8(bitwise_and); REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8(bitwise_or); @@ -81,7 +87,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { // int16 for 1 input 1 output #define REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT16(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_int16, target, int16_t, int16_t) + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_int16, target, int16_t, int16_t) REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT16(bitwise_not); @@ -89,7 +96,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { // int16 for 2 input 1 output #define REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(cinn_nvgpu_##func__##_int16, target, int16_t, int16_t, int16_t); + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_nvgpu_##func__##_int16, target, int16_t, int16_t, int16_t); REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16(bitwise_and); REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16(bitwise_or); @@ -100,7 +108,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { // float #define REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_fp32, target, float, float); + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_fp32, target, float, float); REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(abs); REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(exp); @@ -132,7 +141,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { #undef REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT #define REGISTER_EXTERN_FUNC_1_IN_FLOAT_1_OUT_BOOL(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_fp32, target, float, bool); + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_fp32, target, float, bool); REGISTER_EXTERN_FUNC_1_IN_FLOAT_1_OUT_BOOL(isnan); REGISTER_EXTERN_FUNC_1_IN_FLOAT_1_OUT_BOOL(isfinite); @@ -141,7 +151,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { #undef REGISTER_EXTERN_FUNC_1_IN_FLOAT_1_OUT_BOOL #define REGISTER_EXTERN_FUNC_2_IN_1_FLOAT(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(cinn_nvgpu_##func__##_fp32, target, float, float, float); + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_nvgpu_##func__##_fp32, target, float, float, float); REGISTER_EXTERN_FUNC_2_IN_1_FLOAT(pow) REGISTER_EXTERN_FUNC_2_IN_1_FLOAT(mod) @@ -151,7 +162,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { // double #define REGISTER_EXTERN_FUNC_1_IN_1_FP64(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_fp64, target, double, double); + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_fp64, target, double, double); REGISTER_EXTERN_FUNC_1_IN_1_FP64(abs); REGISTER_EXTERN_FUNC_1_IN_1_FP64(exp); @@ -183,7 +195,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { #undef REGISTER_EXTERN_FUNC_1_IN_1_FP64 #define REGISTER_EXTERN_FUNC_1_IN_FP64_1_OUT_BOOL(func__) \ - 
REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_fp64, target, double, bool); + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_fp64, target, double, bool); REGISTER_EXTERN_FUNC_1_IN_FP64_1_OUT_BOOL(isnan); REGISTER_EXTERN_FUNC_1_IN_FP64_1_OUT_BOOL(isfinite); @@ -192,7 +205,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { #undef REGISTER_EXTERN_FUNC_1_IN_FP64_1_OUT_BOOL #define REGISTER_EXTERN_FUNC_2_IN_1_FP64(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(cinn_nvgpu_##func__##_fp64, target, double, double, double); + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_nvgpu_##func__##_fp64, target, double, double, double); REGISTER_EXTERN_FUNC_2_IN_1_FP64(pow) REGISTER_EXTERN_FUNC_2_IN_1_FP64(mod) @@ -202,7 +216,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { // int32 #define REGISTER_EXTERN_FUNC_1_IN_1_INT32(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_int32, target, int, int); + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_int32, target, int, int); REGISTER_EXTERN_FUNC_1_IN_1_INT32(bitwise_not) REGISTER_EXTERN_FUNC_1_IN_1_INT32(clz) @@ -212,7 +227,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { #undef REGISTER_EXTERN_FUNC_1_IN_1_INT32 #define REGISTER_EXTERN_FUNC_1_IN_1_INT64(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_int64, target, int64_t, int64_t); + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT( \ + cinn_nvgpu_##func__##_int64, target, int64_t, int64_t); REGISTER_EXTERN_FUNC_1_IN_1_INT64(bitwise_not) REGISTER_EXTERN_FUNC_1_IN_1_INT64(clz) @@ -222,7 +238,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { #undef REGISTER_EXTERN_FUNC_1_IN_1_INT64 #define REGISTER_EXTERN_FUNC_2_IN_1_INT32(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(cinn_nvgpu_##func__##_int32, target, int, int, int); + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_nvgpu_##func__##_int32, target, int, int, int); REGISTER_EXTERN_FUNC_2_IN_1_INT32(pow) REGISTER_EXTERN_FUNC_2_IN_1_INT32(left_shift) @@ -236,7 +253,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { #undef REGISTER_EXTERN_FUNC_2_IN_1_INT32 #define REGISTER_EXTERN_FUNC_2_IN_1_INT64(func__) \ - REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(cinn_nvgpu_##func__##_int64, target, int64_t, int64_t, int64_t); + REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT( \ + cinn_nvgpu_##func__##_int64, target, int64_t, int64_t, int64_t); REGISTER_EXTERN_FUNC_2_IN_1_INT64(pow) REGISTER_EXTERN_FUNC_2_IN_1_INT64(bitwise_and) @@ -247,16 +265,16 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { #undef REGISTER_EXTERN_FUNC_2_IN_1_INT64 - FunctionProto::shape_inference_t inference_shape_globalpool = [](const std::vector &args, - int offset) { - auto t = args[0].as_tensor(); - std::vector shape; - shape.push_back(t->shape[0]); - shape.push_back(t->shape[1]); - shape.push_back(cinn::ir::Expr(1)); - shape.push_back(cinn::ir::Expr(1)); - return shape; - }; + FunctionProto::shape_inference_t inference_shape_globalpool = + [](const std::vector &args, int offset) { + auto t = args[0].as_tensor(); + std::vector shape; + shape.push_back(t->shape[0]); + shape.push_back(t->shape[1]); + shape.push_back(cinn::ir::Expr(1)); + shape.push_back(cinn::ir::Expr(1)); + return shape; + }; REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_cuda_find_int, target) .SetRetType() @@ -349,16 +367,17 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { #undef _REGISTER_CINN_NVGPU_GT_NUM -#define _REGISTER_CINN_NVGPU_INDEX_ADD(TYPE_SUFFIX, TYPE) \ - REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_nvgpu_index_add_##TYPE_SUFFIX, target) \ - .SetRetType() \ - .AddInputType() \ - 
.AddInputType() \ - .AddInputType() \ - .AddInputType() \ - .AddInputType() \ - .AddInputType() \ - .AddInputType() \ +#define _REGISTER_CINN_NVGPU_INDEX_ADD(TYPE_SUFFIX, TYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_nvgpu_index_add_##TYPE_SUFFIX, \ + target) \ + .SetRetType() \ + .AddInputType() \ + .AddInputType() \ + .AddInputType() \ + .AddInputType() \ + .AddInputType() \ + .AddInputType() \ + .AddInputType() \ .End(); _REGISTER_CINN_NVGPU_INDEX_ADD(bool, bool); @@ -403,7 +422,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { CINN_REGISTER_HELPER(cinn_cuda_host_api) { using cinn::runtime::cuda::cinn_call_cuda_kernel; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_cuda_kernel, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_cuda_kernel, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // kernel_fn .AddInputType() // args @@ -418,7 +438,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_cublas; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_cublas, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_cublas, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -438,7 +459,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_batched_cublas; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_batched_cublas, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_batched_cublas, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -459,7 +481,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_cuda_memset; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_cuda_memset, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_cuda_memset, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -469,7 +492,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_cuda_memcpy; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_cuda_memcpy, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_cuda_memcpy, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -478,7 +502,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_gaussian_random; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_gaussian_random, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_gaussian_random, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -489,7 +514,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_uniform_random; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_uniform_random, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_uniform_random, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -500,7 +526,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_randint; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_randint, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_randint, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -509,7 +536,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using 
cinn::runtime::cuda::cinn_call_cholesky_nvgpu; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_cholesky_nvgpu, cinn::common::DefaultNVGPUTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_cholesky_nvgpu, + cinn::common::DefaultNVGPUTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -520,7 +548,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_triangular_solve_nvgpu; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_triangular_solve_nvgpu, cinn::common::DefaultNVGPUTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_triangular_solve_nvgpu, + cinn::common::DefaultNVGPUTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -534,9 +563,11 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .AddInputType() // stream .End(); - // TODO(thisjiang): change msg type from 'int' to 'std::string' when custom call support 'std::string' type + // TODO(thisjiang): change msg type from 'int' to 'std::string' when custom + // call support 'std::string' type using cinn::runtime::cuda::cinn_assert_true_nvgpu; - REGISTER_EXTERN_FUNC_HELPER(cinn_assert_true_nvgpu, cinn::common::DefaultNVGPUTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_assert_true_nvgpu, + cinn::common::DefaultNVGPUTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -547,7 +578,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { #ifdef CINN_WITH_CUDNN using cinn::runtime::cuda::cinn_call_cudnn_conv2d_forward; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_conv2d_forward, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_conv2d_forward, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -577,7 +609,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_cudnn_conv2d_backward_data; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_conv2d_backward_data, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_conv2d_backward_data, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -607,7 +640,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_cudnn_conv2d_backward_filter; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_conv2d_backward_filter, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_conv2d_backward_filter, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -637,7 +671,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_cudnn_pool2d_forward; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_pool2d_forward, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_pool2d_forward, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -663,7 +698,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_cudnn_pool2d_backward; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_pool2d_backward, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_pool2d_backward, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // v_args .AddInputType() // num_args @@ -689,7 +725,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_cudnn_softmax_forward; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_softmax_forward, 
cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_softmax_forward, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // args .AddInputType() // num_args @@ -709,7 +746,8 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .End(); using cinn::runtime::cuda::cinn_call_cudnn_softmax_backward; - REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_softmax_backward, cinn::common::DefaultHostTarget()) + REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_softmax_backward, + cinn::common::DefaultHostTarget()) .SetRetType() .AddInputType() // args .AddInputType() // num_args diff --git a/paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc b/paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc index d3c9e9423549e..9f7bdefcd2fcc 100644 --- a/paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc +++ b/paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc @@ -93,10 +93,11 @@ CINN_REGISTER_HELPER(cuda_intrinsics_reduce) { .AddInputType() .End(); -#define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL(REDUCE_TYPE, DTYPE) \ - REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_block_reduce_##REDUCE_TYPE##_internal, target) \ - .SetRetType() \ - .AddInputType() \ +#define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL(REDUCE_TYPE, DTYPE) \ + REGISTER_FACKED_EXTERN_FUNC_HELPER( \ + cinn_block_reduce_##REDUCE_TYPE##_internal, target) \ + .SetRetType() \ + .AddInputType() \ .End(); EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL) diff --git a/paddle/cinn/runtime/cuda/cuda_module.cc b/paddle/cinn/runtime/cuda/cuda_module.cc index c5facd97bb1cd..2df567c547cbc 100644 --- a/paddle/cinn/runtime/cuda/cuda_module.cc +++ b/paddle/cinn/runtime/cuda/cuda_module.cc @@ -32,7 +32,8 @@ namespace cinn { namespace runtime { namespace cuda { -CUDAModule::CUDAModule(const std::string& data, Kind kind) : data_(data), kind_(kind) { +CUDAModule::CUDAModule(const std::string& data, Kind kind) + : data_(data), kind_(kind) { CHECK(!data.empty()); cudaGetDeviceCount(&num_devices_); @@ -54,13 +55,15 @@ void CUDAModule::LaunchKernel(int device_id, void** args, size_t share_memory_size, CUstream stream) { - VLOG(3) << "cuLaunchKernel with func_name : " << func_name << ", gridDim.x:" << gridDim.x - << ", gridDim.y:" << gridDim.y << ", gridDim.z:" << gridDim.z << ", blockDim.x:" << blockDim.x + VLOG(3) << "cuLaunchKernel with func_name : " << func_name + << ", gridDim.x:" << gridDim.x << ", gridDim.y:" << gridDim.y + << ", gridDim.z:" << gridDim.z << ", blockDim.x:" << blockDim.x << ", blockDim.y:" << blockDim.y << ", blockDim.z:" << blockDim.z << ", share_memory_size:" << share_memory_size; auto function = GetFunction(device_id, func_name); CHECK(function); - cinn::utils::RecordEvent record_run("cuLaunchKernel", cinn::utils::EventType::kInstruction); + cinn::utils::RecordEvent record_run("cuLaunchKernel", + cinn::utils::EventType::kInstruction); CUDA_DRIVER_CALL(cuLaunchKernel(function, gridDim.x, gridDim.y, @@ -74,9 +77,11 @@ void CUDAModule::LaunchKernel(int device_id, nullptr)); } -CUfunction CUDAModule::GetFunction(int device_id, const std::string& func_name) { +CUfunction CUDAModule::GetFunction(int device_id, + const std::string& func_name) { VLOG(5) << "GetFuncion : " << func_name << " with device_id : " << device_id; - cinn::utils::RecordEvent record_run("cuLaunchKernel", cinn::utils::EventType::kOrdinary); + cinn::utils::RecordEvent record_run("cuLaunchKernel", + cinn::utils::EventType::kOrdinary); if (!module_per_card_[device_id]) { std::lock_guard lock(mutex_); // Compilation with parameters @@ -85,9 +90,9 
@@ CUfunction CUDAModule::GetFunction(int device_id, const std::string& func_name) std::vector jit_opt_vals(jit_num_options); // set up size of compilation log buffer - jit_options[0] = CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES; + jit_options[0] = CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES; size_t log_buffer_size = 1024; - jit_opt_vals[0] = reinterpret_cast(log_buffer_size); + jit_opt_vals[0] = reinterpret_cast(log_buffer_size); // set up pointer to the compilation log buffer jit_options[1] = CU_JIT_ERROR_LOG_BUFFER; @@ -96,43 +101,53 @@ CUfunction CUDAModule::GetFunction(int device_id, const std::string& func_name) int value = 1; // Specifies whether to create debug information in output (-g) - jit_options[2] = CU_JIT_GENERATE_DEBUG_INFO; + jit_options[2] = CU_JIT_GENERATE_DEBUG_INFO; jit_opt_vals[2] = reinterpret_cast(value); // Generate verbose log messages - jit_options[3] = CU_JIT_LOG_VERBOSE; + jit_options[3] = CU_JIT_LOG_VERBOSE; jit_opt_vals[3] = reinterpret_cast(value); // Generate line number information (-lineinfo) - jit_options[4] = CU_JIT_GENERATE_LINE_INFO; + jit_options[4] = CU_JIT_GENERATE_LINE_INFO; jit_opt_vals[4] = reinterpret_cast(value); if (runtime::CanUseNvccCompiler()) { - CUDA_DRIVER_CALL(cuModuleLoad(&module_per_card_[device_id], data_.c_str())); + CUDA_DRIVER_CALL( + cuModuleLoad(&module_per_card_[device_id], data_.c_str())); } else { - CUDA_DRIVER_CALL(cuModuleLoadDataEx( - &module_per_card_[device_id], data_.c_str(), jit_num_options, jit_options.data(), jit_opt_vals.data())); + CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module_per_card_[device_id], + data_.c_str(), + jit_num_options, + jit_options.data(), + jit_opt_vals.data())); } } CUfunction func; - CUDA_DRIVER_CALL(cuModuleGetFunction(&func, module_per_card_[device_id], func_name.c_str())); + CUDA_DRIVER_CALL(cuModuleGetFunction( + &func, module_per_card_[device_id], func_name.c_str())); return func; } -CUdeviceptr CUDAModule::GetGlobal(int device_id, const std::string& name, size_t nbytes) { +CUdeviceptr CUDAModule::GetGlobal(int device_id, + const std::string& name, + size_t nbytes) { if (!module_per_card_[device_id]) { std::lock_guard lock(mutex_); if (runtime::CanUseNvccCompiler()) { - CUDA_DRIVER_CALL(cuModuleLoad(&module_per_card_[device_id], data_.c_str())); + CUDA_DRIVER_CALL( + cuModuleLoad(&module_per_card_[device_id], data_.c_str())); } else { - CUDA_DRIVER_CALL(cuModuleLoadData(&module_per_card_[device_id], data_.c_str())); + CUDA_DRIVER_CALL( + cuModuleLoadData(&module_per_card_[device_id], data_.c_str())); } } size_t _nbytes; CUdeviceptr global; - CUDA_DRIVER_CALL(cuModuleGetGlobal(&global, &_nbytes, module_per_card_[device_id], name.c_str())); + CUDA_DRIVER_CALL(cuModuleGetGlobal( + &global, &_nbytes, module_per_card_[device_id], name.c_str())); return global; } diff --git a/paddle/cinn/runtime/cuda/cuda_module.h b/paddle/cinn/runtime/cuda/cuda_module.h index 8bb276a0c3e55..c904d7976d542 100644 --- a/paddle/cinn/runtime/cuda/cuda_module.h +++ b/paddle/cinn/runtime/cuda/cuda_module.h @@ -36,7 +36,7 @@ namespace cuda { class CUDAModule { public: enum class Kind { - PTX = 0, + PTX = 0, CUBIN = 1, }; @@ -48,7 +48,7 @@ class CUDAModule { dim3 blockDim, void** args, size_t share_memory_size = 0, - CUstream stream = nullptr); + CUstream stream = nullptr); //! Get a function. 
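// A minimal sketch of the lazy per-device load in CUDAModule::GetFunction
// above: the CUmodule is created under a mutex on first use, JIT options are
// passed as parallel (CUjit_option, void*) arrays, and the kernel handle is
// then looked up by name. LoadAndGet and its parameters are illustrative,
// not the real CINN interface; the CUDA_DRIVER_CALL error checks are
// omitted here.
#include <cuda.h>
#include <mutex>
#include <string>

CUfunction LoadAndGet(CUmodule* mod, std::mutex* mu, const std::string& ptx,
                      const char* fn_name) {
  if (*mod == nullptr) {
    std::lock_guard<std::mutex> lock(*mu);
    if (*mod == nullptr) {  // re-checked under the lock in this sketch
      static char log_buf[1024];
      CUjit_option opts[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
                             CU_JIT_ERROR_LOG_BUFFER};
      void* vals[] = {reinterpret_cast<void*>(sizeof(log_buf)), log_buf};
      cuModuleLoadDataEx(mod, ptx.c_str(), 2, opts, vals);
    }
  }
  CUfunction func = nullptr;
  cuModuleGetFunction(&func, *mod, fn_name);
  return func;
}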
CUfunction GetFunction(int device_id, const std::string& func_name); diff --git a/paddle/cinn/runtime/cuda/cuda_module_test.cc b/paddle/cinn/runtime/cuda/cuda_module_test.cc index 85d66dd355596..3e150890cc3e6 100644 --- a/paddle/cinn/runtime/cuda/cuda_module_test.cc +++ b/paddle/cinn/runtime/cuda/cuda_module_test.cc @@ -102,13 +102,20 @@ TEST(CUDAModule, float16) { auto* y_p{y_device.data()}; void* args[] = {&x_p, &size, &y_p}; - cuda_module.LaunchKernel(0, "cast_fp32_to_fp16_cuda_kernel", blocks_per_grid, threads_per_block, args); + cuda_module.LaunchKernel(0, + "cast_fp32_to_fp16_cuda_kernel", + blocks_per_grid, + threads_per_block, + args); CUDA_CALL(cudaDeviceSynchronize()); std::vector y_host = y_device.to_host(); - bool res = std::equal(x_host.begin(), x_host.end(), y_host.begin(), [](float x, float16 y) -> bool { - return std::abs(x - static_cast(y)) < 1e-2f; - }); + bool res = std::equal(x_host.begin(), + x_host.end(), + y_host.begin(), + [](float x, float16 y) -> bool { + return std::abs(x - static_cast(y)) < 1e-2f; + }); CHECK(res) << "The difference between two arrays exceeds the bound."; } @@ -164,13 +171,20 @@ TEST(CUDAModule, bfloat16) { auto* y_p{y_device.data()}; void* args[] = {&x_p, &size, &y_p}; - cuda_module.LaunchKernel(0, "cast_fp32_to_bf16_cuda_kernel", blocks_per_grid, threads_per_block, args); + cuda_module.LaunchKernel(0, + "cast_fp32_to_bf16_cuda_kernel", + blocks_per_grid, + threads_per_block, + args); CUDA_CALL(cudaDeviceSynchronize()); std::vector y_host = y_device.to_host(); - bool res = std::equal(x_host.begin(), x_host.end(), y_host.begin(), [](float x, bfloat16 y) -> bool { - return std::abs(x - static_cast(y)) < 1e-2f; - }); + bool res = std::equal(x_host.begin(), + x_host.end(), + y_host.begin(), + [](float x, bfloat16 y) -> bool { + return std::abs(x - static_cast(y)) < 1e-2f; + }); CHECK(res) << "The difference between two arrays exceeds the bound."; } diff --git a/paddle/cinn/runtime/cuda/cuda_util.cc b/paddle/cinn/runtime/cuda/cuda_util.cc index 8a40c99aa328d..366863904a213 100644 --- a/paddle/cinn/runtime/cuda/cuda_util.cc +++ b/paddle/cinn/runtime/cuda/cuda_util.cc @@ -65,11 +65,14 @@ class CublasHandle { CUDA_CALL(cudaDeviceGetMemPool(&mem_pool, 0)); uint64_t threshold = UINT32_MAX; - CUDA_CALL(cudaMemPoolSetAttribute(mem_pool, cudaMemPoolAttrReleaseThreshold, &threshold)); + CUDA_CALL(cudaMemPoolSetAttribute( + mem_pool, cudaMemPoolAttrReleaseThreshold, &threshold)); int enable = 1; - CUDA_CALL(cudaMemPoolSetAttribute(mem_pool, cudaMemPoolReuseFollowEventDependencies, &enable)); - CUDA_CALL(cudaMemPoolSetAttribute(mem_pool, cudaMemPoolReuseAllowInternalDependencies, &enable)); + CUDA_CALL(cudaMemPoolSetAttribute( + mem_pool, cudaMemPoolReuseFollowEventDependencies, &enable)); + CUDA_CALL(cudaMemPoolSetAttribute( + mem_pool, cudaMemPoolReuseAllowInternalDependencies, &enable)); } cudaStream_t custream; cublasHandle_t cuhandle; @@ -85,12 +88,15 @@ void cinn_call_cuda_kernel(void *kernel_fn, int block_y, int block_z, void *stream) { - VLOG(3) << "cinn_call_cuda_kernel, grid_dim={" << grid_x << ", " << grid_y << ", " << grid_z << "}, block_dim={" - << block_x << ", " << block_y << ", " << block_z << "}, num_args=" << num_args << ", stream=" << stream; + VLOG(3) << "cinn_call_cuda_kernel, grid_dim={" << grid_x << ", " << grid_y + << ", " << grid_z << "}, block_dim={" << block_x << ", " << block_y + << ", " << block_z << "}, num_args=" << num_args + << ", stream=" << stream; std::vector kernel_args; { - cinn::utils::RecordEvent 
record_run("prepare_args", cinn::utils::EventType::kInstruction); + cinn::utils::RecordEvent record_run("prepare_args", + cinn::utils::EventType::kInstruction); kernel_args.reserve(num_args); cinn_pod_value_t *args = static_cast(v_args); for (int idx = 0; idx < num_args; ++idx) { @@ -103,7 +109,8 @@ void cinn_call_cuda_kernel(void *kernel_fn, } { - cinn::utils::RecordEvent record_run("cuLaunchKernel", cinn::utils::EventType::kInstruction); + cinn::utils::RecordEvent record_run("cuLaunchKernel", + cinn::utils::EventType::kInstruction); CUDA_DRIVER_CALL(cuLaunchKernel(static_cast(kernel_fn), grid_x, grid_y, @@ -134,15 +141,17 @@ void cinn_call_cublas(void *v_args, int b3, int b4, void *stream) { - cinn::utils::RecordEvent record_run("cinn_call_cublas", cinn::utils::EventType::kInstruction); + cinn::utils::RecordEvent record_run("cinn_call_cublas", + cinn::utils::EventType::kInstruction); CHECK_EQ(num_args, 3); cublasHandle_t &cuhandle = CublasHandle::GetInstance().GetCublasHandle(); - cinn_pod_value_t *args = static_cast(v_args); - cudaStream_t custream = static_cast(stream); + cinn_pod_value_t *args = static_cast(v_args); + cudaStream_t custream = static_cast(stream); CUBLAS_CALL(cublasSetStream(cuhandle, custream)); VLOG(3) << "a1 ~ a4: " << a1 << " " << a2 << " " << a3 << " " << a4; VLOG(3) << "b1 ~ b4: " << b1 << " " << b2 << " " << b3 << " " << b4; - VLOG(3) << "trans_a: " << trans_a << ", trans_b: " << trans_b << ", trans_o: " << trans_o; + VLOG(3) << "trans_a: " << trans_a << ", trans_b: " << trans_b + << ", trans_o: " << trans_o; void *A = args[0].operator cinn_buffer_t *()->memory; void *B = args[1].operator cinn_buffer_t *()->memory; @@ -152,22 +161,28 @@ void cinn_call_cublas(void *v_args, int n = trans_o ? (trans_b ? b3 : b4) : (trans_a ? a4 : a3); int k = trans_a ? a3 : a4; - cublasOperation_t trans_op_l = - trans_o ? (trans_a ? CUBLAS_OP_N : CUBLAS_OP_T) : (trans_b ? CUBLAS_OP_T : CUBLAS_OP_N); - cublasOperation_t trans_op_r = - trans_o ? (trans_b ? CUBLAS_OP_N : CUBLAS_OP_T) : (trans_a ? CUBLAS_OP_T : CUBLAS_OP_N); - int ldl = trans_op_l == CUBLAS_OP_N ? m : k; // trans_o ? (trans_a ? k : m) : (trans_b ? k : m); - int ldr = trans_op_r == CUBLAS_OP_N ? k : n; // trans_o ? (trans_b ? n : k) : (trans_a ? n : k); + cublasOperation_t trans_op_l = trans_o + ? (trans_a ? CUBLAS_OP_N : CUBLAS_OP_T) + : (trans_b ? CUBLAS_OP_T : CUBLAS_OP_N); + cublasOperation_t trans_op_r = trans_o + ? (trans_b ? CUBLAS_OP_N : CUBLAS_OP_T) + : (trans_a ? CUBLAS_OP_T : CUBLAS_OP_N); + int ldl = trans_op_l == CUBLAS_OP_N + ? m + : k; // trans_o ? (trans_a ? k : m) : (trans_b ? k : m); + int ldr = trans_op_r == CUBLAS_OP_N + ? k + : n; // trans_o ? (trans_b ? n : k) : (trans_a ? n : k); int ldc = m; void *lhs = trans_o ? A : B; void *rhs = trans_o ? 
B : A; cudaDataType_t cuda_dtype; - auto type_code = args[0].operator cinn_buffer_t *()->type.code; - bool is_float = type_code == cinn_type_float; + auto type_code = args[0].operator cinn_buffer_t *()->type.code; + bool is_float = type_code == cinn_type_float; bool is_bfloat16 = type_code == cinn_type_bfloat; - int bytes = args[0].operator cinn_buffer_t *()->type.bits / CHAR_BIT; + int bytes = args[0].operator cinn_buffer_t *()->type.bits / CHAR_BIT; if (is_float && bytes == sizeof(common::float16)) { cuda_dtype = CUDA_R_16F; } else if (is_float && bytes == sizeof(float)) { @@ -177,29 +192,61 @@ void cinn_call_cublas(void *v_args, } else if (is_bfloat16) { cuda_dtype = CUDA_R_16BF; } else { - LOG(FATAL) << "unsupported cublas data type: " << static_cast(type_code) << ", bytes = " << bytes; + LOG(FATAL) << "unsupported cublas data type: " + << static_cast(type_code) << ", bytes = " << bytes; } if (a1 * a2 * b1 * b2 == 1) { VLOG(3) << "call cublasGemm for a1 * a2 * b1 * b2 == 1"; - cinn::utils::RecordEvent record_run("Call cublasGemm", cinn::utils::EventType::kInstruction); - CUBLAS_CALL( - cublasGemm(cuda_dtype, cuhandle, trans_op_l, trans_op_r, m, n, k, alpha, lhs, ldl, rhs, ldr, beta, C, ldc)); + cinn::utils::RecordEvent record_run("Call cublasGemm", + cinn::utils::EventType::kInstruction); + CUBLAS_CALL(cublasGemm(cuda_dtype, + cuhandle, + trans_op_l, + trans_op_r, + m, + n, + k, + alpha, + lhs, + ldl, + rhs, + ldr, + beta, + C, + ldc)); } else if (a1 * b1 == 1) { CHECK(a2 == b2 || a2 == 1 || b2 == 1); if (b2 == 1 && trans_op_r == CUBLAS_OP_N) { // In case of [1, bs, M, K] * [1, 1, K, N] - VLOG(3) << "call cublasGemm for a1 * b1 = 1, b2 = 1, trans_op_r:" << trans_op_r; - cinn::utils::RecordEvent record_run("Call cublasGemm", cinn::utils::EventType::kInstruction); - CUBLAS_CALL(cublasGemm( - cuda_dtype, cuhandle, trans_op_l, trans_op_r, m, a2 * n, k, alpha, lhs, ldl, A, ldr, beta, C, ldc)); + VLOG(3) << "call cublasGemm for a1 * b1 = 1, b2 = 1, trans_op_r:" + << trans_op_r; + cinn::utils::RecordEvent record_run("Call cublasGemm", + cinn::utils::EventType::kInstruction); + CUBLAS_CALL(cublasGemm(cuda_dtype, + cuhandle, + trans_op_l, + trans_op_r, + m, + a2 * n, + k, + alpha, + lhs, + ldl, + A, + ldr, + beta, + C, + ldc)); } else { int stride_l = trans_o ? (a2 > 1 ? a3 * a4 : 0) : (b2 > 1 ? b3 * b4 : 0); int stride_r = trans_o ? (b2 > 1 ? b3 * b4 : 0) : (a2 > 1 ? a3 * a4 : 0); - int batch = std::max(a2, b2); - VLOG(3) << "call cublasGemmStridedBatched with a1*b1 = 1, stride_l = " << stride_l << ", stride_r = " << stride_r + int batch = std::max(a2, b2); + VLOG(3) << "call cublasGemmStridedBatched with a1*b1 = 1, stride_l = " + << stride_l << ", stride_r = " << stride_r << ", batch = " << batch; - cinn::utils::RecordEvent record_run("Call cublasGemmStridedBatched", cinn::utils::EventType::kInstruction); + cinn::utils::RecordEvent record_run("Call cublasGemmStridedBatched", + cinn::utils::EventType::kInstruction); CUBLAS_CALL(cublasGemmStridedBatched(cuda_dtype, cuhandle, trans_op_l, @@ -221,44 +268,53 @@ void cinn_call_cublas(void *v_args, batch)); } } else { - int l1 = trans_o ? a1 : b1, l2 = trans_o ? a2 : b2, l3 = trans_o ? a3 : b3, l4 = trans_o ? a4 : b4; - int r1 = trans_o ? b1 : a1, r2 = trans_o ? b2 : a2, r3 = trans_o ? b3 : a3, r4 = trans_o ? b4 : a4; + int l1 = trans_o ? a1 : b1, l2 = trans_o ? a2 : b2, l3 = trans_o ? a3 : b3, + l4 = trans_o ? a4 : b4; + int r1 = trans_o ? b1 : a1, r2 = trans_o ? b2 : a2, r3 = trans_o ? b3 : a3, + r4 = trans_o ? 
b4 : a4; - if ((l1 == r1 && l2 == r2) || (l1 == 1 && l2 == 1) || (r1 == 1 && r2 == 1)) { + if ((l1 == r1 && l2 == r2) || (l1 == 1 && l2 == 1) || + (r1 == 1 && r2 == 1)) { int stride_l = (l1 == 1 && l2 == 1) ? 0 : l3 * l4; int stride_r = (r1 == 1 && r2 == 1) ? 0 : r3 * r4; // four types matmul: // (N, L) * (N, L) , (N, 1) * (N, 1) // (N, L) * (1, 1) , (1, 1) * (N, L) - VLOG(3) << "call cublasGemmStridedBatched for stride_l = " << stride_l << ", stride_r = " << stride_r + VLOG(3) << "call cublasGemmStridedBatched for stride_l = " << stride_l + << ", stride_r = " << stride_r << ", batch = " << std::max(l1, r1) * std::max(l2, r2); - cinn::utils::RecordEvent record_run("Call cublasGemmStridedBatched", cinn::utils::EventType::kInstruction); - CUBLAS_CALL(cublasGemmStridedBatched(cuda_dtype, - cuhandle, - trans_op_l, - trans_op_r, - m, - n, - k, - alpha, - lhs, - ldl, - stride_l, - rhs, - ldr, - stride_r, - beta, - C, - ldc, - m * n, - std::max(l1, r1) * std::max(l2, r2))); + cinn::utils::RecordEvent record_run("Call cublasGemmStridedBatched", + cinn::utils::EventType::kInstruction); + CUBLAS_CALL( + cublasGemmStridedBatched(cuda_dtype, + cuhandle, + trans_op_l, + trans_op_r, + m, + n, + k, + alpha, + lhs, + ldl, + stride_l, + rhs, + ldr, + stride_r, + beta, + C, + ldc, + m * n, + std::max(l1, r1) * std::max(l2, r2))); } else { - cinn::utils::RecordEvent record_run("Call cublasGemmBatched", cinn::utils::EventType::kInstruction); + cinn::utils::RecordEvent record_run("Call cublasGemmBatched", + cinn::utils::EventType::kInstruction); // (N, L) / (N, 1) / (1, L) - int bstride_l = (l1 != 1 && l2 != 1) ? (l2 * m * k) : ((l1 != 1) ? m * k : 0); + int bstride_l = + (l1 != 1 && l2 != 1) ? (l2 * m * k) : ((l1 != 1) ? m * k : 0); // (N, L) / (N, 1) / (1, L) - int bstride_r = (r1 != 1 && r2 != 1) ? (r2 * k * n) : ((r1 != 1) ? k * n : 0); + int bstride_r = + (r1 != 1 && r2 != 1) ? (r2 * k * n) : ((r1 != 1) ? k * n : 0); int bstride_c = std::max(l2, r2) * m * n; int stride_l = l2 == 1 ? 
0 : l3 * l4; @@ -268,9 +324,12 @@ void cinn_call_cublas(void *v_args, // (N, 1) * (N, L) , (1, L) * (N, L) // (N, 1) * (1, L) , (1, L) * (N, 1) - void **ptr_arr = nullptr; + void **ptr_arr = nullptr; cudaStream_t g_stream = CublasHandle::GetInstance().GetCuStream(); - CUDA_CALL(cudaMallocAsync(&ptr_arr, sizeof(void *) * 3 * std::max(l1, r1) * std::max(l2, r2), g_stream)); + CUDA_CALL(cudaMallocAsync( + &ptr_arr, + sizeof(void *) * 3 * std::max(l1, r1) * std::max(l2, r2), + g_stream)); std::vector ptr(3 * std::max(l1, r1) * std::max(l2, r2)); void **ptr_a = ptr.data(); @@ -279,31 +338,39 @@ void cinn_call_cublas(void *v_args, for (int idx = 0, index = 0; idx < std::max(l1, r1); ++idx) { for (int idy = 0; idy < std::max(l2, r2); ++idy) { - ptr_a[index] = reinterpret_cast(lhs) + (idx * bstride_l + idy * stride_l) * bytes; - ptr_b[index] = reinterpret_cast(rhs) + (idx * bstride_r + idy * stride_r) * bytes; - ptr_c[index] = reinterpret_cast(C) + (idx * bstride_c + idy * m * n) * bytes; + ptr_a[index] = reinterpret_cast(lhs) + + (idx * bstride_l + idy * stride_l) * bytes; + ptr_b[index] = reinterpret_cast(rhs) + + (idx * bstride_r + idy * stride_r) * bytes; + ptr_c[index] = reinterpret_cast(C) + + (idx * bstride_c + idy * m * n) * bytes; ++index; } } - CUDA_CALL(cudaMemcpyAsync(ptr_arr, ptr.data(), ptr.size() * sizeof(void *), cudaMemcpyHostToDevice, g_stream)); + CUDA_CALL(cudaMemcpyAsync(ptr_arr, + ptr.data(), + ptr.size() * sizeof(void *), + cudaMemcpyHostToDevice, + g_stream)); CUDA_CALL(cudaStreamSynchronize(g_stream)); - CUBLAS_CALL(cublasGemmBatched(cuda_dtype, - cuhandle, - trans_op_l, - trans_op_r, - m, - n, - k, - alpha, - ptr_arr, - ldl, - ptr_arr + std::max(l1, r1) * std::max(l2, r2), - ldr, - beta, - ptr_arr + std::max(l1, r1) * std::max(l2, r2) * 2, - ldc, - std::max(l1, r1) * std::max(l2, r2))); + CUBLAS_CALL( + cublasGemmBatched(cuda_dtype, + cuhandle, + trans_op_l, + trans_op_r, + m, + n, + k, + alpha, + ptr_arr, + ldl, + ptr_arr + std::max(l1, r1) * std::max(l2, r2), + ldr, + beta, + ptr_arr + std::max(l1, r1) * std::max(l2, r2) * 2, + ldc, + std::max(l1, r1) * std::max(l2, r2))); CUDA_CALL(cudaFreeAsync(ptr_arr, custream)); } } @@ -329,15 +396,15 @@ void cinn_call_batched_cublas(void *v_args, // A * [B, C, D, ...] or [B, C, D, ...] 
* A CHECK_EQ((num_args - 1) % 2, 0); cublasHandle_t &cuhandle = CublasHandle::GetInstance().GetCublasHandle(); - cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args); - cudaStream_t custream = static_cast<cudaStream_t>(stream); + cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args); + cudaStream_t custream = static_cast<cudaStream_t>(stream); CUBLAS_CALL(cublasSetStream(cuhandle, custream)); cudaDataType_t cuda_dtype; - auto type_code = args[0].operator cinn_buffer_t *()->type.code; - bool is_float = type_code == cinn_type_float; + auto type_code = args[0].operator cinn_buffer_t *()->type.code; + bool is_float = type_code == cinn_type_float; bool is_bfloat16 = type_code == cinn_type_bfloat; - int bytes = args[0].operator cinn_buffer_t *()->type.bits / CHAR_BIT; + int bytes = args[0].operator cinn_buffer_t *()->type.bits / CHAR_BIT; if (is_float && bytes == sizeof(common::float16)) { cuda_dtype = CUDA_R_16F; } else if (is_float && bytes == sizeof(float)) { @@ -347,23 +414,32 @@ void cinn_call_batched_cublas(void *v_args, } else if (is_bfloat16) { cuda_dtype = CUDA_R_16BF; } else { - LOG(FATAL) << "unsupported cublas data type: " << static_cast<int>(type_code) << ", bytes = " << bytes; + LOG(FATAL) << "unsupported cublas data type: " + << static_cast<int>(type_code) << ", bytes = " << bytes; } int m = trans_o ? (trans_a ? a4 : a3) : (trans_b ? b3 : b4); int n = trans_o ? (trans_b ? b3 : b4) : (trans_a ? a4 : a3); int k = trans_a ? a3 : a4; - cublasOperation_t trans_op_l = - trans_o ? (trans_a ? CUBLAS_OP_N : CUBLAS_OP_T) : (trans_b ? CUBLAS_OP_T : CUBLAS_OP_N); - cublasOperation_t trans_op_r = - trans_o ? (trans_b ? CUBLAS_OP_N : CUBLAS_OP_T) : (trans_a ? CUBLAS_OP_T : CUBLAS_OP_N); - int ldl = trans_op_l == CUBLAS_OP_N ? m : k; // trans_o ? (trans_a ? k : m) : (trans_b ? k : m); - int ldr = trans_op_r == CUBLAS_OP_N ? k : n; // trans_o ? (trans_b ? n : k) : (trans_a ? n : k); + cublasOperation_t trans_op_l = trans_o + ? (trans_a ? CUBLAS_OP_N : CUBLAS_OP_T) + : (trans_b ? CUBLAS_OP_T : CUBLAS_OP_N); + cublasOperation_t trans_op_r = trans_o + ? (trans_b ? CUBLAS_OP_N : CUBLAS_OP_T) + : (trans_a ? CUBLAS_OP_T : CUBLAS_OP_N); + int ldl = trans_op_l == CUBLAS_OP_N + ? m + : k; // trans_o ? (trans_a ? k : m) : (trans_b ? k : m); + int ldr = trans_op_r == CUBLAS_OP_N + ? k + : n; // trans_o ? (trans_b ? n : k) : (trans_a ? n : k); int ldc = m; - int l1 = trans_o ? a1 : b1, l2 = trans_o ? a2 : b2, l3 = trans_o ? a3 : b3, l4 = trans_o ? a4 : b4; - int r1 = trans_o ? b1 : a1, r2 = trans_o ? b2 : a2, r3 = trans_o ? b3 : a3, r4 = trans_o ? b4 : a4; + int l1 = trans_o ? a1 : b1, l2 = trans_o ? a2 : b2, l3 = trans_o ? a3 : b3, + l4 = trans_o ? a4 : b4; + int r1 = trans_o ? b1 : a1, r2 = trans_o ? b2 : a2, r3 = trans_o ? b3 : a3, + r4 = trans_o ? b4 : a4; // (N, L): L * M * K // (N, 1): 1 * M * K @@ -383,9 +459,10 @@ void cinn_call_batched_cublas(void *v_args, std::vector<void *> ptr(3 * std::max(l1, r1) * std::max(l2, r2) * num_gemm); void **ptr_a = ptr.data(); void **ptr_b = ptr.data() + std::max(l1, r1) * std::max(l2, r2) * num_gemm; - void **ptr_c = ptr.data() + std::max(l1, r1) * std::max(l2, r2) * num_gemm * 2; + void **ptr_c = + ptr.data() + std::max(l1, r1) * std::max(l2, r2) * num_gemm * 2; - void **ptr_arr = nullptr; + void **ptr_arr = nullptr; cudaStream_t g_stream = CublasHandle::GetInstance().GetCuStream(); CUDA_CALL(cudaMallocAsync(&ptr_arr, sizeof(void *) * ptr.size(), g_stream)); @@ -397,8 +474,8 @@ void cinn_call_batched_cublas(void *v_args, // if opside is 1, exchange A and B.
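// A condensed sketch of the pointer-table idiom used around here (the
// opside swap continues just below): cublasGemmBatched wants device-side
// arrays of A/B/C pointers, so one allocation holds all three tables back
// to back, filled on the host with byte offsets and copied over on the same
// stream before the batched call. GemmPointerTable and its parameters are
// illustrative, not CINN's actual interface.
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

void** GemmPointerTable(void* A, void* B, void* C, int batch, size_t bytes_a,
                        size_t bytes_b, size_t bytes_c, cudaStream_t stream) {
  std::vector<void*> host(3 * batch);
  for (int b = 0; b < batch; ++b) {
    host[b] = static_cast<uint8_t*>(A) + b * bytes_a;              // A table
    host[batch + b] = static_cast<uint8_t*>(B) + b * bytes_b;      // B table
    host[2 * batch + b] = static_cast<uint8_t*>(C) + b * bytes_c;  // C table
  }
  void* dev = nullptr;
  cudaMallocAsync(&dev, sizeof(void*) * host.size(), stream);
  cudaMemcpyAsync(dev, host.data(), sizeof(void*) * host.size(),
                  cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);
  // The caller passes dev, dev + batch and dev + 2 * batch as the
  // Aarray/Barray/Carray arguments and releases with cudaFreeAsync.
  return static_cast<void**>(dev);
}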
if (opside) { auto tmp = A; - A = B; - B = tmp; + A = B; + B = tmp; } void *lhs = trans_o ? A : B; @@ -406,59 +483,74 @@ void cinn_call_batched_cublas(void *v_args, for (int idx = 0; idx < std::max(l1, r1); ++idx) { for (int idy = 0; idy < std::max(l2, r2); ++idy) { - ptr_a[index] = reinterpret_cast(lhs) + (idx * bstride_l + idy * stride_l) * bytes; - ptr_b[index] = reinterpret_cast(rhs) + (idx * bstride_r + idy * stride_r) * bytes; - ptr_c[index] = reinterpret_cast(C) + (idx * bstride_c + idy * m * n) * bytes; + ptr_a[index] = reinterpret_cast(lhs) + + (idx * bstride_l + idy * stride_l) * bytes; + ptr_b[index] = reinterpret_cast(rhs) + + (idx * bstride_r + idy * stride_r) * bytes; + ptr_c[index] = reinterpret_cast(C) + + (idx * bstride_c + idy * m * n) * bytes; ++index; } } } - CUDA_CALL(cudaMemcpyAsync(ptr_arr, ptr.data(), ptr.size() * sizeof(void *), cudaMemcpyHostToDevice, g_stream)); + CUDA_CALL(cudaMemcpyAsync(ptr_arr, + ptr.data(), + ptr.size() * sizeof(void *), + cudaMemcpyHostToDevice, + g_stream)); CUDA_CALL(cudaStreamSynchronize(g_stream)); - CUBLAS_CALL(cublasGemmBatched(cuda_dtype, - cuhandle, - trans_op_l, - trans_op_r, - m, - n, - k, - alpha, - ptr_arr, - ldl, - ptr_arr + std::max(l1, r1) * std::max(l2, r2) * num_gemm, - ldr, - beta, - ptr_arr + std::max(l1, r1) * std::max(l2, r2) * 2 * num_gemm, - ldc, - std::max(l1, r1) * std::max(l2, r2) * num_gemm)); + CUBLAS_CALL(cublasGemmBatched( + cuda_dtype, + cuhandle, + trans_op_l, + trans_op_r, + m, + n, + k, + alpha, + ptr_arr, + ldl, + ptr_arr + std::max(l1, r1) * std::max(l2, r2) * num_gemm, + ldr, + beta, + ptr_arr + std::max(l1, r1) * std::max(l2, r2) * 2 * num_gemm, + ldc, + std::max(l1, r1) * std::max(l2, r2) * num_gemm)); CUDA_CALL(cudaFreeAsync(ptr_arr, custream)); } -void cinn_call_cuda_memset(void *v_args, int num_args, int value, size_t count, void *stream) { +void cinn_call_cuda_memset( + void *v_args, int num_args, int value, size_t count, void *stream) { CHECK_EQ(num_args, 1) << "The cinn_call_cuda_memset only accept a output"; - VLOG(4) << "call cinn_call_cuda_memset with value=" << value << ", count=" << count; + VLOG(4) << "call cinn_call_cuda_memset with value=" << value + << ", count=" << count; cinn_pod_value_t *args = static_cast(v_args); - void *output = args[0].operator cinn_buffer_t *()->memory; + void *output = args[0].operator cinn_buffer_t *()->memory; cudaStream_t custream = static_cast(stream); CUDA_CALL(cudaMemsetAsync(output, value, count, custream)); } -void cinn_call_cuda_memcpy(void *v_args, int num_args, size_t count, void *stream) { - CHECK_EQ(num_args, 2) << "The cinn_call_cuda_memcpy only accept a input and a output"; +void cinn_call_cuda_memcpy(void *v_args, + int num_args, + size_t count, + void *stream) { + CHECK_EQ(num_args, 2) + << "The cinn_call_cuda_memcpy only accept a input and a output"; VLOG(4) << "call cinn_call_cuda_memcpy with count=" << count; cinn_pod_value_t *args = static_cast(v_args); - void *input = args[0].operator cinn_buffer_t *()->memory; - void *output = args[1].operator cinn_buffer_t *()->memory; + void *input = args[0].operator cinn_buffer_t *()->memory; + void *output = args[1].operator cinn_buffer_t *()->memory; cudaStream_t custream = static_cast(stream); - CUDA_CALL(cudaMemcpyAsync(output, input, count, cudaMemcpyDeviceToDevice, custream)); + CUDA_CALL(cudaMemcpyAsync( + output, input, count, cudaMemcpyDeviceToDevice, custream)); } #ifdef CINN_WITH_CUDNN @@ -491,7 +583,9 @@ class CudnnHandle { } private: - CudnnHandle() : workspace_(nullptr), size_(0) { 
CUDNN_CALL(cudnnCreate(&cuhandle_)); } + CudnnHandle() : workspace_(nullptr), size_(0) { + CUDNN_CALL(cudnnCreate(&cuhandle_)); + } cudnnHandle_t cuhandle_; void *workspace_; size_t size_; @@ -505,8 +599,12 @@ class ConvAlgoMap { static ConvAlgoMap instance; return instance; } - void InsertAlgo(const std::string &key, const int algo) { algo_map_[key] = algo; } - int GetAlgo(const std::string &key) { return algo_map_.count(key) ? algo_map_[key] : -1; } + void InsertAlgo(const std::string &key, const int algo) { + algo_map_[key] = algo; + } + int GetAlgo(const std::string &key) { + return algo_map_.count(key) ? algo_map_[key] : -1; + } private: ConvAlgoMap() {} @@ -516,17 +614,17 @@ class ConvAlgoMap { cudnnDataType_t convert_to_cudnn_dtype(void *v_args, int num_args) { CHECK_GT(num_args, 0) << "the number of arguments must larger than zero"; cinn_pod_value_t *args = static_cast(v_args); - auto type_code = args[0].operator cinn_buffer_t *()->type.code; - int bits = args[0].operator cinn_buffer_t *()->type.bits; + auto type_code = args[0].operator cinn_buffer_t *()->type.code; + int bits = args[0].operator cinn_buffer_t *()->type.bits; for (int i = 1; i < num_args; ++i) { auto t = args[i].operator cinn_buffer_t *()->type.code; - int b = args[0].operator cinn_buffer_t *()->type.bits; + int b = args[0].operator cinn_buffer_t *()->type.bits; if (t != type_code || bits != b) { LOG(FATAL) << "The types of all arguments need to be consistent."; } } cudnnDataType_t data_type; - bool is_float = type_code == cinn_type_float; + bool is_float = type_code == cinn_type_float; bool is_bfloat16 = type_code == cinn_type_bfloat; if (is_float && bits == 16) { data_type = CUDNN_DATA_HALF; @@ -537,7 +635,8 @@ cudnnDataType_t convert_to_cudnn_dtype(void *v_args, int num_args) { } else if (is_float && bits == 64) { data_type = CUDNN_DATA_DOUBLE; } else { - LOG(FATAL) << "unsupported cudnn data type: " << static_cast(type_code) << ", bits = " << bits; + LOG(FATAL) << "unsupported cudnn data type: " << static_cast(type_code) + << ", bits = " << bits; } return data_type; } @@ -551,7 +650,8 @@ cudnnDataType_t get_cudnn_compute_dtype(cudnnDataType_t data_type) { case CUDNN_DATA_DOUBLE: return CUDNN_DATA_DOUBLE; default: - LOG(FATAL) << "unsupported cudnn data type, only support float16/bfloat16/float32/float64 now!"; + LOG(FATAL) << "unsupported cudnn data type, only support " + "float16/bfloat16/float32/float64 now!"; } return CUDNN_DATA_FLOAT; } @@ -629,47 +729,64 @@ void cinn_call_cudnn_conv2d_forward(void *v_args, cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle(); CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); cinn_pod_value_t *args = static_cast(v_args); - void *_x = args[0].operator cinn_buffer_t *()->memory; - void *_w = args[1].operator cinn_buffer_t *()->memory; - void *_y = args[2].operator cinn_buffer_t *()->memory; + void *_x = args[0].operator cinn_buffer_t *()->memory; + void *_w = args[1].operator cinn_buffer_t *()->memory; + void *_y = args[2].operator cinn_buffer_t *()->memory; cudnnTensorFormat_t tensor_format = static_cast(format); - cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); + cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); cudnnTensorDescriptor_t x_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + x_desc, tensor_format, data_type, input_n, input_c, input_h, 
input_w)); cudnnFilterDescriptor_t w_desc; CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc)); - CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc, data_type, tensor_format, filter_n, filter_c, filter_h, filter_w)); + CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc, + data_type, + tensor_format, + filter_n, + filter_c, + filter_h, + filter_w)); cudnnConvolutionDescriptor_t conv_desc; CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc)); - CUDNN_CALL(cudnnSetConvolution2dDescriptor(conv_desc, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - CUDNN_CROSS_CORRELATION, - get_cudnn_compute_dtype(data_type))); + CUDNN_CALL( + cudnnSetConvolution2dDescriptor(conv_desc, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + CUDNN_CROSS_CORRELATION, + get_cudnn_compute_dtype(data_type))); CUDNN_CALL(cudnnSetConvolutionGroupCount(conv_desc, groups)); CUDNN_CALL(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH)); cudnnTensorDescriptor_t y_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, tensor_format, data_type, output_n, output_c, output_h, output_w)); - - auto &conv_algo_map = ConvAlgoMap::GetInstance(); - std::string hash_key = "conv2d forward, layout=" + debug_cudnn_tensor_format(tensor_format) + - ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" + std::to_string(input_n) + - "," + std::to_string(input_c) + "," + std::to_string(input_h) + "," + std::to_string(input_w) + - "}, filter_nchw={" + std::to_string(filter_n) + "," + std::to_string(filter_c) + "," + - std::to_string(filter_h) + "," + std::to_string(filter_w) + "}, output_nchw={" + - std::to_string(output_n) + "," + std::to_string(output_c) + "," + std::to_string(output_h) + - "," + std::to_string(output_w) + "}"; + CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, + tensor_format, + data_type, + output_n, + output_c, + output_h, + output_w)); + + auto &conv_algo_map = ConvAlgoMap::GetInstance(); + std::string hash_key = + "conv2d forward, layout=" + debug_cudnn_tensor_format(tensor_format) + + ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" + + std::to_string(input_n) + "," + std::to_string(input_c) + "," + + std::to_string(input_h) + "," + std::to_string(input_w) + + "}, filter_nchw={" + std::to_string(filter_n) + "," + + std::to_string(filter_c) + "," + std::to_string(filter_h) + "," + + std::to_string(filter_w) + "}, output_nchw={" + std::to_string(output_n) + + "," + std::to_string(output_c) + "," + std::to_string(output_h) + "," + + std::to_string(output_w) + "}"; VLOG(4) << hash_key; cudnnConvolutionFwdAlgo_t algo; int algo_int = conv_algo_map.GetAlgo(hash_key); @@ -678,7 +795,8 @@ void cinn_call_cudnn_conv2d_forward(void *v_args, } else { int count = 0; cudnnConvolutionFwdAlgoPerf_t algo_perf; - CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(handle, x_desc, w_desc, conv_desc, y_desc, 1, &count, &algo_perf)); + CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm( + handle, x_desc, w_desc, conv_desc, y_desc, 1, &count, &algo_perf)); algo = algo_perf.algo; conv_algo_map.InsertAlgo(hash_key, static_cast(algo_perf.algo)); @@ -689,12 +807,14 @@ void cinn_call_cudnn_conv2d_forward(void *v_args, } size_t workspace_size = 0; - CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(handle, x_desc, w_desc, conv_desc, y_desc, algo, &workspace_size)); + CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize( + handle, x_desc, w_desc, conv_desc, y_desc, algo, &workspace_size)); - void *workspace_data = 
CudnnHandle::GetInstance().GetWorkSpace(workspace_size); + void *workspace_data = + CudnnHandle::GetInstance().GetWorkSpace(workspace_size); if (data_type == CUDNN_DATA_DOUBLE) { const double alpha_fp64 = static_cast(alpha); - const double beta_fp64 = static_cast(beta); + const double beta_fp64 = static_cast(beta); CUDNN_CALL(cudnnConvolutionForward(handle, &alpha_fp64, x_desc, @@ -709,8 +829,19 @@ void cinn_call_cudnn_conv2d_forward(void *v_args, y_desc, _y)); } else { - CUDNN_CALL(cudnnConvolutionForward( - handle, &alpha, x_desc, _x, w_desc, _w, conv_desc, algo, workspace_data, workspace_size, &beta, y_desc, _y)); + CUDNN_CALL(cudnnConvolutionForward(handle, + &alpha, + x_desc, + _x, + w_desc, + _w, + conv_desc, + algo, + workspace_data, + workspace_size, + &beta, + y_desc, + _y)); } CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc)); @@ -748,47 +879,65 @@ void cinn_call_cudnn_conv2d_backward_data(void *v_args, cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle(); CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); cinn_pod_value_t *args = static_cast(v_args); - void *_w = args[0].operator cinn_buffer_t *()->memory; - void *_dy = args[1].operator cinn_buffer_t *()->memory; - void *_dx = args[2].operator cinn_buffer_t *()->memory; + void *_w = args[0].operator cinn_buffer_t *()->memory; + void *_dy = args[1].operator cinn_buffer_t *()->memory; + void *_dx = args[2].operator cinn_buffer_t *()->memory; cudnnTensorFormat_t tensor_format = static_cast(format); - cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); + cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); cudnnTensorDescriptor_t x_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w)); cudnnFilterDescriptor_t w_desc; CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc)); - CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc, data_type, tensor_format, filter_n, filter_c, filter_h, filter_w)); + CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc, + data_type, + tensor_format, + filter_n, + filter_c, + filter_h, + filter_w)); cudnnConvolutionDescriptor_t conv_desc; CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc)); - CUDNN_CALL(cudnnSetConvolution2dDescriptor(conv_desc, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - CUDNN_CROSS_CORRELATION, - get_cudnn_compute_dtype(data_type))); + CUDNN_CALL( + cudnnSetConvolution2dDescriptor(conv_desc, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + CUDNN_CROSS_CORRELATION, + get_cudnn_compute_dtype(data_type))); CUDNN_CALL(cudnnSetConvolutionGroupCount(conv_desc, groups)); CUDNN_CALL(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH)); cudnnTensorDescriptor_t y_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, tensor_format, data_type, output_n, output_c, output_h, output_w)); - - auto &conv_algo_map = ConvAlgoMap::GetInstance(); - std::string hash_key = "conv2d backward data, layout=" + debug_cudnn_tensor_format(tensor_format) + - ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" + std::to_string(input_n) + - "," + std::to_string(input_c) + "," + std::to_string(input_h) + "," + std::to_string(input_w) + - "}, filter_nchw={" + std::to_string(filter_n) + "," + std::to_string(filter_c) + "," + - 
std::to_string(filter_h) + "," + std::to_string(filter_w) + "}, output_nchw={" + - std::to_string(output_n) + "," + std::to_string(output_c) + "," + std::to_string(output_h) + - "," + std::to_string(output_w) + "}"; + CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, + tensor_format, + data_type, + output_n, + output_c, + output_h, + output_w)); + + auto &conv_algo_map = ConvAlgoMap::GetInstance(); + std::string hash_key = + "conv2d backward data, layout=" + + debug_cudnn_tensor_format(tensor_format) + + ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" + + std::to_string(input_n) + "," + std::to_string(input_c) + "," + + std::to_string(input_h) + "," + std::to_string(input_w) + + "}, filter_nchw={" + std::to_string(filter_n) + "," + + std::to_string(filter_c) + "," + std::to_string(filter_h) + "," + + std::to_string(filter_w) + "}, output_nchw={" + std::to_string(output_n) + + "," + std::to_string(output_c) + "," + std::to_string(output_h) + "," + + std::to_string(output_w) + "}"; VLOG(4) << hash_key; @@ -799,8 +948,8 @@ void cinn_call_cudnn_conv2d_backward_data(void *v_args, } else { int count = 0; cudnnConvolutionBwdDataAlgoPerf_t algo_perf; - CUDNN_CALL( - cudnnFindConvolutionBackwardDataAlgorithm(handle, w_desc, y_desc, conv_desc, x_desc, 1, &count, &algo_perf)); + CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm( + handle, w_desc, y_desc, conv_desc, x_desc, 1, &count, &algo_perf)); algo = algo_perf.algo; conv_algo_map.InsertAlgo(hash_key, static_cast(algo_perf.algo)); @@ -811,13 +960,14 @@ void cinn_call_cudnn_conv2d_backward_data(void *v_args, } size_t workspace_size = 0; - CUDNN_CALL( - cudnnGetConvolutionBackwardDataWorkspaceSize(handle, w_desc, y_desc, conv_desc, x_desc, algo, &workspace_size)); + CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, w_desc, y_desc, conv_desc, x_desc, algo, &workspace_size)); - void *workspace_data = CudnnHandle::GetInstance().GetWorkSpace(workspace_size); + void *workspace_data = + CudnnHandle::GetInstance().GetWorkSpace(workspace_size); if (data_type == CUDNN_DATA_DOUBLE) { const double alpha_fp64 = static_cast(alpha); - const double beta_fp64 = static_cast(beta); + const double beta_fp64 = static_cast(beta); CUDNN_CALL(cudnnConvolutionBackwardData(handle, &alpha_fp64, w_desc, @@ -832,8 +982,19 @@ void cinn_call_cudnn_conv2d_backward_data(void *v_args, x_desc, _dx)); } else { - CUDNN_CALL(cudnnConvolutionBackwardData( - handle, &alpha, w_desc, _w, y_desc, _dy, conv_desc, algo, workspace_data, workspace_size, &beta, x_desc, _dx)); + CUDNN_CALL(cudnnConvolutionBackwardData(handle, + &alpha, + w_desc, + _w, + y_desc, + _dy, + conv_desc, + algo, + workspace_data, + workspace_size, + &beta, + x_desc, + _dx)); } CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc)); @@ -872,47 +1033,65 @@ void cinn_call_cudnn_conv2d_backward_filter(void *v_args, CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); cinn_pod_value_t *args = static_cast(v_args); - void *_x = args[0].operator cinn_buffer_t *()->memory; + void *_x = args[0].operator cinn_buffer_t *()->memory; void *_dy = args[1].operator cinn_buffer_t *()->memory; void *_dw = args[2].operator cinn_buffer_t *()->memory; cudnnTensorFormat_t tensor_format = static_cast(format); - cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); + cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); cudnnTensorDescriptor_t x_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(x_desc, tensor_format, data_type, 
input_n, input_c, input_h, input_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w)); cudnnFilterDescriptor_t w_desc; CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc)); - CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc, data_type, tensor_format, filter_n, filter_c, filter_h, filter_w)); + CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc, + data_type, + tensor_format, + filter_n, + filter_c, + filter_h, + filter_w)); cudnnConvolutionDescriptor_t conv_desc; CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc)); - CUDNN_CALL(cudnnSetConvolution2dDescriptor(conv_desc, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - CUDNN_CROSS_CORRELATION, - get_cudnn_compute_dtype(data_type))); + CUDNN_CALL( + cudnnSetConvolution2dDescriptor(conv_desc, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + CUDNN_CROSS_CORRELATION, + get_cudnn_compute_dtype(data_type))); CUDNN_CALL(cudnnSetConvolutionGroupCount(conv_desc, groups)); CUDNN_CALL(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH)); cudnnTensorDescriptor_t y_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, tensor_format, data_type, output_n, output_c, output_h, output_w)); - - auto &algo_map = ConvAlgoMap::GetInstance(); - std::string hash_key = "conv2d backward filter, layout=" + debug_cudnn_tensor_format(tensor_format) + - ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" + std::to_string(input_n) + - "," + std::to_string(input_c) + "," + std::to_string(input_h) + "," + std::to_string(input_w) + - "}, filter_nchw={" + std::to_string(filter_n) + "," + std::to_string(filter_c) + "," + - std::to_string(filter_h) + "," + std::to_string(filter_w) + "}, output_nchw={" + - std::to_string(output_n) + "," + std::to_string(output_c) + "," + std::to_string(output_h) + - "," + std::to_string(output_w) + "}"; + CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, + tensor_format, + data_type, + output_n, + output_c, + output_h, + output_w)); + + auto &algo_map = ConvAlgoMap::GetInstance(); + std::string hash_key = + "conv2d backward filter, layout=" + + debug_cudnn_tensor_format(tensor_format) + + ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" + + std::to_string(input_n) + "," + std::to_string(input_c) + "," + + std::to_string(input_h) + "," + std::to_string(input_w) + + "}, filter_nchw={" + std::to_string(filter_n) + "," + + std::to_string(filter_c) + "," + std::to_string(filter_h) + "," + + std::to_string(filter_w) + "}, output_nchw={" + std::to_string(output_n) + + "," + std::to_string(output_c) + "," + std::to_string(output_h) + "," + + std::to_string(output_w) + "}"; VLOG(4) << hash_key; @@ -923,8 +1102,8 @@ void cinn_call_cudnn_conv2d_backward_filter(void *v_args, } else { int count = 0; cudnnConvolutionBwdFilterAlgoPerf_t algo_perf; - CUDNN_CALL( - cudnnFindConvolutionBackwardFilterAlgorithm(handle, x_desc, y_desc, conv_desc, w_desc, 1, &count, &algo_perf)); + CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm( + handle, x_desc, y_desc, conv_desc, w_desc, 1, &count, &algo_perf)); algo = algo_perf.algo; algo_map.InsertAlgo(hash_key, static_cast(algo_perf.algo)); @@ -935,13 +1114,14 @@ void cinn_call_cudnn_conv2d_backward_filter(void *v_args, } size_t workspace_size = 0; - CUDNN_CALL( - cudnnGetConvolutionBackwardFilterWorkspaceSize(handle, x_desc, y_desc, conv_desc, w_desc, algo, &workspace_size)); + 
CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, x_desc, y_desc, conv_desc, w_desc, algo, &workspace_size)); - void *workspace_data = CudnnHandle::GetInstance().GetWorkSpace(workspace_size); + void *workspace_data = + CudnnHandle::GetInstance().GetWorkSpace(workspace_size); if (data_type == CUDNN_DATA_DOUBLE) { const double alpha_fp64 = static_cast(alpha); - const double beta_fp64 = static_cast(beta); + const double beta_fp64 = static_cast(beta); CUDNN_CALL(cudnnConvolutionBackwardFilter(handle, &alpha_fp64, x_desc, @@ -956,8 +1136,19 @@ void cinn_call_cudnn_conv2d_backward_filter(void *v_args, w_desc, _dw)); } else { - CUDNN_CALL(cudnnConvolutionBackwardFilter( - handle, &alpha, x_desc, _x, y_desc, _dy, conv_desc, algo, workspace_data, workspace_size, &beta, w_desc, _dw)); + CUDNN_CALL(cudnnConvolutionBackwardFilter(handle, + &alpha, + x_desc, + _x, + y_desc, + _dy, + conv_desc, + algo, + workspace_data, + workspace_size, + &beta, + w_desc, + _dw)); } CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc)); @@ -995,9 +1186,9 @@ void cinn_call_cudnn_pool2d_forward(void *v_args, void *_x = args[0].operator cinn_buffer_t *()->memory; void *_y = args[1].operator cinn_buffer_t *()->memory; - cudnnPoolingMode_t pool_mode = static_cast(mode); + cudnnPoolingMode_t pool_mode = static_cast(mode); cudnnTensorFormat_t tensor_format = static_cast(format); - cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); + cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); if (GetCinnCudnnDeterministic() && pool_mode == CUDNN_POOLING_MAX) { pool_mode = CUDNN_POOLING_MAX_DETERMINISTIC; @@ -1005,34 +1196,54 @@ void cinn_call_cudnn_pool2d_forward(void *v_args, std::string hash_key = "pool2d forward, layout=" + debug_cudnn_tensor_format(tensor_format) + - ", pool_type=" + debug_cudnn_pool_mode(pool_mode) + ", dtype=" + debug_cudnn_tensor_dtype(data_type) + - ", input_nchw={" + std::to_string(input_n) + "," + std::to_string(input_c) + "," + std::to_string(input_h) + "," + - std::to_string(input_w) + "}, kernel_hw={" + std::to_string(kernel_h) + "," + std::to_string(kernel_w) + - "}, pad_hw={" + std::to_string(pad_h) + "," + std::to_string(pad_w) + "}, stride_hw={" + - std::to_string(stride_h) + "," + std::to_string(stride_w) + "}, output_nchw={" + std::to_string(output_n) + "," + - std::to_string(output_c) + "," + std::to_string(output_h) + "," + std::to_string(output_w) + "}"; + ", pool_type=" + debug_cudnn_pool_mode(pool_mode) + + ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" + + std::to_string(input_n) + "," + std::to_string(input_c) + "," + + std::to_string(input_h) + "," + std::to_string(input_w) + + "}, kernel_hw={" + std::to_string(kernel_h) + "," + + std::to_string(kernel_w) + "}, pad_hw={" + std::to_string(pad_h) + "," + + std::to_string(pad_w) + "}, stride_hw={" + std::to_string(stride_h) + + "," + std::to_string(stride_w) + "}, output_nchw={" + + std::to_string(output_n) + "," + std::to_string(output_c) + "," + + std::to_string(output_h) + "," + std::to_string(output_w) + "}"; VLOG(4) << hash_key; cudnnPoolingDescriptor_t pool_desc; CUDNN_CALL(cudnnCreatePoolingDescriptor(&pool_desc)); - CUDNN_CALL(cudnnSetPooling2dDescriptor( - pool_desc, pool_mode, CUDNN_NOT_PROPAGATE_NAN, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w)); + CUDNN_CALL(cudnnSetPooling2dDescriptor(pool_desc, + pool_mode, + CUDNN_NOT_PROPAGATE_NAN, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w)); cudnnTensorDescriptor_t x_desc; 
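The three conv2d entry points above (forward, backward data, backward filter) share one algorithm-selection idiom: build a human-readable hash_key from layout, dtype, and shapes, probe a process-wide cache, and fall back to the expensive cudnnFind*Algorithm search only on a miss. A minimal cache sketch consistent with how ConvAlgoMap is used here (the real class is not shown in this diff and may differ, e.g. in thread safety), under the hypothetical name AlgoCache:

#include <string>
#include <unordered_map>

// Hypothetical stand-in for ConvAlgoMap: maps the hash_key string built
// above to the integer id of the cuDNN algorithm found for that problem.
class AlgoCache {
 public:
  static AlgoCache &GetInstance() {
    static AlgoCache instance;  // Meyers singleton, one cache per process
    return instance;
  }
  // Mirrors the usage above: a negative value means "not cached yet",
  // so the caller runs cudnnFind*Algorithm and then calls InsertAlgo.
  int GetAlgo(const std::string &key) const {
    auto it = map_.find(key);
    return it == map_.end() ? -1 : it->second;
  }
  void InsertAlgo(const std::string &key, int algo) { map_[key] = algo; }

 private:
  std::unordered_map<std::string, int> map_;
};

With a cache of this shape, the Find call runs once per unique (layout, dtype, shape) problem per process instead of on every kernel launch.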
CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w)); cudnnTensorDescriptor_t y_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, tensor_format, data_type, output_n, output_c, output_h, output_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, + tensor_format, + data_type, + output_n, + output_c, + output_h, + output_w)); if (data_type == CUDNN_DATA_DOUBLE) { const double alpha_fp64 = static_cast(alpha); - const double beta_fp64 = static_cast(beta); - CUDNN_CALL(cudnnPoolingForward(handle, pool_desc, &alpha_fp64, x_desc, _x, &beta_fp64, y_desc, _y)); + const double beta_fp64 = static_cast(beta); + CUDNN_CALL(cudnnPoolingForward( + handle, pool_desc, &alpha_fp64, x_desc, _x, &beta_fp64, y_desc, _y)); } else { - CUDNN_CALL(cudnnPoolingForward(handle, pool_desc, &alpha, x_desc, _x, &beta, y_desc, _y)); + CUDNN_CALL(cudnnPoolingForward( + handle, pool_desc, &alpha, x_desc, _x, &beta, y_desc, _y)); } CUDNN_CALL(cudnnDestroyPoolingDescriptor(pool_desc)); @@ -1066,14 +1277,14 @@ void cinn_call_cudnn_pool2d_backward(void *v_args, CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); cinn_pod_value_t *args = static_cast(v_args); - void *_x = args[0].operator cinn_buffer_t *()->memory; - void *_y = args[1].operator cinn_buffer_t *()->memory; + void *_x = args[0].operator cinn_buffer_t *()->memory; + void *_y = args[1].operator cinn_buffer_t *()->memory; void *_dy = args[2].operator cinn_buffer_t *()->memory; void *_dx = args[3].operator cinn_buffer_t *()->memory; - cudnnPoolingMode_t pool_mode = static_cast(mode); + cudnnPoolingMode_t pool_mode = static_cast(mode); cudnnTensorFormat_t tensor_format = static_cast(format); - cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); + cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); if (GetCinnCudnnDeterministic() && pool_mode == CUDNN_POOLING_MAX) { pool_mode = CUDNN_POOLING_MAX_DETERMINISTIC; @@ -1081,36 +1292,74 @@ void cinn_call_cudnn_pool2d_backward(void *v_args, std::string hash_key = "pool2d backward, layout=" + debug_cudnn_tensor_format(tensor_format) + - ", pool_type=" + debug_cudnn_pool_mode(pool_mode) + ", dtype=" + debug_cudnn_tensor_dtype(data_type) + - ", input_nchw={" + std::to_string(input_n) + "," + std::to_string(input_c) + "," + std::to_string(input_h) + "," + - std::to_string(input_w) + "}, kernel_hw={" + std::to_string(kernel_h) + "," + std::to_string(kernel_w) + - "}, pad_hw={" + std::to_string(pad_h) + "," + std::to_string(pad_w) + "}, stride_hw={" + - std::to_string(stride_h) + "," + std::to_string(stride_w) + ", output_nchw={" + std::to_string(output_n) + "," + - std::to_string(output_c) + "," + std::to_string(output_h) + "," + std::to_string(output_w) + "}"; + ", pool_type=" + debug_cudnn_pool_mode(pool_mode) + + ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" + + std::to_string(input_n) + "," + std::to_string(input_c) + "," + + std::to_string(input_h) + "," + std::to_string(input_w) + + "}, kernel_hw={" + std::to_string(kernel_h) + "," + + std::to_string(kernel_w) + "}, pad_hw={" + std::to_string(pad_h) + "," + + std::to_string(pad_w) + "}, stride_hw={" + std::to_string(stride_h) + + "," + std::to_string(stride_w) + ", output_nchw={" + + std::to_string(output_n) + "," + std::to_string(output_c) + 
"," + + std::to_string(output_h) + "," + std::to_string(output_w) + "}"; VLOG(4) << hash_key; cudnnPoolingDescriptor_t pool_desc; CUDNN_CALL(cudnnCreatePoolingDescriptor(&pool_desc)); - CUDNN_CALL(cudnnSetPooling2dDescriptor( - pool_desc, pool_mode, CUDNN_NOT_PROPAGATE_NAN, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w)); + CUDNN_CALL(cudnnSetPooling2dDescriptor(pool_desc, + pool_mode, + CUDNN_NOT_PROPAGATE_NAN, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w)); cudnnTensorDescriptor_t x_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w)); cudnnTensorDescriptor_t y_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, tensor_format, data_type, output_n, output_c, output_h, output_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, + tensor_format, + data_type, + output_n, + output_c, + output_h, + output_w)); if (data_type == CUDNN_DATA_DOUBLE) { const double alpha_fp64 = static_cast(alpha); - const double beta_fp64 = static_cast(beta); - CUDNN_CALL(cudnnPoolingBackward( - handle, pool_desc, &alpha_fp64, y_desc, _y, y_desc, _dy, x_desc, _x, &beta_fp64, x_desc, _dx)); + const double beta_fp64 = static_cast(beta); + CUDNN_CALL(cudnnPoolingBackward(handle, + pool_desc, + &alpha_fp64, + y_desc, + _y, + y_desc, + _dy, + x_desc, + _x, + &beta_fp64, + x_desc, + _dx)); } else { - CUDNN_CALL( - cudnnPoolingBackward(handle, pool_desc, &alpha, y_desc, _y, y_desc, _dy, x_desc, _x, &beta, x_desc, _dx)); + CUDNN_CALL(cudnnPoolingBackward(handle, + pool_desc, + &alpha, + y_desc, + _y, + y_desc, + _dy, + x_desc, + _x, + &beta, + x_desc, + _dx)); } CUDNN_CALL(cudnnDestroyPoolingDescriptor(pool_desc)); @@ -1141,25 +1390,47 @@ void cinn_call_cudnn_softmax_forward(void *v_args, void *_x = args[0].operator cinn_buffer_t *()->memory; void *_y = args[1].operator cinn_buffer_t *()->memory; - cudnnSoftmaxMode_t softmax_mode = static_cast(mode); + cudnnSoftmaxMode_t softmax_mode = static_cast(mode); cudnnTensorFormat_t tensor_format = static_cast(format); - cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); + cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); cudnnTensorDescriptor_t x_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w)); cudnnTensorDescriptor_t y_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, tensor_format, data_type, output_n, output_c, output_h, output_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, + tensor_format, + data_type, + output_n, + output_c, + output_h, + output_w)); if (data_type == CUDNN_DATA_DOUBLE) { const double alpha_fp64 = static_cast(alpha); - const double beta_fp64 = static_cast(beta); - CUDNN_CALL( - cudnnSoftmaxForward(handle, CUDNN_SOFTMAX_LOG, softmax_mode, &alpha_fp64, x_desc, _x, &beta_fp64, y_desc, _y)); + const double beta_fp64 = static_cast(beta); + CUDNN_CALL(cudnnSoftmaxForward(handle, + CUDNN_SOFTMAX_LOG, + softmax_mode, + &alpha_fp64, + x_desc, + _x, + &beta_fp64, + y_desc, + _y)); } else { - CUDNN_CALL(cudnnSoftmaxForward(handle, 
CUDNN_SOFTMAX_LOG, softmax_mode, &alpha, x_desc, _x, &beta, y_desc, _y)); + CUDNN_CALL(cudnnSoftmaxForward(handle, + CUDNN_SOFTMAX_LOG, + softmax_mode, + &alpha, + x_desc, + _x, + &beta, + y_desc, + _y)); } CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc)); @@ -1186,30 +1457,55 @@ void cinn_call_cudnn_softmax_backward(void *v_args, CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); cinn_pod_value_t *args = static_cast(v_args); - void *_y = args[0].operator cinn_buffer_t *()->memory; + void *_y = args[0].operator cinn_buffer_t *()->memory; void *_dy = args[1].operator cinn_buffer_t *()->memory; void *_dx = args[2].operator cinn_buffer_t *()->memory; - cudnnSoftmaxMode_t softmax_mode = static_cast(mode); + cudnnSoftmaxMode_t softmax_mode = static_cast(mode); cudnnTensorFormat_t tensor_format = static_cast(format); - cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); + cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args); cudnnTensorDescriptor_t x_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w)); cudnnTensorDescriptor_t y_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, tensor_format, data_type, output_n, output_c, output_h, output_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, + tensor_format, + data_type, + output_n, + output_c, + output_h, + output_w)); if (data_type == CUDNN_DATA_DOUBLE) { const double alpha_fp64 = static_cast(alpha); - const double beta_fp64 = static_cast(beta); - CUDNN_CALL(cudnnSoftmaxBackward( - handle, CUDNN_SOFTMAX_LOG, softmax_mode, &alpha_fp64, y_desc, _y, y_desc, _dy, &beta_fp64, x_desc, _dx)); + const double beta_fp64 = static_cast(beta); + CUDNN_CALL(cudnnSoftmaxBackward(handle, + CUDNN_SOFTMAX_LOG, + softmax_mode, + &alpha_fp64, + y_desc, + _y, + y_desc, + _dy, + &beta_fp64, + x_desc, + _dx)); } else { - CUDNN_CALL(cudnnSoftmaxBackward( - handle, CUDNN_SOFTMAX_LOG, softmax_mode, &alpha, y_desc, _y, y_desc, _dy, &beta, x_desc, _dx)); + CUDNN_CALL(cudnnSoftmaxBackward(handle, + CUDNN_SOFTMAX_LOG, + softmax_mode, + &alpha, + y_desc, + _y, + y_desc, + _dy, + &beta, + x_desc, + _dx)); } CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc)); @@ -1235,21 +1531,26 @@ void Gemm(const cublasHandle_t &cublas, float *output_data, const std::vector &output_shape, cudaStream_t stream) { - int lhs_row = lhs_shape[0]; - int lhs_col = lhs_shape[1]; - int rhs_row = rhs_shape[0]; - int rhs_col = rhs_shape[1]; + int lhs_row = lhs_shape[0]; + int lhs_col = lhs_shape[1]; + int rhs_row = rhs_shape[0]; + int rhs_col = rhs_shape[1]; int output_row = output_shape[0]; int output_col = output_shape[1]; // copy values of bias_data to the output_data if (bias_data != nullptr) { - cudaMemcpyAsync(output_data, bias_data, output_row * output_col * sizeof(float), cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(output_data, + bias_data, + output_row * output_col * sizeof(float), + cudaMemcpyDeviceToDevice, + stream); } int contracting_size = lhs_trans ? lhs_row : lhs_col; CHECK_EQ(contracting_size, (rhs_trans ? rhs_col : rhs_row)) - << "The contracting dimension value of lhs matrix should be equal to the one of rhs matrix."; + << "The contracting dimension value of lhs matrix should be equal to the " + "one of rhs matrix."; auto trans_a = rhs_trans ? 
CUBLAS_OP_T : CUBLAS_OP_N; auto trans_b = lhs_trans ? CUBLAS_OP_T : CUBLAS_OP_N; cublasSgemm(cublas, @@ -1281,13 +1582,13 @@ void GemmStridedBatched(const cublasHandle_t &cublas, float *output_data, const std::vector &output_shape, cudaStream_t stream) { - int lhs_bs = lhs_shape[0]; - int lhs_row = lhs_shape[1]; - int lhs_col = lhs_shape[2]; - int rhs_bs = rhs_shape[0]; - int rhs_row = rhs_shape[1]; - int rhs_col = rhs_shape[2]; - int output_bs = output_shape[0]; + int lhs_bs = lhs_shape[0]; + int lhs_row = lhs_shape[1]; + int lhs_col = lhs_shape[2]; + int rhs_bs = rhs_shape[0]; + int rhs_row = rhs_shape[1]; + int rhs_col = rhs_shape[2]; + int output_bs = output_shape[0]; int output_row = output_shape[1]; int output_col = output_shape[2]; CHECK_EQ(lhs_bs, rhs_bs); @@ -1295,17 +1596,21 @@ void GemmStridedBatched(const cublasHandle_t &cublas, // copy values of bias_data to the output_data if (bias_data != nullptr) { - cudaMemcpyAsync( - output_data, bias_data, output_bs * output_row * output_col * sizeof(float), cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(output_data, + bias_data, + output_bs * output_row * output_col * sizeof(float), + cudaMemcpyDeviceToDevice, + stream); } int contracting_size = lhs_trans ? lhs_row : lhs_col; CHECK_EQ(contracting_size, (rhs_trans ? rhs_col : rhs_row)) - << "The contracting dimension value of lhs matrix should be equal to the one of rhs matrix."; - auto trans_a = rhs_trans ? CUBLAS_OP_T : CUBLAS_OP_N; - auto trans_b = lhs_trans ? CUBLAS_OP_T : CUBLAS_OP_N; - int64_t lhs_stride = lhs_row * lhs_col; - int64_t rhs_stride = rhs_row * rhs_col; + << "The contracting dimension value of lhs matrix should be equal to the " + "one of rhs matrix."; + auto trans_a = rhs_trans ? CUBLAS_OP_T : CUBLAS_OP_N; + auto trans_b = lhs_trans ? CUBLAS_OP_T : CUBLAS_OP_N; + int64_t lhs_stride = lhs_row * lhs_col; + int64_t rhs_stride = rhs_row * rhs_col; int64_t output_stride = output_row * output_col; cublasSgemmStridedBatched(cublas, trans_a, @@ -1345,31 +1650,41 @@ class CusolverHandle { cusolverDnHandle_t handle_; }; -void cinn_call_cholesky_nvgpu(void *v_args, int num_args, int batch_size, int m, bool upper, void *stream) { +void cinn_call_cholesky_nvgpu(void *v_args, + int num_args, + int batch_size, + int m, + bool upper, + void *stream) { cinn_pod_value_t *args = static_cast(v_args); - cinn_buffer_t *x = args[0].operator cinn_buffer_t *(); - cinn_buffer_t *out = args[1].operator cinn_buffer_t *(); - // In cuSOLVER, dense matrix stores in COL_MAJOR, thus FILL_MODE needs to be filpped. - // See also: https://docs.nvidia.com/cuda/cusolver/index.html#matrix-dense-format - cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; - size_t numel = x->num_elements(); - uint8_t bits = x->type.bits; - uint8_t bytes = bits / 8; + cinn_buffer_t *x = args[0].operator cinn_buffer_t *(); + cinn_buffer_t *out = args[1].operator cinn_buffer_t *(); + // In cuSOLVER, dense matrix stores in COL_MAJOR, thus FILL_MODE needs to be + // filpped. See also: + // https://docs.nvidia.com/cuda/cusolver/index.html#matrix-dense-format + cublasFillMode_t uplo = + upper ? 
CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + size_t numel = x->num_elements(); + uint8_t bits = x->type.bits; + uint8_t bytes = bits / 8; CHECK_EQ(x->type.code, cinn_type_code_t::cinn_type_float); - CHECK(bits == 32 || bits == 64) << "Unsupported bits = " << bits << " float data type for cholesky"; + CHECK(bits == 32 || bits == 64) + << "Unsupported bits = " << bits << " float data type for cholesky"; auto cuda_stream = static_cast(stream); // Copy data from x to out - void *x_ptr = reinterpret_cast(x->memory); + void *x_ptr = reinterpret_cast(x->memory); void *out_ptr = reinterpret_cast(out->memory); - CUDA_CALL(cudaMemcpyAsync(out_ptr, x_ptr, numel * bytes, cudaMemcpyDeviceToDevice, cuda_stream)); + CUDA_CALL(cudaMemcpyAsync( + out_ptr, x_ptr, numel * bytes, cudaMemcpyDeviceToDevice, cuda_stream)); // Generate pointer array thrust::host_vector host_out_ptr(batch_size, nullptr); for (int i = 0; i < batch_size; ++i) { host_out_ptr[i] = reinterpret_cast(out_ptr) + i * m * m * bytes; } - thrust::device_vector dev_out_ptr(host_out_ptr.begin(), host_out_ptr.end()); + thrust::device_vector dev_out_ptr(host_out_ptr.begin(), + host_out_ptr.end()); // Store the return value of each matrix thrust::host_vector host_info(batch_size, 0); thrust::device_vector dev_info(host_info.begin(), host_info.end()); @@ -1377,27 +1692,31 @@ void cinn_call_cholesky_nvgpu(void *v_args, int num_args, int batch_size, int m, cusolverDnHandle_t handler = CusolverHandle::GetInstance().GetHandle(); CUSOLVER_CALL(cusolverDnSetStream(handler, cuda_stream)); if (bits == 32) { - CUSOLVER_CALL(cusolverDnSpotrfBatched(handler, - uplo, - m, - reinterpret_cast(dev_out_ptr.data().get()), - m, - thrust::raw_pointer_cast(dev_info.data()), - batch_size)); + CUSOLVER_CALL(cusolverDnSpotrfBatched( + handler, + uplo, + m, + reinterpret_cast(dev_out_ptr.data().get()), + m, + thrust::raw_pointer_cast(dev_info.data()), + batch_size)); } else if (bits == 64) { - CUSOLVER_CALL(cusolverDnDpotrfBatched(handler, - uplo, - m, - reinterpret_cast(dev_out_ptr.data().get()), - m, - thrust::raw_pointer_cast(dev_info.data()), - batch_size)); + CUSOLVER_CALL(cusolverDnDpotrfBatched( + handler, + uplo, + m, + reinterpret_cast(dev_out_ptr.data().get()), + m, + thrust::raw_pointer_cast(dev_info.data()), + batch_size)); } // Check result thrust::copy(dev_info.begin(), dev_info.end(), host_info.begin()); for (int i = 0; i < host_info.size(); i++) { - CHECK_EQ(host_info[i], 0) << "Cholesky decomposition fail, please check the " << i + 1 << "th input matrix."; + CHECK_EQ(host_info[i], 0) + << "Cholesky decomposition fail, please check the " << i + 1 + << "th input matrix."; } } @@ -1412,36 +1731,43 @@ void cinn_call_triangular_solve_nvgpu(void *v_args, bool unit_diagonal, void *stream) { cublasHandle_t &handle = CublasHandle::GetInstance().GetCublasHandle(); - cudaStream_t custream = static_cast(stream); + cudaStream_t custream = static_cast(stream); CUBLAS_CALL(cublasSetStream(handle, custream)); - int b_rows = left_side ? k : m; - int b_cols = left_side ? m : k; - int lda = m; - int ldb = b_rows; - cublasSideMode_t side = left_side ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; - cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + int b_rows = left_side ? k : m; + int b_cols = left_side ? m : k; + int lda = m; + int ldb = b_rows; + cublasSideMode_t side = left_side ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; + cublasFillMode_t uplo = + upper ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; cublasOperation_t transa = transpose_a ? 
CUBLAS_OP_T : CUBLAS_OP_N; - cublasDiagType_t diag = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; + cublasDiagType_t diag = + unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; cinn_pod_value_t *args = static_cast(v_args); - cinn_buffer_t *input1 = args[0].operator cinn_buffer_t *(); - cinn_buffer_t *input2 = args[1].operator cinn_buffer_t *(); - cinn_buffer_t *output = args[2].operator cinn_buffer_t *(); + cinn_buffer_t *input1 = args[0].operator cinn_buffer_t *(); + cinn_buffer_t *input2 = args[1].operator cinn_buffer_t *(); + cinn_buffer_t *output = args[2].operator cinn_buffer_t *(); CHECK_EQ(input1->type.code, cinn_type_code_t::cinn_type_float); CHECK_EQ(input2->type.code, cinn_type_code_t::cinn_type_float); CHECK_EQ(input1->type.bits, input2->type.bits); - uint8_t bits = input1->type.bits; + uint8_t bits = input1->type.bits; uint8_t bytes = bits / 8; - CHECK(bits == 32 || bits == 64) << "unsupported bits = " << bits << " float data type for triangular solve"; + CHECK(bits == 32 || bits == 64) << "unsupported bits = " << bits + << " float data type for triangular solve"; std::string debug_info = - "triangular solve op: left_side=" + std::to_string(left_side) + ", upper=" + std::to_string(uplo) + - ", transpose_a=" + std::to_string(transa) + ", unit_diagonal=" + std::to_string(unit_diagonal) + - ", batch_size=" + std::to_string(batch_size) + ", m=" + std::to_string(m) + ", k=" + std::to_string(k) + - ", input1_dtype={code: " + std::to_string(input1->type.code) + ", bits: " + std::to_string(input1->type.bits) + - "}" + ", input2_dtype={code: " + std::to_string(input2->type.code) + + "triangular solve op: left_side=" + std::to_string(left_side) + + ", upper=" + std::to_string(uplo) + + ", transpose_a=" + std::to_string(transa) + + ", unit_diagonal=" + std::to_string(unit_diagonal) + + ", batch_size=" + std::to_string(batch_size) + + ", m=" + std::to_string(m) + ", k=" + std::to_string(k) + + ", input1_dtype={code: " + std::to_string(input1->type.code) + + ", bits: " + std::to_string(input1->type.bits) + "}" + + ", input2_dtype={code: " + std::to_string(input2->type.code) + ", bits: " + std::to_string(input2->type.bits) + "}"; VLOG(4) << debug_info; @@ -1449,10 +1775,12 @@ void cinn_call_triangular_solve_nvgpu(void *v_args, void *b_ptr = reinterpret_cast(input2->memory); void *x_ptr = reinterpret_cast(output->memory); - // The API cublasStrsmBatched overwrites the right-hand sides, so the right-hand sides should be copied to the output. - // The output can then be used directly for the calculation. + // The API cublasStrsmBatched overwrites the right-hand sides, so the + // right-hand sides should be copied to the output. The output can then be + // used directly for the calculation. 
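Both cinn_call_cholesky_nvgpu above and the triangular solve here stage their inputs the same way: the batched cuSOLVER/cuBLAS APIs (cusolverDn*potrfBatched, cublas*trsmBatched) take a device-resident array of per-matrix pointers rather than a base pointer plus stride, so a pointer table is filled on the host and uploaded through thrust. A small helper sketch of that idiom (hypothetical MakeBatchPointers, not part of this change; base and stride_elems are assumed inputs):

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

// Builds the device-side pointer table expected by batched BLAS/solver
// calls: entry i points at the i-th matrix of the batch.
template <typename T>
thrust::device_vector<T *> MakeBatchPointers(T *base,
                                             int batch_size,
                                             size_t stride_elems) {
  thrust::host_vector<T *> host_ptrs(batch_size);
  for (int i = 0; i < batch_size; ++i) {
    host_ptrs[i] = base + i * stride_elems;
  }
  // Constructing the device_vector copies the table into GPU memory;
  // thrust::raw_pointer_cast(result.data()) is what the API receives.
  return thrust::device_vector<T *>(host_ptrs.begin(), host_ptrs.end());
}

This mirrors the host_out_ptr/dev_out_ptr preparation in the Cholesky path and the a_array/x_array tables built below for the triangular solve.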
size_t numel = input2->num_elements(); - CUDA_CALL(cudaMemcpyAsync(x_ptr, b_ptr, numel * bytes, cudaMemcpyDeviceToDevice, custream)); + CUDA_CALL(cudaMemcpyAsync( + x_ptr, b_ptr, numel * bytes, cudaMemcpyDeviceToDevice, custream)); std::vector a_array(batch_size, nullptr); std::vector x_array(batch_size, nullptr); @@ -1465,39 +1793,47 @@ void cinn_call_triangular_solve_nvgpu(void *v_args, if (bits == 32) { std::vector alpha(batch_size, 1.0f); - CUBLAS_CALL(cublasStrsmBatched(handle, - side, - uplo, - transa, - diag, - b_rows, - b_cols, - alpha.data(), - reinterpret_cast(dev_a_array.data().get()), - lda, - reinterpret_cast(dev_x_array.data().get()), - ldb, - batch_size)); + CUBLAS_CALL( + cublasStrsmBatched(handle, + side, + uplo, + transa, + diag, + b_rows, + b_cols, + alpha.data(), + reinterpret_cast(dev_a_array.data().get()), + lda, + reinterpret_cast(dev_x_array.data().get()), + ldb, + batch_size)); } else if (bits == 64) { std::vector alpha(batch_size, 1.0); - CUBLAS_CALL(cublasDtrsmBatched(handle, - side, - uplo, - transa, - diag, - b_rows, - b_cols, - alpha.data(), - reinterpret_cast(dev_a_array.data().get()), - lda, - reinterpret_cast(dev_x_array.data().get()), - ldb, - batch_size)); + CUBLAS_CALL(cublasDtrsmBatched( + handle, + side, + uplo, + transa, + diag, + b_rows, + b_cols, + alpha.data(), + reinterpret_cast(dev_a_array.data().get()), + lda, + reinterpret_cast(dev_x_array.data().get()), + ldb, + batch_size)); } } -void cinn_assert_true_nvgpu(void *v_args, int num_args, int msg, bool only_warning, void *stream) { - cinn_assert_true(v_args, num_args, msg, only_warning, stream, common::DefaultNVGPUTarget()); +void cinn_assert_true_nvgpu( + void *v_args, int num_args, int msg, bool only_warning, void *stream) { + cinn_assert_true(v_args, + num_args, + msg, + only_warning, + stream, + common::DefaultNVGPUTarget()); } void cinn_gpu_cublas_mul(const std::vector &attrs, @@ -1509,20 +1845,33 @@ void cinn_gpu_cublas_mul(const std::vector &attrs, CHECK_EQ(input1->type.code, cinn_type_code_t::cinn_type_float); cudaStream_t custream = static_cast(stream); CUBLAS_CALL(cublasSetStream(handle, custream)); - float *x_data = reinterpret_cast(input1->memory); - float *y_data = reinterpret_cast(input2->memory); + float *x_data = reinterpret_cast(input1->memory); + float *y_data = reinterpret_cast(input2->memory); float *out_data = reinterpret_cast(output->memory); - int M = 1; + int M = 1; CHECK_GE(attrs.size(), 6); for (int i = 0; i < attrs[attrs.size() - 2]; i++) { M *= attrs[i]; } - int N = attrs[attrs.size() - 3]; - int K = attrs[attrs.size() - 4]; + int N = attrs[attrs.size() - 3]; + int K = attrs[attrs.size() - 4]; float alpha = 1.f; - float beta = 0.f; + float beta = 0.f; // M,N * N,K - cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, K, M, N, &alpha, y_data, K, x_data, N, &beta, out_data, K); + cublasSgemm(handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + K, + M, + N, + &alpha, + y_data, + K, + x_data, + N, + &beta, + out_data, + K); } void cinn_gpu_cublas_gemm(const std::vector &attrs, @@ -1532,22 +1881,23 @@ void cinn_gpu_cublas_gemm(const std::vector &attrs, cinn_buffer_t *output, cudaStream_t stream) { cublasHandle_t &handle = CublasHandle::GetInstance().GetCublasHandle(); - cudaStream_t custream = static_cast(stream); + cudaStream_t custream = static_cast(stream); CUBLAS_CALL(cublasSetStream(handle, custream)); CHECK_EQ(lhs->type.code, cinn_type_code_t::cinn_type_float); - const float *lhs_data = reinterpret_cast(lhs->memory); - const float *rhs_data = reinterpret_cast(rhs->memory); - const float 
*bias_data = bias ? reinterpret_cast(bias->memory) : nullptr; - float *output_data = reinterpret_cast(output->memory); + const float *lhs_data = reinterpret_cast(lhs->memory); + const float *rhs_data = reinterpret_cast(rhs->memory); + const float *bias_data = + bias ? reinterpret_cast(bias->memory) : nullptr; + float *output_data = reinterpret_cast(output->memory); CHECK_GE(attrs.size(), 13); int lhs_dim_size = attrs[attrs.size() - 7]; int rhs_dim_size = attrs[attrs.size() - 6]; int out_dim_size = attrs[attrs.size() - 5]; - bool lhs_trans = static_cast(attrs[attrs.size() - 4]); - bool rhs_trans = static_cast(attrs[attrs.size() - 3]); - bool out_trans = static_cast(attrs[attrs.size() - 2]); + bool lhs_trans = static_cast(attrs[attrs.size() - 4]); + bool rhs_trans = static_cast(attrs[attrs.size() - 3]); + bool out_trans = static_cast(attrs[attrs.size() - 2]); // 1)C = A^T * B --> C^T = B^T * A // 2)C = A * B^T --> C^T = B * A^T // 3)C = A^T * B^T --> C^T = B * A @@ -1556,8 +1906,9 @@ void cinn_gpu_cublas_gemm(const std::vector &attrs, lhs_trans = static_cast(attrs[attrs.size() - 3]) ^ out_trans; rhs_trans = static_cast(attrs[attrs.size() - 4]) ^ out_trans; } - const float alpha = *reinterpret_cast(&attrs[attrs.size() - 1]); - const float beta = bias ? 1.f : 0.f; + const float alpha = + *reinterpret_cast(&attrs[attrs.size() - 1]); + const float beta = bias ? 1.f : 0.f; VLOG(4) << "The lhs_trans value used by cinn_gpu_cublas_gemm: " << lhs_trans; VLOG(4) << "The rhs_trans value used by cinn_gpu_cublas_gemm: " << rhs_trans; VLOG(4) << "The out_trans value used by cinn_gpu_cublas_gemm: " << out_trans; @@ -1616,9 +1967,13 @@ void cinn_gpu_cublas_gemm(const std::vector &attrs, class CurandGenerator { public: - CurandGenerator() { CURAND_CALL(curandCreateGenerator(&generator_, CURAND_RNG_PSEUDO_DEFAULT)); } + CurandGenerator() { + CURAND_CALL(curandCreateGenerator(&generator_, CURAND_RNG_PSEUDO_DEFAULT)); + } - CurandGenerator(curandRngType rng_type) { CURAND_CALL(curandCreateGenerator(&generator_, rng_type)); } + CurandGenerator(curandRngType rng_type) { + CURAND_CALL(curandCreateGenerator(&generator_, rng_type)); + } ~CurandGenerator() { CURAND_CALL(curandDestroyGenerator(generator_)); } @@ -1635,7 +1990,8 @@ class CurandGenerator { auto rand_seed = (seed == 0ULL) ? 
RandomSeed::GetOrSet() : seed; if (rand_seed != 0ULL && rand_seed != seed_) { CURAND_CALL(curandSetPseudoRandomGeneratorSeed(generator_, rand_seed)); - VLOG(4) << "Change curand random seed from: " << seed_ << " to: " << rand_seed; + VLOG(4) << "Change curand random seed from: " << seed_ + << " to: " << rand_seed; seed_ = rand_seed; } return *this; @@ -1644,7 +2000,8 @@ class CurandGenerator { CurandGenerator &SetStream(cudaStream_t stream) { if (stream != nullptr && stream != stream_) { CURAND_CALL(curandSetStream(generator_, stream)); - VLOG(4) << "Change curand generator stream from: " << stream_ << " to: " << stream; + VLOG(4) << "Change curand generator stream from: " << stream_ + << " to: " << stream; stream_ = stream; } return *this; @@ -1653,7 +2010,7 @@ class CurandGenerator { private: curandGenerator_t generator_; unsigned long long seed_ = 0ULL; - cudaStream_t stream_ = nullptr; + cudaStream_t stream_ = nullptr; }; class CurandGeneratorFactory { @@ -1668,10 +2025,12 @@ class CurandGeneratorFactory { static CurandGenerator &Get(CurandGeneratorType type) { switch (type) { case CurandGeneratorType::GENERATOR_GAUSSIAN: - static CurandGenerator gaussian_generator(CURAND_RNG_PSEUDO_PHILOX4_32_10); + static CurandGenerator gaussian_generator( + CURAND_RNG_PSEUDO_PHILOX4_32_10); return gaussian_generator; case CurandGeneratorType::GENERATOR_UNIFORM: - static CurandGenerator uniform_generator(CURAND_RNG_PSEUDO_PHILOX4_32_10); + static CurandGenerator uniform_generator( + CURAND_RNG_PSEUDO_PHILOX4_32_10); return uniform_generator; case CurandGeneratorType::GENERATOR_RANDINT: static CurandGenerator randint_generator(CURAND_RNG_PSEUDO_MT19937); @@ -1683,20 +2042,22 @@ class CurandGeneratorFactory { } }; -void cinn_call_gaussian_random(void *v_args, int num_args, float mean, float std, int seed, void *stream) { +void cinn_call_gaussian_random( + void *v_args, int num_args, float mean, float std, int seed, void *stream) { cinn_pod_value_t *args = static_cast(v_args); - cinn_buffer_t *output = args[0].operator cinn_buffer_t *(); - cinn_type_t dtype = output->type; - size_t numel = output->num_elements(); + cinn_buffer_t *output = args[0].operator cinn_buffer_t *(); + cinn_type_t dtype = output->type; + size_t numel = output->num_elements(); curandGenerator_t generator = - CurandGeneratorFactory::Get(CurandGeneratorFactory::CurandGeneratorType::GENERATOR_GAUSSIAN) + CurandGeneratorFactory::Get( + CurandGeneratorFactory::CurandGeneratorType::GENERATOR_GAUSSIAN) .SetStream(static_cast(stream)) .SetSeed(seed) .GetGenerator(); - VLOG(4) << "cinn_call_gaussian_random: output_size=" << numel << ", mean=" << mean << ", std=" << std - << ", seed=" << seed; + VLOG(4) << "cinn_call_gaussian_random: output_size=" << numel + << ", mean=" << mean << ", std=" << std << ", seed=" << seed; if (dtype == cinn_float32_t()) { float *ptr = reinterpret_cast(output->memory); @@ -1705,24 +2066,27 @@ void cinn_call_gaussian_random(void *v_args, int num_args, float mean, float std double *ptr = reinterpret_cast(output->memory); CURAND_CALL(curandGenerateNormalDouble(generator, ptr, numel, mean, std)); } else { - LOG(FATAL) << "gaussian_random only support float32 and float64! Please check."; + LOG(FATAL) + << "gaussian_random only support float32 and float64! 
Please check."; } } -void cinn_call_uniform_random(void *v_args, int num_args, float min, float max, int seed, void *stream) { +void cinn_call_uniform_random( + void *v_args, int num_args, float min, float max, int seed, void *stream) { cinn_pod_value_t *args = static_cast(v_args); - cinn_buffer_t *output = args[0].operator cinn_buffer_t *(); - cinn_type_t dtype = output->type; - size_t numel = output->num_elements(); + cinn_buffer_t *output = args[0].operator cinn_buffer_t *(); + cinn_type_t dtype = output->type; + size_t numel = output->num_elements(); curandGenerator_t generator = - CurandGeneratorFactory::Get(CurandGeneratorFactory::CurandGeneratorType::GENERATOR_UNIFORM) + CurandGeneratorFactory::Get( + CurandGeneratorFactory::CurandGeneratorType::GENERATOR_UNIFORM) .SetStream(static_cast(stream)) .SetSeed(seed) .GetGenerator(); - VLOG(4) << "cinn_call_uniform_random: output_size=" << numel << ", min=" << min << ", max=" << max - << ", seed=" << seed; + VLOG(4) << "cinn_call_uniform_random: output_size=" << numel + << ", min=" << min << ", max=" << max << ", seed=" << seed; if (dtype == cinn_float32_t()) { float *ptr = reinterpret_cast(output->memory); @@ -1731,20 +2095,22 @@ void cinn_call_uniform_random(void *v_args, int num_args, float min, float max, double *ptr = reinterpret_cast(output->memory); CURAND_CALL(curandGenerateUniformDouble(generator, ptr, numel)); } else { - LOG(FATAL) << "uniform_random only support float32 and float64! Please check."; + LOG(FATAL) + << "uniform_random only support float32 and float64! Please check."; } } void cinn_call_randint(void *v_args, int num_args, int seed, void *stream) { cinn_pod_value_t *args = static_cast(v_args); - cinn_buffer_t *output = args[0].operator cinn_buffer_t *(); - cinn_type_t dtype = output->type; - size_t numel = output->num_elements(); + cinn_buffer_t *output = args[0].operator cinn_buffer_t *(); + cinn_type_t dtype = output->type; + size_t numel = output->num_elements(); VLOG(4) << "cinn_call_randint: output_size=" << numel << ", seed=" << seed; curandGenerator_t generator = - CurandGeneratorFactory::Get(CurandGeneratorFactory::CurandGeneratorType::GENERATOR_RANDINT) + CurandGeneratorFactory::Get( + CurandGeneratorFactory::CurandGeneratorType::GENERATOR_RANDINT) .SetStream(static_cast(stream)) .SetSeed(seed) .GetGenerator(); @@ -1763,9 +2129,9 @@ namespace { cudnnDataType_t convert_to_cudnn_dtype(cinn_buffer_t *input) { CHECK(input) << "the pointer of input is null"; auto type_code = input->type.code; - int bits = input->type.bits; + int bits = input->type.bits; cudnnDataType_t data_type; - bool is_float = type_code == cinn_type_float; + bool is_float = type_code == cinn_type_float; bool is_bfloat16 = type_code == cinn_type_bfloat; if (is_float && bits == 16) { data_type = CUDNN_DATA_HALF; @@ -1776,7 +2142,8 @@ cudnnDataType_t convert_to_cudnn_dtype(cinn_buffer_t *input) { } else if (is_float && bits == 64) { data_type = CUDNN_DATA_DOUBLE; } else { - LOG(FATAL) << "unsupported cudnn data type: " << static_cast(type_code) << ", bits = " << bits; + LOG(FATAL) << "unsupported cudnn data type: " << static_cast(type_code) + << ", bits = " << bits; } return data_type; } @@ -1837,40 +2204,60 @@ void cinn_gpu_cudnn_conv2d(const absl::flat_hash_map &attr, cudnnTensorDescriptor_t x_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(x_desc, cudnn_tensor_format, data_type, input_n, input_c, input_h, input_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(x_desc, + cudnn_tensor_format, + 
data_type, + input_n, + input_c, + input_h, + input_w)); cudnnFilterDescriptor_t w_desc; CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc)); - CUDNN_CALL( - cudnnSetFilter4dDescriptor(w_desc, data_type, cudnn_tensor_format, weights_n, weights_c, weights_h, weights_w)); + CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc, + data_type, + cudnn_tensor_format, + weights_n, + weights_c, + weights_h, + weights_w)); cudnnConvolutionDescriptor_t conv_desc; CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc)); - CUDNN_CALL(cudnnSetConvolution2dDescriptor(conv_desc, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - CUDNN_CROSS_CORRELATION, - get_cudnn_compute_dtype(data_type))); + CUDNN_CALL( + cudnnSetConvolution2dDescriptor(conv_desc, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + CUDNN_CROSS_CORRELATION, + get_cudnn_compute_dtype(data_type))); CUDNN_CALL(cudnnSetConvolutionGroupCount(conv_desc, groups)); CUDNN_CALL(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH)); cudnnTensorDescriptor_t y_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc)); - CUDNN_CALL( - cudnnSetTensor4dDescriptor(y_desc, cudnn_tensor_format, data_type, output_n, output_c, output_h, output_w)); - - auto &conv_algo_map = ConvAlgoMap::GetInstance(); - std::string hash_key = "conv2d forward, layout=" + debug_cudnn_tensor_format(CUDNN_TENSOR_NCHW) + - ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" + std::to_string(input_n) + - "," + std::to_string(input_c) + "," + std::to_string(input_h) + "," + std::to_string(input_w) + - "}, filter_nchw={" + std::to_string(weights_n) + "," + std::to_string(weights_c) + "," + - std::to_string(weights_h) + "," + std::to_string(weights_w) + "}, output_nchw={" + - std::to_string(output_n) + "," + std::to_string(output_c) + "," + std::to_string(output_h) + - "," + std::to_string(output_w) + "}"; + CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc, + cudnn_tensor_format, + data_type, + output_n, + output_c, + output_h, + output_w)); + + auto &conv_algo_map = ConvAlgoMap::GetInstance(); + std::string hash_key = + "conv2d forward, layout=" + debug_cudnn_tensor_format(CUDNN_TENSOR_NCHW) + + ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" + + std::to_string(input_n) + "," + std::to_string(input_c) + "," + + std::to_string(input_h) + "," + std::to_string(input_w) + + "}, filter_nchw={" + std::to_string(weights_n) + "," + + std::to_string(weights_c) + "," + std::to_string(weights_h) + "," + + std::to_string(weights_w) + "}, output_nchw={" + + std::to_string(output_n) + "," + std::to_string(output_c) + "," + + std::to_string(output_h) + "," + std::to_string(output_w) + "}"; cudnnConvolutionFwdAlgo_t algo; int algo_int = conv_algo_map.GetAlgo(hash_key); @@ -1879,7 +2266,8 @@ void cinn_gpu_cudnn_conv2d(const absl::flat_hash_map &attr, } else { int count = 0; cudnnConvolutionFwdAlgoPerf_t algo_perf; - CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(handle, x_desc, w_desc, conv_desc, y_desc, 1, &count, &algo_perf)); + CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm( + handle, x_desc, w_desc, conv_desc, y_desc, 1, &count, &algo_perf)); algo = algo_perf.algo; conv_algo_map.InsertAlgo(hash_key, static_cast(algo_perf.algo)); @@ -1890,17 +2278,40 @@ void cinn_gpu_cudnn_conv2d(const absl::flat_hash_map &attr, } size_t ws_size = 0; - CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(handle, x_desc, w_desc, conv_desc, y_desc, algo, &ws_size)); + CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize( + handle, x_desc, w_desc, 
conv_desc, y_desc, algo, &ws_size)); void *ws_data = CudnnHandle::GetInstance().GetWorkSpace(ws_size); if (data_type == CUDNN_DATA_DOUBLE) { double alpha[] = {1.f}, beta[] = {0.f}; - CUDNN_CALL(cudnnConvolutionForward( - handle, alpha, x_desc, _x, w_desc, _w, conv_desc, algo, ws_data, ws_size, beta, y_desc, _y)); + CUDNN_CALL(cudnnConvolutionForward(handle, + alpha, + x_desc, + _x, + w_desc, + _w, + conv_desc, + algo, + ws_data, + ws_size, + beta, + y_desc, + _y)); } else { float alpha[] = {1.f}, beta[] = {0.f}; - CUDNN_CALL(cudnnConvolutionForward( - handle, alpha, x_desc, _x, w_desc, _w, conv_desc, algo, ws_data, ws_size, beta, y_desc, _y)); + CUDNN_CALL(cudnnConvolutionForward(handle, + alpha, + x_desc, + _x, + w_desc, + _w, + conv_desc, + algo, + ws_data, + ws_size, + beta, + y_desc, + _y)); } CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc)); @@ -1909,11 +2320,12 @@ void cinn_gpu_cudnn_conv2d(const absl::flat_hash_map &attr, CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc)); } -void cinn_gpu_cudnn_conv2d_backward_data(const absl::flat_hash_map &attr, - cinn_buffer_t *w, - cinn_buffer_t *dy, - cinn_buffer_t *dx, - cudaStream_t stream) { +void cinn_gpu_cudnn_conv2d_backward_data( + const absl::flat_hash_map &attr, + cinn_buffer_t *w, + cinn_buffer_t *dy, + cinn_buffer_t *dx, + cudaStream_t stream) { GetAttrValue(attr, input_n, -1); GetAttrValue(attr, input_c, -1); GetAttrValue(attr, input_h, -1); @@ -1936,7 +2348,7 @@ void cinn_gpu_cudnn_conv2d_backward_data(const absl::flat_hash_map(stream))); - void *_w = w->memory; + void *_w = w->memory; void *_dy = dy->memory; void *_dx = dx->memory; @@ -1944,39 +2356,61 @@ void cinn_gpu_cudnn_conv2d_backward_data(const absl::flat_hash_map(algo_perf.algo)); @@ -1998,17 +2432,40 @@ void cinn_gpu_cudnn_conv2d_backward_data(const absl::flat_hash_map &attr, - cinn_buffer_t *x, - cinn_buffer_t *dy, - cinn_buffer_t *dw, - cudaStream_t stream) { +void cinn_gpu_cudnn_conv2d_backward_filter( + const absl::flat_hash_map &attr, + cinn_buffer_t *x, + cinn_buffer_t *dy, + cinn_buffer_t *dw, + cudaStream_t stream) { GetAttrValue(attr, input_n, -1); GetAttrValue(attr, input_c, -1); GetAttrValue(attr, input_h, -1); @@ -2045,7 +2503,7 @@ void cinn_gpu_cudnn_conv2d_backward_filter(const absl::flat_hash_map(stream))); - void *_x = x->memory; + void *_x = x->memory; void *_dy = dy->memory; void *_dw = dw->memory; @@ -2053,39 +2511,61 @@ void cinn_gpu_cudnn_conv2d_backward_filter(const absl::flat_hash_map(algo_perf.algo)); @@ -2106,17 +2586,40 @@ void cinn_gpu_cudnn_conv2d_backward_filter(const absl::flat_hash_map &attrs, CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); CHECK_EQ(attrs.size(), 17); // Here the input paddings are pad_top, pad_bottom, pad_left, pad_right. - // Since pad_top==pad_bottom and pad_left==pad_rifht, we only take pad_top and pad_left. - int input_n = attrs[0]; - int input_c = attrs[1]; - int input_h = attrs[2]; - int input_w = attrs[3]; - int kernel_h = attrs[4]; - int kernel_w = attrs[5]; - int pad_h = attrs[6]; - int pad_w = attrs[8]; - int stride_h = attrs[10]; - int stride_w = attrs[11]; - int output_n = attrs[12]; - int output_c = attrs[13]; - int output_h = attrs[14]; - int output_w = attrs[15]; - int adaptive = attrs[16]; + // Since pad_top==pad_bottom and pad_left==pad_rifht, we only take pad_top and + // pad_left. 
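The seventeen-integer attribute vector decoded in the reads that follow is purely positional, guarded only by the CHECK_EQ(attrs.size(), 17) above. An illustrative unpacking (this struct is not part of the diff; the indices mirror the attrs[i] reads below, including the skipped attrs[7]/attrs[9], since pad_bottom and pad_right equal pad_top and pad_left):

#include <vector>

// Positional layout of the pool2d attribute vector, as consumed below.
struct Pool2dAttrs {
  int input_n, input_c, input_h, input_w;      // attrs[0..3]
  int kernel_h, kernel_w;                      // attrs[4..5]
  int pad_h, pad_w;                            // attrs[6] top, attrs[8] left
  int stride_h, stride_w;                      // attrs[10..11]
  int output_n, output_c, output_h, output_w;  // attrs[12..15]
  bool adaptive;                               // attrs[16], stored as int
};

Pool2dAttrs UnpackPool2dAttrs(const std::vector<int> &attrs) {
  return Pool2dAttrs{attrs[0],  attrs[1],  attrs[2],  attrs[3],
                     attrs[4],  attrs[5],  attrs[6],  attrs[8],
                     attrs[10], attrs[11], attrs[12], attrs[13],
                     attrs[14], attrs[15], attrs[16] != 0};
}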
+ int input_n = attrs[0]; + int input_c = attrs[1]; + int input_h = attrs[2]; + int input_w = attrs[3]; + int kernel_h = attrs[4]; + int kernel_w = attrs[5]; + int pad_h = attrs[6]; + int pad_w = attrs[8]; + int stride_h = attrs[10]; + int stride_w = attrs[11]; + int output_n = attrs[12]; + int output_c = attrs[13]; + int output_h = attrs[14]; + int output_w = attrs[15]; + int adaptive = attrs[16]; std::string pool_type = str_attrs[0]; cudnnPoolingDescriptor_t pooling_desc; CUDNN_CALL(cudnnCreatePoolingDescriptor(&pooling_desc)); @@ -2170,33 +2674,65 @@ void cinn_gpu_cudnn_pool2d(const std::vector &attrs, auto data_type = convert_to_cudnn_dtype(input); - CUDNN_CALL(cudnnSetPooling2dDescriptor( - pooling_desc, pool_mode, CUDNN_NOT_PROPAGATE_NAN, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w)); + CUDNN_CALL(cudnnSetPooling2dDescriptor(pooling_desc, + pool_mode, + CUDNN_NOT_PROPAGATE_NAN, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w)); cudnnTensorDescriptor_t in_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc, CUDNN_TENSOR_NCHW, data_type, input_n, input_c, input_h, input_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc, + CUDNN_TENSOR_NCHW, + data_type, + input_n, + input_c, + input_h, + input_w)); cudnnTensorDescriptor_t out_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc)); - CUDNN_CALL( - cudnnSetTensor4dDescriptor(out_desc, CUDNN_TENSOR_NCHW, data_type, output_n, output_c, output_h, output_w)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc, + CUDNN_TENSOR_NCHW, + data_type, + output_n, + output_c, + output_h, + output_w)); - void *in_data = input->memory; + void *in_data = input->memory; void *out_data = output->memory; if (data_type == CUDNN_DATA_DOUBLE) { double alpha = 1.0f; - double beta = 0.0f; - CUDNN_CALL(cudnnPoolingForward(handle, pooling_desc, &alpha, in_desc, in_data, &beta, out_desc, out_data)); + double beta = 0.0f; + CUDNN_CALL(cudnnPoolingForward(handle, + pooling_desc, + &alpha, + in_desc, + in_data, + &beta, + out_desc, + out_data)); } else { float alpha = 1.0f; - float beta = 0.0f; - CUDNN_CALL(cudnnPoolingForward(handle, pooling_desc, &alpha, in_desc, in_data, &beta, out_desc, out_data)); + float beta = 0.0f; + CUDNN_CALL(cudnnPoolingForward(handle, + pooling_desc, + &alpha, + in_desc, + in_data, + &beta, + out_desc, + out_data)); } cudnnDestroyTensorDescriptor(in_desc); @@ -2213,8 +2749,8 @@ void cinn_gpu_cudnn_softmax(const std::vector &attrs, for (int i = 0; i < rank; i++) { shape.push_back(attrs[i]); } - int axis = attrs.back(); - axis = axis < 0 ? rank + axis : axis; + int axis = attrs.back(); + axis = axis < 0 ? 
rank + axis : axis; int inner_num = 1; int outer_num = 1; for (int i = 0; i < shape.size(); i++) { @@ -2229,20 +2765,32 @@ void cinn_gpu_cudnn_softmax(const std::vector &attrs, cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle(); CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); - void *in_data = input->memory; + void *in_data = input->memory; void *out_data = output->memory; cudnnTensorDescriptor_t in_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc, CUDNN_TENSOR_NCHW, data_type, outer_num, shape[axis], inner_num, 1)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc, + CUDNN_TENSOR_NCHW, + data_type, + outer_num, + shape[axis], + inner_num, + 1)); cudnnTensorDescriptor_t out_desc; CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc)); - CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc, CUDNN_TENSOR_NCHW, data_type, outer_num, shape[axis], inner_num, 1)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc, + CUDNN_TENSOR_NCHW, + data_type, + outer_num, + shape[axis], + inner_num, + 1)); if (data_type == CUDNN_DATA_DOUBLE) { double alpha = 1.f; - double beta = 0.f; + double beta = 0.f; CUDNN_CALL(cudnnSoftmaxForward(handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, @@ -2254,7 +2802,7 @@ void cinn_gpu_cudnn_softmax(const std::vector &attrs, out_data)); } else { float alpha = 1.f; - float beta = 0.f; + float beta = 0.f; CUDNN_CALL(cudnnSoftmaxForward(handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, diff --git a/paddle/cinn/runtime/cuda/cuda_util.h b/paddle/cinn/runtime/cuda/cuda_util.h index 1e8d691e48fa3..de30699fa9de4 100644 --- a/paddle/cinn/runtime/cuda/cuda_util.h +++ b/paddle/cinn/runtime/cuda/cuda_util.h @@ -42,15 +42,37 @@ void cinn_gpu_cublas_gemm(const std::vector& attrs, cinn_buffer_t* output, cudaStream_t stream = nullptr); -void cinn_call_gaussian_random(void* v_args, int num_args, float mean, float std, int seed, void* stream = nullptr); +void cinn_call_gaussian_random(void* v_args, + int num_args, + float mean, + float std, + int seed, + void* stream = nullptr); -void cinn_call_uniform_random(void* v_args, int num_args, float min, float max, int seed, void* stream = nullptr); +void cinn_call_uniform_random(void* v_args, + int num_args, + float min, + float max, + int seed, + void* stream = nullptr); -void cinn_call_randint(void* v_args, int num_args, int seed, void* stream = nullptr); +void cinn_call_randint(void* v_args, + int num_args, + int seed, + void* stream = nullptr); -void cinn_call_cholesky_nvgpu(void* v_args, int num_args, int batch_size, int m, bool upper, void* stream = nullptr); +void cinn_call_cholesky_nvgpu(void* v_args, + int num_args, + int batch_size, + int m, + bool upper, + void* stream = nullptr); -void cinn_assert_true_nvgpu(void* v_args, int num_args, int msg, bool only_warning, void* stream = nullptr); +void cinn_assert_true_nvgpu(void* v_args, + int num_args, + int msg, + bool only_warning, + void* stream = nullptr); void cinn_call_triangular_solve_nvgpu(void* v_args, int num_args, @@ -63,8 +85,15 @@ void cinn_call_triangular_solve_nvgpu(void* v_args, bool unit_diagonal, void* stream = nullptr); -void cinn_call_cuda_memset(void* v_args, int num_args, int value, size_t count, void* stream = nullptr); -void cinn_call_cuda_memcpy(void* v_args, int num_args, size_t count, void* stream = nullptr); +void cinn_call_cuda_memset(void* v_args, + int num_args, + int value, + size_t count, + void* stream = nullptr); +void cinn_call_cuda_memcpy(void* v_args, + int num_args, 
+ size_t count, + void* stream = nullptr); /** * Call a CUDA compiled kernel. @@ -123,20 +152,22 @@ void cinn_gpu_cudnn_conv2d(const absl::flat_hash_map& attr, cinn_buffer_t* x, cinn_buffer_t* w, cinn_buffer_t* y, - cudaStream_t stream = nullptr, + cudaStream_t stream = nullptr, common::Layout target = common::Layout::kNCHW); -void cinn_gpu_cudnn_conv2d_backward_data(const absl::flat_hash_map& attr, - cinn_buffer_t* w, - cinn_buffer_t* dy, - cinn_buffer_t* dx, - cudaStream_t stream = nullptr); +void cinn_gpu_cudnn_conv2d_backward_data( + const absl::flat_hash_map& attr, + cinn_buffer_t* w, + cinn_buffer_t* dy, + cinn_buffer_t* dx, + cudaStream_t stream = nullptr); -void cinn_gpu_cudnn_conv2d_backward_filter(const absl::flat_hash_map& attr, - cinn_buffer_t* x, - cinn_buffer_t* dy, - cinn_buffer_t* dw, - cudaStream_t stream = nullptr); +void cinn_gpu_cudnn_conv2d_backward_filter( + const absl::flat_hash_map& attr, + cinn_buffer_t* x, + cinn_buffer_t* dy, + cinn_buffer_t* dw, + cudaStream_t stream = nullptr); void cinn_gpu_cudnn_pool2d(const std::vector& attrs, const std::vector& str_attrs, diff --git a/paddle/cinn/runtime/cuda/float16.h b/paddle/cinn/runtime/cuda/float16.h index 4bf8c64614b17..15bd2cee3fc69 100644 --- a/paddle/cinn/runtime/cuda/float16.h +++ b/paddle/cinn/runtime/cuda/float16.h @@ -19,7 +19,8 @@ #pragma once #endif // __cplusplus -#if defined(_M_X64) || defined(__x86_64__) || defined(_M_IX86) || defined(__i386__) +#if defined(_M_X64) || defined(__x86_64__) || defined(_M_IX86) || \ + defined(__i386__) #define __CINN_x86__ #include #endif @@ -74,12 +75,12 @@ struct CINN_ALIGN(2) float16 { #ifdef __cplusplus // The following defaulted special class member functions // are added to make float16 pass the std::is_trivial test - float16() = default; + float16() = default; float16(const float16& o) = default; float16& operator=(const float16& o) = default; - float16(float16&& o) = default; + float16(float16&& o) = default; float16& operator=(float16&& o) = default; - ~float16() = default; + ~float16() = default; // Constructors #ifdef CINN_CUDA_FP16 @@ -95,7 +96,7 @@ struct CINN_ALIGN(2) float16 { __host__ __device__ inline explicit float16(float val) { #if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) half tmp = __float2half(val); - x = *reinterpret_cast(&tmp); + x = *reinterpret_cast(&tmp); #elif defined(__F16C__) && defined(__CINN_x86__) x = _cvtss_sh(val, 0); @@ -104,7 +105,7 @@ struct CINN_ALIGN(2) float16 { // Conversion routine adapted from // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion Bits v, s; - v.f = val; + v.f = val; uint32_t sign = v.si & sigN; v.si ^= sign; sign >>= shiftSign; // logical shift @@ -124,7 +125,8 @@ struct CINN_ALIGN(2) float16 { __host__ __device__ inline explicit float16(bool b) : x(b ? 
0x3c00 : 0) {} template - __host__ __device__ inline explicit float16(const T& val) : x(float16(static_cast(val)).x) {} + __host__ __device__ inline explicit float16(const T& val) + : x(float16(static_cast(val)).x) {} // Assignment operators #ifdef CINN_CUDA_FP16 @@ -220,7 +222,7 @@ struct CINN_ALIGN(2) float16 { // Conversion routine adapted from // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion Bits v; - v.ui = this->x; + v.ui = this->x; int32_t sign = v.si & sigC; v.si ^= sign; sign <<= shiftSign; @@ -238,9 +240,13 @@ struct CINN_ALIGN(2) float16 { #endif } - __host__ __device__ inline explicit operator bool() const { return (x & 0x7fff) != 0; } + __host__ __device__ inline explicit operator bool() const { + return (x & 0x7fff) != 0; + } - __host__ __device__ inline explicit operator int8_t() const { return static_cast(static_cast(*this)); } + __host__ __device__ inline explicit operator int8_t() const { + return static_cast(static_cast(*this)); + } __host__ __device__ inline explicit operator uint8_t() const { return static_cast(static_cast(*this)); @@ -270,7 +276,9 @@ struct CINN_ALIGN(2) float16 { return static_cast(static_cast(*this)); } - __host__ __device__ inline operator double() const { return static_cast(static_cast(*this)); } + __host__ __device__ inline operator double() const { + return static_cast(static_cast(*this)); + } private: union Bits { @@ -279,7 +287,7 @@ struct CINN_ALIGN(2) float16 { uint32_t ui; }; - static const int shift = 13; + static const int shift = 13; static const int shiftSign = 16; static const int32_t infN = 0x7F800000; @@ -288,7 +296,8 @@ struct CINN_ALIGN(2) float16 { static const int32_t sigN = 0x80000000; // sign bit static constexpr int32_t infC = infN >> shift; - static constexpr int32_t nanN = (infC + 1) << shift; // minimum flt16 nan as float32 + static constexpr int32_t nanN = (infC + 1) + << shift; // minimum flt16 nan as float32 static constexpr int32_t maxC = maxN >> shift; static constexpr int32_t minC = minN >> shift; static constexpr int32_t sigC = sigN >> shiftSign; @@ -353,7 +362,7 @@ __device__ inline half operator*(const half& a, const half& b) { __device__ inline half operator/(const half& a, const half& b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 - float num = __half2float(a); + float num = __half2float(a); float denom = __half2float(b); return __float2half(num / denom); #else @@ -442,7 +451,8 @@ __device__ inline bool operator>=(const half& a, const half& b) { #endif // CINN_CUDA_FP16 // Arithmetic operators for float16 on GPU -__host__ __device__ inline float16 operator+(const float16& a, const float16& b) { +__host__ __device__ inline float16 operator+(const float16& a, + const float16& b) { #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hadd(a.to_half(), b.to_half())); #else @@ -450,7 +460,8 @@ __host__ __device__ inline float16 operator+(const float16& a, const float16& b) #endif } -__host__ __device__ inline float16 operator-(const float16& a, const float16& b) { +__host__ __device__ inline float16 operator-(const float16& a, + const float16& b) { #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hsub(a.to_half(), b.to_half())); #else @@ -458,7 +469,8 @@ __host__ __device__ inline float16 operator-(const float16& a, const float16& b) #endif } -__host__ __device__ inline float16 operator*(const float16& a, const float16& b) { +__host__ __device__ inline float16 operator*(const 
float16& a, + const float16& b) { #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return float16(__hmul(a.to_half(), b.to_half())); #else @@ -466,10 +478,11 @@ __host__ __device__ inline float16 operator*(const float16& a, const float16& b) #endif } -__host__ __device__ inline float16 operator/(const float16& a, const float16& b) { +__host__ __device__ inline float16 operator/(const float16& a, + const float16& b) { #if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 // TODO(kexinzhao): check which cuda version starts to support __hdiv - float num = __half2float(a.to_half()); + float num = __half2float(a.to_half()); float denom = __half2float(b.to_half()); return float16(num / denom); #else @@ -487,22 +500,26 @@ __host__ __device__ inline float16 operator-(const float16& a) { #endif } -__host__ __device__ inline float16& operator+=(float16& a, const float16& b) { // NOLINT +__host__ __device__ inline float16& operator+=(float16& a, + const float16& b) { // NOLINT a = a + b; return a; } -__host__ __device__ inline float16& operator-=(float16& a, const float16& b) { // NOLINT +__host__ __device__ inline float16& operator-=(float16& a, + const float16& b) { // NOLINT a = a - b; return a; } -__host__ __device__ inline float16& operator*=(float16& a, const float16& b) { // NOLINT +__host__ __device__ inline float16& operator*=(float16& a, + const float16& b) { // NOLINT a = a * b; return a; } -__host__ __device__ inline float16& operator/=(float16& a, const float16& b) { // NOLINT +__host__ __device__ inline float16& operator/=(float16& a, + const float16& b) { // NOLINT a = a / b; return a; } @@ -570,9 +587,13 @@ __host__ __device__ inline bool(isnan)(const float16& a) { #endif } -__host__ __device__ inline bool(isinf)(const float16& a) { return (a.x & 0x7fff) == 0x7c00; } +__host__ __device__ inline bool(isinf)(const float16& a) { + return (a.x & 0x7fff) == 0x7c00; +} -__host__ __device__ inline bool(isfinite)(const float16& a) { return !((isnan)(a)) && !((isinf)(a)); } +__host__ __device__ inline bool(isfinite)(const float16& a) { + return !((isnan)(a)) && !((isinf)(a)); +} __host__ __device__ inline float16(abs)(const float16& a) { #if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) @@ -582,7 +603,9 @@ __host__ __device__ inline float16(abs)(const float16& a) { #endif } -__host__ __device__ inline float16(log)(const float16& a) { return float16(std::log(static_cast(a))); } +__host__ __device__ inline float16(log)(const float16& a) { + return float16(std::log(static_cast(a))); +} #ifdef __cplusplus } // namespace common @@ -594,34 +617,43 @@ __device__ inline cinn::common::float16 __shfl_sync(unsigned mask, cinn::common::float16 var, int srcLane, int width = warpSize) { - return cinn::common::float16(__shfl_sync(mask, var.to_half(), srcLane, width)); + return cinn::common::float16( + __shfl_sync(mask, var.to_half(), srcLane, width)); } -__device__ inline cinn::common::float16 __shfl_up_sync(unsigned mask, - cinn::common::float16 var, - unsigned int delta, - int width = warpSize) { - return cinn::common::float16(__shfl_up_sync(mask, var.to_half(), delta, width)); +__device__ inline cinn::common::float16 __shfl_up_sync( + unsigned mask, + cinn::common::float16 var, + unsigned int delta, + int width = warpSize) { + return cinn::common::float16( + __shfl_up_sync(mask, var.to_half(), delta, width)); } -__device__ inline cinn::common::float16 __shfl_down_sync(unsigned mask, - cinn::common::float16 var, - unsigned 
int delta, - int width = warpSize) { - return cinn::common::float16(__shfl_down_sync(mask, var.to_half(), delta, width)); +__device__ inline cinn::common::float16 __shfl_down_sync( + unsigned mask, + cinn::common::float16 var, + unsigned int delta, + int width = warpSize) { + return cinn::common::float16( + __shfl_down_sync(mask, var.to_half(), delta, width)); } -__device__ inline cinn::common::float16 __shfl_xor_sync(unsigned mask, - cinn::common::float16 var, - int laneMask, - int width = warpSize) { - return cinn::common::float16(__shfl_xor_sync(mask, var.to_half(), laneMask, width)); +__device__ inline cinn::common::float16 __shfl_xor_sync( + unsigned mask, + cinn::common::float16 var, + int laneMask, + int width = warpSize) { + return cinn::common::float16( + __shfl_xor_sync(mask, var.to_half(), laneMask, width)); } -__host__ __device__ inline cinn::common::float16 max(const cinn::common::float16& a, const cinn::common::float16& b) { +__host__ __device__ inline cinn::common::float16 max( + const cinn::common::float16& a, const cinn::common::float16& b) { return a > b ? a : b; } -__host__ __device__ inline cinn::common::float16 min(const cinn::common::float16& a, const cinn::common::float16& b) { +__host__ __device__ inline cinn::common::float16 min( + const cinn::common::float16& a, const cinn::common::float16& b) { return a < b ? a : b; } #endif // __cplusplus && CINN_CUDA_FP16 diff --git a/paddle/cinn/runtime/custom_function.cc b/paddle/cinn/runtime/custom_function.cc index 3b612b208f89e..d04d00c2d3d72 100644 --- a/paddle/cinn/runtime/custom_function.cc +++ b/paddle/cinn/runtime/custom_function.cc @@ -32,10 +32,13 @@ using hlir::framework::Shape; using hlir::framework::Tensor; namespace utils { -void AssertTrueMsgTool::SetMsg(int key, const std::string& msg) { global_msg_[key] = msg; } +void AssertTrueMsgTool::SetMsg(int key, const std::string& msg) { + global_msg_[key] = msg; +} const std::string& AssertTrueMsgTool::GetMsg(int key) { - CHECK(global_msg_.find(key) != global_msg_.end()) << "Cannot find assert_true message key " << key; + CHECK(global_msg_.find(key) != global_msg_.end()) + << "Cannot find assert_true message key " << key; return global_msg_[key]; } @@ -45,24 +48,30 @@ void AssertTrueMsgTool::InitFlagInfo() { return; } // default value - flag_values_ = {{"only_warning", false}, {"rtol", 1e-5f}, {"atol", 1e-8f}, {"equal_nan", false}}; + flag_values_ = {{"only_warning", false}, + {"rtol", 1e-5f}, + {"atol", 1e-8f}, + {"equal_nan", false}}; if (CheckStringFlagFalse(FLAGS_cinn_check_fusion_accuracy_pass) || CheckStringFlagTrue(FLAGS_cinn_check_fusion_accuracy_pass)) { // using default value - LOG(INFO) << "The FLAGS_cinn_check_fusion_accuracy_pass will check fusion group accuracy with: " + LOG(INFO) << "The FLAGS_cinn_check_fusion_accuracy_pass will check fusion " + "group accuracy with: " "\"only_warning=false;rtol=1e-5;atol=1e-8;equal_nan=false\""; return; } // parse flags - const auto& args = cinn::utils::Split(FLAGS_cinn_check_fusion_accuracy_pass, ";"); + const auto& args = + cinn::utils::Split(FLAGS_cinn_check_fusion_accuracy_pass, ";"); for (const auto& str : args) { if (str.empty()) { continue; } const auto& flag_arg = cinn::utils::Split(str, "="); - CHECK_EQ(flag_arg.size(), 2UL) << "The FLAGS_cinn_check_fusion_accuracy_pass must be the format of " - "\"only_warning=false;rtol=1e-5;atol=1e-8;equal_nan=false\""; + CHECK_EQ(flag_arg.size(), 2UL) + << "The FLAGS_cinn_check_fusion_accuracy_pass must be the format of " + 
"\"only_warning=false;rtol=1e-5;atol=1e-8;equal_nan=false\""; if (flag_arg[0] == "only_warning" || flag_arg[0] == "equal_nan") { // bool type parameter @@ -71,19 +80,30 @@ void AssertTrueMsgTool::InitFlagInfo() { // string type parameter flag_values_[flag_arg[0]] = std::stof(flag_arg[1]); } else { - LOG(FATAL) << "The FLAGS_cinn_check_fusion_accuracy_pass only support parameter " - "\"only_warning/rtol/atol/equal_nan\" now"; + LOG(FATAL) + << "The FLAGS_cinn_check_fusion_accuracy_pass only support parameter " + "\"only_warning/rtol/atol/equal_nan\" now"; } } - LOG(INFO) << "The FLAGS_cinn_check_fusion_accuracy_pass will check fusion group accuracy with: \"" - << "only_warning=" << cinn::utils::Attribute2String(flag_values_.at("only_warning")) - << ";rtol=" << cinn::utils::Attribute2String(flag_values_.at("rtol")) - << ";atol=" << cinn::utils::Attribute2String(flag_values_.at("atol")) - << ";equal_nan=" << cinn::utils::Attribute2String(flag_values_.at("equal_nan")) << "\""; + LOG(INFO) << "The FLAGS_cinn_check_fusion_accuracy_pass will check fusion " + "group accuracy with: \"" + << "only_warning=" + << cinn::utils::Attribute2String(flag_values_.at("only_warning")) + << ";rtol=" + << cinn::utils::Attribute2String(flag_values_.at("rtol")) + << ";atol=" + << cinn::utils::Attribute2String(flag_values_.at("atol")) + << ";equal_nan=" + << cinn::utils::Attribute2String(flag_values_.at("equal_nan")) + << "\""; } -bool MemcpyToHost(void* dst, const void* src, size_t bytes, const Target& input_target, void* stream = nullptr) { +bool MemcpyToHost(void* dst, + const void* src, + size_t bytes, + const Target& input_target, + void* stream = nullptr) { if (input_target == common::DefaultNVGPUTarget()) { #ifdef CINN_WITH_CUDA const auto& cuda_stream = static_cast(stream); @@ -91,7 +111,8 @@ bool MemcpyToHost(void* dst, const void* src, size_t bytes, const Target& input_ cudaStreamSynchronize(cuda_stream); return true; #else - LOG(FATAL) << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + LOG(FATAL) + << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; return false; #endif } @@ -99,34 +120,51 @@ bool MemcpyToHost(void* dst, const void* src, size_t bytes, const Target& input_ memcpy(dst, src, bytes); return true; } - LOG(FATAL) << "MemcpyToHost Only support cpu or nvgpu -> cpu, but here the input target is " << input_target - << "! Please check."; + LOG(FATAL) << "MemcpyToHost Only support cpu or nvgpu -> cpu, but here the " + "input target is " + << input_target << "! Please check."; return false; } -bool MemcpyToDevice(void* dst, const void* src, size_t bytes, const Target& input_target, void* stream = nullptr) { +bool MemcpyToDevice(void* dst, + const void* src, + size_t bytes, + const Target& input_target, + void* stream = nullptr) { #ifdef CINN_WITH_CUDA if (input_target == common::DefaultNVGPUTarget()) { - cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, static_cast(stream)); + cudaMemcpyAsync(dst, + src, + bytes, + cudaMemcpyDeviceToDevice, + static_cast(stream)); return true; } else if (input_target == common::DefaultHostTarget()) { - cudaMemcpyAsync(dst, src, bytes, cudaMemcpyHostToDevice, static_cast(stream)); + cudaMemcpyAsync(dst, + src, + bytes, + cudaMemcpyHostToDevice, + static_cast(stream)); return true; } else { - LOG(FATAL) << "MemcpyToDevice only support cpu or nvgpu -> nvgpu, but here the input target is " << input_target - << "! 
Please check."; + LOG(FATAL) << "MemcpyToDevice only support cpu or nvgpu -> nvgpu, but here " + "the input target is " + << input_target << "! Please check."; return false; } #else - LOG(FATAL) - << "MemcpyToDevice only support nvgpu, and NVGPU Target only support when flag CINN_WITH_CUDA ON! Please check."; + LOG(FATAL) << "MemcpyToDevice only support nvgpu, and NVGPU Target only " + "support when flag CINN_WITH_CUDA ON! Please check."; return false; #endif } } // namespace utils -void CheckAssertTrue( - const bool* x, const size_t numel, bool only_warning, const std::string& msg, const Target& target) { +void CheckAssertTrue(const bool* x, + const size_t numel, + bool only_warning, + const std::string& msg, + const Target& target) { // check false number and first false offset int error_num = 0, first_diff = -1; for (int i = 0; i < numel; ++i) { @@ -157,17 +195,23 @@ void CheckAssertTrue( } } -void cinn_assert_true(void* v_args, int num_args, int msg, bool only_warning, void* stream, const Target& target) { +void cinn_assert_true(void* v_args, + int num_args, + int msg, + bool only_warning, + void* stream, + const Target& target) { // why x->type and output->type are empty? - // CHECK(x->type == cinn_bool_t()) << "The input type of AssertTrue should be bool, but here " << x->type.bits + // CHECK(x->type == cinn_bool_t()) << "The input type of AssertTrue should be + // bool, but here " << x->type.bits // << "! Please check."; - // CHECK(output->type == cinn_bool_t()) << "The output type of AssertTrue should be bool, but here " << - // output->type.bits + // CHECK(output->type == cinn_bool_t()) << "The output type of AssertTrue + // should be bool, but here " << output->type.bits // << "! Please check."; cinn_pod_value_t* args = static_cast(v_args); - cinn_buffer_t* x = args[0].operator cinn_buffer_t*(); + cinn_buffer_t* x = args[0].operator cinn_buffer_t*(); cinn_buffer_t* output = args[1].operator cinn_buffer_t*(); // create cpu tensor @@ -183,15 +227,21 @@ void cinn_assert_true(void* v_args, int num_args, int msg, bool only_warning, vo // copy data from gpu to cpu const bool* src = reinterpret_cast(x->memory); - size_t numel = cpu_tensor->shape().numel(); + size_t numel = cpu_tensor->shape().numel(); utils::MemcpyToHost(dst, src, numel * sizeof(bool), target, stream); - CheckAssertTrue(dst, numel, only_warning, utils::AssertTrueMsgTool::GetInstance()->GetMsg(msg), target); + CheckAssertTrue(dst, + numel, + only_warning, + utils::AssertTrueMsgTool::GetInstance()->GetMsg(msg), + target); if (target == common::DefaultNVGPUTarget()) { - utils::MemcpyToDevice(output->memory, x->memory, numel * sizeof(bool), target, stream); + utils::MemcpyToDevice( + output->memory, x->memory, numel * sizeof(bool), target, stream); } else { - utils::MemcpyToHost(output->memory, x->memory, numel * sizeof(bool), target, stream); + utils::MemcpyToHost( + output->memory, x->memory, numel * sizeof(bool), target, stream); } } diff --git a/paddle/cinn/runtime/custom_function.h b/paddle/cinn/runtime/custom_function.h index 5c8a6c3a34a70..103da8b5eba89 100644 --- a/paddle/cinn/runtime/custom_function.h +++ b/paddle/cinn/runtime/custom_function.h @@ -43,8 +43,10 @@ class AssertTrueMsgTool { const T& GetFlagValue(const std::string& param) { InitFlagInfo(); CHECK(flag_values_.count(param)) - << "The FLAGS_cinn_check_fusion_accuracy_pass only support parameter \"only_warning/rtol/atol/equal_nan\" now"; - CHECK(absl::holds_alternative(flag_values_.at(param))) << "Try get value from a error type!"; + << "The 
FLAGS_cinn_check_fusion_accuracy_pass only support parameter " + "\"only_warning/rtol/atol/equal_nan\" now"; + CHECK(absl::holds_alternative<T>(flag_values_.at(param))) + << "Try get value from a error type!"; return absl::get<T>(flag_values_.at(param)); } @@ -60,7 +62,12 @@ class AssertTrueMsgTool { }; } // namespace utils -void cinn_assert_true(void* v_args, int num_args, int msg, bool only_warning, void* stream, const Target& target); +void cinn_assert_true(void* v_args, + int num_args, + int msg, + bool only_warning, + void* stream, + const Target& target); } // namespace runtime } // namespace cinn diff --git a/paddle/cinn/runtime/custom_function_test.cc b/paddle/cinn/runtime/custom_function_test.cc index 0f5935fa774f6..df88a0e4b817b 100644 --- a/paddle/cinn/runtime/custom_function_test.cc +++ b/paddle/cinn/runtime/custom_function_test.cc @@ -36,14 +36,19 @@ namespace runtime { class CinnBufferAllocHelper { public: - CinnBufferAllocHelper(cinn_device_kind_t device, cinn_type_t type, const std::vector<int>& shape, int align = 0) { + CinnBufferAllocHelper(cinn_device_kind_t device, + cinn_type_t type, + const std::vector<int>& shape, + int align = 0) { buffer_ = cinn_buffer_t::new_(device, type, shape, align); } template <typename T> T* mutable_data(const Target& target) { if (target_ != common::UnkTarget()) { - CHECK_EQ(target, target_) << "Cannot alloc twice, the memory had alloced at " << target_ << "! Please check."; + CHECK_EQ(target, target_) + << "Cannot alloc twice, the memory had alloced at " << target_ + << "! Please check."; return reinterpret_cast<T*>(buffer_->memory); } @@ -54,10 +59,12 @@ class CinnBufferAllocHelper { #ifdef CINN_WITH_CUDA cudaMalloc(&buffer_->memory, buffer_->num_elements() * sizeof(T)); #else - LOG(FATAL) << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + LOG(FATAL) << "NVGPU Target only support on flag CINN_WITH_CUDA ON! " + "Please check."; #endif } else { - LOG(FATAL) << "Only support nvgpu and cpu, but here " << target << "! Please check."; + LOG(FATAL) << "Only support nvgpu and cpu, but here " << target + << "! Please check."; } return reinterpret_cast<T*>(buffer_->memory); @@ -81,10 +88,12 @@ class CinnBufferAllocHelper { #ifdef CINN_WITH_CUDA cudaFree(buffer_->memory); #else - LOG(FATAL) << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + LOG(FATAL) << "NVGPU Target only support on flag CINN_WITH_CUDA ON! " + "Please check."; #endif } else { - LOG(FATAL) << "Only support nvgpu and cpu, but here " << target_ << "! Please check."; + LOG(FATAL) << "Only support nvgpu and cpu, but here " << target_ + << "! Please check."; } delete buffer_; } @@ -100,7 +109,10 @@ class CinnBufferAllocHelper { }; template <typename T> -void SetInputValue(T* input, const T* input_h, size_t num, const Target& target) { +void SetInputValue(T* input, + const T* input_h, + size_t num, + const Target& target) { if (target == common::DefaultHostTarget()) { for (int i = 0; i < num; ++i) { input[i] = input_h[i]; @@ -109,7 +121,8 @@ void SetInputValue(T* input, const T* input_h, size_t num, const Target& target) #ifdef CINN_WITH_CUDA cudaMemcpy(input, input_h, num * sizeof(T), cudaMemcpyHostToDevice); #else - LOG(FATAL) << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + LOG(FATAL) + << "NVGPU Target only support on flag CINN_WITH_CUDA ON! 
Please check."; #endif } } @@ -121,30 +134,33 @@ TEST(CinnAssertTrue, test_true) { // set inpute value true bool input_h = true; - auto* input = x.mutable_data(target); + auto* input = x.mutable_data(target); SetInputValue(input, &input_h, 1, target); CinnBufferAllocHelper y(cinn_x86_device, cinn_bool_t(), {1}); auto* output = y.mutable_data(target); - cinn_pod_value_t v_args[2] = {cinn_pod_value_t(x.get()), cinn_pod_value_t(y.get())}; + cinn_pod_value_t v_args[2] = {cinn_pod_value_t(x.get()), + cinn_pod_value_t(y.get())}; std::stringstream ss; ss << "Test AssertTrue(true) on " << target; const auto& msg = ss.str(); - int msg_key = static_cast(std::hash()(msg)); + int msg_key = static_cast(std::hash()(msg)); cinn::runtime::utils::AssertTrueMsgTool::GetInstance()->SetMsg(msg_key, msg); cinn_assert_true(v_args, 2, msg_key, true, nullptr, target); if (target == common::DefaultHostTarget()) { - ASSERT_EQ(input[0], output[0]) << "The output of AssertTrue should be the same as input"; + ASSERT_EQ(input[0], output[0]) + << "The output of AssertTrue should be the same as input"; } else if (target == common::DefaultNVGPUTarget()) { #ifdef CINN_WITH_CUDA bool output_h = false; cudaMemcpy(&output_h, output, sizeof(bool), cudaMemcpyDeviceToHost); - ASSERT_EQ(input_h, output_h) << "The output of AssertTrue should be the same as input"; + ASSERT_EQ(input_h, output_h) + << "The output of AssertTrue should be the same as input"; #endif } } @@ -156,30 +172,33 @@ TEST(CinnAssertTrue, test_false_only_warning) { // set inpute value false bool input_h = false; - auto* input = x.mutable_data(target); + auto* input = x.mutable_data(target); SetInputValue(input, &input_h, 1, target); CinnBufferAllocHelper y(cinn_x86_device, cinn_bool_t(), {1}); auto* output = y.mutable_data(target); - cinn_pod_value_t v_args[2] = {cinn_pod_value_t(x.get()), cinn_pod_value_t(y.get())}; + cinn_pod_value_t v_args[2] = {cinn_pod_value_t(x.get()), + cinn_pod_value_t(y.get())}; std::stringstream ss; ss << "Test AssertTrue(false, only_warning=true) on " << target; const auto& msg = ss.str(); - int msg_key = static_cast(std::hash()(msg)); + int msg_key = static_cast(std::hash()(msg)); cinn::runtime::utils::AssertTrueMsgTool::GetInstance()->SetMsg(msg_key, msg); cinn_assert_true(v_args, 2, msg_key, true, nullptr, target); if (target == common::DefaultHostTarget()) { - ASSERT_EQ(input[0], output[0]) << "The output of AssertTrue should be the same as input"; + ASSERT_EQ(input[0], output[0]) + << "The output of AssertTrue should be the same as input"; } else if (target == common::DefaultNVGPUTarget()) { #ifdef CINN_WITH_CUDA bool output_h = false; cudaMemcpy(&output_h, output, sizeof(bool), cudaMemcpyDeviceToHost); - ASSERT_EQ(input_h, output_h) << "The output of AssertTrue should be the same as input"; + ASSERT_EQ(input_h, output_h) + << "The output of AssertTrue should be the same as input"; #endif } } @@ -198,14 +217,15 @@ TEST(CustomCallGaussianRandom, test_target_nvgpu) { CinnBufferAllocHelper out(cinn_x86_device, cinn_float32_t(), {2, 3}); auto* output = out.mutable_data(target); - int num_args = 1; + int num_args = 1; cinn_pod_value_t v_args[1] = {cinn_pod_value_t(out.get())}; if (target == common::DefaultHostTarget()) { LOG(INFO) << "Op gaussian random only support on NVGPU"; } else if (target == common::DefaultNVGPUTarget()) { #ifdef CINN_WITH_CUDA - cinn::runtime::cuda::cinn_call_gaussian_random(v_args, num_args, mean, std, seed, nullptr); + cinn::runtime::cuda::cinn_call_gaussian_random( + v_args, num_args, mean, std, seed, 
nullptr); float output_data[6] = {0.0}; cudaMemcpy(output_data, output, 6 * sizeof(float), cudaMemcpyDeviceToHost); @@ -213,7 +233,8 @@ TEST(CustomCallGaussianRandom, test_target_nvgpu) { VLOG(6) << output_data[i]; } #else - LOG(FATAL) << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + LOG(FATAL) + << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; #endif } } @@ -232,14 +253,15 @@ TEST(CustomCallUniformRandom, test_target_nvgpu) { CinnBufferAllocHelper out(cinn_x86_device, cinn_float32_t(), {2, 3}); auto* output = out.mutable_data(target); - int num_args = 1; + int num_args = 1; cinn_pod_value_t v_args[1] = {cinn_pod_value_t(out.get())}; if (target == common::DefaultHostTarget()) { LOG(INFO) << "Op uniform random only support on NVGPU"; } else if (target == common::DefaultNVGPUTarget()) { #ifdef CINN_WITH_CUDA - cinn::runtime::cuda::cinn_call_uniform_random(v_args, num_args, min, max, seed, nullptr); + cinn::runtime::cuda::cinn_call_uniform_random( + v_args, num_args, min, max, seed, nullptr); float output_data[6] = {0.0f}; cudaMemcpy(output_data, output, 6 * sizeof(float), cudaMemcpyDeviceToHost); @@ -247,7 +269,8 @@ TEST(CustomCallUniformRandom, test_target_nvgpu) { VLOG(6) << output_data[i]; } #else - LOG(FATAL) << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + LOG(FATAL) + << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; #endif } } @@ -264,8 +287,15 @@ TEST(CustomCallCholesky, test) { // Input matrix x CinnBufferAllocHelper x(cinn_x86_device, cinn_float32_t(), {m, m}); - float input_h[9] = { - 0.96329159, 0.88160539, 0.40593964, 0.88160539, 1.39001071, 0.48823422, 0.40593964, 0.48823422, 0.19755946}; + float input_h[9] = {0.96329159, + 0.88160539, + 0.40593964, + 0.88160539, + 1.39001071, + 0.48823422, + 0.40593964, + 0.48823422, + 0.19755946}; auto* input = x.mutable_data(target); SetInputValue(input, input_h, m * m, target); @@ -274,35 +304,60 @@ TEST(CustomCallCholesky, test) { auto* output = out.mutable_data(target); // Result matrix - // In the calculation result of MKL, the matrix !upper part is the same as the original input - float host_result[9] = { - 0.98147416, 0.88160539, 0.40593964, 0.89824611, 0.76365221, 0.48823422, 0.41360193, 0.15284170, 0.055967092}; - // In the calculation results of cuSOLVER, the upper and lower triangles of the matrix are the same - float gpu_result[9] = { - 0.98147416, 0.89824611, 0.41360193, 0.89824611, 0.76365221, 0.15284170, 0.41360193, 0.15284170, 0.055967092}; - - int num_args = 2; - cinn_pod_value_t v_args[2] = {cinn_pod_value_t(x.get()), cinn_pod_value_t(out.get())}; + // In the calculation result of MKL, the matrix !upper part is the same as the + // original input + float host_result[9] = {0.98147416, + 0.88160539, + 0.40593964, + 0.89824611, + 0.76365221, + 0.48823422, + 0.41360193, + 0.15284170, + 0.055967092}; + // In the calculation results of cuSOLVER, the upper and lower triangles of + // the matrix are the same + float gpu_result[9] = {0.98147416, + 0.89824611, + 0.41360193, + 0.89824611, + 0.76365221, + 0.15284170, + 0.41360193, + 0.15284170, + 0.055967092}; + + int num_args = 2; + cinn_pod_value_t v_args[2] = {cinn_pod_value_t(x.get()), + cinn_pod_value_t(out.get())}; if (target == common::DefaultHostTarget()) { #ifdef CINN_WITH_MKL_CBLAS cinn_call_cholesky_host(v_args, num_args, batch_size, m, upper); for (int i = 0; i < batch_size * m * m; i++) { - ASSERT_NEAR(output[i], host_result[i], 1e-5) << "The output of Cholesky should be the 
same as result"; + ASSERT_NEAR(output[i], host_result[i], 1e-5) + << "The output of Cholesky should be the same as result"; } #else - LOG(INFO) << "Host Target only support on flag CINN_WITH_MKL_CBLAS ON! Please check."; + LOG(INFO) << "Host Target only support on flag CINN_WITH_MKL_CBLAS ON! " + "Please check."; #endif } else if (target == common::DefaultNVGPUTarget()) { #ifdef CINN_WITH_CUDA - cinn::runtime::cuda::cinn_call_cholesky_nvgpu(v_args, num_args, batch_size, m, upper); + cinn::runtime::cuda::cinn_call_cholesky_nvgpu( + v_args, num_args, batch_size, m, upper); std::vector host_output(batch_size * m * m, 0.0f); - cudaMemcpy(host_output.data(), output, batch_size * m * m * sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(host_output.data(), + output, + batch_size * m * m * sizeof(float), + cudaMemcpyDeviceToHost); for (int i = 0; i < batch_size * m * m; i++) { - ASSERT_NEAR(host_output[i], gpu_result[i], 1e-5) << "The output of Cholesky should be the same as result"; + ASSERT_NEAR(host_output[i], gpu_result[i], 1e-5) + << "The output of Cholesky should be the same as result"; } #else - LOG(INFO) << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + LOG(INFO) + << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; #endif } } @@ -311,12 +366,12 @@ TEST(CustomCallCholesky, test) { TEST(CustomCallTriangularSolve, test) { Target target = common::DefaultNVGPUTarget(); - int batch_size = 1; - int m = 3; - int k = 1; - bool left_side = true; - bool upper = true; - bool transpose_a = false; + int batch_size = 1; + int m = 3; + int k = 1; + bool left_side = true; + bool upper = true; + bool transpose_a = false; bool unit_diagonal = false; double input_a_host[9] = {1.0, 1.0, 1.0, 0.0, 2.0, 1.0, 0.0, 0.0, -1.0}; @@ -335,15 +390,27 @@ TEST(CustomCallTriangularSolve, test) { // Result matrix res double result[3] = {7.0, -2.0, -5.0}; - constexpr int num_args = 3; - cinn_pod_value_t v_args[num_args] = { - cinn_pod_value_t(a.get()), cinn_pod_value_t(b.get()), cinn_pod_value_t(out.get())}; - cinn::runtime::cuda::cinn_call_triangular_solve_nvgpu( - v_args, num_args, batch_size, m, k, left_side, upper, transpose_a, unit_diagonal); + constexpr int num_args = 3; + cinn_pod_value_t v_args[num_args] = {cinn_pod_value_t(a.get()), + cinn_pod_value_t(b.get()), + cinn_pod_value_t(out.get())}; + cinn::runtime::cuda::cinn_call_triangular_solve_nvgpu(v_args, + num_args, + batch_size, + m, + k, + left_side, + upper, + transpose_a, + unit_diagonal); std::vector device_output(batch_size * m * k, 0.0f); - cudaMemcpy(device_output.data(), output, batch_size * m * k * sizeof(double), cudaMemcpyDeviceToHost); + cudaMemcpy(device_output.data(), + output, + batch_size * m * k * sizeof(double), + cudaMemcpyDeviceToHost); for (int i = 0; i < batch_size * m * k; i++) { - ASSERT_NEAR(device_output[i], result[i], 1e-5) << "The output of triangular solve should be the same as result"; + ASSERT_NEAR(device_output[i], result[i], 1e-5) + << "The output of triangular solve should be the same as result"; } } #endif diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc index cf213a58fe21e..0cb6193b810d2 100644 --- a/paddle/cinn/runtime/flags.cc +++ b/paddle/cinn/runtime/flags.cc @@ -37,28 +37,35 @@ using ::GFLAGS_NAMESPACE::Int32FromEnv; using ::GFLAGS_NAMESPACE::Int64FromEnv; using ::GFLAGS_NAMESPACE::StringFromEnv; -DEFINE_string(cinn_x86_builtin_code_root, StringFromEnv("FLAGS_cinn_x86_builtin_code_root", ""), ""); +DEFINE_string(cinn_x86_builtin_code_root, + 
StringFromEnv("FLAGS_cinn_x86_builtin_code_root", ""), + ""); DEFINE_string(cinn_nvcc_cmd_path, StringFromEnv("FLAGS_cinn_nvcc_cmd_path", "/usr/local/cuda/bin"), "Setting nvcc default path!"); DEFINE_int32(cinn_parallel_compile_size, Int32FromEnv("FLAGS_cinn_parallel_compile_size", 16), - "When use parallel compile, set the number of group compiled by each thread."); + "When use parallel compile, set the number of group compiled by " + "each thread."); DEFINE_int32(cinn_parallel_compile_thread, Int32FromEnv("FLAGS_cinn_parallel_compile_thread", -1), "How much thread the parallel compile used."); -DEFINE_bool(cinn_use_op_fusion, BoolFromEnv("FLAGS_cinn_use_op_fusion", true), "Whether to use op fusion pass."); +DEFINE_bool(cinn_use_op_fusion, + BoolFromEnv("FLAGS_cinn_use_op_fusion", true), + "Whether to use op fusion pass."); DEFINE_bool(cinn_use_common_subexpression_elimination, - BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", false), + BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", + false), "Whether to use common subexpression elimination pass."); -DEFINE_string(cinn_custom_call_deny_ops, - StringFromEnv("FLAGS_cinn_custom_call_deny_ops", ""), - "a blacklist of op are denied by MarkCustomCallOps pass, separated by ;"); +DEFINE_string( + cinn_custom_call_deny_ops, + StringFromEnv("FLAGS_cinn_custom_call_deny_ops", ""), + "a blacklist of op are denied by MarkCustomCallOps pass, separated by ;"); DEFINE_bool(cinn_use_custom_call, BoolFromEnv("FLAGS_cinn_use_custom_call", true), @@ -70,7 +77,8 @@ DEFINE_bool(cinn_use_fill_constant_folding, DEFINE_string(cinn_check_fusion_accuracy_pass, StringFromEnv("FLAGS_cinn_check_fusion_accuracy_pass", ""), - "Check the correct of fusion kernels, if the results not satisfied 'allclose(rtol=1e-05f, atol=1e-08f)', " + "Check the correct of fusion kernels, if the results not " + "satisfied 'allclose(rtol=1e-05f, atol=1e-08f)', " "report error and exited."); DEFINE_bool(cinn_use_cuda_vectorize, @@ -81,7 +89,9 @@ DEFINE_bool(cinn_ir_schedule, BoolFromEnv("FLAGS_cinn_ir_schedule", true), "Whether use reconstructed schedule primitives."); -DEFINE_bool(use_reduce_split_pass, BoolFromEnv("FLAGS_use_reduce_split_pass", false), "Whether use reduce split pass."); +DEFINE_bool(use_reduce_split_pass, + BoolFromEnv("FLAGS_use_reduce_split_pass", false), + "Whether use reduce split pass."); DEFINE_bool(cinn_use_dense_merge_pass, BoolFromEnv("FLAGS_cinn_use_dense_merge_pass", false), @@ -89,7 +99,8 @@ DEFINE_bool(cinn_use_dense_merge_pass, DEFINE_bool(nvrtc_compile_to_cubin, BoolFromEnv("FLAGS_nvrtc_compile_to_cubin", false), - "Whether nvrtc compile cuda source into cubin instead of ptx (only works after cuda-11.1)."); + "Whether nvrtc compile cuda source into cubin instead of ptx (only " + "works after cuda-11.1)."); DEFINE_bool(cinn_compile_with_nvrtc, BoolFromEnv("FLAGS_cinn_compile_with_nvrtc", true), @@ -98,11 +109,13 @@ DEFINE_bool(cinn_compile_with_nvrtc, // FLAGS for performance analysis and accuracy debug DEFINE_bool(cinn_sync_run, BoolFromEnv("FLAGS_cinn_sync_run", false), - "Whether sync all devices after each instruction run, which is used for debug."); + "Whether sync all devices after each instruction run, which is " + "used for debug."); DEFINE_string(cinn_self_check_accuracy, StringFromEnv("FLAGS_cinn_self_check_accuracy", ""), - "Whether self-check accuracy after each instruction run, which is used for debug."); + "Whether self-check accuracy after each instruction run, which " + "is used for debug."); 
DEFINE_int64(cinn_self_check_accuracy_num, Int64FromEnv("FLAGS_cinn_self_check_accuracy_num", 0L), @@ -110,21 +123,27 @@ DEFINE_int64(cinn_self_check_accuracy_num, DEFINE_string(cinn_fusion_groups_graphviz_dir, StringFromEnv("FLAGS_cinn_fusion_groups_graphviz_dir", ""), - "Specify the directory path of dot file of graph, which is used for debug."); + "Specify the directory path of dot file of graph, which is used " + "for debug."); DEFINE_string(cinn_source_code_save_path, StringFromEnv("FLAGS_cinn_source_code_save_path", ""), - "Specify the directory path of generated source code, which is used for debug."); + "Specify the directory path of generated source code, which is " + "used for debug."); DEFINE_string(cinn_pass_visualize_dir, StringFromEnv("FLAGS_cinn_pass_visualize_dir", ""), - "Specify the directory path of pass visualize file of graph, which is used for debug."); + "Specify the directory path of pass visualize file of graph, " + "which is used for debug."); -DEFINE_bool(enable_auto_tuner, BoolFromEnv("FLAGS_enable_auto_tuner", false), "Whether enable auto tuner."); +DEFINE_bool(enable_auto_tuner, + BoolFromEnv("FLAGS_enable_auto_tuner", false), + "Whether enable auto tuner."); DEFINE_bool(auto_schedule_use_cost_model, BoolFromEnv("FLAGS_auto_schedule_use_cost_model", true), - "Whether to use cost model in auto schedule, this is an on-developing flag and it will be removed when " + "Whether to use cost model in auto schedule, this is an " + "on-developing flag and it will be removed when " "cost model is stable."); DEFINE_bool(enhance_vertical_fusion_with_recompute, @@ -133,12 +152,13 @@ DEFINE_bool(enhance_vertical_fusion_with_recompute, DEFINE_bool(verbose_function_register, BoolFromEnv("FLAGS_verbose_function_register", false), - "Whether to verbose function regist log. This will only work if CINN build with flag -DWITH_DEBUG=ON."); + "Whether to verbose function regist log. This will only work if " + "CINN build with flag -DWITH_DEBUG=ON."); -DEFINE_int32( - cinn_profiler_state, - Int32FromEnv("FLAGS_cinn_profiler_state", -1), - "Specify the ProfilerState by Int in CINN, 0 for kDisabled, 1 for kCPU, 2 for kCUDA, 3 for kAll, default 0."); +DEFINE_int32(cinn_profiler_state, + Int32FromEnv("FLAGS_cinn_profiler_state", -1), + "Specify the ProfilerState by Int in CINN, 0 for kDisabled, 1 for " + "kCPU, 2 for kCUDA, 3 for kAll, default 0."); namespace cinn { namespace runtime { @@ -146,7 +166,8 @@ namespace runtime { bool CheckStringFlagTrue(const std::string& flag) { // from gflag FlagValue::ParseFrom: // https://github.com/gflags/gflags/blob/master/src/gflags.cc#L292 - static const std::unordered_set kTrue = {"1", "t", "true", "y", "yes", "T", "True", "TRUE", "Y", "yes"}; + static const std::unordered_set kTrue = { + "1", "t", "true", "y", "yes", "T", "True", "TRUE", "Y", "yes"}; return kTrue.count(flag); } @@ -186,13 +207,14 @@ unsigned long long RandomSeed::GetOrSet(unsigned long long seed) { unsigned long long RandomSeed::Clear() { auto old_seed = seed_; - seed_ = 0ULL; + seed_ = 0ULL; return old_seed; } bool CanUseNvccCompiler() { std::string nvcc_dir = FLAGS_cinn_nvcc_cmd_path + "/nvcc"; - return (access(nvcc_dir.c_str(), 0) == -1 ? false : true) && (!FLAGS_cinn_compile_with_nvrtc); + return (access(nvcc_dir.c_str(), 0) == -1 ? 
false : true) && + (!FLAGS_cinn_compile_with_nvrtc); } bool IsCompiledWithCUDA() { @@ -215,7 +237,8 @@ common::Target CurrentTarget::target_ = common::DefaultTarget(); void CurrentTarget::SetCurrentTarget(const common::Target& target) { if (!IsCompiledWithCUDA() && target.arch == common::Target::Arch::NVGPU) { - LOG(FATAL) << "Current CINN version does not support NVGPU, please try to recompile with -DWITH_CUDA."; + LOG(FATAL) << "Current CINN version does not support NVGPU, please try to " + "recompile with -DWITH_CUDA."; } else { target_ = target; } diff --git a/paddle/cinn/runtime/flags.h b/paddle/cinn/runtime/flags.h index c13103a3335bc..892644cdb54b5 100644 --- a/paddle/cinn/runtime/flags.h +++ b/paddle/cinn/runtime/flags.h @@ -35,7 +35,7 @@ class RandomSeed { static unsigned long long Clear(); private: - RandomSeed() = default; + RandomSeed() = default; RandomSeed(const RandomSeed &) = delete; RandomSeed &operator=(const RandomSeed &) = delete; @@ -51,7 +51,7 @@ class CurrentTarget { static void SetCurrentTarget(const common::Target &target); private: - CurrentTarget() = default; + CurrentTarget() = default; CurrentTarget(const CurrentTarget &) = delete; CurrentTarget &operator=(const CurrentTarget &) = delete; diff --git a/paddle/cinn/runtime/intrinsic.cc b/paddle/cinn/runtime/intrinsic.cc index d9f80996d754b..41e12331650b6 100644 --- a/paddle/cinn/runtime/intrinsic.cc +++ b/paddle/cinn/runtime/intrinsic.cc @@ -60,7 +60,13 @@ Expr IntrinsicCall(Type type, const std::string& fn_name, const std::vector& args, const std::vector& write_args) { - return ir::Call::Make(type, fn_name, args, write_args, ir::CallType::Intrinsic, ir::FunctionRef(), 0); + return ir::Call::Make(type, + fn_name, + args, + write_args, + ir::CallType::Intrinsic, + ir::FunctionRef(), + 0); } } // namespace runtime diff --git a/paddle/cinn/runtime/intrinsic.h b/paddle/cinn/runtime/intrinsic.h index 49798f9bdc84c..00507dbc405fd 100644 --- a/paddle/cinn/runtime/intrinsic.h +++ b/paddle/cinn/runtime/intrinsic.h @@ -21,7 +21,8 @@ #include "paddle/cinn/runtime/intrinsic_types.h" /** - * \file This file implements some runtime concepts used in analysis and codegen. + * \file This file implements some runtime concepts used in analysis and + * codegen. */ namespace cinn { @@ -40,59 +41,62 @@ static const char* buffer_destroy = "cinn_buffer_t::delete_"; static const char* buffer_load = "cinn_buffer_load"; -static const char* buffer_malloc = "cinn_buffer_malloc"; -static const char* buffer_free = "cinn_buffer_free"; +static const char* buffer_malloc = "cinn_buffer_malloc"; +static const char* buffer_free = "cinn_buffer_free"; static const char* buffer_create_default = "cinn_buffer_new_default"; -static const char* buffer_get_data_handle = "cinn_buffer_get_data_handle"; -static const char* buffer_get_data_const_handle = "cinn_buffer_get_data_const_handle"; +static const char* buffer_get_data_handle = "cinn_buffer_get_data_handle"; +static const char* buffer_get_data_const_handle = + "cinn_buffer_get_data_const_handle"; //! 
Buffer load an element of some primitive type // @{ static const char* buffer_load_bfloat16 = "buffer_load_bfloat16"; -static const char* buffer_load_float16 = "buffer_load_float16"; -static const char* buffer_load_float32 = "buffer_load_float32"; -static const char* buffer_load_float64 = "buffer_load_float64"; +static const char* buffer_load_float16 = "buffer_load_float16"; +static const char* buffer_load_float32 = "buffer_load_float32"; +static const char* buffer_load_float64 = "buffer_load_float64"; // @} static const char* pod_value_ty = "cinn_pod_value_t"; -static const char* float_to_cinn_pod_value_repr = "float_to_cinn_pod_value"; -static const char* double_to_cinn_pod_value_repr = "double_to_cinn_pod_value"; -static const char* bfloat16_to_cinn_pod_value_repr = "bfloat16_to_cinn_pod_value"; -static const char* float16_to_cinn_pod_value_repr = "float16_to_cinn_pod_value"; +static const char* float_to_cinn_pod_value_repr = "float_to_cinn_pod_value"; +static const char* double_to_cinn_pod_value_repr = "double_to_cinn_pod_value"; +static const char* bfloat16_to_cinn_pod_value_repr = + "bfloat16_to_cinn_pod_value"; +static const char* float16_to_cinn_pod_value_repr = "float16_to_cinn_pod_value"; static const char* bool_to_cinn_pod_value_repr = "bool_to_cinn_pod_value"; -static const char* int8_to_cinn_pod_value_repr = "int8_to_cinn_pod_value"; +static const char* int8_to_cinn_pod_value_repr = "int8_to_cinn_pod_value"; static const char* int16_to_cinn_pod_value_repr = "int16_to_cinn_pod_value"; static const char* int32_to_cinn_pod_value_repr = "int32_to_cinn_pod_value"; static const char* int64_to_cinn_pod_value_repr = "int64_to_cinn_pod_value"; -static const char* uint8_to_cinn_pod_value_repr = "uint8_to_cinn_pod_value"; +static const char* uint8_to_cinn_pod_value_repr = "uint8_to_cinn_pod_value"; static const char* uint16_to_cinn_pod_value_repr = "uint16_to_cinn_pod_value"; static const char* uint32_to_cinn_pod_value_repr = "uint32_to_cinn_pod_value"; static const char* uint64_to_cinn_pod_value_repr = "uint64_to_cinn_pod_value"; -static const char* buffer_p_to_cinn_pod_value_repr = "buffer_p_to_cinn_pod_value"; +static const char* buffer_p_to_cinn_pod_value_repr = + "buffer_p_to_cinn_pod_value"; static const char* pod_value_to_buffer_p = "cinn_pod_value_to_buffer_p"; -static const char* pod_value_to_bool = "cinn_pod_value_to_bool"; +static const char* pod_value_to_bool = "cinn_pod_value_to_bool"; -static const char* pod_value_to_int8 = "cinn_pod_value_to_int8"; +static const char* pod_value_to_int8 = "cinn_pod_value_to_int8"; static const char* pod_value_to_int16 = "cinn_pod_value_to_int16"; static const char* pod_value_to_int32 = "cinn_pod_value_to_int32"; static const char* pod_value_to_int64 = "cinn_pod_value_to_int64"; -static const char* pod_value_to_uint8 = "cinn_pod_value_to_uint8"; +static const char* pod_value_to_uint8 = "cinn_pod_value_to_uint8"; static const char* pod_value_to_uint16 = "cinn_pod_value_to_uint16"; static const char* pod_value_to_uint32 = "cinn_pod_value_to_uint32"; static const char* pod_value_to_uint64 = "cinn_pod_value_to_uint64"; -static const char* pod_value_to_float = "cinn_pod_value_to_float"; -static const char* pod_value_to_double = "cinn_pod_value_to_double"; +static const char* pod_value_to_float = "cinn_pod_value_to_float"; +static const char* pod_value_to_double = "cinn_pod_value_to_double"; static const char* pod_value_to_bfloat16 = "cinn_pod_value_to_bfloat16"; -static const char* pod_value_to_float16 = "cinn_pod_value_to_float16"; +static const 
char* pod_value_to_float16 = "cinn_pod_value_to_float16"; static const char* pod_value_to_void_p = "cinn_pod_value_to_void_p"; diff --git a/paddle/cinn/runtime/intrinsic_types.h b/paddle/cinn/runtime/intrinsic_types.h index 15b417fde84ad..6a6c460e6323c 100644 --- a/paddle/cinn/runtime/intrinsic_types.h +++ b/paddle/cinn/runtime/intrinsic_types.h @@ -26,12 +26,15 @@ namespace runtime { * Type representation for cinn_buffer_t. */ struct BufferType { - static BufferType Create(const Type& primitive) { return BufferType(primitive); } + static BufferType Create(const Type& primitive) { + return BufferType(primitive); + } static Type cinn_type(); private: - explicit BufferType(const Type& primitive_type) : primitive_type(primitive_type) { + explicit BufferType(const Type& primitive_type) + : primitive_type(primitive_type) { CHECK(primitive_type.valid()); CHECK(primitive_type.is_primitive()); } diff --git a/paddle/cinn/runtime/tiny_runtime.cc b/paddle/cinn/runtime/tiny_runtime.cc index 303f903e4c7d9..efb4c2075ec0a 100644 --- a/paddle/cinn/runtime/tiny_runtime.cc +++ b/paddle/cinn/runtime/tiny_runtime.cc @@ -65,10 +65,10 @@ void *load_program(const char *paramfile) { ctx->major_v = *(int *)(buf + 4); ctx->minor_v = *(int *)(buf + 8); - int *namelist_pos = (int *)(buf + 16); - int *podvalue_pos = (int *)(buf + *namelist_pos); + int *namelist_pos = (int *)(buf + 16); + int *podvalue_pos = (int *)(buf + *namelist_pos); int *persistent_pos = (int *)(buf + *podvalue_pos); - int *inst_pos = (int *)(buf + *persistent_pos); + int *inst_pos = (int *)(buf + *persistent_pos); if (fsize < *inst_pos) { return nullptr; } @@ -77,8 +77,8 @@ void *load_program(const char *paramfile) { std::vector namev(namelen); std::map name2index; for (int i = 0; i < namelen; i++) { - int offset = (namelist_pos + 2)[i]; - namev[i] = (char *)(buf + offset); + int offset = (namelist_pos + 2)[i]; + namev[i] = (char *)(buf + offset); name2index[namev[i]] = i; } @@ -106,7 +106,8 @@ void *load_program(const char *paramfile) { ctx->instructions.push_back(inst); int instargc = inst_pos[2 + i * 3 + 1]; ctx->inst_argc.push_back(instargc); - cinn_pod_value_t *argv = (cinn_pod_value_t *)(buf + inst_pos[2 + i * 3 + 2]); + cinn_pod_value_t *argv = + (cinn_pod_value_t *)(buf + inst_pos[2 + i * 3 + 2]); for (int i = 0; i < instargc; i++) { int idx = (uintptr_t)((cinn_buffer_t *)argv[i]); cinn_value_t tmp_v; @@ -119,7 +120,7 @@ void *load_program(const char *paramfile) { } int set_maxconcurrency(int c) { - int old_c = max_num_workers; + int old_c = max_num_workers; max_num_workers = c; return old_c; } @@ -129,8 +130,8 @@ void run_program(void *ctx) { param_context_t *pc = (param_context_t *)ctx; for (int i = 0; i < pc->instructions.size(); i++) { const char *sym = pc->instructions[i].c_str(); - void *p = dlsym(RTLD_DEFAULT, sym); - func_t f = (func_t)p; + void *p = dlsym(RTLD_DEFAULT, sym); + func_t f = (func_t)p; f(pc->inst_argv[i], pc->inst_argc[i]); } } @@ -144,7 +145,9 @@ cinn_pod_value_t *get_pod_value(void *ctx, const char *tname) { } typedef int (*FCINNParallelLambda)(int task_id, int num_task, void *datas); -int cinn_backend_parallel_launch(FCINNParallelLambda flambda, void *datas, int num_task) { +int cinn_backend_parallel_launch(FCINNParallelLambda flambda, + void *datas, + int num_task) { int num_workers = max_num_workers; if (num_task == 0) num_task = num_workers; omp_set_num_threads(num_task); diff --git a/paddle/cinn/utils/data_util.cc b/paddle/cinn/utils/data_util.cc index 515c381e3381d..5066395305f75 100644 --- 
a/paddle/cinn/utils/data_util.cc +++ b/paddle/cinn/utils/data_util.cc @@ -18,7 +18,11 @@ namespace cinn { -void SetRandInt(hlir::framework::Tensor tensor, const common::Target& target, int seed, int low, int high) { +void SetRandInt(hlir::framework::Tensor tensor, + const common::Target& target, + int seed, + int low, + int high) { if (seed == -1) { std::random_device rd; seed = rd(); @@ -34,7 +38,10 @@ void SetRandInt(hlir::framework::Tensor tensor, const common::Target& target, in auto* data = tensor->mutable_data(target); #ifdef CINN_WITH_CUDA if (target == common::DefaultNVGPUTarget()) { - cudaMemcpy(data, random_data.data(), num_ele * sizeof(int), cudaMemcpyHostToDevice); + cudaMemcpy(data, + random_data.data(), + num_ele * sizeof(int), + cudaMemcpyHostToDevice); return; } #endif @@ -43,7 +50,9 @@ void SetRandInt(hlir::framework::Tensor tensor, const common::Target& target, in } template <> -void SetRandData(hlir::framework::Tensor tensor, const common::Target& target, int seed) { +void SetRandData(hlir::framework::Tensor tensor, + const common::Target& target, + int seed) { if (seed == -1) { std::random_device rd; seed = rd(); @@ -59,7 +68,10 @@ void SetRandData(hlir::framework::Tensor tensor, const common::Target& targ auto* data = tensor->mutable_data(target); #ifdef CINN_WITH_CUDA if (target == common::DefaultNVGPUTarget()) { - cudaMemcpy(data, random_data.data(), num_ele * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(data, + random_data.data(), + num_ele * sizeof(float), + cudaMemcpyHostToDevice); return; } #endif @@ -68,7 +80,9 @@ void SetRandData(hlir::framework::Tensor tensor, const common::Target& targ } template <> -void SetRandData(hlir::framework::Tensor tensor, const common::Target& target, int seed) { +void SetRandData(hlir::framework::Tensor tensor, + const common::Target& target, + int seed) { if (seed == -1) { std::random_device rd; seed = rd(); @@ -84,7 +98,10 @@ void SetRandData(hlir::framework::Tensor tensor, const common::Target& ta auto* data = tensor->mutable_data(target); #ifdef CINN_WITH_CUDA if (target == common::DefaultNVGPUTarget()) { - cudaMemcpy(data, random_data.data(), num_ele * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(data, + random_data.data(), + num_ele * sizeof(float), + cudaMemcpyHostToDevice); } else if (target == common::DefaultHostTarget()) { std::copy(random_data.begin(), random_data.end(), data); } else { @@ -97,12 +114,16 @@ void SetRandData(hlir::framework::Tensor tensor, const common::Target& ta } template -std::vector GetTensorData(const hlir::framework::Tensor& tensor, const common::Target& target) { +std::vector GetTensorData(const hlir::framework::Tensor& tensor, + const common::Target& target) { auto size = tensor->shape().numel(); std::vector data(size); #ifdef CINN_WITH_CUDA if (target == common::DefaultNVGPUTarget()) { - cudaMemcpy(data.data(), static_cast(tensor->data()), size * sizeof(T), cudaMemcpyDeviceToHost); + cudaMemcpy(data.data(), + static_cast(tensor->data()), + size * sizeof(T), + cudaMemcpyDeviceToHost); } else if (target == common::DefaultHostTarget()) { std::copy(tensor->data(), tensor->data() + size, data.begin()); } else { @@ -115,7 +136,9 @@ std::vector GetTensorData(const hlir::framework::Tensor& tensor, const common return data; } -template std::vector GetTensorData(const hlir::framework::Tensor& tensor, const common::Target& target); -template std::vector GetTensorData(const hlir::framework::Tensor& tensor, const common::Target& target); +template std::vector GetTensorData( + const 
hlir::framework::Tensor& tensor, const common::Target& target); +template std::vector<int> GetTensorData<int>( + const hlir::framework::Tensor& tensor, const common::Target& target); } // namespace cinn diff --git a/paddle/cinn/utils/data_util.h b/paddle/cinn/utils/data_util.h index 0d1b65042e799..a55ad554579f1 100644 --- a/paddle/cinn/utils/data_util.h +++ b/paddle/cinn/utils/data_util.h @@ -25,21 +25,31 @@ namespace cinn { /** - * @brief Fill an int Tensor with random data, which is going to be [low, high). + * @brief Fill an int Tensor with random data, which is going to be [low, + * high). * - * @param tensor A Tensor that needs to be filled with data has to be of type Int. + * @param tensor A Tensor that needs to be filled with data has to be of type + * Int. * @param target The type of device that tensor need. * @param seed Random number seed. Default setting is -1. - * @param low Set the lower bound of the data range, which is represented as [low, high). - * @param high Set the upper bound of the data range, which is represented as [low, high). + * @param low Set the lower bound of the data range, which is represented as + * [low, high). + * @param high Set the upper bound of the data range, which is represented as + * [low, high). */ -void SetRandInt( hlir::framework::Tensor tensor, const common::Target& target, int seed = -1, int low = 0, int high = 11); +void SetRandInt(hlir::framework::Tensor tensor, + const common::Target& target, + int seed = -1, + int low = 0, + int high = 11); template <typename T> -void SetRandData(hlir::framework::Tensor tensor, const common::Target& target, int seed = -1); +void SetRandData(hlir::framework::Tensor tensor, + const common::Target& target, + int seed = -1); template <typename T> -std::vector<T> GetTensorData(const hlir::framework::Tensor& tensor, const common::Target& target); +std::vector<T> GetTensorData(const hlir::framework::Tensor& tensor, + const common::Target& target); } // namespace cinn diff --git a/paddle/cinn/utils/dot_lang.cc b/paddle/cinn/utils/dot_lang.cc index 184eed21a0a64..f36f90b781f12 100644 --- a/paddle/cinn/utils/dot_lang.cc +++ b/paddle/cinn/utils/dot_lang.cc @@ -25,7 +25,7 @@ size_t dot_node_counter{0}; size_t dot_cluster_counter{0}; void ResetDotCounters() { - dot_node_counter = 0; + dot_node_counter = 0; dot_cluster_counter = 0; } @@ -35,7 +35,9 @@ std::string DotAttr::repr() const { return ss.str(); } -DotNode::DotNode(const std::string& name, const std::vector<DotAttr>& attrs, const std::string& cluster_id) +DotNode::DotNode(const std::string& name, + const std::vector<DotAttr>& attrs, + const std::string& cluster_id) : name(name), attrs(attrs), cluster_id_(cluster_id) { std::stringstream ss; ss << "node_" << dot_node_counter++; @@ -60,7 +62,9 @@ std::string DotNode::repr() const { return ss.str(); } -DotCluster::DotCluster(const std::string& name, const std::vector<DotAttr>& attrs) : name(name), attrs(attrs) { +DotCluster::DotCluster(const std::string& name, + const std::vector<DotAttr>& attrs) + : name(name), attrs(attrs) { std::stringstream ss; ss << "cluster_" << dot_cluster_counter++; id_ = ss.str(); @@ -102,16 +106,21 @@ void DotLang::AddNode(const std::string& id, } } -void DotLang::AddCluster(const std::string& id, const std::vector<DotAttr>& attrs) { +void DotLang::AddCluster(const std::string& id, + const std::vector<DotAttr>& attrs) { CHECK(!clusters_.count(id)) << "duplicate Cluster '" << id << "'"; clusters_.emplace(id, DotCluster{id, attrs}); } -void DotLang::AddEdge(const std::string& 
source, + const std::string& target, + const std::vector& attrs) { CHECK(!source.empty()); CHECK(!target.empty()); - CHECK(nodes_.find(source) != nodes_.end()) << "Call AddNode to add " << source << " to dot first"; - CHECK(nodes_.find(target) != nodes_.end()) << "Call AddNode to add " << target << " to dot first"; + CHECK(nodes_.find(source) != nodes_.end()) + << "Call AddNode to add " << source << " to dot first"; + CHECK(nodes_.find(target) != nodes_.end()) + << "Call AddNode to add " << target << " to dot first"; auto sid = nodes_.at(source).id(); auto tid = nodes_.at(target).id(); edges_.emplace_back(sid, tid, attrs); diff --git a/paddle/cinn/utils/dot_lang.h b/paddle/cinn/utils/dot_lang.h index 6b7abfcf5f566..786f83ce7524f 100644 --- a/paddle/cinn/utils/dot_lang.h +++ b/paddle/cinn/utils/dot_lang.h @@ -46,9 +46,9 @@ class DotLang { */ void AddNode(const std::string& id, const std::vector& attrs, - std::string label = "", + std::string label = "", std::string cluster_id = "", - bool allow_duplicate = false); + bool allow_duplicate = false); /** * Add a subgraph to the DOT graph. @@ -63,7 +63,9 @@ class DotLang { * @param target The id of the sink of the edge. * @param attrs The attributes of the edge. */ - void AddEdge(const std::string& source, const std::string& target, const std::vector& attrs); + void AddEdge(const std::string& source, + const std::string& target, + const std::vector& attrs); std::string operator()() const { return Build(); } @@ -81,7 +83,8 @@ struct DotAttr { std::string key; std::string value; - DotAttr(const std::string& key, const std::string& value) : key(key), value(value) {} + DotAttr(const std::string& key, const std::string& value) + : key(key), value(value) {} std::string repr() const; }; @@ -91,7 +94,9 @@ struct DotNode { std::vector attrs; DotNode() = default; - DotNode(const std::string& name, const std::vector& attrs, const std::string& cluster_id); + DotNode(const std::string& name, + const std::vector& attrs, + const std::string& cluster_id); std::string id() const { return id_; } std::string cluster_id() const { return cluster_id_; } @@ -125,7 +130,9 @@ struct DotEdge { std::string target; std::vector attrs; - DotEdge(const std::string& source, const std::string& target, const std::vector& attrs) + DotEdge(const std::string& source, + const std::string& target, + const std::vector& attrs) : source(source), target(target), attrs(attrs) {} std::string repr() const; diff --git a/paddle/cinn/utils/event.cc b/paddle/cinn/utils/event.cc index bf731496eb976..6f319bf5e44a2 100644 --- a/paddle/cinn/utils/event.cc +++ b/paddle/cinn/utils/event.cc @@ -58,12 +58,12 @@ std::string Summary::Format(const std::vector &events) { std::unordered_map unique_items; std::unordered_map category_cost; - double total_cost = 0.0; + double total_cost = 0.0; size_t max_annot_size = 20; for (auto &e : events) { if (unique_items.count(e.annotation_) == 0U) { items.emplace_back(e); - unique_items[e.annotation_] = &items.back(); + unique_items[e.annotation_] = &items.back(); unique_items.at(e.annotation_)->info.duration_ = 0.0; } // Sum cost for category @@ -76,7 +76,8 @@ std::string Summary::Format(const std::vector &events) { } // Calculate Ratio for (auto &item : items) { - item.sub_raito = item.info.duration_ / category_cost[item.info.type_] * 100.0; + item.sub_raito = + item.info.duration_ / category_cost[item.info.type_] * 100.0; item.total_raito = item.info.duration_ / total_cost * 100.0; } @@ -88,13 +89,18 @@ std::string Summary::Format(const std::vector &events) { 
std::string Summary::AsStr(const std::vector &items, int data_width) { std::ostringstream os; - os << "\n\n-------------------------> Profiling Report <-------------------------\n\n"; + os << "\n\n-------------------------> Profiling Report " + "<-------------------------\n\n"; - std::vector titles = {"Category", "Name", "CostTime(ms)", "Ratio in Category(%)", "Ratio in Total(%)"}; - std::vector widths = {20, data_width, 20, 20, 20}; + std::vector titles = {"Category", + "Name", + "CostTime(ms)", + "Ratio in Category(%)", + "Ratio in Total(%)"}; + std::vector widths = {20, data_width, 20, 20, 20}; size_t pad_size = 0; - int idx = 0; + int idx = 0; for (auto &t : titles) { pad_size = widths[idx] >= t.size() ? widths[idx] - t.size() : 1; os << ' ' << t << std::string(pad_size, ' '); @@ -109,7 +115,7 @@ std::string Summary::AsStr(const std::vector &items, int data_width) { std::to_string(item.info.duration_), item.sub_raito.ToStr(), item.total_raito.ToStr()}; - idx = 0; + idx = 0; for (auto &info : infos) { pad_size = widths[idx] > info.size() ? widths[idx] - info.size() : 1; os << ' ' << info << std::string(pad_size, ' '); diff --git a/paddle/cinn/utils/event.h b/paddle/cinn/utils/event.h index 1f1212e468f70..dad99634bf41e 100644 --- a/paddle/cinn/utils/event.h +++ b/paddle/cinn/utils/event.h @@ -79,7 +79,9 @@ class Summary { Raito total_raito{0.0}; // precentage of total process Item(const HostEvent& e) : info(e) {} - bool operator<(const Item& other) const { return total_raito.value > other.total_raito.value; } + bool operator<(const Item& other) const { + return total_raito.value > other.total_raito.value; + } }; static std::string Format(const std::vector& events); @@ -101,7 +103,9 @@ class HostEventRecorder { std::vector& Events() { return events_; } - void RecordEvent(const std::string& annotation, double duration, EventType type) { + void RecordEvent(const std::string& annotation, + double duration, + EventType type) { GetInstance().Events().emplace_back(annotation, duration, type); } diff --git a/paddle/cinn/utils/functional.cc b/paddle/cinn/utils/functional.cc index 44826cae9da4d..9fd5799bc6e87 100644 --- a/paddle/cinn/utils/functional.cc +++ b/paddle/cinn/utils/functional.cc @@ -23,8 +23,9 @@ std::vector GetPositiveAxes(const std::vector& axes, int rank) { std::vector new_axes(axes.size()); for (int i = 0; i < axes.size(); ++i) { int axis = axes[i] + (axes[i] < 0 ? rank : 0); - CHECK(axis >= 0 && axis < rank) << "The axis should in [" << -rank << ", " << rank << "), but axes[" << i - << "]=" << axes[i] << " not."; + CHECK(axis >= 0 && axis < rank) + << "The axis should in [" << -rank << ", " << rank << "), but axes[" + << i << "]=" << axes[i] << " not."; new_axes[i] = axis; } return new_axes; @@ -32,7 +33,8 @@ std::vector GetPositiveAxes(const std::vector& axes, int rank) { int GetPositiveAxes(int axis, int rank) { int dim = axis + (axis < 0 ? 
rank : 0); - CHECK(dim >= 0 && dim < rank) << "The axis should in [0, " << rank << "), but axis=" << axis << " not."; + CHECK(dim >= 0 && dim < rank) + << "The axis should in [0, " << rank << "), but axis=" << axis << " not."; return dim; } diff --git a/paddle/cinn/utils/functional.h b/paddle/cinn/utils/functional.h index ff92b06f26719..79f71cc10c386 100644 --- a/paddle/cinn/utils/functional.h +++ b/paddle/cinn/utils/functional.h @@ -26,10 +26,14 @@ namespace cinn { namespace utils { template -std::vector Map(const InT &in, std::function fn) { +std::vector Map( + const InT &in, + std::function fn) { std::vector res; - std::transform( - in.begin(), in.end(), std::back_inserter(res), [&](const typename InT::value_type &x) { return fn(x); }); + std::transform(in.begin(), + in.end(), + std::back_inserter(res), + [&](const typename InT::value_type &x) { return fn(x); }); return res; } @@ -56,32 +60,42 @@ auto Max(T &&t, Ts &&...ts) { template struct IsVector { template - static auto infer(U *) - -> std::enable_if_t, U>::value, std::true_type>; + static auto infer(U *) -> std::enable_if_t< + std::is_same, U>::value, + std::true_type>; template static std::false_type infer(...); - static constexpr bool value = decltype(infer>>(nullptr))::value; + static constexpr bool value = + decltype(infer>>(nullptr))::value; }; template -struct IsString : std::integral_constant>::value> {}; +struct IsString : std::integral_constant< + bool, + std::is_same>::value> {}; template auto Flatten(const absl::optional> &c) - -> std::enable_if_t::value || IsString::value, std::vector> { + -> std::enable_if_t::value || IsString::value, + std::vector> { return c ? std::vector{c->get()} : std::vector{}; } template