Skip to content

Commit

Permalink
[xla:collectives] NFC: Migrate XLA:GPU to strongly typed RankId
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702170919
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Dec 3, 2024
1 parent 9b1e229 commit 9c668fd
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 58 deletions.
2 changes: 1 addition & 1 deletion xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class ResourceRequests : public Thunk::ResourceRequests {
NcclClique::AcquiredCliquesMap cliques_map;

for (const CliqueRequest& r : ordered_cliques) {
std::optional<int64_t> rank = r.key.rank(params.global_device_id);
std::optional<RankId> rank = r.key.rank(params.global_device_id);

if (!rank.has_value()) {
return absl::InternalError(absl::StrCat(
Expand Down
6 changes: 6 additions & 0 deletions xla/service/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/collectives:nccl_communicator",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/service:collective_ops_utils",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:device_memory_allocator",
Expand Down Expand Up @@ -257,6 +258,7 @@ cc_library(
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/service:collective_ops_utils",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:device_memory_allocator",
Expand All @@ -282,6 +284,7 @@ cc_library(
"//xla:types",
"//xla:util",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/service:global_device_id",
"//xla/service:lockable",
"//xla/service:rendezvous",
Expand Down Expand Up @@ -313,6 +316,7 @@ cc_library(
compatible_with = get_compatible_with_portable(),
deps = [
"//xla/core/collectives:clique_id",
"//xla/core/collectives:rank_id",
"//xla/service:global_device_id",
"//xla/tsl/lib/gtl:int_type",
"@com_google_absl//absl/algorithm:container",
Expand Down Expand Up @@ -950,6 +954,7 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/hlo/ir:hlo",
"//xla/hlo/translate/mhlo_to_hlo:attribute_exporter",
"//xla/service:buffer_assignment",
Expand Down Expand Up @@ -1187,6 +1192,7 @@ cc_library(
":nccl_clique_key",
"//xla:executable_run_options",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/ffi:execution_context",
"//xla/hlo/ir:hlo",
"//xla/hlo/translate/mhlo_to_hlo:location_exporter",
Expand Down
23 changes: 15 additions & 8 deletions xla/service/gpu/runtime/nccl_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
Expand All @@ -32,6 +33,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/backends/gpu/collectives/nccl_communicator.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/primitive_util.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
Expand Down Expand Up @@ -297,7 +299,7 @@ class DefaultNcclApi final : public NcclApi {

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>> CommSplit(
absl::Span<const Communicator* const> comms, int32_t color,
absl::Span<const int32_t> keys, std::optional<Config> config) final;
absl::Span<const RankId> keys, std::optional<Config> config) final;

absl::Status CommAbort(Communicator* comm) final;
absl::Status CommFinalize(Communicator* comm) final;
Expand Down Expand Up @@ -406,8 +408,9 @@ DefaultNcclApi::CommInitRanks(int32_t nranks, const NcclCliqueId& clique_id,
auto activate_context = ranks[i].device->Activate();

TF_ASSIGN_OR_RETURN(auto nccl_unique_id, AsNcclUniqueId(clique_id));
XLA_NCCL_RETURN_IF_ERROR(ncclCommInitRankConfig(
&comm_handles[i], nranks, nccl_unique_id, ranks[i].rank, &comm_config));
XLA_NCCL_RETURN_IF_ERROR(
ncclCommInitRankConfig(&comm_handles[i], nranks, nccl_unique_id,
ranks[i].rank.value(), &comm_config));
}
TF_RETURN_IF_ERROR(GroupEnd());

Expand All @@ -420,11 +423,15 @@ DefaultNcclApi::CommInitRanks(int32_t nranks, const NcclCliqueId& clique_id,

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
DefaultNcclApi::CommSplit(absl::Span<const Communicator* const> comms,
int32_t color, absl::Span<const int32_t> keys,
int32_t color, absl::Span<const RankId> keys,
std::optional<Config> config) {
auto rank_formatter = [](std::string* str, RankId rank) {
absl::StrAppend(str, rank.value());
};

VLOG(1) << absl::StreamFormat(
"Split %d NCCL communicators using color %d and keys: [%s]", comms.size(),
color, absl::StrJoin(keys, ","));
color, absl::StrJoin(keys, ",", rank_formatter));

#if !defined(TENSORFLOW_USE_ROCM) || TF_ROCM_VERSION >= 60000
if (keys.size() != comms.size()) {
Expand Down Expand Up @@ -456,9 +463,9 @@ DefaultNcclApi::CommSplit(absl::Span<const Communicator* const> comms,
for (size_t i = 0; i < comms.size(); ++i) {
VLOG(1) << "Split NCCL communicator " << comms[i] << " with color " << color
<< " and key " << keys[i];
XLA_NCCL_RETURN_IF_ERROR(ncclCommSplit(Cast(comms[i]), color, keys[i],
&split_comms_handles[i],
/*config=*/comm_config_ptr));
XLA_NCCL_RETURN_IF_ERROR(ncclCommSplit(
Cast(comms[i]), color, keys[i].value(), &split_comms_handles[i],
/*config=*/comm_config_ptr));
}
TF_RETURN_IF_ERROR(GroupEnd());

Expand Down
7 changes: 4 additions & 3 deletions xla/service/gpu/runtime/nccl_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
#include "xla/shape_util.h"
Expand Down Expand Up @@ -118,11 +119,11 @@ class NcclApi {
};

struct DeviceRank {
DeviceRank(se::StreamExecutor* device, int32_t rank)
DeviceRank(se::StreamExecutor* device, RankId rank)
: device(device), rank(rank) {}

se::StreamExecutor* device;
int32_t rank;
RankId rank;
};

// Returns a slice of device memory `buff` containing `count` values of data
Expand Down Expand Up @@ -157,7 +158,7 @@ class NcclApi {
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommsplit
virtual absl::StatusOr<std::vector<std::unique_ptr<Communicator>>> CommSplit(
absl::Span<const Communicator* const> comms, int32_t color,
absl::Span<const int32_t> keys, std::optional<Config> config) = 0;
absl::Span<const RankId> keys, std::optional<Config> config) = 0;

// Abort any uncompleted operations and destroys the communicator. Frees
// resources that are allocated to a communicator object comm.
Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/runtime/nccl_api_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/gpu/runtime/nccl_api.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
Expand Down Expand Up @@ -96,7 +97,7 @@ class NcclApiStub final : public NcclApi {
}

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>> CommSplit(
absl::Span<const Communicator* const>, int32_t, absl::Span<const int32_t>,
absl::Span<const Communicator* const>, int32_t, absl::Span<const RankId>,
std::optional<Config>) final {
return UnimplementedError();
}
Expand Down
53 changes: 28 additions & 25 deletions xla/service/gpu/runtime/nccl_clique.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ limitations under the License.
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/debug_options_flags.h"
#include "xla/executable_run_options.h"
#include "xla/service/global_device_id.h"
Expand Down Expand Up @@ -109,12 +110,12 @@ static bool TerminateOnNcclError() {

NcclCliqueCommunicators::NcclCliqueCommunicators(
NcclCliqueKey clique_key, std::optional<NcclCliqueId> clique_id,
absl::btree_map<int32_t, std::unique_ptr<Communicator>> communicators)
absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators)
: clique_key_(std::move(clique_key)),
clique_id_(std::move(clique_id)),
communicators_(std::move(communicators)) {}

std::optional<Communicator*> NcclCliqueCommunicators::comm(int32_t rank) {
std::optional<Communicator*> NcclCliqueCommunicators::comm(RankId rank) {
if (auto it = communicators_.find(rank); it != communicators_.end()) {
return it->second.get();
}
Expand All @@ -126,7 +127,7 @@ bool NcclCliqueCommunicators::IsLocal() const {
}

void NcclCliqueCommunicators::ForEachComm(
absl::FunctionRef<void(int32_t, Communicator*)> fn) {
absl::FunctionRef<void(RankId, Communicator*)> fn) {
for (auto& [rank, comm] : communicators_) {
fn(rank, comm.get());
}
Expand All @@ -141,7 +142,7 @@ std::string NcclCliqueCommunicators::DebugString() const {
int32_t cnt = 0;
for (const auto& [rank, comm] : communicators_) {
if (cnt++) absl::StrAppend(&out, ", ");
absl::StrAppendFormat(&out, "[rank=%d, comm=%p]", rank, comm.get());
absl::StrAppendFormat(&out, "[rank=%d, comm=%p]", rank.value(), comm.get());
}
return out;
}
Expand Down Expand Up @@ -195,7 +196,7 @@ static void CheckClique(const NcclCliqueKey& clique_key,
VLOG(5) << "Checking NCCL clique " << clique_key.ToString()
<< " for async errors; num_communicators="
<< clique->num_communicators();
clique->ForEachComm([](int32_t rank, Communicator* comm) {
clique->ForEachComm([](RankId rank, Communicator* comm) {
if (auto status = CheckComm(comm); !status.ok()) LOG(ERROR) << status;
});
} else {
Expand Down Expand Up @@ -241,7 +242,7 @@ static void StartNcclCliqueHeartBeatMonitor() {

static auto DeviceRanksToString(absl::Span<const NcclApi::DeviceRank> ranks) {
return absl::StrJoin(ranks, ",", [](std::string* str, auto& rank) {
str->append(std::to_string(rank.rank));
str->append(std::to_string(rank.rank.value()));
});
}

Expand All @@ -251,7 +252,7 @@ static auto DeviceRanksToString(absl::Span<const NcclApi::DeviceRank> ranks) {
static absl::StatusOr<std::shared_ptr<NcclClique::Lock>> InitializeNcclClique(
se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key,
const NcclCliqueIdCallback& clique_id_callback,
int32_t num_local_participants, int32_t rank, NcclApi::Config& config) {
int32_t num_local_participants, RankId rank, NcclApi::Config& config) {
int nranks = clique_key.devices().size();
VLOG(3) << "Initialize NCCL clique " << clique_key.ToString() << " rank #"
<< rank << "; num_local_participants=" << num_local_participants;
Expand All @@ -274,7 +275,7 @@ static absl::StatusOr<std::shared_ptr<NcclClique::Lock>> InitializeNcclClique(
return Internal(
"Failed to synchronize device activity on rank %d. Do not attempt "
"to initialize NCCL clique.",
device_rank.rank);
device_rank.rank.value());
}
}

Expand All @@ -296,7 +297,7 @@ static absl::StatusOr<std::shared_ptr<NcclClique::Lock>> InitializeNcclClique(
std::vector<std::unique_ptr<Communicator>> created_comms,
NcclApi::Default()->CommInitRanks(nranks, clique_id, ranks, config));

absl::btree_map<int32_t, std::unique_ptr<Communicator>> comms;
absl::btree_map<RankId, std::unique_ptr<Communicator>> comms;
for (size_t i = 0; i < ranks.size(); ++i) {
comms[ranks[i].rank] = std::move(created_comms[i]);
}
Expand Down Expand Up @@ -332,7 +333,7 @@ static absl::StatusOr<std::shared_ptr<NcclClique::Lock>> InitializeNcclClique(
auto rendezvous_key = std::make_tuple(run_id, clique_key);
auto initialization_rendezvous_name =
absl::StrFormat("initialize clique for rank %d; clique=%s; run_id=%d",
rank, clique_key.ToString(), run_id.ToInt());
rank.value(), clique_key.ToString(), run_id.ToInt());

NcclApi::DeviceRank device_rank = {device, rank};
bool synchronized = device->SynchronizeAllActivity();
Expand Down Expand Up @@ -375,17 +376,18 @@ static int32_t GetCommSplitColor(const NcclCliqueKey& clique_key) {
static absl::StatusOr<std::shared_ptr<NcclClique::Lock>> InitializeNcclClique(
se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key,
std::shared_ptr<NcclClique::Lock> parent_clique,
int32_t num_local_participants, int32_t rank, NcclApi::Config& config) {
int32_t num_local_participants, RankId rank, NcclApi::Config& config) {
// Find our rank in the parent clique.
const NcclCliqueKey& parent_clique_key = (*parent_clique)->clique_key();
int32_t parent_rank = *parent_clique_key.rank(clique_key.devices()[rank]);
RankId parent_rank =
*parent_clique_key.rank(clique_key.devices()[rank.value()]);

VLOG(3) << "Initialize NCCL clique " << clique_key.ToString() << " rank #"
<< rank << " by splitting rank #" << parent_rank
<< rank << " by splitting rank #" << parent_rank.value()
<< " in parent clique " << parent_clique_key.ToString()
<< "; num_local_participants=" << num_local_participants;

using RankPair = std::pair<int32_t, int32_t>;
using RankPair = std::pair<RankId, RankId>;
RankPair rank_pair = {parent_rank, rank};

// Current approach for communicator splitting works because of XLA's SPMD
Expand All @@ -402,26 +404,26 @@ static absl::StatusOr<std::shared_ptr<NcclClique::Lock>> InitializeNcclClique(
auto split = [&](absl::Span<const RankPair* const> rank_pairs)
-> absl::StatusOr<NcclClique::Lock> {
// Collect mapping from ranks in parent clique to ranks in a new clique.
absl::btree_map<int32_t, int32_t> rank_mapping;
absl::btree_map<RankId, RankId> rank_mapping;
for (auto* rank_pair : rank_pairs) {
rank_mapping[rank_pair->first] = rank_pair->second;
}

auto rank_mapping_formatter = [](std::string* str, auto mapping) {
absl::StrAppend(str, mapping.first, "->", mapping.second);
absl::StrAppend(str, mapping.first.value(), "->", mapping.second.value());
};

// Collect parent communicators we'll be splitting from and keys for
// creating new communicators.
std::vector<Communicator*> parent_comms;
std::vector<int32_t> keys;
std::vector<RankId> keys;

for (auto& [parent_rank, split_rank] : rank_mapping) {
auto parent_comm = (*parent_clique)->comm(parent_rank);
if (!parent_comm.has_value()) {
return absl::InvalidArgumentError(absl::StrFormat(
"Parent clique %s does not have a communicator for rank %d",
parent_clique_key.ToString(), parent_rank));
parent_clique_key.ToString(), parent_rank.value()));
}

parent_comms.push_back(*parent_comm);
Expand All @@ -441,7 +443,7 @@ static absl::StatusOr<std::shared_ptr<NcclClique::Lock>> InitializeNcclClique(
auto splitted_comms,
NcclApi::Default()->CommSplit(parent_comms, color, keys, config));

absl::btree_map<int32_t, std::unique_ptr<Communicator>> comms;
absl::btree_map<RankId, std::unique_ptr<Communicator>> comms;
for (size_t i = 0; i < splitted_comms.size(); ++i) {
comms[keys[i]] = std::move(splitted_comms[i]);
}
Expand Down Expand Up @@ -476,8 +478,9 @@ static absl::StatusOr<std::shared_ptr<NcclClique::Lock>> InitializeNcclClique(
// will update cliques state, and others will destroy unused communicators.
auto rendezvous_key = std::make_tuple(run_id, clique_key, parent_clique_key);
auto initialization_rendezvous_name = absl::StrFormat(
"initialize clique for rank %d; clique=%s; run_id=%d; parent=%s", rank,
clique_key.ToString(), run_id.ToInt(), parent_clique_key.ToString());
"initialize clique for rank %d; clique=%s; run_id=%d; parent=%s",
rank.value(), clique_key.ToString(), run_id.ToInt(),
parent_clique_key.ToString());

return RendezvousSingle<absl::StatusOr<NcclClique::Lock>>(
initialization_rendezvous_name, rendezvous_key, rank_pair,
Expand All @@ -490,7 +493,7 @@ using AcquiredCliquesMap = NcclClique::AcquiredCliquesMap;

absl::StatusOr<std::shared_ptr<NcclClique::Lock>> AcquireNcclClique(
se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key,
const NcclCliqueIdCallback& clique_id_callback, int32_t rank,
const NcclCliqueIdCallback& clique_id_callback, RankId rank,
size_t num_local_participants, const AcquiredCliquesMap& acquired_cliques,
int64_t max_nchannels) {
VLOG(2) << "Acquire NCCL clique " << clique_key.ToString() << "; run"
Expand All @@ -502,8 +505,8 @@ absl::StatusOr<std::shared_ptr<NcclClique::Lock>> AcquireNcclClique(
// members participate in XLA run.
auto rendezvous_key = std::make_tuple(run_id, clique_key);
auto rendezvous_name =
absl::StrFormat("acquire clique for rank %d; clique=%s; run_id=%d", rank,
clique_key.ToString(), run_id.ToInt());
absl::StrFormat("acquire clique for rank %d; clique=%s; run_id=%d",
rank.value(), clique_key.ToString(), run_id.ToInt());

TF_ASSIGN_OR_RETURN(
std::shared_ptr<NcclClique::Lock> clique,
Expand Down Expand Up @@ -552,7 +555,7 @@ absl::Status NcclClique::CheckAsyncErrors() {

absl::Status NcclCliqueCommunicators::AsyncErrorChecker::Check() {
absl::Status status = absl::OkStatus();
communicators_.ForEachComm([&status](int32_t rank, Communicator* comm) {
communicators_.ForEachComm([&status](RankId rank, Communicator* comm) {
// Do not overwrite previous errors.
if (!status.ok()) return;
status = NcclApi::Default()->CommGetAsyncError(comm);
Expand Down
Loading

0 comments on commit 9c668fd

Please sign in to comment.