Skip to content

Commit

Permalink
[tsl:concurrency] Add MakeAsyncValueRef overloads with automatic resu…
Browse files Browse the repository at this point in the history
…lt type inference

PiperOrigin-RevId: 666482145
  • Loading branch information
ezhulenev authored and copybara-github committed Aug 22, 2024
1 parent 6036e64 commit 21f90d1
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
22 changes: 20 additions & 2 deletions xla/tsl/concurrency/async_value_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -969,10 +969,18 @@ template <typename T, typename F, typename R = std::invoke_result_t<F>,
std::enable_if_t<std::is_constructible_v<T, R>>* = nullptr>
AsyncValueRef<T> MakeAsyncValueRef(AsyncValue::Executor& executor, F&& f) {
auto result = MakeUnconstructedAsyncValueRef<T>();
executor.Execute([result, f = std::forward<F>(f)] { result.emplace(f()); });
executor.Execute(
[result, f = std::forward<F>(f)]() mutable { result.emplace(f()); });
return result;
}

// A `MakeAsyncValueRef` overload that automatically infers the type of result
// from `f`.
template <typename F, typename R = std::invoke_result_t<F>>
AsyncValueRef<R> MakeAsyncValueRef(AsyncValue::Executor& executor, F&& f) {
return MakeAsyncValueRef<R>(executor, std::forward<F>(f));
}

// Allocates an AsyncValueRef that is constructed from the result of calling an
// `f` on a user-provided `executor`. `F` must return an absl::StatusOr<U>, and
// result of type `T` must be constructible from `U`.
Expand All @@ -988,7 +996,7 @@ template <typename T, typename F, typename R = std::invoke_result_t<F>,
std::is_constructible_v<T, typename R::value_type>>* = nullptr>
AsyncValueRef<T> TryMakeAsyncValueRef(AsyncValue::Executor& executor, F&& f) {
auto result = MakeUnconstructedAsyncValueRef<T>();
executor.Execute([result, f = std::forward<F>(f)] {
executor.Execute([result, f = std::forward<F>(f)]() mutable {
absl::StatusOr<typename R::value_type> status_or = f();
if (ABSL_PREDICT_TRUE(status_or.ok())) {
result.emplace(std::move(status_or).value());
Expand All @@ -999,6 +1007,16 @@ AsyncValueRef<T> TryMakeAsyncValueRef(AsyncValue::Executor& executor, F&& f) {
return result;
}

// A `TryMakeAsyncValueRef` overload that automatically infers the type of
// result from `f`.
template <typename F, typename R = std::invoke_result_t<F>,
std::enable_if_t<internal::is_status_or_v<R>>* = nullptr>
AsyncValueRef<typename R::value_type> TryMakeAsyncValueRef(
AsyncValue::Executor& executor, F&& f) {
return TryMakeAsyncValueRef<typename R::value_type>(executor,
std::forward<F>(f));
}

//===----------------------------------------------------------------------===//
// Constructing non-reference-counted values in user provided storage.
//===----------------------------------------------------------------------===//
Expand Down
36 changes: 36 additions & 0 deletions xla/tsl/concurrency/async_value_ref_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,17 @@ TEST(AsyncValueRefTest, MakeAsyncValueRef) {
EXPECT_EQ(ref.get(), 42.0f);
}

{ // Make AsyncValueRef with automatic type inference.
AsyncValueRef<float> ref =
MakeAsyncValueRef(executor, []() -> float { return 42.0f; });

EXPECT_FALSE(ref.IsAvailable());
EXPECT_EQ(executor.Quiesce(), 1);

EXPECT_TRUE(ref.IsAvailable());
EXPECT_EQ(ref.get(), 42.0f);
}

{ // Make AsyncValueRef from a function that returns a StatusOr value.
AsyncValueRef<float> ref = TryMakeAsyncValueRef<float>(
executor, []() -> absl::StatusOr<float> { return 42.0f; });
Expand All @@ -445,6 +456,18 @@ TEST(AsyncValueRefTest, MakeAsyncValueRef) {
EXPECT_EQ(ref.get(), 42.0f);
}

{ // Make AsyncValueRef from a function that returns a StatusOr value with
// automatic type inference.
AsyncValueRef<float> ref = TryMakeAsyncValueRef(
executor, []() -> absl::StatusOr<float> { return 42.0f; });

EXPECT_FALSE(ref.IsAvailable());
EXPECT_EQ(executor.Quiesce(), 1);

EXPECT_TRUE(ref.IsAvailable());
EXPECT_EQ(ref.get(), 42.0f);
}

{ // Make AsyncValueRef from a function that returns a StatusOr error.
AsyncValueRef<float> ref = TryMakeAsyncValueRef<float>(
executor,
Expand All @@ -456,6 +479,19 @@ TEST(AsyncValueRefTest, MakeAsyncValueRef) {
EXPECT_TRUE(ref.IsError());
EXPECT_EQ(ref.GetError(), absl::InternalError("test"));
}

{ // Make AsyncValueRef from a function that returns a StatusOr error with
// automatic type inference.
AsyncValueRef<float> ref = TryMakeAsyncValueRef(
executor,
[]() -> absl::StatusOr<float> { return absl::InternalError("test"); });

EXPECT_FALSE(ref.IsAvailable());
EXPECT_EQ(executor.Quiesce(), 1);

EXPECT_TRUE(ref.IsError());
EXPECT_EQ(ref.GetError(), absl::InternalError("test"));
}
}

TEST(AsyncValueRefTest, MapAvailableOnExecutor) {
Expand Down

0 comments on commit 21f90d1

Please sign in to comment.