Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
add FLAGS_cinn_parallel_compile_thread to control parallel compile th…
Browse files Browse the repository at this point in the history
…read number (#1377)
  • Loading branch information
thisjiang authored Apr 27, 2023
1 parent 41ba0d6 commit 1d7fad1
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 11 deletions.
25 changes: 17 additions & 8 deletions cinn/hlir/framework/parallel_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,14 @@
#include "cinn/ir/module.h"

DECLARE_int32(cinn_parallel_compile_size);
DECLARE_int32(cinn_parallel_compile_thread);

namespace cinn {
namespace hlir {
namespace framework {
static constexpr int DebugLogMaxLen = 30000;

std::vector<std::unique_ptr<Instruction>> ParallelCompiler::operator()() {
if (!FLAGS_cinn_parallel_compile_size) {
return std::vector<std::unique_ptr<Instruction>>();
}
if (graph_->fusion_groups.size() == 0) {
hlir::framework::ApplyPasses(graph_.get(), {"BuildNonFusedGroupsPass"});
}
Expand Down Expand Up @@ -70,9 +68,17 @@ void ParallelCompiler::SplitTask() {
CHECK(graph_->fusion_groups.size());
CHECK(graph_->fusion_groups.size() == option_.lowered_funcs.size() || option_.lowered_funcs.size() == 0);
// split task
int num_per_task = std::max((graph_->fusion_groups.size() - 1) / FLAGS_cinn_parallel_compile_size + 1, 16UL);
int max_task_num =
FLAGS_cinn_parallel_compile_thread > 0 ? FLAGS_cinn_parallel_compile_thread : graph_->fusion_groups.size();

int group_per_task = graph_->fusion_groups.size();
if (max_task_num > 1) {
group_per_task = FLAGS_cinn_parallel_compile_size > 0
? FLAGS_cinn_parallel_compile_size
: ((graph_->fusion_groups.size() + max_task_num - 1) / max_task_num);
}

for (int idx = 0; idx < graph_->fusion_groups.size(); idx += num_per_task) {
for (int idx = 0; idx < graph_->fusion_groups.size(); idx += group_per_task) {
tasks_.emplace_back(this, scope_, graph_, option_, target_);
}
VLOG(2) << "Split task to " << tasks_.size() << " sub-task!";
Expand Down Expand Up @@ -133,15 +139,17 @@ void ParallelCompiler::Task::Lowering() {
continue;
}
auto& group = graph->fusion_groups[idx];
VLOG(1) << "=============================================";
VLOG(1) << "Lowering Group:\n" << graph->DebugGroupedGraph(group->CollectNodes());
VLOG(1) << "=============================================";
VLOG(1) << "Start Lowering Group " << idx << " at " << std::this_thread::get_id() << " :\n"
<< "Group " << idx << " {\n"
<< graph->DebugGroupedGraph(group->CollectNodes()) << "}\n";
lowered_funcs.emplace_back(std::move(op_lowerer.Lower(group)));
CHECK_EQ(lowered_funcs.back().size(), 1) << "Lowerd Function Is Not Equal 1!";
}
}

void ParallelCompiler::Task::CodegenAndJit() {
VLOG(2) << "Start Codegen and JIT with Group [" << cinn::utils::Join(this->gidx, ", ") << "] at "
<< std::this_thread::get_id();
// build module
ir::Module::Builder builder(common::UniqName("module"), target);
for (auto& func : lowered_funcs) {
Expand Down Expand Up @@ -193,6 +201,7 @@ void ParallelCompiler::Task::CodegenAndJit() {
void ParallelCompiler::Task::BuildInstruction() {
// create instruction.
for (int idx : gidx) {
VLOG(2) << "Start BuildInstruction of Group " << idx << " at " << std::this_thread::get_id();
auto& group = graph->fusion_groups[idx];
CHECK(group->input_names.size() > 0 || group->output_names.size() > 0);
auto instr = std::unique_ptr<Instruction>(
Expand Down
2 changes: 1 addition & 1 deletion cinn/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void Module::Builder::Clear() {

Module Module::Builder::Build() {
if (module_->functions.empty()) {
LOG(ERROR) << "Module has no functions";
VLOG(1) << "Module has no functions";
}

auto res = ir::Module(module_.get());
Expand Down
7 changes: 5 additions & 2 deletions cinn/runtime/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@ using ::GFLAGS_NAMESPACE::StringFromEnv;
// String flag `cinn_x86_builtin_code_root`. Its default value is obtained from
// the FLAGS_cinn_x86_builtin_code_root environment variable via StringFromEnv,
// with "" as the fallback argument. No help text is supplied for this flag.
DEFINE_string(cinn_x86_builtin_code_root, StringFromEnv("FLAGS_cinn_x86_builtin_code_root", ""), "");

DEFINE_int32(cinn_parallel_compile_size,
// Revert changes in PR #990 to pass the model unittests
Int32FromEnv("FLAGS_cinn_parallel_compile_size", 8),
Int32FromEnv("FLAGS_cinn_parallel_compile_size", 16),
"When use parallel compile, set the number of group compiled by each thread.");

// Integer flag `cinn_parallel_compile_thread`. Its default value is obtained
// from the FLAGS_cinn_parallel_compile_thread environment variable via
// Int32FromEnv, with -1 as the fallback argument.
// In the ParallelCompiler::SplitTask code shown in this change, a value > 0 is
// used as the maximum task count, and any other value (including the -1
// fallback) makes the maximum task count equal to the number of fusion groups.
DEFINE_int32(cinn_parallel_compile_thread,
Int32FromEnv("FLAGS_cinn_parallel_compile_thread", -1),
"How much thread the parallel compile used.");

// Boolean flag `cinn_use_op_fusion`. Its default value is obtained from the
// FLAGS_cinn_use_op_fusion environment variable via BoolFromEnv, with `true`
// as the fallback argument.
DEFINE_bool(cinn_use_op_fusion, BoolFromEnv("FLAGS_cinn_use_op_fusion", true), "Whether to use op fusion pass.");

DEFINE_bool(cinn_use_common_subexpression_elimination,
Expand Down

0 comments on commit 1d7fad1

Please sign in to comment.