diff --git a/xla/pjrt/cpu/BUILD b/xla/pjrt/cpu/BUILD index cd7d7cf1d8096..38200e06ceea1 100644 --- a/xla/pjrt/cpu/BUILD +++ b/xla/pjrt/cpu/BUILD @@ -195,7 +195,6 @@ cc_library( "//xla/service/cpu:cpu_executable_run_options", "//xla/service/cpu:cpu_runtime", "//xla/service/cpu:cpu_xfeed", - "//xla/service/cpu:simple_orc_jit", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/concurrency:ref_count", diff --git a/xla/pjrt/cpu/cpu_client.cc b/xla/pjrt/cpu/cpu_client.cc index c880294584d2f..636f6412bb51d 100644 --- a/xla/pjrt/cpu/cpu_client.cc +++ b/xla/pjrt/cpu/cpu_client.cc @@ -90,7 +90,6 @@ limitations under the License. #include "xla/service/cpu/cpu_executable_run_options.h" #include "xla/service/cpu/cpu_runtime.h" #include "xla/service/cpu/cpu_xfeed.h" -#include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" #include "xla/service/dump.h" diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index a8d5cdaaf127b..95f702d804ca2 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -11,7 +11,6 @@ load( "//xla:xla.bzl", "xla_cc_binary", "xla_cc_test", - "xla_internal", ) load("//xla/tests:build_defs.bzl", "xla_test") load("//xla/tsl:tsl.bzl", "internal_visibility", "tf_openmp_copts", "tsl_copts") @@ -231,7 +230,6 @@ cc_library( ":onednn_ops_rewriter", ":parallel_task_assignment", ":runtime_symbol_generator", - ":simple_orc_jit", ":thunk_emitter", ":xla_framework", "//xla:cpu_function_runtime", @@ -540,41 +538,6 @@ cc_library( ], ) -cc_library( - name = "simple_orc_jit", - srcs = ["simple_orc_jit.cc"], - hdrs = ["simple_orc_jit.h"], - copts = if_enable_acl(["-DXLA_CPU_USE_ACL=1"]) + tsl_copts(), - deps = [ - ":orc_jit_memory_mapper", - ":runtime_symbol_generator", - "//xla:types", - "//xla:util", - "//xla/backends/cpu/codegen:contiguous_section_memory_manager", - "//xla/backends/cpu/codegen:cpu_features", - "//xla/backends/cpu/codegen:ir_compiler", - "//xla/backends/cpu/codegen:jit_compiler", - "//xla/service:custom_call_target_registry", - "//xla/service:llvm_compiler", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Core", - "@llvm-project//llvm:ExecutionEngine", - "@llvm-project//llvm:MC", # fixdeps: keep - "@llvm-project//llvm:OrcJIT", - "@llvm-project//llvm:OrcShared", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:Target", # fixdeps: keep - "@llvm-project//llvm:TargetParser", - "@llvm-project//mlir:mlir_c_runner_utils", - "@tsl//tsl/platform:logging", - ] + xla_internal(["service/cpu:named_orc_jit_memory_mapper"]), -) - cc_library( name = "runtime_lightweight_check", hdrs = ["runtime_lightweight_check.h"], @@ -617,7 +580,6 @@ cc_library( hdrs = ["cpu_executable.h"], deps = [ ":cpu_runtime", - ":simple_orc_jit", "//xla:executable_run_options", "//xla:literal", "//xla:shape_tree", diff --git a/xla/service/cpu/cpu_executable.cc b/xla/service/cpu/cpu_executable.cc index 9bf0eefe873fd..2b3949fe7eab1 100644 --- a/xla/service/cpu/cpu_executable.cc +++ b/xla/service/cpu/cpu_executable.cc @@ -54,7 +54,6 @@ limitations under the License. #include "xla/literal.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/cpu_runtime.h" -#include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" #include "xla/service/executable.h" diff --git a/xla/service/cpu/cpu_executable.h b/xla/service/cpu/cpu_executable.h index fb21eacfda4b1..2222bf0adcb48 100644 --- a/xla/service/cpu/cpu_executable.h +++ b/xla/service/cpu/cpu_executable.h @@ -36,7 +36,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" #include "xla/service/executable.h" diff --git a/xla/service/cpu/cpu_runtime.h b/xla/service/cpu/cpu_runtime.h index d814ff4f74395..1212e7763e6a1 100644 --- a/xla/service/cpu/cpu_runtime.h +++ b/xla/service/cpu/cpu_runtime.h @@ -37,10 +37,8 @@ namespace runtime { // Names of runtime functions. These get resolved from the generated code to the // right symbol at link time in one of two ways: -// 1. When using the JIT, the symbol resolver (SimpleResolver in -// third_party/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc) maps -// this symbol name to -// the actual symbol. +// 1. When using the JIT, the symbol resolver (xla::cpu::RuntimeSymbolGenerator) +// maps this symbol name to the actual symbol. // 2. When using ahead-of-time compilation, the linker can resolve the name // because it is a symbol in the cpu_runtime library. extern const char* const kEigenMatMulF16SymbolName; diff --git a/xla/service/cpu/simple_orc_jit.cc b/xla/service/cpu/simple_orc_jit.cc deleted file mode 100644 index 99db17d5aba3f..0000000000000 --- a/xla/service/cpu/simple_orc_jit.cc +++ /dev/null @@ -1,191 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/cpu/simple_orc_jit.h" - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ExecutionEngine/ExecutionEngine.h" -#include "llvm/ExecutionEngine/Orc/Core.h" -#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" -#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" -#include "llvm/ExecutionEngine/Orc/SymbolStringPool.h" -#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" -#include "llvm/ExecutionEngine/RTDyldMemoryManager.h" -#include "llvm/ExecutionEngine/RuntimeDyld.h" -#include "llvm/IR/FMF.h" -#include "llvm/IR/Mangler.h" -#include "llvm/Support/CodeGen.h" -#include "llvm/Support/Error.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/Process.h" -#include "llvm/Target/TargetOptions.h" -#include "llvm/TargetParser/Host.h" -#include "xla/backends/cpu/codegen/contiguous_section_memory_manager.h" -#include "xla/backends/cpu/codegen/cpu_features.h" -#include "xla/backends/cpu/codegen/ir_compiler.h" -#include "xla/backends/cpu/codegen/jit_compiler.h" -#include "xla/service/cpu/orc_jit_memory_mapper.h" -#include "xla/service/cpu/runtime_symbol_generator.h" -#include "xla/service/llvm_compiler.h" -#include "tsl/platform/logging.h" - -namespace xla::cpu { - -SimpleOrcJIT::SimpleOrcJIT( - std::unique_ptr target_process_control, - std::unique_ptr execution_session, - const llvm::TargetOptions& target_options, llvm::CodeGenOptLevel opt_level, - bool optimize_for_size, bool disable_expensive_passes, - bool disable_slp_vectorizer, llvm::FastMathFlags fast_math_flags, - LLVMCompiler::ModuleHook pre_optimization_hook, - LLVMCompiler::ModuleHook post_optimization_hook, - std::function post_codegen_hook, - size_t num_jit_dylibs, absl::string_view max_cpu_isa) - : target_machine_builder_(JitCompiler::InferTargetMachineBuilder( - target_options, opt_level, CpuFeatureFromString(max_cpu_isa))), - target_machine_(target_machine_builder_().value()), - target_triple_(target_machine_->getTargetTriple()), - data_layout_(target_machine_->createDataLayout()), - target_process_control_(std::move(target_process_control)), - execution_session_(std::move(execution_session)), - object_layer_(*execution_session_, - []() { - return std::make_unique( - orc_jit_memory_mapper::GetInstance()); - }), - compile_layer_( - *execution_session_, object_layer_, - std::make_unique( - target_machine_builder_, - IrCompiler::Options{ - /*optimization_level=*/static_cast(opt_level), - /*optimize_for_size=*/optimize_for_size, - /*fast_math_flags=*/fast_math_flags, - /*disable_expensive_passes=*/disable_expensive_passes, - /*disable_slp_vectorizer=*/disable_slp_vectorizer, - }, - IrCompiler::CompilationHooks{ - std::move(pre_optimization_hook), - std::move(post_optimization_hook), - std::move(post_codegen_hook), - })), - gdb_jit_event_listener_( - llvm::JITEventListener::createGDBRegistrationListener()), - perf_jit_event_listener_( - llvm::JITEventListener::createPerfJITEventListener()) { - VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() - << " features: " << target_machine_->getTargetFeatureString().str(); - - // Always create at least one dylib. - num_jit_dylibs = std::max(size_t{1}, num_jit_dylibs); - jit_dylibs_.resize(num_jit_dylibs); - for (size_t i = 0; i < num_jit_dylibs; ++i) { - jit_dylibs_[i] = &execution_session_->createBareJITDylib( - absl::StrCat("")); - jit_dylibs_[i]->addGenerator( - std::make_unique(data_layout_)); - } - - object_layer_.registerJITEventListener(*this); - if (perf_jit_event_listener_) { - object_layer_.registerJITEventListener(*perf_jit_event_listener_); - } - - // Copied from LLJIT, required to find symbols on Windows. - if (target_triple_.isOSBinFormatCOFF()) { - object_layer_.setOverrideObjectFlagsWithResponsibilityFlags(true); - object_layer_.setAutoClaimResponsibilityForObjectSymbols(true); - } -} - -SimpleOrcJIT::~SimpleOrcJIT() { - if (auto err = execution_session_->endSession()) { - execution_session_->reportError(std::move(err)); - } -} - -llvm::Expected> SimpleOrcJIT::Create( - const llvm::TargetOptions& target_options, llvm::CodeGenOptLevel opt_level, - bool optimize_for_size, bool disable_expensive_passes, - bool disable_slp_vectorizer, llvm::FastMathFlags fast_math_flags, - LLVMCompiler::ModuleHook pre_optimization_hook, - LLVMCompiler::ModuleHook post_optimization_hook, - std::function post_codegen_hook, - size_t num_jit_dylibs, absl::string_view max_cpu_isa) { - auto SSP = std::make_shared(); - auto target_process_control = - llvm::orc::SelfExecutorProcessControl::Create(std::move(SSP)); - if (!target_process_control) { - return target_process_control.takeError(); - } - - auto execution_session = std::make_unique( - std::make_unique()); - return std::make_unique( - std::move(*target_process_control), std::move(execution_session), - target_options, opt_level, optimize_for_size, disable_expensive_passes, - disable_slp_vectorizer, fast_math_flags, std::move(pre_optimization_hook), - std::move(post_optimization_hook), std::move(post_codegen_hook), - num_jit_dylibs, std::move(max_cpu_isa)); -} - -void SimpleOrcJIT::notifyObjectLoaded( - llvm::JITEventListener::ObjectKey key, - const llvm::object::ObjectFile& object, - const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { - gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info); - size_of_generated_code_in_bytes_ += object.getData().size(); -} - -void SimpleOrcJIT::notifyFreeingObject(llvm::JITEventListener::ObjectKey key) { - gdb_jit_event_listener_->notifyFreeingObject(key); -} - -llvm::Error SimpleOrcJIT::AddObjFile( - std::unique_ptr obj_file, size_t dylib_index) { - return object_layer_.add(*jit_dylibs_[dylib_index], std::move(obj_file)); -} - -llvm::Error SimpleOrcJIT::AddModule(llvm::orc::ThreadSafeModule module, - size_t dylib_index) { - return compile_layer_.add(*jit_dylibs_[dylib_index], std::move(module)); -} - -void SimpleOrcJIT::DoneCompiling() { - // The target machine takes a non-trivial amount of memory, so once we are - // done compiling throw it away. - target_machine_.reset(); -} - -llvm::Expected SimpleOrcJIT::FindCompiledSymbol( - const std::string& name) { - return execution_session_->lookup(jit_dylibs_, name); -} - -} // namespace xla::cpu diff --git a/xla/service/cpu/simple_orc_jit.h b/xla/service/cpu/simple_orc_jit.h deleted file mode 100644 index 25b23fe94de28..0000000000000 --- a/xla/service/cpu/simple_orc_jit.h +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_ -#define XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ExecutionEngine/JITEventListener.h" -#include "llvm/ExecutionEngine/Orc/Core.h" -#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" -#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" -#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" -#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" -#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" -#include "llvm/ExecutionEngine/RuntimeDyld.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/FMF.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/CodeGen.h" -#include "llvm/Support/Error.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Target/TargetMachine.h" -#include "llvm/Target/TargetOptions.h" -#include "llvm/TargetParser/Triple.h" -#include "xla/backends/cpu/codegen/cpu_features.h" // IWYU pragma: keep -#include "xla/backends/cpu/codegen/ir_compiler.h" -#include "xla/service/llvm_compiler.h" - -namespace xla::cpu { - -// Simplified LLVM JIT based on the new Orc API. -// -// This class wraps Orc's functionality into a single interface that only -// exposes what we need for XLA. -// -// Supports JIT-ing multiple modules but without cross-module linking. -// Implements eager compilation - the module is lowered to binary as soon as -// it's added to the JIT. -class SimpleOrcJIT : public llvm::JITEventListener { - public: - using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer; - using CompileLayerT = llvm::orc::IRCompileLayer; - - // Create a new JIT, targeting the host architecture. - // - // {pre,post}_optimization_hook is invoked on the module before/after all - // LLVM IR-level optimizations. post_codegen_hook is invoked after - // compiling to machine code. - SimpleOrcJIT( - std::unique_ptr target_process_control, - std::unique_ptr execution_session, - const llvm::TargetOptions& target_options, - llvm::CodeGenOptLevel opt_level, bool optimize_for_size, - bool disable_expensive_passes, bool disable_slp_vectorizer, - llvm::FastMathFlags fast_math_flags, - LLVMCompiler::ModuleHook pre_optimization_hook, - LLVMCompiler::ModuleHook post_optimization_hook, - std::function post_codegen_hook, - size_t num_jit_dylibs, absl::string_view max_cpu_isa); - - static llvm::Expected> Create( - const llvm::TargetOptions& target_options, - llvm::CodeGenOptLevel opt_level, bool optimize_for_size, - bool disable_expensive_passes, bool disable_slp_vectorizer, - llvm::FastMathFlags fast_math_flags, - LLVMCompiler::ModuleHook pre_optimization_hook, - LLVMCompiler::ModuleHook post_optimization_hook, - std::function post_codegen_hook, - size_t num_jit_dylibs, absl::string_view max_cpu_isa); - - ~SimpleOrcJIT() override; - - const llvm::DataLayout& data_layout() const { return data_layout_; } - - const llvm::Triple& target_triple() const { return target_triple_; } - - llvm::Error AddObjFile(std::unique_ptr obj_file, - size_t dylib_index = 0); - llvm::Error AddModule(llvm::orc::ThreadSafeModule module, - size_t dylib_index = 0); - - // Discards objects we no longer need once we are done compiling. - void DoneCompiling(); - - // Get the runtime address of the compiled symbol whose name is given. Returns - // nullptr if the symbol cannot be found. - llvm::Expected FindCompiledSymbol( - const std::string& name); - - llvm::TargetMachine* target_machine() const { return target_machine_.get(); } - - int64_t SizeOfGeneratedCodeInBytes() const { - return size_of_generated_code_in_bytes_; - } - - void AddKernelSymbol(std::string_view name) { - kernel_symbols_.insert(std::string(name)); - } - - private: - void notifyObjectLoaded( - llvm::JITEventListener::ObjectKey key, - const llvm::object::ObjectFile& object, - const llvm::RuntimeDyld::LoadedObjectInfo& object_info) override; - void notifyFreeingObject(llvm::JITEventListener::ObjectKey key) override; - - // Target machine builder that is used to construct target machines for this - // instance of SimpleOrcJIT, and to construct `target_machine_`. - IrCompiler::TargetMachineBuilder target_machine_builder_; - std::shared_ptr target_machine_; - - llvm::Triple target_triple_; - const llvm::DataLayout data_layout_; - std::unique_ptr target_process_control_; - std::unique_ptr execution_session_; - ObjLayerT object_layer_; - CompileLayerT compile_layer_; - llvm::SmallVector jit_dylibs_; - int64_t size_of_generated_code_in_bytes_ = 0; - - // Symbols corresponding to kernel functions. Because we use module splitting, - // some of the modules might have a declaration, but no definition of the - // kernel function, and this is fine, and should not log an error. - absl::flat_hash_set kernel_symbols_; - - // Non owning pointer to a JIT event listener that registers the JIT events - // with an attached GDB. - // - // Note: we get a pointer to this event listener using - // `createGDBRegistrationListener` which makes it look like we're supposed to - // free this, but the function is poorly named and really just returns a - // pointer to a static object. - llvm::JITEventListener* gdb_jit_event_listener_; - - llvm::JITEventListener* perf_jit_event_listener_; -}; - -} // namespace xla::cpu - -#endif // XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_ diff --git a/xla/service/cpu/tests/BUILD b/xla/service/cpu/tests/BUILD index d98ab38c9935c..033e921dd4dfb 100644 --- a/xla/service/cpu/tests/BUILD +++ b/xla/service/cpu/tests/BUILD @@ -331,7 +331,6 @@ xla_cc_test( "//xla/backends/cpu/codegen:cpu_features", "//xla/hlo/ir:hlo", "//xla/service/cpu:cpu_compiler", - "//xla/service/cpu:simple_orc_jit", "//xla/tests:hlo_test_base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings",