Skip to content

Commit

Permalink
[XLA:CPU] Use absl::call_once to lazily initialize kernel and comparator functions
Browse files Browse the repository at this point in the history

This simplifies the code a little and should otherwise be performance neutral.

It is also a bit easier to understand whether the code is safe or not. Previously, sort_thunk did the following:
```c++
  LessThan* less_than = less_than_ptr_.load();
  if (less_than == nullptr) {
    Comparator comparator = ...
    absl::MutexLock lock(&mutex_);
    less_than_ = ...;
    less_than_ptr_.store(...);
  }
```
However, two racing threads might both observe that `less_than` is nullptr which results in both of them trying to acquire the mutex and populate both `less_than_` and `less_than_ptr_`.

The problem is that another thread may observe a non-null `less_than_ptr_` without acquiring the mutex, and thus may access `less_than_` while the mutex-holding thread is still in the middle of assigning it — i.e. it can read an object in a partially-initialized state.

While it is simple enough to recheck `less_than_ptr_` after the mutex is acquired, it is even simpler to just use `call_once`.

This has the added benefit of only using an acquire atomic operation internal to the `call_once` implementation vs the `seq_cst` load on `less_than_ptr_`.

PiperOrigin-RevId: 701435907
  • Loading branch information
majnemer authored and Google-ML-Automation committed Nov 30, 2024
1 parent 533eb56 commit 20d4636
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 46 deletions.
2 changes: 2 additions & 0 deletions xla/backends/cpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ cc_library(
"//xla/stream_executor:launch_dim",
"//xla/tsl/concurrency:async_value",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
Expand Down Expand Up @@ -966,6 +967,7 @@ cc_library(
"//xla/stream_executor:device_memory",
"//xla/tsl/concurrency:async_value",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/container:inlined_vector",
Expand Down
32 changes: 15 additions & 17 deletions xla/backends/cpu/runtime/kernel_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.

#define EIGEN_USE_THREADS

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
Expand All @@ -28,13 +27,13 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/base/attributes.h"
#include "absl/base/call_once.h"
#include "absl/base/optimization.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/numeric/bits.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "unsupported/Eigen/CXX11/Tensor"
#include "xla/backends/cpu/runtime/buffer_allocations.h"
Expand Down Expand Up @@ -119,8 +118,7 @@ KernelThunk<num_arguments, num_results>::KernelThunk(
kernel_name_(std::move(kernel_name)),
thread_dim_(thread_dim),
min_alignment_(min_alignment),
call_once_(thread_dim_ == se::ThreadDim()),
kernel_ptr_(nullptr) {
call_once_(thread_dim_ == se::ThreadDim()) {
// Resize storage for arguments and results buffers if it is dynamic.
if constexpr (IsDynamic(num_arguments)) {
arguments_buffers_.resize(arguments_buffers.size());
Expand Down Expand Up @@ -206,20 +204,20 @@ KernelThunk<num_arguments, num_results>::ExecuteInternal(

// TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk
// initialization stage.
Kernel* kernel = kernel_ptr_.load(std::memory_order_acquire);

// Because thunks are owned by a parent CpuExecutable, we can safely assume
// that kernel pointer will not change after we find it the first time.
if (ABSL_PREDICT_FALSE(kernel == nullptr)) {
TF_ASSIGN_OR_RETURN(XLA_CPU_Kernel * kernel_fn,
params.function_registry->FindKernel(kernel_name_));

absl::MutexLock lock(&mutex_);
if ((kernel = kernel_ptr_.load(std::memory_order_relaxed)) == nullptr) {
kernel = &kernel_.emplace(num_kernel_args_, kernel_fn);
kernel_ptr_.store(kernel, std::memory_order_release);
absl::call_once(kernel_init_flag_, [&]() {
// Because thunks are owned by a parent CpuExecutable, we can safely assume
// that kernel pointer will not change after we find it the first time.
absl::StatusOr<Thunk::FunctionRegistry::Kernel> kernel_fn =
params.function_registry->FindKernel(kernel_name_);

if (ABSL_PREDICT_TRUE(kernel_fn.ok())) {
kernel_.emplace(num_kernel_args_, *kernel_fn);
} else {
kernel_ = std::move(kernel_fn.status());
}
}
});
TF_RETURN_IF_ERROR(kernel_.status());
Kernel* kernel = &kernel_.value();

// Use a fast path if kernel called just once.
if (ABSL_PREDICT_TRUE(call_once_)) {
Expand Down
8 changes: 3 additions & 5 deletions xla/backends/cpu/runtime/kernel_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,11 @@ limitations under the License.
#include <type_traits>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/base/call_once.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/runtime/kernel.h"
#include "xla/backends/cpu/runtime/kernel_c_api.h"
Expand Down Expand Up @@ -120,9 +119,8 @@ class KernelThunk : public Thunk {
bool call_once_;

// Lazily loaded host kernel corresponding to `kernel_name_`.
absl::Mutex mutex_;
std::optional<Kernel> kernel_ ABSL_GUARDED_BY(mutex_);
std::atomic<Kernel*> kernel_ptr_; // pointer to `kernel_`
absl::once_flag kernel_init_flag_;
absl::StatusOr<Kernel> kernel_;

// Pre-initialized kernel arguments that are updated with memory addresses
// before the kernel launch.
Expand Down
46 changes: 25 additions & 21 deletions xla/backends/cpu/runtime/sort_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/call_once.h"
#include "absl/base/dynamic_annotations.h"
#include "absl/base/optimization.h"
#include "absl/container/inlined_vector.h"
Expand All @@ -39,7 +40,6 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/layout_util.h"
Expand Down Expand Up @@ -115,8 +115,7 @@ SortThunk::SortThunk(Info info, absl::Span<const Input> inputs,
dimension_(dimension),
is_stable_(is_stable),
direction_(direction),
less_than_(std::move(less_than)),
less_than_ptr_(&*less_than_) {}
less_than_(std::move(less_than)) {}

SortThunk::SortThunk(Info info, absl::Span<const Input> inputs,
int64_t dimension, bool is_stable,
Expand All @@ -127,8 +126,7 @@ SortThunk::SortThunk(Info info, absl::Span<const Input> inputs,
dimension_(dimension),
is_stable_(is_stable),
direction_(direction),
comparator_name_(std::move(comparator_name)),
less_than_ptr_(nullptr) {}
comparator_name_(std::move(comparator_name)) {}

namespace {

Expand Down Expand Up @@ -789,25 +787,31 @@ tsl::AsyncValueRef<SortThunk::ExecuteEvent> SortThunk::Execute(
input.slice.ToString(), data.back().opaque());
}

LessThan* less_than = less_than_ptr_.load();

// Because thunks are owned by a parent CpuExecutable, we can safely assume
// that comparator pointer will not change after we find it the first time,
// and we can create a comparator adaptor to a LessThan function.
if (ABSL_PREDICT_FALSE(less_than == nullptr)) {
TF_ASSIGN_OR_RETURN(
FunctionRegistry::Comparator comparator,
params.function_registry->FindComparator(comparator_name_));

absl::MutexLock lock(&mutex_);
less_than_ = [comparator](const void** data) {
bool result;
comparator(&result, nullptr, data, nullptr, nullptr, nullptr);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&result, sizeof(result));
return result;
};
less_than_ptr_.store(less_than = &*less_than_);
}
absl::call_once(less_than_init_flag_, [&]() {
if (less_than_.ok()) {
// `less_than_` may already be initialized in the constructor.
return;
}
absl::StatusOr<FunctionRegistry::Comparator> comparator =
params.function_registry->FindComparator(comparator_name_);

if (ABSL_PREDICT_TRUE(comparator.ok())) {
less_than_ = [comparator](const void** data) {
bool result;
(*comparator)(&result, nullptr, data, nullptr, nullptr, nullptr);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&result, sizeof(result));
return result;
};
} else {
less_than_ = std::move(comparator.status());
}
});

TF_RETURN_IF_ERROR(less_than_.status());
LessThan* less_than = &less_than_.value();

TF_RETURN_IF_ERROR(SortInplace(absl::MakeSpan(data), shapes, dimension_,
is_stable_, less_than, direction_));
Expand Down
6 changes: 3 additions & 3 deletions xla/backends/cpu/runtime/sort_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include <string>
#include <vector>

#include "absl/base/call_once.h"
#include "absl/base/thread_annotations.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -84,9 +85,8 @@ class SortThunk final : public Thunk {
std::string comparator_name_;

// Lazily resolved LessThan comparator function.
absl::Mutex mutex_;
std::optional<LessThan> less_than_ ABSL_GUARDED_BY(mutex_);
std::atomic<LessThan*> less_than_ptr_; // pointer to `less_than_`
absl::once_flag less_than_init_flag_;
absl::StatusOr<LessThan> less_than_;
};

} // namespace xla::cpu
Expand Down

0 comments on commit 20d4636

Please sign in to comment.