From cc9a9c0e842d274f8e57339ccf82d2c42c97a88d Mon Sep 17 00:00:00 2001 From: Will Froom Date: Wed, 11 Dec 2024 02:36:38 -0800 Subject: [PATCH] [XLA:CPU] Use KernelApiIrBuilder in IrEmitter2 PiperOrigin-RevId: 705032240 --- xla/service/cpu/BUILD | 2 +- xla/service/cpu/ir_emitter2.cc | 119 ++++------------------------ xla/service/cpu/ir_emitter2.h | 35 +------- xla/service/cpu/ir_emitter2_test.cc | 32 ++++---- 4 files changed, 36 insertions(+), 152 deletions(-) diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 7098fea386963..d8868d2f7425f 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -649,7 +649,6 @@ cc_library( ":dot_op_emitter", ":elemental_math_emitter", ":ir_emitter", - ":ir_function", ":parallel_loop_emitter", ":shape_partition", "//xla:cpu_function_runtime", @@ -657,6 +656,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/backends/cpu/codegen:kernel_api_ir_builder", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:elemental_ir_emitter", diff --git a/xla/service/cpu/ir_emitter2.cc b/xla/service/cpu/ir_emitter2.cc index 73c6ebfd3c574..60d0e6a74523d 100644 --- a/xla/service/cpu/ir_emitter2.cc +++ b/xla/service/cpu/ir_emitter2.cc @@ -49,6 +49,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "llvm/Support/CodeGen.h" +#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -81,45 +82,6 @@ limitations under the License. #include "tsl/platform/statusor.h" namespace xla::cpu { -namespace { - -// Following struct types correspond to HostKernel C API. -// See: xla/stream_executor/host/host_kernel_c_api.h - -static llvm::StructType* Dim3StructTy(llvm::LLVMContext& ctx, - std::string_view name) { - auto* i64 = llvm::IntegerType::getInt64Ty(ctx); - return llvm::StructType::create(name, i64, i64, i64); -} - -static llvm::StructType* KernelThreadDimTy(llvm::LLVMContext& ctx) { - return Dim3StructTy(ctx, "SE_HOST_KernelThreadDim"); -} - -static llvm::StructType* KernelThreadTy(llvm::LLVMContext& ctx) { - return Dim3StructTy(ctx, "SE_HOST_KernelThread"); -} - -static llvm::StructType* KernelArgTy(llvm::LLVMContext& ctx) { - auto* ptr = llvm::PointerType::getUnqual(ctx); - auto* i64 = llvm::IntegerType::getInt64Ty(ctx); - return llvm::StructType::create("SE_HOST_KernelArg", ptr, i64); -} - -static llvm::StructType* KernelCallFrameTy(llvm::LLVMContext& ctx) { - auto* ptr = llvm::PointerType::getUnqual(ctx); - auto* i64 = llvm::IntegerType::getInt64Ty(ctx); - return llvm::StructType::create("SE_HOST_KernelCallFrame", ptr, ptr, i64, - ptr); -} - -static llvm::FunctionType* KernelFunctionTy(llvm::LLVMContext& ctx) { - return llvm::FunctionType::get(llvm::PointerType::getUnqual(ctx), - llvm::PointerType::getUnqual(ctx), - /*isVarArg=*/false); -} - -} // namespace //===----------------------------------------------------------------------===// // ElementalIrEmitter @@ -217,10 +179,10 @@ IrEmitter2::IrEmitter2(const HloModule& hlo_module, llvm::Module* module, : hlo_module_(hlo_module), module_(module), nested_ir_emitter_(nested_ir_emitter), - call_frame_ty_(KernelCallFrameTy(module_->getContext())), - thread_dims_ty_(KernelThreadDimTy(module_->getContext())), - thread_ty_(KernelThreadTy(module_->getContext())), - arg_ty_(KernelArgTy(module_->getContext())) {} + kernel_api_ir_builder_(module_->getContext(), + hlo_module_.config() + .debug_options() + .xla_llvm_enable_invariant_load_metadata()) {} bool IrEmitter2::fast_min_max() const { return hlo_module_.config().debug_options().xla_cpu_enable_fast_min_max(); @@ -656,61 +618,6 @@ absl::Status IrEmitter2::VerifyKernelParameters( return absl::OkStatus(); } -IrEmitter2::KernelThreadDims IrEmitter2::EmitKernelThreadDims( - llvm::IRBuilderBase& b, llvm::Value* call_frame) { - auto* td_gep = b.CreateStructGEP(call_frame_ty_, call_frame, 0, "tdims_gep"); - auto* tdims = b.CreateLoad(b.getPtrTy(), td_gep, "tdims"); - auto* x_gep = b.CreateStructGEP(thread_dims_ty_, tdims, 0, "tdim_x_gep"); - auto* y_gep = b.CreateStructGEP(thread_dims_ty_, tdims, 1, "tdim_y_gep"); - auto* z_gep = b.CreateStructGEP(thread_dims_ty_, tdims, 2, "tdim_z_gep"); - - return {b.CreateLoad(b.getInt64Ty(), x_gep, "tdim_x"), - b.CreateLoad(b.getInt64Ty(), y_gep, "tdim_y"), - b.CreateLoad(b.getInt64Ty(), z_gep, "tdim_z")}; -} - -IrEmitter2::KernelThread IrEmitter2::EmitKernelThread(llvm::IRBuilderBase& b, - llvm::Value* call_frame) { - auto* t_gep = b.CreateStructGEP(call_frame_ty_, call_frame, 1, "tid_gep"); - auto* tids = b.CreateLoad(b.getPtrTy(), t_gep, "tids"); - auto* x_gep = b.CreateStructGEP(thread_ty_, tids, 0, "tid_x_gep"); - auto* y_gep = b.CreateStructGEP(thread_ty_, tids, 1, "tid_y_gep"); - auto* z_gep = b.CreateStructGEP(thread_ty_, tids, 2, "tid_z_gep"); - - return {b.CreateLoad(b.getInt64Ty(), x_gep, "tid_x"), - b.CreateLoad(b.getInt64Ty(), y_gep, "tid_y"), - b.CreateLoad(b.getInt64Ty(), z_gep, "tid_z")}; -} - -llvm_ir::IrArray IrEmitter2::EmitKernelArgument(llvm::IRBuilderBase& b, - llvm::Value* call_frame, - int64_t index, - const Shape& shape) { - llvm::Type* ptr = llvm::PointerType::get(b.getContext(), 0); - std::string name = absl::StrCat("arg", index); - - auto* args_gep = b.CreateStructGEP(call_frame_ty_, call_frame, 3, "args_gep"); - auto* args = b.CreateLoad(ptr, args_gep, "args"); - auto* data_gep = b.CreateConstGEP2_32(arg_ty_, args, index, 0, name + "_gep"); - auto* data = b.CreateLoad(ptr, data_gep, name); - - // All buffers passed to host kernels are expected to be properly aligned, - // emit metadata to allow LLVM to use that information for optimization. - llvm_ir::SetAlignmentMetadataForLoad(data, cpu_function_runtime::MinAlign()); - - // All buffers pointers passed to host kernels are expected to be - // dereferenceable. - IrEmitter::AttachDereferenceableMetadataForLoad(data, ByteSizeOf(shape)); - - // All buffers pointers passed to host kernels are expected to be invariant - // over the whole program. Note the metadata is attached only to loading - // buffer pointers, not to loading actual buffers. - AttachInvariantLoadMetadataForLoad(data); - - return llvm_ir::IrArray(data, llvm_ir::ShapeToIrType(shape, b.getContext()), - shape); -} - absl::StatusOr IrEmitter2::EmitKernelPrototype( std::string_view name, absl::Span arguments, absl::Span results) { @@ -778,8 +685,8 @@ absl::StatusOr IrEmitter2::EmitKernelPrototype( // Create a kernel function with HostKernel API. We use external linkage // because we'll be resolving this function from the XLA runtime. - llvm::Function* function = llvm::Function::Create( - KernelFunctionTy(ctx), llvm::GlobalValue::ExternalLinkage, name, module_); + llvm::Function* function = + kernel_api_ir_builder_.EmitKernelFunction(*module_, name); function->setCallingConv(llvm::CallingConv::C); // Generate unwind information so that GDB can crawl through the stack frames @@ -802,8 +709,10 @@ absl::StatusOr IrEmitter2::EmitKernelPrototype( llvm::Value* call_frame = function->getArg(0); // Build thread coordinates from the call frame. - KernelThreadDims kernel_thread_dims = EmitKernelThreadDims(b, call_frame); - KernelThread kernel_thread = EmitKernelThread(b, call_frame); + KernelApiIrBuilder::ThreadDims kernel_thread_dims = + kernel_api_ir_builder_.EmitKernelThreadDims(b, call_frame); + KernelApiIrBuilder::ThreadId kernel_thread = + kernel_api_ir_builder_.EmitKernelThread(b, call_frame); int64_t idx = 0; @@ -815,7 +724,8 @@ absl::StatusOr IrEmitter2::EmitKernelPrototype( std::vector ir_arguments; for (int64_t i = 0; i < arguments.size(); ++i) { const KernelParameter& argument = arguments[i]; - auto ir_argument = EmitKernelArgument(b, call_frame, idx++, argument.shape); + auto ir_argument = kernel_api_ir_builder_.EmitKernelArgument( + b, call_frame, idx++, argument.shape); if (auto* noalias = get_noalias(argument.slice)) { ir_argument.AddNoaliasMetadata(noalias); } @@ -833,7 +743,8 @@ absl::StatusOr IrEmitter2::EmitKernelPrototype( // IrArrays for the results. std::vector ir_results; for (const KernelParameter& result : results) { - auto ir_result = EmitKernelArgument(b, call_frame, idx++, result.shape); + auto ir_result = kernel_api_ir_builder_.EmitKernelArgument( + b, call_frame, idx++, result.shape); if (auto* noalias = get_noalias(result.slice)) { ir_result.AddNoaliasMetadata(noalias); } diff --git a/xla/service/cpu/ir_emitter2.h b/xla/service/cpu/ir_emitter2.h index 3c7f874c041f5..38f97c87d07c3 100644 --- a/xla/service/cpu/ir_emitter2.h +++ b/xla/service/cpu/ir_emitter2.h @@ -32,6 +32,7 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -82,20 +83,6 @@ class IrEmitter2 { BufferAllocation::Slice slice; }; - // Thread dimensions of the kernel invocation. - struct KernelThreadDims { - llvm::Value* x; - llvm::Value* y; - llvm::Value* z; - }; - - // Thread coordinates of the kernel invocation. - struct KernelThread { - llvm::Value* x; - llvm::Value* y; - llvm::Value* z; - }; - // Emitted kernel information that defines how to launch it at run time. struct KernelInfo { explicit KernelInfo(KernelPrototype prototype, @@ -167,8 +154,8 @@ class IrEmitter2 { llvm::BasicBlock* return_block; // LLVM values identifying kernel invocation thread coordinates. - KernelThreadDims thread_dims; - KernelThread thread; + KernelApiIrBuilder::ThreadDims thread_dims; + KernelApiIrBuilder::ThreadId thread; // LLVM values corresponding to the kernel arguments and results arrays. All // tuples are flattened as we do not have any tuples at run time and only @@ -221,16 +208,6 @@ class IrEmitter2 { absl::Span arguments, absl::Span results); - KernelThreadDims EmitKernelThreadDims(llvm::IRBuilderBase& b, - llvm::Value* call_frame); - - KernelThread EmitKernelThread(llvm::IRBuilderBase& b, - llvm::Value* call_frame); - - llvm_ir::IrArray EmitKernelArgument(llvm::IRBuilderBase& b, - llvm::Value* call_frame, int64_t index, - const Shape& shape); - // Returns parallel config for the given instruction or std::nullopt if // the instruction has to be compiled to a single threaded loop. std::optional GetParallelConfig(const HloInstruction* instr); @@ -268,11 +245,7 @@ class IrEmitter2 { // to reductions inside fusions). IrEmitter* nested_ir_emitter_; - // LLVM types defining HostKernel API (see host_kernel_c_api.h). - llvm::StructType* call_frame_ty_; - llvm::StructType* thread_dims_ty_; - llvm::StructType* thread_ty_; - llvm::StructType* arg_ty_; + KernelApiIrBuilder kernel_api_ir_builder_; // Keeps track of all the functions emitted so far. std::vector kernels_; diff --git a/xla/service/cpu/ir_emitter2_test.cc b/xla/service/cpu/ir_emitter2_test.cc index ee2464c7b9cad..16d043c7ac438 100644 --- a/xla/service/cpu/ir_emitter2_test.cc +++ b/xla/service/cpu/ir_emitter2_test.cc @@ -144,40 +144,40 @@ TEST_F(IrEmitter2Test, BuildKernelPrototype) { absl::StrCat(R"( CHECK: define ptr @test(ptr %0) #0 { - CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 0 - CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 0 - CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 1 - CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 2 + CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 0 + CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThreadDim, {{.*}} i32 0 + CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThreadDim, {{.*}} i32 1 + CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThreadDim, {{.*}} i32 2 CHECK: load i64 CHECK: load i64 CHECK: load i64 - CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 1 - CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 0 - CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 1 - CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 2 + CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 1 + CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThread, {{.*}} i32 0 + CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThread, {{.*}} i32 1 + CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThread, {{.*}} i32 2 CHECK: load i64 CHECK: load i64 CHECK: load i64 - CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 + CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 3 CHECK: load ptr - CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 0, i32 0 + CHECK: getelementptr %XLA_CPU_KernelArg, {{.*}} i32 0, i32 0 CHECK: %[[ARG0:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0:.+]], !dereferenceable ![[DEREF_BYTES:.+]], !align ![[ALIGNMENT:.+]] - CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 + CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 3 CHECK: load ptr - CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 1, i32 0 + CHECK: getelementptr %XLA_CPU_KernelArg, {{.*}} i32 1, i32 0 CHECK: %[[ARG1:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] - CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 + CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 3 CHECK: load ptr - CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 2, i32 0 + CHECK: getelementptr %XLA_CPU_KernelArg, {{.*}} i32 2, i32 0 CHECK: %[[ARG2:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] - CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 + CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 3 CHECK: load ptr - CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 3, i32 0 + CHECK: getelementptr %XLA_CPU_KernelArg, {{.*}} i32 3, i32 0 CHECK: %[[ARG3:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] CHECK-NEXT: %[[PTR0:.+]] = getelementptr inbounds float, ptr %[[ARG0]]