Skip to content

Commit

Permalink
[xla:cpu] Remove unused globals and function declarations from LLVM m…
Browse files Browse the repository at this point in the history
…odules after split

PiperOrigin-RevId: 666548782
  • Loading branch information
ezhulenev authored and copybara-github committed Aug 23, 2024
1 parent b95c143 commit 4cf7b55
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions xla/service/cpu/cpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Mangler.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -1131,6 +1132,32 @@ CreateConstantAllocations(const BufferAssignment& assignment) {
return constants;
}

// Removes unused globals and function declarations from the LLVM module.
//
// After splitting LLVM module into multiple parts, we end up with unused
// symbols in each part: external globals and function declarations. We don't
// support linking across modules added to SimpleOrcJIT, and we don't need it,
// because we never construct LLVM IR that might require cross-module linking,
// so we can just remove unused symbols from each part.
static void RemoveUnusedSymbols(llvm::Module& module) {
llvm::SmallVector<llvm::GlobalVariable*> unused_globals;
llvm::SmallVector<llvm::Function*> unused_functions;

for (llvm::GlobalVariable& gv : module.globals()) {
if (gv.use_empty()) unused_globals.push_back(&gv);
}
for (llvm::Function& f : module.functions()) {
if (f.isDeclaration() && f.use_empty()) unused_functions.push_back(&f);
}

for (auto* gv : unused_globals) {
module.eraseGlobalVariable(gv);
}
for (auto* f : unused_functions) {
f->eraseFromParent();
}
}

absl::StatusOr<std::unique_ptr<CpuExecutable>>
CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
TraceMe trace([&] {
Expand Down Expand Up @@ -1313,11 +1340,13 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {

llvm::SplitModule(
*llvm_module, num_parts,
[&, dylib_index = 0](auto llvm_module_part) mutable {
[&, n = 0](std::unique_ptr<llvm::Module> llvm_module_part) mutable {
// Remove unused symbols left in the module after splitting.
RemoveUnusedSymbols(*llvm_module_part);
cantFail((*jit)->AddModule(
llvm::orc::ThreadSafeModule(std::move(llvm_module_part),
thread_safe_context),
dylib_index++ % parallel_codegen_split_count));
n++ % parallel_codegen_split_count));
},
/*PreserveLocals=*/true);
} else {
Expand Down

0 comments on commit 4cf7b55

Please sign in to comment.