Skip to content

Commit

Permalink
[xla:gpu] Pass custom-call results as xla:ffi results to handlers
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621980275
  • Loading branch information
ezhulenev authored and copybara-github committed Apr 5, 2024
1 parent b0c6c26 commit b31cc49
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 109 deletions.
95 changes: 69 additions & 26 deletions xla/ffi/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,45 @@ template <PrimitiveType dtype> using BufferR3 = Buffer<dtype, 3>;
template <PrimitiveType dtype> using BufferR4 = Buffer<dtype, 4>;
// clang-format on

namespace internal {

inline BufferBase DecodeBuffer(XLA_FFI_Buffer* buf) {
size_t size_bytes = primitive_util::ByteWidth(PrimitiveType(buf->dtype));
for (int64_t i = 0; i < buf->rank; ++i) size_bytes *= buf->dims[i];

BufferBase buffer;
buffer.dtype = PrimitiveType(buf->dtype);
buffer.data = se::DeviceMemoryBase(buf->data, size_bytes);
buffer.dimensions = absl::MakeConstSpan(buf->dims, buf->rank);
return buffer;
}

template <PrimitiveType dtype, size_t rank>
std::optional<Buffer<dtype, rank>> DecodeBuffer(XLA_FFI_Buffer* buf,
DiagnosticEngine& diagnostic) {
if (auto buf_dtype = PrimitiveType(buf->dtype);
XLA_FFI_PREDICT_FALSE(buf_dtype != dtype)) {
return diagnostic.Emit("Wrong buffer dtype: expected ")
<< primitive_util::LowercasePrimitiveTypeName(dtype) << " but got "
<< primitive_util::LowercasePrimitiveTypeName(buf_dtype);
}

if constexpr (rank != internal::kDynamicRank) {
if (XLA_FFI_PREDICT_FALSE(buf->rank != rank)) {
return diagnostic.Emit("Wrong buffer rank: expected ")
<< rank << " but got " << buf->rank;
}
}

Buffer<dtype, rank> buffer;
buffer.data =
se::DeviceMemory<NativeType<dtype>>(se::DeviceMemoryBase(buf->data));
buffer.dimensions = absl::MakeConstSpan(buf->dims, buf->rank);
return buffer;
}

} // namespace internal

//===----------------------------------------------------------------------===//
// Arguments binding
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -117,16 +156,7 @@ struct ArgDecoding<BufferBase> {
<< XLA_FFI_ArgType_BUFFER << " but got " << type;
}

auto* buf = reinterpret_cast<XLA_FFI_Buffer*>(arg);

size_t size_bytes = primitive_util::ByteWidth(PrimitiveType(buf->dtype));
for (int64_t i = 0; i < buf->rank; ++i) size_bytes *= buf->dims[i];

BufferBase buffer;
buffer.dtype = PrimitiveType(buf->dtype);
buffer.data = se::DeviceMemoryBase(buf->data, size_bytes);
buffer.dimensions = absl::MakeConstSpan(buf->dims, buf->rank);
return buffer;
return internal::DecodeBuffer(reinterpret_cast<XLA_FFI_Buffer*>(arg));
}
};

Expand All @@ -140,27 +170,40 @@ struct ArgDecoding<Buffer<dtype, rank>> {
<< XLA_FFI_ArgType_BUFFER << " but got " << type;
}

auto* buf = reinterpret_cast<XLA_FFI_Buffer*>(arg);
return internal::DecodeBuffer<dtype, rank>(
reinterpret_cast<XLA_FFI_Buffer*>(arg), diagnostic);
}
};

//===----------------------------------------------------------------------===//
// Results decoding
//===----------------------------------------------------------------------===//

if (auto actual_dtype = PrimitiveType(buf->dtype);
XLA_FFI_PREDICT_FALSE(actual_dtype != dtype)) {
return diagnostic.Emit("Wrong buffer dtype: expected ")
<< primitive_util::LowercasePrimitiveTypeName(dtype) << " but got "
<< primitive_util::LowercasePrimitiveTypeName(actual_dtype);
template <>
struct RetDecoding<BufferBase> {
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
static std::optional<Result<BufferBase>> Decode(
XLA_FFI_RetType type, void* arg, DiagnosticEngine& diagnostic) {
if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_RetType_BUFFER)) {
return diagnostic.Emit("Wrong result type: expected ")
<< XLA_FFI_RetType_BUFFER << " but got " << type;
}
return internal::DecodeBuffer(reinterpret_cast<XLA_FFI_Buffer*>(arg));
}
};

if constexpr (rank != internal::kDynamicRank) {
if (XLA_FFI_PREDICT_FALSE(buf->rank != rank)) {
return diagnostic.Emit("Wrong buffer rank: expected ")
<< rank << " but got " << buf->rank;
}
template <PrimitiveType dtype, size_t rank>
struct RetDecoding<Buffer<dtype, rank>> {
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
static std::optional<Result<Buffer<dtype, rank>>> Decode(
XLA_FFI_RetType type, void* arg, DiagnosticEngine& diagnostic) {
if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_RetType_BUFFER)) {
return diagnostic.Emit("Wrong result type: expected ")
<< XLA_FFI_RetType_BUFFER << " but got " << type;
}

Buffer<dtype, rank> buffer;
buffer.data = se::DeviceMemory<internal::NativeType<dtype>>(
se::DeviceMemoryBase(buf->data));
buffer.dimensions = absl::MakeConstSpan(buf->dims, buf->rank);
return buffer;
return internal::DecodeBuffer<dtype, rank>(
reinterpret_cast<XLA_FFI_Buffer*>(arg), diagnostic);
}
};

Expand Down
53 changes: 30 additions & 23 deletions xla/service/gpu/custom_call_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,13 +350,13 @@ TEST_F(CustomCallTest, WithStatusFailed) {
// XLA runtime custom calls provides type-safe custom call API
//===----------------------------------------------------------------------===//

static absl::Status AlwaysFail(ffi::BufferBase arg, int32_t value) {
static absl::Status AlwaysFail(ffi::Result<ffi::BufferBase>, int32_t value) {
return absl::InternalError(absl::StrCat("Uh oh, wrong value: ", value));
}

XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail,
ffi::Ffi::Bind()
.Arg<ffi::BufferBase>() // arg
.Ret<ffi::BufferBase>() //
.Attr<int32_t>("value") // value
);
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$always_fail",
Expand All @@ -376,9 +376,9 @@ TEST_F(CustomCallTest, RuntimeCustomCallAlwaysFail) {
}

static absl::Status Memcpy(se::Stream* stream, ffi::BufferBase src,
ffi::BufferBase dst) {
ffi::Result<ffi::BufferBase> dst) {
return stream->MemcpyD2D(
&dst.data, src.data,
&dst->data, src.data,
absl::c_accumulate(src.dimensions, 1.0, std::multiplies<int64_t>()) *
sizeof(float));
}
Expand All @@ -387,7 +387,7 @@ XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy,
ffi::Ffi::Bind()
.Ctx<ffi::Stream>()
.Arg<ffi::BufferBase>() // src
.Arg<ffi::BufferBase>() // dst
.Ret<ffi::BufferBase>() // dst
);
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", PLATFORM,
kMemcpy);
Expand All @@ -405,13 +405,14 @@ TEST_F(CustomCallTest, ExportedFfiMemcpy) {
EXPECT_THAT(result.data<float>(), ::testing::Each(42));
}

static absl::Status HandleUserPointer(ffi::BufferBase, const std::string* str) {
static absl::Status HandleUserPointer(ffi::Result<ffi::BufferBase>,
const std::string* str) {
return absl::InternalError(*str);
}

XLA_FFI_DEFINE_HANDLER(kHandleUserPointer, HandleUserPointer,
ffi::Ffi::Bind()
.Arg<ffi::BufferBase>() // buffer for result
.Ret<ffi::BufferBase>() // buffer for result
.Attr<ffi::Pointer<std::string>>("message"));

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$user_data", PLATFORM,
Expand All @@ -435,14 +436,14 @@ TEST_F(CustomCallTest, PassUserPointerWithAttrs) {
}

bool is_ffi_invoked = false;
static absl::Status IsInvoked(ffi::BufferBase) {
static absl::Status IsInvoked(ffi::Result<ffi::BufferBase>) {
is_ffi_invoked = true;
return absl::OkStatus();
}

XLA_FFI_DEFINE_HANDLER(
kIsInvoked, IsInvoked,
ffi::Ffi::Bind().Arg<ffi::BufferBase>()); // Buffer for result (unused).
ffi::Ffi::Bind().Ret<ffi::BufferBase>()); // Buffer for result (unused).

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$isinvoked", PLATFORM,
kIsInvoked);
Expand Down Expand Up @@ -477,7 +478,8 @@ TEST_F(CustomCallTest, ExportedFfiUnknownTarget) {
// fusions/address_computation_fusion_test.cc

// Reusing kExpectedOpaque from the original test.
static absl::Status Opaque(ffi::BufferBase, const std::string* str) {
static absl::Status Opaque(ffi::Result<ffi::BufferBase>,
const std::string* str) {
std::string opaque(*str);
if (opaque != kExpectedOpaque)
return absl::InternalError(absl::StrFormat(
Expand All @@ -488,7 +490,7 @@ static absl::Status Opaque(ffi::BufferBase, const std::string* str) {

XLA_FFI_DEFINE_HANDLER(kOpaque, Opaque,
ffi::Ffi::Bind()
.Arg<ffi::BufferBase>() // Dummy result buffer.
.Ret<ffi::BufferBase>() // Dummy result buffer.
.Attr<ffi::Pointer<std::string>>("opaque"));

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$opaque", PLATFORM,
Expand All @@ -514,28 +516,31 @@ static absl::Status TokensChecker(std::vector<ffi::BufferBase> inputs,
return absl::OkStatus();
}

static absl::Status Tokens1Input(ffi::BufferBase input1, ffi::BufferBase,
static absl::Status Tokens1Input(ffi::BufferBase input1,
ffi::Result<ffi::BufferBase>,
const std::string* opaque) {
return TokensChecker({input1}, opaque);
}

static absl::Status Tokens2Inputs(ffi::BufferBase input1,
ffi::BufferBase input2, ffi::BufferBase,
ffi::BufferBase input2,
ffi::Result<ffi::BufferBase>,
const std::string* opaque) {
return TokensChecker({input1, input2}, opaque);
}

static absl::Status Tokens3Inputs(ffi::BufferBase input1,
ffi::BufferBase input2,
ffi::BufferBase input3, ffi::BufferBase,
ffi::BufferBase input3,
ffi::Result<ffi::BufferBase>,
const std::string* opaque) {
return TokensChecker({input1, input2, input3}, opaque);
}

XLA_FFI_DEFINE_HANDLER(kTokens1Input, Tokens1Input,
ffi::Ffi::Bind()
.Arg<ffi::BufferBase>() // 1 input buffer.
.Arg<ffi::BufferBase>() // Output buffer.
.Ret<ffi::BufferBase>() // Output buffer.
.Attr<ffi::Pointer<std::string>>("opaque"));

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens_1input",
Expand All @@ -545,7 +550,7 @@ XLA_FFI_DEFINE_HANDLER(kTokens2Inputs, Tokens2Inputs,
ffi::Ffi::Bind()
.Arg<ffi::BufferBase>() // 1st input buffer.
.Arg<ffi::BufferBase>() // 2nd input buffer.
.Arg<ffi::BufferBase>() // Output buffer.
.Ret<ffi::BufferBase>() // Output buffer.
.Attr<ffi::Pointer<std::string>>("opaque"));

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens_2inputs",
Expand All @@ -556,7 +561,7 @@ XLA_FFI_DEFINE_HANDLER(kTokens3Inputs, Tokens3Inputs,
.Arg<ffi::BufferBase>() // 1st input buffer.
.Arg<ffi::BufferBase>() // 2nd input buffer.
.Arg<ffi::BufferBase>() // 3rd input buffer.
.Arg<ffi::BufferBase>() // Output buffer.
.Ret<ffi::BufferBase>() // Output buffer.
.Attr<ffi::Pointer<std::string>>("opaque"));

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens_3inputs",
Expand Down Expand Up @@ -596,10 +601,12 @@ TEST_P(CustomCallTokensTest, ExportedFfiTokensTest) {
INSTANTIATE_TEST_SUITE_P(CustomCallTokensTest, CustomCallTokensTest,
::testing::ValuesIn(GetTokenTestCases()));

static absl::Status AlwaysSucceed(ffi::BufferBase) { return absl::OkStatus(); }
static absl::Status AlwaysSucceed(ffi::Result<ffi::BufferBase>) {
return absl::OkStatus();
}

XLA_FFI_DEFINE_HANDLER(kAlwaysSucceed, AlwaysSucceed,
ffi::Ffi::Bind().Arg<ffi::BufferBase>());
ffi::Ffi::Bind().Ret<ffi::BufferBase>());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$always_succeed",
PLATFORM, kAlwaysSucceed);
Expand All @@ -619,7 +626,7 @@ TEST_F(CustomCallTest, ExportedFfiWithStatusSucceeded) {
// XLA:FFI handler for testing attributes decoding
//===----------------------------------------------------------------------===//

static absl::Status FfiAttributes(ffi::BufferBase,
static absl::Status FfiAttributes(ffi::Result<ffi::BufferBase>,
absl::Span<const int32_t> i32_arr) {
if (i32_arr.size() != 4)
return absl::InternalError("i32_arr size does not match");
Expand All @@ -632,7 +639,7 @@ static absl::Status FfiAttributes(ffi::BufferBase,

XLA_FFI_DEFINE_HANDLER(kFfiAttributes, FfiAttributes,
ffi::Ffi::Bind()
.Arg<ffi::BufferBase>()
.Ret<ffi::BufferBase>()
.Attr<absl::Span<const int32_t>>("i32_arr"));

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla.gpu.ffi_attributes",
Expand All @@ -656,7 +663,7 @@ TEST_F(CustomCallTest, FfiAttributes) {

static absl::Status MemcpyWithCalledComputation(
se::Stream* stream, se::OwningScratchAllocator<> scratch_allocator,
ffi::BufferBase src, ffi::BufferBase dst,
ffi::BufferBase src, ffi::Result<ffi::BufferBase> dst,
const HloComputation* called_computation) {
if (called_computation == nullptr)
return absl::InternalError("Called computation is not defined");
Expand All @@ -680,7 +687,7 @@ XLA_FFI_DEFINE_HANDLER(kMemcpyWithCalledComputation,
.Ctx<ffi::Stream>()
.Ctx<ffi::ScratchAllocator>() // scratch
.Arg<ffi::BufferBase>() // src
.Arg<ffi::BufferBase>() // dst
.Ret<ffi::BufferBase>() // dst
.Ctx<ffi::CalledComputation>());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
Expand Down
Loading

0 comments on commit b31cc49

Please sign in to comment.