Skip to content

Commit

Permalink
PR #19869: [XLA] Wraps the FFI backend config opaque string with CustomCallBackendConfig
Browse files Browse the repository at this point in the history

Imported from GitHub PR #19869

Copybara import of the project:

--
f4419ec by Yunlong Liu <[email protected]>:

Proto changes.

Change CPU/GPU call sites.

Adds CPU impl and test.

Adds CPU tests.

Makes up another call site.

add back nooss tag

--
04cb10e by Yunlong Liu <[email protected]>:

fix protos and clean up code a bit

Merging this change closes #19869

COPYBARA_INTEGRATE_REVIEW=#19869 from yliu120:custom_call_config 04cb10e
PiperOrigin-RevId: 701886222
  • Loading branch information
yliu120 authored and Google-ML-Automation committed Dec 2, 2024
1 parent 069e841 commit 40e1c61
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 60 deletions.
43 changes: 43 additions & 0 deletions xla/pjrt/cpu/cpu_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,49 @@ TEST(TfrtCpuClientTest, ForwardUserDataToFfiHandler) {
*result_literal));
}

// Fills every element of the rank-1 f32 result buffer with the value of the
// "attr" float attribute. Returns OkStatus unconditionally.
static absl::Status MemsetFromAttr(
    float attr, ffi::Result<ffi::BufferR1<PrimitiveType::F32>> result) {
  float* dst = result->typed_data();
  const size_t count = result->element_count();
  for (size_t idx = 0; idx < count; ++idx) {
    dst[idx] = attr;
  }
  return absl::OkStatus();
}

// Binds the typed-FFI handler: one float attribute named "attr" and a single
// rank-1 f32 result buffer.
XLA_FFI_DEFINE_HANDLER(kMemsetFromAttr, MemsetFromAttr,
                       ffi::Ffi::Bind()
                           .Attr<float>("attr")
                           .Ret<ffi::BufferR1<PrimitiveType::F32>>());

// Registers the handler for the CPU ("HOST") platform under the custom-call
// target name "MemsetFromAttr".
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "MemsetFromAttr", "HOST",
                         kMemsetFromAttr);

// Verifies that a typed-FFI custom call receives its attributes through the
// CPU BackendConfig proto ("custom_call_config.attributes") rather than a raw
// opaque string: the handler should see attr = 3.0 and memset the output.
TEST(TfrtCpuClientTest, PassAttrToFfiHandler) {
  static constexpr char const* kProgram = R"(
HloModule ffi_handler
ENTRY main {
ROOT %custom-call = f32[4] custom-call(),
custom_call_target="MemsetFromAttr",
api_version=API_VERSION_TYPED_FFI,
backend_config={"custom_call_config": {"attributes": "{attr = 3.0 : f32}"}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions()));

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kProgram, {}));
  XlaComputation xla_computation(hlo_module->ToProto());
  TF_ASSERT_OK_AND_ASSIGN(auto executable,
                          client->Compile(xla_computation, {}));

  // The custom call takes no operands; the result comes solely from the attr.
  ExecuteOptions opts;
  auto result = executable->Execute(/*argument_handles=*/{{}}, opts);

  TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<xla::Literal> result_literal,
                          result->at(0).at(0)->ToLiteralSync());
  // The handler fills the f32[4] output with the attribute value 3.0.
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f, 3.0f}), *result_literal));
}

} // namespace

//===----------------------------------------------------------------------===//
Expand Down
43 changes: 43 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,49 @@ TEST(StreamExecutorGpuClientTest, ForwardUserDataToFfiHandler) {
*result_literal));
}

// Fills the rank-1 f32 result buffer on-device with the "attr" value by
// bit-casting the float into a 32-bit pattern and issuing a stream Memset32.
static absl::Status MemsetFromAttr(
    se::Stream* stream, float attr,
    ffi::Result<ffi::BufferR1<PrimitiveType::F32>> result) {
  uint32_t bit_pattern;
  std::memcpy(&bit_pattern, &attr, sizeof(bit_pattern));

  se::DeviceMemoryBase dst = result->device_memory();
  return stream->Memset32(&dst, bit_pattern, dst.size());
}

// Binds the typed-FFI handler: a stream from the execution context, one float
// attribute named "attr", and a single rank-1 f32 result buffer.
XLA_FFI_DEFINE_HANDLER(kMemsetFromAttr, MemsetFromAttr,
                       ffi::Ffi::Bind()
                           .Ctx<ffi::Stream>()
                           .Attr<float>("attr")
                           .Ret<ffi::BufferR1<PrimitiveType::F32>>());

// Registers the handler under the canonical platform name resolved for "GPU"
// at static-initialization time.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "MemsetFromAttr",
                         PlatformUtil::CanonicalPlatformName("GPU").value(),
                         kMemsetFromAttr);

// Verifies that a typed-FFI custom call receives its attributes through the
// GpuBackendConfig proto ("custom_call_backend_config.attributes"): the
// handler should see attr = 3.0 and memset the f32[4] output with it.
TEST(StreamExecutorGpuClientTest, PassAttrToFfiHandler) {
  static constexpr char const* kProgram = R"(
HloModule ffi_handler
ENTRY main {
ROOT %custom-call = f32[4] custom-call(),
custom_call_target="MemsetFromAttr",
api_version=API_VERSION_TYPED_FFI,
backend_config={"custom_call_backend_config": {"attributes": "{attr = 3.0 : f32}"}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto client,
                          GetStreamExecutorGpuClient(GpuClientOptions()));
  TF_ASSERT_OK_AND_ASSIGN(auto executable,
                          CompileExecutable(kProgram, *client));

  // The custom call takes no operands; the result comes solely from the attr.
  ExecuteOptions opts;
  auto result = executable->Execute(/*argument_handles=*/{{}}, opts);
  TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<xla::Literal> result_literal,
                          ExtractSingleResult(result));
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f, 3.0f}), *result_literal));
}

TEST(StreamExecutorGpuClientTest, ToLiteralAsync) {
TF_ASSERT_OK_AND_ASSIGN(auto client,
GetStreamExecutorGpuClient(GpuClientOptions()));
Expand Down
1 change: 1 addition & 0 deletions xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ cc_library(
"//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
"//xla/service:pattern_matcher",
"//xla/service/cpu:backend_config_proto_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
Expand Down
13 changes: 13 additions & 0 deletions xla/service/cpu/backend_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@ package xla.cpu;

import "xla/service/cpu/onednn_config.proto";

// Backend config for a general custom call instruction, e.g. XLA FFI.
message CustomCallBackendConfig {
  // Generic configurations that can be parsed by XLA. At most one of the
  // fields is set, depending on the custom call's API version.
  oneof raw_backend_config_oneof {
    // An opaque ASCII string that the compiler passes through to legacy
    // (non-typed-FFI) custom-call targets.
    string opaque = 1;
    // Attributes parsed by XLA FFI for API_VERSION_TYPED_FFI calls: a textual
    // MLIR dictionary attribute, e.g. "{attr = 3.0 : f32}".
    string attributes = 2;
  }
}

// Backend config for XLA:CPU.
message BackendConfig {
// Number of partitions per outer dimension (in order, starting with
Expand All @@ -19,5 +30,7 @@ message BackendConfig {
OneDnnSoftmaxConfig onednn_softmax_config = 4;
// Configuration to be used by oneDNN convolution
OneDnnConvolutionConfig onednn_conv_config = 5;
// Configuration to be used by general custom call, e.g., FFI.
CustomCallBackendConfig custom_call_config = 6;
}
}
16 changes: 14 additions & 2 deletions xla/service/cpu/thunk_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ limitations under the License.
#include "xla/layout_util.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/cpu/backend_config.pb.h"
#include "xla/service/cpu/dot_op_emitter.h"
#include "xla/service/cpu/ir_emission_utils.h"
#include "xla/service/cpu/ir_emitter2.h"
Expand Down Expand Up @@ -964,13 +965,24 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitCustomCallThunk(
}

// Get backend config and buffer assignments.
auto backend_config = custom_call->opaque();
auto backend_config = custom_call->backend_config<BackendConfig>();
if (!backend_config.ok()) {
LOG(WARNING) << "Unable to parse backend config for custom call: "
<< backend_config.status().message() << "\n"
<< "Fall back to parse the opaque str.";
}
auto& backend_config_str =
!backend_config.ok()
? custom_call->opaque()
: ((version == API_VERSION_TYPED_FFI)
? backend_config->custom_call_config().attributes()
: backend_config->custom_call_config().opaque());
TF_ASSIGN_OR_RETURN(auto op_buffers,
GetCustomCallOpBuffers(instruction, buffer_assignment_));

return ThunkSequence::Of<CustomCallThunk>(ThunkInfo(instruction),
custom_call_target, op_buffers,
backend_config, version);
backend_config_str, version);
}

absl::StatusOr<ThunkSequence> ThunkEmitter::EmitSliceToDynamicThunk(
Expand Down
14 changes: 14 additions & 0 deletions xla/service/gpu/backend_configs.proto
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,19 @@ message CudnnfMHABackendConfig {
int32 sliding_window_length = 24;
}

// Backend config for a general custom call instruction, e.g. XLA FFI.
message CustomCallBackendConfig {
  // Generic configurations that can be parsed by XLA. At most one of the
  // fields is set, depending on the custom call's API version.
  oneof raw_backend_config_oneof {
    // An opaque ASCII string that the compiler passes through to legacy
    // (non-typed-FFI) custom-call targets.
    string opaque = 1;
    // Attributes parsed by XLA FFI for API_VERSION_TYPED_FFI calls: a textual
    // MLIR dictionary attribute, e.g. "{value = 42 : i32}".
    string attributes = 2;
  }
}

// Generic backend config for XLA:GPU
// Next-Id: 12
message GpuBackendConfig {
// Specifies which operation queue the current instruction will run on.
// A backend may have multiple operation queues to run instructions
Expand Down Expand Up @@ -298,6 +310,8 @@ message GpuBackendConfig {
CudnnNormBackendConfig cudnn_norm_backend_config = 8;

CudnnfMHABackendConfig cudnn_fmha_backend_config = 9;

CustomCallBackendConfig custom_call_backend_config = 11;
}

// This attribute instructs the latency-hiding scheduler to
Expand Down
17 changes: 17 additions & 0 deletions xla/service/gpu/custom_call_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,23 @@ TEST_F(CustomCallTest, RuntimeCustomCallAlwaysFail) {
EXPECT_THAT(status.message(), ::testing::HasSubstr("Uh oh, wrong value: 42"));
}

// Same as the above test but just pass attribute through
// the backend config proto string instead.
TEST_F(CustomCallTest, PassAttributesByBackendConfig) {
XlaBuilder b(TestName());
CustomCall(
&b, "__xla_test$$always_fail", /*operands=*/{},
ShapeUtil::MakeShape(F32, {}), /*opaque=*/
R"({"custom_call_backend_config": {"attributes": "{value = 42 : i32}"}})",
/*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
auto status = Execute(&b, {}).status();
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
EXPECT_THAT(status.message(), ::testing::HasSubstr("Uh oh, wrong value: 42"));
}

static absl::Status Memcpy(se::Stream* stream, ffi::AnyBuffer src,
ffi::Result<ffi::AnyBuffer> dst) {
se::DeviceMemoryBase dst_mem = dst->device_memory();
Expand Down
60 changes: 30 additions & 30 deletions xla/service/gpu/fusions/custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,6 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
// For legacy custom calls we convert all API versions into the latest
// status-returning one and pass backend config as an opaque string.
CustomCallThunk::CustomCallTarget custom_call_target;
std::string opaque;

// For XLA FFI handlers we decode opaque backend config into attributes map
// at IR emission time, so that we do not need to parse MLIR at run time. For
Expand Down Expand Up @@ -695,47 +694,48 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
custom_call.api_version());
}

auto& backend_config_str = custom_call.raw_backend_config_string();
switch (custom_call.api_version()) {
case CustomCallApiVersion::API_VERSION_ORIGINAL:
case CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED:
if (!backend_config_str.empty()) {
opaque = backend_config_str;
}
break;

case CustomCallApiVersion::API_VERSION_TYPED_FFI:
if (!backend_config_str.empty()) {
mlir::Attribute attr = mlir::parseAttribute(
backend_config_str, ir_emitter_context.mlir_context());
if (auto dict = mlir::dyn_cast_or_null<mlir::DictionaryAttr>(attr)) {
TF_ASSIGN_OR_RETURN(attributes, xla::ffi::BuildAttributesMap(dict));
break;
}
return absl::InternalError(
"Unsupported backend config. Expected a string parsable into "
"dictionary attribute");
}
break;

default:
return Internal("Unknown custom-call API version enum value: %d",
custom_call.api_version());
auto backend_config = custom_call.backend_config<GpuBackendConfig>();
if (!backend_config.ok()) {
LOG(WARNING) << "Unable to parse backend config for custom call: "
<< backend_config.status().message() << "\n"
<< "Fall back to parse the raw backend config str.";
}

std::unique_ptr<Thunk> thunk;
auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(&fusion);

auto ffi_thunk = [&](Slices ops, Slices res) {
auto ffi_thunk =
[&](Slices ops,
Slices res) -> absl::StatusOr<std::unique_ptr<CustomCallThunk>> {
auto& called_computations = custom_call.called_computations();
auto& backend_config_str =
backend_config.ok()
? backend_config->custom_call_backend_config().attributes()
: custom_call.raw_backend_config_string();
if (!backend_config_str.empty()) {
mlir::Attribute attr = mlir::parseAttribute(
backend_config_str, ir_emitter_context.mlir_context());
auto dict = mlir::dyn_cast_or_null<mlir::DictionaryAttr>(attr);
if (dict == nullptr) {
return absl::InternalError(
"Unsupported backend config. Expected a string parsable into "
"dictionary attribute");
}
TF_ASSIGN_OR_RETURN(attributes, xla::ffi::BuildAttributesMap(dict));
}
return CustomCallThunk::Create(
thunk_info, call_target_name, registration->bundle, std::move(ops),
std::move(res), std::move(attributes),
called_computations.empty() ? nullptr : called_computations[0]);
};

auto legacy_thunk = [&](Slices ops, Slices res) {
auto legacy_thunk =
[&](Slices ops,
Slices res) -> absl::StatusOr<std::unique_ptr<CustomCallThunk>> {
std::string opaque =
backend_config.ok()
? backend_config->custom_call_backend_config().opaque()
: custom_call.raw_backend_config_string();
return CustomCallThunk::Create(
thunk_info, call_target_name, std::move(custom_call_target),
std::move(ops), std::move(res), std::move(opaque));
Expand Down
53 changes: 25 additions & 28 deletions xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,6 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
// For legacy custom calls we convert all API versions into the latest
// status-returning one and pass backend config as an opaque string.
CustomCallThunk::CustomCallTarget custom_call_target;
std::string opaque;

// For XLA FFI handlers we decode opaque backend config into attributes map
// at IR emission time, so that we do not need to parse MLIR at run time. For
Expand Down Expand Up @@ -1175,45 +1174,43 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
instr->api_version());
}

auto& backend_config_str = instr->raw_backend_config_string();
switch (instr->api_version()) {
case CustomCallApiVersion::API_VERSION_ORIGINAL:
case CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED:
if (!backend_config_str.empty()) {
opaque = backend_config_str;
}
break;
auto backend_config = instr->backend_config<GpuBackendConfig>();
if (!backend_config.ok()) {
LOG(WARNING) << "Unable to parse backend config for custom call: "
<< backend_config.status().message() << "\n"
<< "Fall back to parse the raw backend config str.";
}

case CustomCallApiVersion::API_VERSION_TYPED_FFI:
if (!backend_config_str.empty()) {
mlir::Attribute attr = mlir::parseAttribute(
backend_config_str, ir_emitter_context_->mlir_context());
if (auto dict = mlir::dyn_cast_or_null<mlir::DictionaryAttr>(attr)) {
TF_ASSIGN_OR_RETURN(attributes, xla::ffi::BuildAttributesMap(dict));
break;
}
auto ffi_thunk = [&]() -> absl::StatusOr<std::unique_ptr<CustomCallThunk>> {
auto& called_computations = instr->called_computations();
auto& backend_config_str =
backend_config.ok()
? backend_config->custom_call_backend_config().attributes()
: instr->raw_backend_config_string();
if (!backend_config_str.empty()) {
mlir::Attribute attr = mlir::parseAttribute(
backend_config_str, ir_emitter_context_->mlir_context());
auto dict = mlir::dyn_cast_or_null<mlir::DictionaryAttr>(attr);
if (dict == nullptr) {
return absl::InternalError(
"Unsupported backend config. Expected a string parsable into "
"dictionary attribute");
}
break;

default:
return Internal("Unknown custom-call API version enum value: %d",
instr->api_version());
}

auto ffi_thunk = [&] {
auto& called_computations = instr->called_computations();
TF_ASSIGN_OR_RETURN(attributes, xla::ffi::BuildAttributesMap(dict));
}
return CustomCallThunk::Create(
Thunk::ThunkInfo::WithProfileAnnotation(instr), call_target_name,
registration->bundle, std::move(operands), std::move(results),
std::move(attributes),
called_computations.empty() ? nullptr : called_computations[0]);
};

auto legacy_thunk = [&] {
auto legacy_thunk =
[&]() -> absl::StatusOr<std::unique_ptr<CustomCallThunk>> {
std::string opaque =
backend_config.ok()
? backend_config->custom_call_backend_config().opaque()
: instr->raw_backend_config_string();
return CustomCallThunk::Create(
Thunk::ThunkInfo::WithProfileAnnotation(instr), call_target_name,
std::move(custom_call_target), std::move(operands), std::move(results),
Expand Down

0 comments on commit 40e1c61

Please sign in to comment.