Skip to content

Commit

Permalink
[xla:cpu] Make parallel compilation configurable with XLA flags
Browse files Browse the repository at this point in the history
+ bump default number of splits up to 32

PiperOrigin-RevId: 666487172
  • Loading branch information
ezhulenev authored and copybara-github committed Aug 22, 2024
1 parent 11da26a commit cc369fd
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 11 deletions.
7 changes: 7 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_cpu_use_acl(true);
#endif
opts.set_xla_cpu_use_thunk_runtime(true);
opts.set_xla_cpu_parallel_codegen_split_count(32);
opts.set_xla_cpu_enable_concurrency_optimized_scheduler(false);
opts.set_xla_cpu_prefer_vector_width(256);

Expand Down Expand Up @@ -810,6 +811,12 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
bool_setter_for(&DebugOptions::set_xla_cpu_use_thunk_runtime),
debug_options->xla_cpu_use_thunk_runtime(),
"Use Thunk-based runtime for the CPU backend."));
flag_list->push_back(tsl::Flag(
"xla_cpu_parallel_codegen_split_count",
int32_setter_for(&DebugOptions::set_xla_cpu_parallel_codegen_split_count),
debug_options->xla_cpu_parallel_codegen_split_count(),
"Split LLVM module into at most this many parts before codegen to enable "
"parallel compilation for the CPU backend."));
flag_list->push_back(tsl::Flag(
"xla_cpu_enable_concurrency_optimized_scheduler",
bool_setter_for(
Expand Down
23 changes: 13 additions & 10 deletions xla/service/cpu/cpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1149,10 +1149,8 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {

// We split the LLVM module and distribute it across separate dylibs to enable
// parallel compilation at run time.
//
// TODO(b/359659598): Increase the number of dylibs to 32 and make compilation
// truly parallel. For now we use 2 dylibs to do basic sanity check.
constexpr size_t num_jit_dylibs = 2;
size_t parallel_codegen_split_count =
debug_options.xla_cpu_parallel_codegen_split_count();

auto jit = SimpleOrcJIT::Create(
CompilerTargetOptions(module->config()),
Expand All @@ -1163,7 +1161,7 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook,
post_optimization_ir_hook,
CreateOrcJITPostCompilationHook(module.get(), &obj_files),
num_jit_dylibs);
parallel_codegen_split_count);
if (!jit) {
return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError()));
}
Expand Down Expand Up @@ -1279,10 +1277,12 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));

// We define the number of module parts based on the total number of
// external functions (kernels and comparators) that are called from thunks.
// external functions (kernels and comparators) that are called from thunks,
// and the maximum number of parts that we want to split the module into.
size_t num_external_function =
ir_emitter2.kernels().size() + ir_emitter2.comparators().size();
size_t num_parts = std::min(num_external_function, num_jit_dylibs);
size_t num_parts =
std::min(num_external_function, parallel_codegen_split_count);

// JIT compile the LLVM IR module to in-memory machine code. We split the
// module into `num_parts` parts to allow parallel compilation. In
Expand All @@ -1293,18 +1293,21 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
llvm::orc::ThreadSafeContext thread_safe_context(std::move(llvm_context));

if (num_parts > 1) {
VLOG(2) << "Splitting module into " << num_parts
<< " parts before codegen to enable parallel compilation";
VLOG(3) << "Splitting LLVM module into " << num_parts
<< " parts before codegen to enable parallel compilation"
<< " (max split count: " << parallel_codegen_split_count << ")";
llvm::SplitModule(
*llvm_module, num_parts,
[&, dylib_index = 0](auto llvm_module_part) mutable {
cantFail((*jit)->AddModule(
llvm::orc::ThreadSafeModule(std::move(llvm_module_part),
thread_safe_context),
dylib_index++ % num_jit_dylibs));
dylib_index++ % parallel_codegen_split_count));
},
/*PreserveLocals=*/true);
} else {
VLOG(3) << "Compiled LLVM module without splitting (max split count: "
<< parallel_codegen_split_count << ")";
cantFail((*jit)->AddModule(llvm::orc::ThreadSafeModule(
std::move(llvm_module), thread_safe_context)));
}
Expand Down
3 changes: 3 additions & 0 deletions xla/service/cpu/simple_orc_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <stdint.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
Expand Down Expand Up @@ -381,6 +382,8 @@ SimpleOrcJIT::SimpleOrcJIT(
}
};

// Always create at least one dylib.
num_jit_dylibs = std::max(size_t{1}, num_jit_dylibs);
jit_dylibs_.resize(num_jit_dylibs);
for (size_t i = 0; i < num_jit_dylibs; ++i) {
jit_dylibs_[i] = &execution_session_->createBareJITDylib(
Expand Down
7 changes: 6 additions & 1 deletion xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ message DebugOptions {
// When true, XLA:CPU uses the thunk runtime to execute compiled program.
bool xla_cpu_use_thunk_runtime = 298;

// The number of parts to split the LLVM module into before codegen. This
// allows XLA to compile all parts in parallel, and resolve kernel symbols
// from different dynamic libraries.
int32 xla_cpu_parallel_codegen_split_count = 323;

// A `prefer-vector-width` value that is passed to the LLVM backend. Default
// value is `256` (AVX2 on x86 platforms).
int32 xla_cpu_prefer_vector_width = 308;
Expand Down Expand Up @@ -937,7 +942,7 @@ message DebugOptions {
// effort to parallelize matrix operations.
bool xla_gpu_async_dot = 321;

// Next id: 323
// Next id: 324

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down

0 comments on commit cc369fd

Please sign in to comment.