Skip to content

Commit

Permalink
[XLA:CPU] Use KernelApiIrBuilder in IrEmitter2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705032240
  • Loading branch information
WillFroom authored and Google-ML-Automation committed Dec 11, 2024
1 parent 27f62a1 commit cc9a9c0
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 152 deletions.
2 changes: 1 addition & 1 deletion xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -649,14 +649,14 @@ cc_library(
":dot_op_emitter",
":elemental_math_emitter",
":ir_emitter",
":ir_function",
":parallel_loop_emitter",
":shape_partition",
"//xla:cpu_function_runtime",
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/backends/cpu/codegen:kernel_api_ir_builder",
"//xla/hlo/ir:hlo",
"//xla/service:buffer_assignment",
"//xla/service:elemental_ir_emitter",
Expand Down
119 changes: 15 additions & 104 deletions xla/service/cpu/ir_emitter2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/CodeGen.h"
#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h"
#include "xla/cpu_function_runtime.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand Down Expand Up @@ -81,45 +82,6 @@ limitations under the License.
#include "tsl/platform/statusor.h"

namespace xla::cpu {
namespace {

// Following struct types correspond to HostKernel C API.
// See: xla/stream_executor/host/host_kernel_c_api.h

static llvm::StructType* Dim3StructTy(llvm::LLVMContext& ctx,
std::string_view name) {
auto* i64 = llvm::IntegerType::getInt64Ty(ctx);
return llvm::StructType::create(name, i64, i64, i64);
}

static llvm::StructType* KernelThreadDimTy(llvm::LLVMContext& ctx) {
return Dim3StructTy(ctx, "SE_HOST_KernelThreadDim");
}

static llvm::StructType* KernelThreadTy(llvm::LLVMContext& ctx) {
return Dim3StructTy(ctx, "SE_HOST_KernelThread");
}

static llvm::StructType* KernelArgTy(llvm::LLVMContext& ctx) {
auto* ptr = llvm::PointerType::getUnqual(ctx);
auto* i64 = llvm::IntegerType::getInt64Ty(ctx);
return llvm::StructType::create("SE_HOST_KernelArg", ptr, i64);
}

static llvm::StructType* KernelCallFrameTy(llvm::LLVMContext& ctx) {
auto* ptr = llvm::PointerType::getUnqual(ctx);
auto* i64 = llvm::IntegerType::getInt64Ty(ctx);
return llvm::StructType::create("SE_HOST_KernelCallFrame", ptr, ptr, i64,
ptr);
}

static llvm::FunctionType* KernelFunctionTy(llvm::LLVMContext& ctx) {
return llvm::FunctionType::get(llvm::PointerType::getUnqual(ctx),
llvm::PointerType::getUnqual(ctx),
/*isVarArg=*/false);
}

} // namespace

//===----------------------------------------------------------------------===//
// ElementalIrEmitter
Expand Down Expand Up @@ -217,10 +179,10 @@ IrEmitter2::IrEmitter2(const HloModule& hlo_module, llvm::Module* module,
: hlo_module_(hlo_module),
module_(module),
nested_ir_emitter_(nested_ir_emitter),
call_frame_ty_(KernelCallFrameTy(module_->getContext())),
thread_dims_ty_(KernelThreadDimTy(module_->getContext())),
thread_ty_(KernelThreadTy(module_->getContext())),
arg_ty_(KernelArgTy(module_->getContext())) {}
kernel_api_ir_builder_(module_->getContext(),
hlo_module_.config()
.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {}

bool IrEmitter2::fast_min_max() const {
return hlo_module_.config().debug_options().xla_cpu_enable_fast_min_max();
Expand Down Expand Up @@ -656,61 +618,6 @@ absl::Status IrEmitter2::VerifyKernelParameters(
return absl::OkStatus();
}

IrEmitter2::KernelThreadDims IrEmitter2::EmitKernelThreadDims(
llvm::IRBuilderBase& b, llvm::Value* call_frame) {
auto* td_gep = b.CreateStructGEP(call_frame_ty_, call_frame, 0, "tdims_gep");
auto* tdims = b.CreateLoad(b.getPtrTy(), td_gep, "tdims");
auto* x_gep = b.CreateStructGEP(thread_dims_ty_, tdims, 0, "tdim_x_gep");
auto* y_gep = b.CreateStructGEP(thread_dims_ty_, tdims, 1, "tdim_y_gep");
auto* z_gep = b.CreateStructGEP(thread_dims_ty_, tdims, 2, "tdim_z_gep");

return {b.CreateLoad(b.getInt64Ty(), x_gep, "tdim_x"),
b.CreateLoad(b.getInt64Ty(), y_gep, "tdim_y"),
b.CreateLoad(b.getInt64Ty(), z_gep, "tdim_z")};
}

IrEmitter2::KernelThread IrEmitter2::EmitKernelThread(llvm::IRBuilderBase& b,
llvm::Value* call_frame) {
auto* t_gep = b.CreateStructGEP(call_frame_ty_, call_frame, 1, "tid_gep");
auto* tids = b.CreateLoad(b.getPtrTy(), t_gep, "tids");
auto* x_gep = b.CreateStructGEP(thread_ty_, tids, 0, "tid_x_gep");
auto* y_gep = b.CreateStructGEP(thread_ty_, tids, 1, "tid_y_gep");
auto* z_gep = b.CreateStructGEP(thread_ty_, tids, 2, "tid_z_gep");

return {b.CreateLoad(b.getInt64Ty(), x_gep, "tid_x"),
b.CreateLoad(b.getInt64Ty(), y_gep, "tid_y"),
b.CreateLoad(b.getInt64Ty(), z_gep, "tid_z")};
}

llvm_ir::IrArray IrEmitter2::EmitKernelArgument(llvm::IRBuilderBase& b,
llvm::Value* call_frame,
int64_t index,
const Shape& shape) {
llvm::Type* ptr = llvm::PointerType::get(b.getContext(), 0);
std::string name = absl::StrCat("arg", index);

auto* args_gep = b.CreateStructGEP(call_frame_ty_, call_frame, 3, "args_gep");
auto* args = b.CreateLoad(ptr, args_gep, "args");
auto* data_gep = b.CreateConstGEP2_32(arg_ty_, args, index, 0, name + "_gep");
auto* data = b.CreateLoad(ptr, data_gep, name);

// All buffers passed to host kernels are expected to be properly aligned,
// emit metadata to allow LLVM to use that information for optimization.
llvm_ir::SetAlignmentMetadataForLoad(data, cpu_function_runtime::MinAlign());

// All buffers pointers passed to host kernels are expected to be
// dereferenceable.
IrEmitter::AttachDereferenceableMetadataForLoad(data, ByteSizeOf(shape));

// All buffers pointers passed to host kernels are expected to be invariant
// over the whole program. Note the metadata is attached only to loading
// buffer pointers, not to loading actual buffers.
AttachInvariantLoadMetadataForLoad(data);

return llvm_ir::IrArray(data, llvm_ir::ShapeToIrType(shape, b.getContext()),
shape);
}

absl::StatusOr<IrEmitter2::KernelPrototype> IrEmitter2::EmitKernelPrototype(
std::string_view name, absl::Span<const KernelParameter> arguments,
absl::Span<const KernelParameter> results) {
Expand Down Expand Up @@ -778,8 +685,8 @@ absl::StatusOr<IrEmitter2::KernelPrototype> IrEmitter2::EmitKernelPrototype(

// Create a kernel function with HostKernel API. We use external linkage
// because we'll be resolving this function from the XLA runtime.
llvm::Function* function = llvm::Function::Create(
KernelFunctionTy(ctx), llvm::GlobalValue::ExternalLinkage, name, module_);
llvm::Function* function =
kernel_api_ir_builder_.EmitKernelFunction(*module_, name);
function->setCallingConv(llvm::CallingConv::C);

// Generate unwind information so that GDB can crawl through the stack frames
Expand All @@ -802,8 +709,10 @@ absl::StatusOr<IrEmitter2::KernelPrototype> IrEmitter2::EmitKernelPrototype(

llvm::Value* call_frame = function->getArg(0);
// Build thread coordinates from the call frame.
KernelThreadDims kernel_thread_dims = EmitKernelThreadDims(b, call_frame);
KernelThread kernel_thread = EmitKernelThread(b, call_frame);
KernelApiIrBuilder::ThreadDims kernel_thread_dims =
kernel_api_ir_builder_.EmitKernelThreadDims(b, call_frame);
KernelApiIrBuilder::ThreadId kernel_thread =
kernel_api_ir_builder_.EmitKernelThread(b, call_frame);

int64_t idx = 0;

Expand All @@ -815,7 +724,8 @@ absl::StatusOr<IrEmitter2::KernelPrototype> IrEmitter2::EmitKernelPrototype(
std::vector<llvm_ir::IrArray> ir_arguments;
for (int64_t i = 0; i < arguments.size(); ++i) {
const KernelParameter& argument = arguments[i];
auto ir_argument = EmitKernelArgument(b, call_frame, idx++, argument.shape);
auto ir_argument = kernel_api_ir_builder_.EmitKernelArgument(
b, call_frame, idx++, argument.shape);
if (auto* noalias = get_noalias(argument.slice)) {
ir_argument.AddNoaliasMetadata(noalias);
}
Expand All @@ -833,7 +743,8 @@ absl::StatusOr<IrEmitter2::KernelPrototype> IrEmitter2::EmitKernelPrototype(
// IrArrays for the results.
std::vector<llvm_ir::IrArray> ir_results;
for (const KernelParameter& result : results) {
auto ir_result = EmitKernelArgument(b, call_frame, idx++, result.shape);
auto ir_result = kernel_api_ir_builder_.EmitKernelArgument(
b, call_frame, idx++, result.shape);
if (auto* noalias = get_noalias(result.slice)) {
ir_result.AddNoaliasMetadata(noalias);
}
Expand Down
35 changes: 4 additions & 31 deletions xla/service/cpu/ir_emitter2.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
Expand Down Expand Up @@ -82,20 +83,6 @@ class IrEmitter2 {
BufferAllocation::Slice slice;
};

// Thread dimensions of the kernel invocation.
struct KernelThreadDims {
llvm::Value* x;
llvm::Value* y;
llvm::Value* z;
};

// Thread coordinates of the kernel invocation.
struct KernelThread {
llvm::Value* x;
llvm::Value* y;
llvm::Value* z;
};

// Emitted kernel information that defines how to launch it at run time.
struct KernelInfo {
explicit KernelInfo(KernelPrototype prototype,
Expand Down Expand Up @@ -167,8 +154,8 @@ class IrEmitter2 {
llvm::BasicBlock* return_block;

// LLVM values identifying kernel invocation thread coordinates.
KernelThreadDims thread_dims;
KernelThread thread;
KernelApiIrBuilder::ThreadDims thread_dims;
KernelApiIrBuilder::ThreadId thread;

// LLVM values corresponding to the kernel arguments and results arrays. All
// tuples are flattened as we do not have any tuples at run time and only
Expand Down Expand Up @@ -221,16 +208,6 @@ class IrEmitter2 {
absl::Span<const KernelParameter> arguments,
absl::Span<const KernelParameter> results);

KernelThreadDims EmitKernelThreadDims(llvm::IRBuilderBase& b,
llvm::Value* call_frame);

KernelThread EmitKernelThread(llvm::IRBuilderBase& b,
llvm::Value* call_frame);

llvm_ir::IrArray EmitKernelArgument(llvm::IRBuilderBase& b,
llvm::Value* call_frame, int64_t index,
const Shape& shape);

// Returns parallel config for the given instruction or std::nullopt if
// the instruction has to be compiled to a single threaded loop.
std::optional<ParallelConfig> GetParallelConfig(const HloInstruction* instr);
Expand Down Expand Up @@ -268,11 +245,7 @@ class IrEmitter2 {
// to reductions inside fusions).
IrEmitter* nested_ir_emitter_;

// LLVM types defining HostKernel API (see host_kernel_c_api.h).
llvm::StructType* call_frame_ty_;
llvm::StructType* thread_dims_ty_;
llvm::StructType* thread_ty_;
llvm::StructType* arg_ty_;
KernelApiIrBuilder kernel_api_ir_builder_;

// Keeps track of all the functions emitted so far.
std::vector<KernelInfo> kernels_;
Expand Down
32 changes: 16 additions & 16 deletions xla/service/cpu/ir_emitter2_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,40 +144,40 @@ TEST_F(IrEmitter2Test, BuildKernelPrototype) {
absl::StrCat(R"(
CHECK: define ptr @test(ptr %0) #0 {
CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 0
CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 0
CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 1
CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 2
CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 0
CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThreadDim, {{.*}} i32 0
CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThreadDim, {{.*}} i32 1
CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThreadDim, {{.*}} i32 2
CHECK: load i64
CHECK: load i64
CHECK: load i64
CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 1
CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 0
CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 1
CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 2
CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 1
CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThread, {{.*}} i32 0
CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThread, {{.*}} i32 1
CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThread, {{.*}} i32 2
CHECK: load i64
CHECK: load i64
CHECK: load i64
CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3
CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 3
CHECK: load ptr
CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 0, i32 0
CHECK: getelementptr %XLA_CPU_KernelArg, {{.*}} i32 0, i32 0
CHECK: %[[ARG0:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0:.+]], !dereferenceable ![[DEREF_BYTES:.+]], !align ![[ALIGNMENT:.+]]
CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3
CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 3
CHECK: load ptr
CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 1, i32 0
CHECK: getelementptr %XLA_CPU_KernelArg, {{.*}} i32 1, i32 0
CHECK: %[[ARG1:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]]
CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3
CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 3
CHECK: load ptr
CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 2, i32 0
CHECK: getelementptr %XLA_CPU_KernelArg, {{.*}} i32 2, i32 0
CHECK: %[[ARG2:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]]
CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3
CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 3
CHECK: load ptr
CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 3, i32 0
CHECK: getelementptr %XLA_CPU_KernelArg, {{.*}} i32 3, i32 0
CHECK: %[[ARG3:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]]
CHECK-NEXT: %[[PTR0:.+]] = getelementptr inbounds float, ptr %[[ARG0]]
Expand Down

0 comments on commit cc9a9c0

Please sign in to comment.