Skip to content

Commit

Permalink
[xla:collectives] NFC: Move CliqueIsCallback alias to gpu_executable_…
Browse files Browse the repository at this point in the history
…run_options

PiperOrigin-RevId: 702238215
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Dec 3, 2024
1 parent 42a164f commit 50098e0
Show file tree
Hide file tree
Showing 15 changed files with 83 additions and 70 deletions.
3 changes: 2 additions & 1 deletion xla/core/collectives/BUILD
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
load("//xla/tsl:tsl.bzl", "internal_visibility")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [":friends"],
default_visibility = internal_visibility([":friends"]),
licenses = ["notice"],
)

Expand Down
3 changes: 3 additions & 0 deletions xla/pjrt/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ cc_library(
hdrs = ["nccl_id_store.h"],
deps = [
"//xla:status_macros",
"//xla:util",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:global_device_id",
"//xla/service/gpu/runtime:nccl_api",
Expand All @@ -228,6 +230,7 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
Expand Down
22 changes: 15 additions & 7 deletions xla/pjrt/gpu/nccl_id_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,37 +22,45 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/service/gpu/runtime/nccl_api.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {

absl::StatusOr<CliqueId> NcclIdStore::GetNcclUniqueId(
const gpu::NcclCliqueKey& key) {
absl::StatusOr<CliqueId> NcclIdStore::GetNcclUniqueId(const CliqueKey& key) {
auto* gpu_key = tsl::down_cast<const gpu::NcclCliqueKey*>(&key);
if (gpu_key == nullptr) {
return InvalidArgument("Expected GPU clique key");
}

// The caller must ensure that threads calling this method concurrently have
// unique keys, otherwise the global key-value store may hold the wrong value.
{
absl::MutexLock lock(&mu_);
auto it = cache_.find(key);
auto it = cache_.find(*gpu_key);
if (it != cache_.end()) {
return it->second;
}
}
CliqueId clique_id;
int primary_node_id = device_to_node_.at(key.devices()[0]);
int primary_node_id = device_to_node_.at(gpu_key->devices()[0]);
if (node_id_ == primary_node_id) {
TF_ASSIGN_OR_RETURN(clique_id, gpu::NcclApi::Default()->GetUniqueId());
TF_RETURN_IF_ERROR(kv_store_->Set(key.ToString(), clique_id.ToString()));
TF_RETURN_IF_ERROR(
kv_store_->Set(gpu_key->ToString(), clique_id.ToString()));
} else {
TF_ASSIGN_OR_RETURN(std::string id_str,
kv_store_->Get(key.ToString(), absl::Minutes(10)));
kv_store_->Get(gpu_key->ToString(), absl::Minutes(10)));
clique_id = CliqueId(id_str);
}
absl::MutexLock lock(&mu_);
auto result = cache_.emplace(key, std::move(clique_id));
auto result = cache_.emplace(*gpu_key, std::move(clique_id));
TF_RET_CHECK(result.second) << "Unique ID already in cache.";
return result.first->second;
}
Expand Down
3 changes: 2 additions & 1 deletion xla/pjrt/gpu/nccl_id_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/global_device_id.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
Expand All @@ -42,7 +43,7 @@ class NcclIdStore {
device_to_node_(std::move(device_to_node)),
kv_store_(std::move(kv_store)) {}

absl::StatusOr<CliqueId> GetNcclUniqueId(const gpu::NcclCliqueKey& key);
absl::StatusOr<CliqueId> GetNcclUniqueId(const CliqueKey& key);

private:
const int node_id_;
Expand Down
6 changes: 3 additions & 3 deletions xla/pjrt/gpu/se_gpu_pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1175,8 +1175,8 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
if (num_nodes > 1) {
auto nccl_id_store = std::make_shared<NcclIdStore>(node_id, device_to_node,
std::move(kv_store));
gpu_executable_run_options->set_nccl_clique_id_callback(
[nccl_id_store](const gpu::NcclCliqueKey& key) {
gpu_executable_run_options->set_clique_id_callback(
[nccl_id_store](const CliqueKey& key) {
return nccl_id_store->GetNcclUniqueId(key);
});
}
Expand Down Expand Up @@ -1300,7 +1300,7 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetStreamExecutorGpuClient(

auto gpu_run_options = std::make_unique<gpu::GpuExecutableRunOptions>();
if (options.enable_mock_nccl) {
gpu_run_options->set_enable_mock_nccl_collectives();
gpu_run_options->set_enable_mock_collectives();
}

static const bool xla_gpu_require_exclusive_lock =
Expand Down
4 changes: 4 additions & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,12 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//xla:executable_run_options",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/service:global_device_id",
"//xla/service/gpu/runtime:nccl_clique_key",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

Expand Down
6 changes: 3 additions & 3 deletions xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ class ResourceRequests : public Thunk::ResourceRequests {

bool is_local = r.key.devices().size() == r.num_local_participants;
TF_ASSIGN_OR_RETURN(
const NcclCliqueIdCallback* clique_id_callback,
GetNcclCliqueIdCallback(params.nccl_clique_id_callback, is_local));
const CliqueIdCallback* clique_id_callback,
GetCliqueIdCallback(params.nccl_clique_id_callback, is_local));

int64_t max_channels = r.key.stream_kind() == AsyncStreamKind::kCollective
? params.collective_max_nchannels
Expand Down Expand Up @@ -348,7 +348,7 @@ absl::Status ExecuteThunks(
run_options->run_options().gpu_executable_run_options()
? run_options->run_options()
.gpu_executable_run_options()
->enable_mock_nccl_collectives()
->enable_mock_collectives()
: false;

int64_t collective_max_nchannels =
Expand Down
18 changes: 7 additions & 11 deletions xla/service/gpu/gpu_executable_run_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ limitations under the License.

#include "xla/executable_run_options.h"
#include "xla/service/global_device_id.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"

namespace xla {
namespace gpu {
namespace xla::gpu {

GpuExecutableRunOptions& GpuExecutableRunOptions::set_gpu_global_device_ids(
std::optional<std::map<int, GlobalDeviceId>> gpu_global_device_ids) {
Expand All @@ -37,16 +35,14 @@ GpuExecutableRunOptions::gpu_global_device_ids() const {
return gpu_global_device_ids_;
}

GpuExecutableRunOptions& GpuExecutableRunOptions::set_nccl_clique_id_callback(
NcclCliqueIdCallback nccl_clique_id_callback) {
nccl_clique_id_callback_ = std::move(nccl_clique_id_callback);
GpuExecutableRunOptions& GpuExecutableRunOptions::set_clique_id_callback(
CliqueIdCallback clique_id_callback) {
clique_id_callback_ = std::move(clique_id_callback);
return *this;
}

const NcclCliqueIdCallback& GpuExecutableRunOptions::nccl_clique_id_callback()
const {
return nccl_clique_id_callback_;
const CliqueIdCallback& GpuExecutableRunOptions::clique_id_callback() const {
return clique_id_callback_;
}

} // namespace gpu
} // namespace xla
} // namespace xla::gpu
36 changes: 19 additions & 17 deletions xla/service/gpu/gpu_executable_run_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@ limitations under the License.
#ifndef XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
#define XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_

#include <functional>
#include <map>
#include <optional>

#include "absl/status/statusor.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/executable_run_options.h"
#include "xla/service/global_device_id.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"

namespace xla {
namespace gpu {
namespace xla::gpu {

// A callback to get a unique clique id.
using CliqueIdCallback = // NOLINT
std::function<absl::StatusOr<CliqueId>(const CliqueKey&)>;

// GPU-specific executable options.
// We keep these separate from ExecutableRunOptions to avoid adding
Expand All @@ -40,11 +46,10 @@ class GpuExecutableRunOptions {
const std::optional<std::map<int, GlobalDeviceId>>& gpu_global_device_ids()
const;

// Callback that returns a ncclUniqueId encoded as a string for a group of
// communicating GPU devices. Used only on NVidia GPUs.
GpuExecutableRunOptions& set_nccl_clique_id_callback(
NcclCliqueIdCallback nccl_clique_id_callback);
const NcclCliqueIdCallback& nccl_clique_id_callback() const;
// Callback that returns a unique clieque id for a given clique key.
GpuExecutableRunOptions& set_clique_id_callback(
CliqueIdCallback clique_id_callback);
const CliqueIdCallback& clique_id_callback() const;

// Whether the run requires an exclusive lock on the GPU.
bool requires_exclusive_lock_on_gpu() const {
Expand All @@ -57,24 +62,21 @@ class GpuExecutableRunOptions {
return *this;
}

bool enable_mock_nccl_collectives() const {
return enable_mock_nccl_collectives_;
}
bool enable_mock_collectives() const { return enable_mock_collectives_; }

// Enables mocking nccl collective operations on the GPU.
GpuExecutableRunOptions& set_enable_mock_nccl_collectives() {
enable_mock_nccl_collectives_ = true;
GpuExecutableRunOptions& set_enable_mock_collectives() {
enable_mock_collectives_ = true;
return *this;
}

private:
bool requires_exclusive_lock_on_gpu_ = false;
bool enable_mock_nccl_collectives_ = false;
bool enable_mock_collectives_ = false;
std::optional<std::map<int, GlobalDeviceId>> gpu_global_device_ids_;
NcclCliqueIdCallback nccl_clique_id_callback_;
CliqueIdCallback clique_id_callback_;
};

} // namespace gpu
} // namespace xla
} // namespace xla::gpu

#endif // XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
2 changes: 2 additions & 0 deletions xla/service/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,13 @@ cc_library(
"//xla:types",
"//xla:util",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/service:global_device_id",
"//xla/service:lockable",
"//xla/service:rendezvous",
"//xla/service/gpu:gpu_executable_run_options",
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
Expand Down
18 changes: 10 additions & 8 deletions xla/service/gpu/runtime/nccl_clique.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ limitations under the License.
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/debug_options_flags.h"
#include "xla/executable_run_options.h"
#include "xla/service/global_device_id.h"
#include "xla/service/gpu/gpu_executable_run_options.h"
#include "xla/service/gpu/runtime/nccl_api.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
#include "xla/service/lockable.h"
Expand All @@ -60,24 +62,24 @@ limitations under the License.
namespace xla::gpu {

//===----------------------------------------------------------------------===//
// NcclCliqueIdCallback
// CliqueIdCallback
//===----------------------------------------------------------------------===//

bool IsGlobalNcclConfig() {
static const char* const nccl_comm_id = std::getenv("NCCL_COMM_ID");
return nccl_comm_id != nullptr;
}

absl::StatusOr<const NcclCliqueIdCallback*> GetNcclCliqueIdCallback(
const NcclCliqueIdCallback* clique_id_callback, bool is_local) {
absl::StatusOr<const CliqueIdCallback*> GetCliqueIdCallback(
const CliqueIdCallback* clique_id_callback, bool is_local) {
if (clique_id_callback != nullptr) return clique_id_callback;

TF_RET_CHECK(is_local || IsGlobalNcclConfig())
<< "If non-local devices are taking part of a collective API on "
"GPU, the nccl_clique_id_callback must be provided by the client.";

static auto* local_callback = new NcclCliqueIdCallback(
[](const NcclCliqueKey&) { return NcclApi::Default()->GetUniqueId(); });
static auto* local_callback = new CliqueIdCallback(
[](const CliqueKey&) { return NcclApi::Default()->GetUniqueId(); });
return local_callback;
}

Expand Down Expand Up @@ -252,8 +254,8 @@ static auto DeviceRanksToString(absl::Span<const NcclApi::DeviceRank> ranks) {
// all participating ranks that own a shared pointer).
static absl::StatusOr<std::shared_ptr<NcclClique::Lock>> InitializeNcclClique(
se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key,
const NcclCliqueIdCallback& clique_id_callback,
int32_t num_local_participants, RankId rank, NcclApi::Config& config) {
const CliqueIdCallback& clique_id_callback, int32_t num_local_participants,
RankId rank, NcclApi::Config& config) {
int nranks = clique_key.devices().size();
VLOG(3) << "Initialize NCCL clique " << clique_key.ToString() << " rank #"
<< rank << "; num_local_participants=" << num_local_participants;
Expand Down Expand Up @@ -494,7 +496,7 @@ using AcquiredCliquesMap = NcclClique::AcquiredCliquesMap;

absl::StatusOr<std::shared_ptr<NcclClique::Lock>> AcquireNcclClique(
se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key,
const NcclCliqueIdCallback& clique_id_callback, RankId rank,
const CliqueIdCallback& clique_id_callback, RankId rank,
size_t num_local_participants, const AcquiredCliquesMap& acquired_cliques,
int64_t max_nchannels) {
VLOG(2) << "Acquire NCCL clique " << clique_key.ToString() << "; run"
Expand Down
7 changes: 4 additions & 3 deletions xla/service/gpu/runtime/nccl_clique.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ limitations under the License.
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/executable_run_options.h"
#include "xla/service/gpu/gpu_executable_run_options.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
#include "xla/service/lockable.h"
#include "xla/stream_executor/stream_executor.h"
Expand Down Expand Up @@ -67,8 +68,8 @@ bool IsGlobalNcclConfig();

// Returns a clique id callback passed as an argument if it's not null or a
// default callback to get create a clique id if we are running in local mode.
absl::StatusOr<const NcclCliqueIdCallback*> GetNcclCliqueIdCallback(
const NcclCliqueIdCallback* clique_id_callback, // may be null
absl::StatusOr<const CliqueIdCallback*> GetCliqueIdCallback(
const CliqueIdCallback* clique_id_callback, // may be null
bool is_local);

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -169,7 +170,7 @@ class NcclClique : public Lockable<NcclCliqueCommunicators, NcclCliqueName> {
// cliques.
absl::StatusOr<std::shared_ptr<NcclClique::Lock>> AcquireNcclClique(
se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key,
const NcclCliqueIdCallback& clique_id_callback, RankId rank,
const CliqueIdCallback& clique_id_callback, RankId rank,
size_t num_local_participants,
const NcclClique::AcquiredCliquesMap& acquired_cliques,
int64_t max_nchannels = 0);
Expand Down
8 changes: 0 additions & 8 deletions xla/service/gpu/runtime/nccl_clique_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,6 @@ class NcclCliqueKey : public CliqueKey {
bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b);
bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b);

//===----------------------------------------------------------------------===//
// NcclCliqueId
//===----------------------------------------------------------------------===//

// A callback to get a unique clique id (see `ncclUniqueId` documentation).
using NcclCliqueIdCallback = // NOLINT
std::function<absl::StatusOr<CliqueId>(const NcclCliqueKey&)>;

} // namespace xla::gpu

#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_CLIQUE_KEY_H_
Loading

0 comments on commit 50098e0

Please sign in to comment.