Skip to content

Commit

Permalink
[xla:cpu] Remove unused globals and function declarations from LLVM m…
Browse files Browse the repository at this point in the history
…odules after split

PiperOrigin-RevId: 666355364
  • Loading branch information
ezhulenev authored and copybara-github committed Aug 22, 2024
1 parent e7885c1 commit 4eb3e66
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 52 deletions.
3 changes: 3 additions & 0 deletions xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -987,8 +987,10 @@ cc_library(
"//xla:util",
"//xla/service:llvm_compiler",
"//xla/service/llvm_ir:llvm_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@llvm-project//llvm:Analysis",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Instrumentation",
Expand All @@ -998,6 +1000,7 @@ cc_library(
"@llvm-project//llvm:Passes",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:TargetParser",
"@tsl//tsl/platform:logging",
],
)
Expand Down
61 changes: 39 additions & 22 deletions xla/service/cpu/compiler_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,33 @@ limitations under the License.

#include "xla/service/cpu/compiler_functor.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/MCContext.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SmallVectorMemoryBuffer.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Instrumentation/DataFlowSanitizer.h"
#include "xla/service/cpu/cpu_runtime.h"
#include "xla/service/cpu/llvm_ir_runtime.h"
#include "xla/service/llvm_ir/llvm_util.h"
#include "xla/types.h"
#include "xla/util.h"
#include "tsl/platform/logging.h"

Expand Down Expand Up @@ -102,8 +103,18 @@ llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> CompilerFunctor::operator()(
VLOG(2) << "IR before optimizations";
XLA_VLOG_LINES(2, llvm_ir::DumpToString(&module));

if (pre_optimization_hook_) {
pre_optimization_hook_(module);
// Get a target machine for compilation. If compilations run concurrently on
// multiple threads, `CompilerFunctor` user (in most cases `SimpleOrcJIT`)
// must guarantee that target machine builder will return a unique
// TargetMachine for each compilation, as it is not thread safe.
std::shared_ptr<llvm::TargetMachine> target_machine =
target_machine_builder_();

{ // Synchronize access to user-defined hooks.
absl::MutexLock lock(&mutex_);
if (pre_optimization_hook_) {
pre_optimization_hook_(module);
}
}

llvm::OptimizationLevel opt_level;
Expand Down Expand Up @@ -140,10 +151,10 @@ llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> CompilerFunctor::operator()(
llvm::StandardInstrumentations si(module.getContext(), false);
si.registerCallbacks(pic, &mam);

llvm::PassBuilder pb(target_machine_, pto, {}, &pic);
llvm::PassBuilder pb(target_machine.get(), pto, {}, &pic);

// Add the appropriate TargetLibraryInfo.
llvm::Triple target_triple(target_machine_->getTargetTriple());
llvm::Triple target_triple(target_machine->getTargetTriple());
auto target_library_info_impl =
std::make_unique<llvm::TargetLibraryInfoImpl>(target_triple);
target_library_info_impl->addVectorizableFunctions(
Expand Down Expand Up @@ -184,26 +195,32 @@ llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> CompilerFunctor::operator()(

VLOG(2) << "IR after optimizations";

if (post_optimization_hook_) {
post_optimization_hook_(module);
{ // Synchronize access to user-defined hooks.
absl::MutexLock lock(&mutex_);
if (post_optimization_hook_) {
post_optimization_hook_(module);
}
}

// Generate code.
llvm::MCContext* mc_context;
llvm::legacy::PassManager codegen_passes;
target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream);
target_machine->addPassesToEmitMC(codegen_passes, mc_context, ostream);
codegen_passes.run(module);

std::unique_ptr<llvm::MemoryBuffer> mc_memory_buffer(
new llvm::SmallVectorMemoryBuffer(std::move(mc_stream_buffer)));

if (post_codegen_hook_) {
llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
llvm::object::ObjectFile::createObjectFile(*mc_memory_buffer);
if (obj_file) {
post_codegen_hook_(*obj_file.get());
} else {
LOG(WARNING) << "Could not convert memory buffer to object file!";
{ // Synchronize access to user-defined hooks.
absl::MutexLock lock(&mutex_);
if (post_codegen_hook_) {
llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
llvm::object::ObjectFile::createObjectFile(*mc_memory_buffer);
if (obj_file) {
post_codegen_hook_(*obj_file.get());
} else {
LOG(WARNING) << "Could not convert memory buffer to object file!";
}
}
}

Expand Down
43 changes: 30 additions & 13 deletions xla/service/cpu/compiler_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,42 @@ limitations under the License.
#ifndef XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
#define XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/functional/any_invocable.h"
#include "absl/synchronization/mutex.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "llvm/IR/FMF.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/Error.h"
#include "llvm/Target/TargetMachine.h"
#include "xla/service/llvm_compiler.h"

namespace xla {
namespace cpu {
namespace xla::cpu {

// Functor class for compiling an LLVM module down to an object file. For use by
// Orc JIT compile layer.
class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler {
public:
// Returns an instance of llvm::TargetMachine for a compilation. It can be
// a shared TargetMachine if compilation is single threaded, or must be a
// unique TargetMachine if compilation is multi threaded (because
// TargetMachine is not thread safe).
//
// See `llvm::orc::ConcurrentIRCompiler` to see corresponding API in ORC.
using TargetMachineBuilder =
std::function<std::shared_ptr<llvm::TargetMachine>()>;

explicit CompilerFunctor(
llvm::TargetMachine* target_machine, int opt_level,
TargetMachineBuilder target_machine_builder, int opt_level,
bool optimize_for_size, bool disable_expensive_passes,
bool disable_slp_vectorizer, llvm::FastMathFlags fast_math_flags,
LLVMCompiler::ModuleHook pre_optimization_hook = nullptr,
Expand All @@ -48,37 +61,41 @@ class CompilerFunctor : public llvm::orc::IRCompileLayer::IRCompiler {
bool dfsan_enabled = false,
const std::vector<std::string>& dfsan_abi_list_files = {})
: IRCompiler(llvm::orc::IRSymbolMapper::ManglingOptions()),
target_machine_(target_machine),
target_machine_builder_(std::move(target_machine_builder)),
opt_level_(opt_level),
optimize_for_size_(optimize_for_size),
disable_expensive_passes_(disable_expensive_passes),
disable_slp_vectorizer_(disable_slp_vectorizer),
fast_math_flags_(fast_math_flags),
dfsan_enabled_(dfsan_enabled),
dfsan_abi_list_files_(dfsan_abi_list_files),
pre_optimization_hook_(std::move(pre_optimization_hook)),
post_optimization_hook_(std::move(post_optimization_hook)),
post_codegen_hook_(std::move(post_codegen_hook)),
dfsan_enabled_(dfsan_enabled),
dfsan_abi_list_files_(dfsan_abi_list_files) {}
post_codegen_hook_(std::move(post_codegen_hook)) {}

// Compile a Module to an ObjectFile.
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> operator()(
llvm::Module& module) override;

private:
llvm::TargetMachine* target_machine_;
TargetMachineBuilder target_machine_builder_;
const unsigned opt_level_;
const bool optimize_for_size_;
const bool disable_expensive_passes_;
const bool disable_slp_vectorizer_;
const llvm::FastMathFlags fast_math_flags_;
LLVMCompiler::ModuleHook pre_optimization_hook_;
LLVMCompiler::ModuleHook post_optimization_hook_;
absl::AnyInvocable<void(const llvm::object::ObjectFile&)> post_codegen_hook_;
const bool dfsan_enabled_ = false;
const std::vector<std::string> dfsan_abi_list_files_;

LLVMCompiler::ModuleHook pre_optimization_hook_ ABSL_GUARDED_BY(mutex_);
LLVMCompiler::ModuleHook post_optimization_hook_ ABSL_GUARDED_BY(mutex_);
absl::AnyInvocable<void(const llvm::object::ObjectFile&)> post_codegen_hook_
ABSL_GUARDED_BY(mutex_);

// Synchronizes access to user-defined compilation hooks.
absl::Mutex mutex_;
};

} // namespace cpu
} // namespace xla
} // namespace xla::cpu

#endif // XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
39 changes: 34 additions & 5 deletions xla/service/cpu/cpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Mangler.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -1131,6 +1132,32 @@ CreateConstantAllocations(const BufferAssignment& assignment) {
return constants;
}

// Removes unused globals and function declarations from the LLVM module.
//
// After splitting LLVM module into multiple parts, we end up with unused
// symbols in each part: external globals and function declarations. We don't
// support linking across modules added to SimpleOrcJIT, and we don't need it,
// because we never construct LLVM IR that might require cross-module linking,
// so we can just remove unused symbols from each part.
static void RemoveUnusedSymbols(llvm::Module& module) {
llvm::SmallVector<llvm::GlobalVariable*> unused_globals;
llvm::SmallVector<llvm::Function*> unused_functions;

for (llvm::GlobalVariable& gv : module.globals()) {
if (gv.use_empty()) unused_globals.push_back(&gv);
}
for (llvm::Function& f : module.functions()) {
if (f.isDeclaration() && f.use_empty()) unused_functions.push_back(&f);
}

for (auto* gv : unused_globals) {
module.eraseGlobalVariable(gv);
}
for (auto* f : unused_functions) {
f->eraseFromParent();
}
}

absl::StatusOr<std::unique_ptr<CpuExecutable>>
CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
TraceMe trace([&] {
Expand Down Expand Up @@ -1203,7 +1230,7 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {

// Select an order for emitting the HLO instructions for each
// computation. Using this sequence enables tighter buffer liveness analysis
// and reduced memory usage (as compared to using DependencyHloOrdering).
// and reduced memory usage (as compared to using `DependencyHloOrdering`).
TF_ASSIGN_OR_RETURN(
HloSchedule schedule,
ScheduleModule(module.get(), BufferSizeBytesFunction(),
Expand Down Expand Up @@ -1313,11 +1340,13 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {

llvm::SplitModule(
*llvm_module, num_parts,
[&, dylib_index = 0](auto llvm_module_part) mutable {
[&, n = 0](std::unique_ptr<llvm::Module> llvm_module_part) mutable {
// Remove unused symbols left in the module after splitting.
RemoveUnusedSymbols(*llvm_module_part);
cantFail((*jit)->AddModule(
llvm::orc::ThreadSafeModule(std::move(llvm_module_part),
thread_safe_context),
dylib_index++ % parallel_codegen_split_count));
n++ % parallel_codegen_split_count));
},
/*PreserveLocals=*/true);
} else {
Expand Down Expand Up @@ -1565,7 +1594,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
break;
}
llvm::CodeGenOptLevel opt_level = CodeGenOptLevel(modules[0]->config());
std::unique_ptr<llvm::TargetMachine> target_machine =
std::shared_ptr<llvm::TargetMachine> target_machine =
absl::WrapUnique(target->createTargetMachine(
triple.getTriple(), options.cpu_name(), options.features(),
CompilerTargetOptions(modules[0]->config()), reloc_model,
Expand Down Expand Up @@ -1697,7 +1726,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
};

CompilerFunctor compiler_functor(
target_machine.get(), static_cast<int>(opt_level),
[&] { return target_machine; }, static_cast<int>(opt_level),
options::OptimizeForSizeRequested(module->config()),
module->config().debug_options().xla_llvm_disable_expensive_passes(),
options::SlpVectorizerDisabled(module->config()),
Expand Down
22 changes: 15 additions & 7 deletions xla/service/cpu/simple_orc_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ limitations under the License.
#include "llvm/Support/Memory.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Process.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Host.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
Expand Down Expand Up @@ -95,8 +96,7 @@ extern "C" uint16_t __truncsfbf2(float);
// Converts an F64 value to a BF16.
extern "C" uint16_t __truncdfbf2(double);

namespace xla {
namespace cpu {
namespace xla::cpu {

std::vector<std::string> DetectMachineAttributes() {
std::vector<std::string> result;
Expand Down Expand Up @@ -322,6 +322,13 @@ SimpleOrcJIT::InferTargetMachineForJIT(
return target_machine;
}

static CompilerFunctor::TargetMachineBuilder CreateTargetMachineBuilder(
llvm::TargetOptions target_options, llvm::CodeGenOptLevel opt_level) {
return [target_options, opt_level]() {
return SimpleOrcJIT::InferTargetMachineForJIT(target_options, opt_level);
};
}

SimpleOrcJIT::SimpleOrcJIT(
std::unique_ptr<llvm::orc::ExecutorProcessControl> target_process_control,
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
Expand All @@ -332,7 +339,9 @@ SimpleOrcJIT::SimpleOrcJIT(
LLVMCompiler::ModuleHook post_optimization_hook,
absl::AnyInvocable<void(const llvm::object::ObjectFile&)> post_codegen_hook,
size_t num_jit_dylibs)
: target_machine_(InferTargetMachineForJIT(target_options, opt_level)),
: target_machine_builder_(
CreateTargetMachineBuilder(target_options, opt_level)),
target_machine_(target_machine_builder_()),
target_triple_(target_machine_->getTargetTriple()),
data_layout_(target_machine_->createDataLayout()),
target_process_control_(std::move(target_process_control)),
Expand All @@ -345,7 +354,7 @@ SimpleOrcJIT::SimpleOrcJIT(
compile_layer_(
*execution_session_, object_layer_,
std::make_unique<CompilerFunctor>(
target_machine_.get(), static_cast<int>(opt_level),
[&] { return target_machine_; }, static_cast<int>(opt_level),
optimize_for_size, disable_expensive_passes,
disable_slp_vectorizer, fast_math_flags,
std::move(pre_optimization_hook),
Expand Down Expand Up @@ -707,7 +716,6 @@ bool RegisterKnownJITSymbols() {
}

bool unused = RegisterKnownJITSymbols();
} // namespace

} // namespace cpu
} // namespace xla
} // namespace
} // namespace xla::cpu
Loading

0 comments on commit 4eb3e66

Please sign in to comment.