diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc index ed56b400e5404..589caf3efa9ab 100644 --- a/xla/service/gpu/gpu_executable.cc +++ b/xla/service/gpu/gpu_executable.cc @@ -249,7 +249,7 @@ class ResourceRequests : public Thunk::ResourceRequests { NcclClique::AcquiredCliquesMap cliques_map; for (const CliqueRequest& r : ordered_cliques) { - std::optional rank = r.key.rank(params.global_device_id); + std::optional rank = r.key.rank(params.global_device_id); if (!rank.has_value()) { return absl::InternalError(absl::StrCat( diff --git a/xla/service/gpu/runtime/BUILD b/xla/service/gpu/runtime/BUILD index b7161fe75461b..adb908c0360ec 100644 --- a/xla/service/gpu/runtime/BUILD +++ b/xla/service/gpu/runtime/BUILD @@ -217,6 +217,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/backends/gpu/collectives:nccl_communicator", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", @@ -257,6 +258,7 @@ cc_library( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", @@ -282,6 +284,7 @@ cc_library( "//xla:types", "//xla:util", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/service:global_device_id", "//xla/service:lockable", "//xla/service:rendezvous", @@ -313,6 +316,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//xla/core/collectives:clique_id", + "//xla/core/collectives:rank_id", "//xla/service:global_device_id", "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/algorithm:container", @@ -950,6 +954,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/hlo/ir:hlo", "//xla/hlo/translate/mhlo_to_hlo:attribute_exporter", "//xla/service:buffer_assignment", @@ -1187,6 +1192,7 @@ cc_library( ":nccl_clique_key", "//xla:executable_run_options", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/ffi:execution_context", "//xla/hlo/ir:hlo", "//xla/hlo/translate/mhlo_to_hlo:location_exporter", diff --git a/xla/service/gpu/runtime/nccl_api.cc b/xla/service/gpu/runtime/nccl_api.cc index 85057521ab537..3079dd380a428 100644 --- a/xla/service/gpu/runtime/nccl_api.cc +++ b/xla/service/gpu/runtime/nccl_api.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -32,6 +33,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/backends/gpu/collectives/nccl_communicator.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/primitive_util.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/runtime/nccl_clique_key.h" @@ -297,7 +299,7 @@ class DefaultNcclApi final : public NcclApi { absl::StatusOr>> CommSplit( absl::Span comms, int32_t color, - absl::Span keys, std::optional config) final; + absl::Span keys, std::optional config) final; absl::Status CommAbort(Communicator* comm) final; absl::Status CommFinalize(Communicator* comm) final; @@ -406,8 +408,9 @@ DefaultNcclApi::CommInitRanks(int32_t nranks, const NcclCliqueId& clique_id, auto activate_context = ranks[i].device->Activate(); TF_ASSIGN_OR_RETURN(auto nccl_unique_id, AsNcclUniqueId(clique_id)); - XLA_NCCL_RETURN_IF_ERROR(ncclCommInitRankConfig( - &comm_handles[i], nranks, nccl_unique_id, ranks[i].rank, &comm_config)); + XLA_NCCL_RETURN_IF_ERROR( + ncclCommInitRankConfig(&comm_handles[i], nranks, nccl_unique_id, + ranks[i].rank.value(), &comm_config)); } TF_RETURN_IF_ERROR(GroupEnd()); @@ -420,11 +423,15 @@ DefaultNcclApi::CommInitRanks(int32_t nranks, const NcclCliqueId& clique_id, absl::StatusOr>> DefaultNcclApi::CommSplit(absl::Span comms, - int32_t color, absl::Span keys, + int32_t color, absl::Span keys, std::optional config) { + auto rank_formatter = [](std::string* str, RankId rank) { + absl::StrAppend(str, rank.value()); + }; + VLOG(1) << absl::StreamFormat( "Split %d NCCL communicators using color %d and keys: [%s]", comms.size(), - color, absl::StrJoin(keys, ",")); + color, absl::StrJoin(keys, ",", rank_formatter)); #if !defined(TENSORFLOW_USE_ROCM) || TF_ROCM_VERSION >= 60000 if (keys.size() != comms.size()) { @@ -456,9 +463,9 @@ DefaultNcclApi::CommSplit(absl::Span comms, for (size_t i = 0; i < comms.size(); ++i) { VLOG(1) << "Split NCCL communicator " << comms[i] << " with color " << color << " and key " << keys[i]; - XLA_NCCL_RETURN_IF_ERROR(ncclCommSplit(Cast(comms[i]), color, keys[i], - &split_comms_handles[i], - /*config=*/comm_config_ptr)); + XLA_NCCL_RETURN_IF_ERROR(ncclCommSplit( + Cast(comms[i]), color, keys[i].value(), &split_comms_handles[i], + /*config=*/comm_config_ptr)); } TF_RETURN_IF_ERROR(GroupEnd()); diff --git a/xla/service/gpu/runtime/nccl_api.h b/xla/service/gpu/runtime/nccl_api.h index 524e589959cb9..1452f3329ff44 100644 --- a/xla/service/gpu/runtime/nccl_api.h +++ b/xla/service/gpu/runtime/nccl_api.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/shape_util.h" @@ -118,11 +119,11 @@ class NcclApi { }; struct DeviceRank { - DeviceRank(se::StreamExecutor* device, int32_t rank) + DeviceRank(se::StreamExecutor* device, RankId rank) : device(device), rank(rank) {} se::StreamExecutor* device; - int32_t rank; + RankId rank; }; // Returns a slice of device memory `buff` containing `count` values of data @@ -157,7 +158,7 @@ class NcclApi { // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommsplit virtual absl::StatusOr>> CommSplit( absl::Span comms, int32_t color, - absl::Span keys, std::optional config) = 0; + absl::Span keys, std::optional config) = 0; // Abort any uncompleted operations and destroys the communicator. Frees // resources that are allocated to a communicator object comm. diff --git a/xla/service/gpu/runtime/nccl_api_stub.cc b/xla/service/gpu/runtime/nccl_api_stub.cc index b5157fe1ffcdc..ccd6f8df74ddf 100644 --- a/xla/service/gpu/runtime/nccl_api_stub.cc +++ b/xla/service/gpu/runtime/nccl_api_stub.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/runtime/nccl_api.h" #include "xla/service/gpu/runtime/nccl_clique_key.h" @@ -96,7 +97,7 @@ class NcclApiStub final : public NcclApi { } absl::StatusOr>> CommSplit( - absl::Span, int32_t, absl::Span, + absl::Span, int32_t, absl::Span, std::optional) final { return UnimplementedError(); } diff --git a/xla/service/gpu/runtime/nccl_clique.cc b/xla/service/gpu/runtime/nccl_clique.cc index 64681e4b4e7a3..61ff1bd195fac 100644 --- a/xla/service/gpu/runtime/nccl_clique.cc +++ b/xla/service/gpu/runtime/nccl_clique.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" #include "xla/service/global_device_id.h" @@ -109,12 +110,12 @@ static bool TerminateOnNcclError() { NcclCliqueCommunicators::NcclCliqueCommunicators( NcclCliqueKey clique_key, std::optional clique_id, - absl::btree_map> communicators) + absl::btree_map> communicators) : clique_key_(std::move(clique_key)), clique_id_(std::move(clique_id)), communicators_(std::move(communicators)) {} -std::optional NcclCliqueCommunicators::comm(int32_t rank) { +std::optional NcclCliqueCommunicators::comm(RankId rank) { if (auto it = communicators_.find(rank); it != communicators_.end()) { return it->second.get(); } @@ -126,7 +127,7 @@ bool NcclCliqueCommunicators::IsLocal() const { } void NcclCliqueCommunicators::ForEachComm( - absl::FunctionRef fn) { + absl::FunctionRef fn) { for (auto& [rank, comm] : communicators_) { fn(rank, comm.get()); } @@ -141,7 +142,7 @@ std::string NcclCliqueCommunicators::DebugString() const { int32_t cnt = 0; for (const auto& [rank, comm] : communicators_) { if (cnt++) absl::StrAppend(&out, ", "); - absl::StrAppendFormat(&out, "[rank=%d, comm=%p]", rank, comm.get()); + absl::StrAppendFormat(&out, "[rank=%d, comm=%p]", rank.value(), comm.get()); } return out; } @@ -195,7 +196,7 @@ static void CheckClique(const NcclCliqueKey& clique_key, VLOG(5) << "Checking NCCL clique " << clique_key.ToString() << " for async errors; num_communicators=" << clique->num_communicators(); - clique->ForEachComm([](int32_t rank, Communicator* comm) { + clique->ForEachComm([](RankId rank, Communicator* comm) { if (auto status = CheckComm(comm); !status.ok()) LOG(ERROR) << status; }); } else { @@ -241,7 +242,7 @@ static void StartNcclCliqueHeartBeatMonitor() { static auto DeviceRanksToString(absl::Span ranks) { return absl::StrJoin(ranks, ",", [](std::string* str, auto& rank) { - str->append(std::to_string(rank.rank)); + str->append(std::to_string(rank.rank.value())); }); } @@ -251,7 +252,7 @@ static auto DeviceRanksToString(absl::Span ranks) { static absl::StatusOr> InitializeNcclClique( se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, const NcclCliqueIdCallback& clique_id_callback, - int32_t num_local_participants, int32_t rank, NcclApi::Config& config) { + int32_t num_local_participants, RankId rank, NcclApi::Config& config) { int nranks = clique_key.devices().size(); VLOG(3) << "Initialize NCCL clique " << clique_key.ToString() << " rank #" << rank << "; num_local_participants=" << num_local_participants; @@ -274,7 +275,7 @@ static absl::StatusOr> InitializeNcclClique( return Internal( "Failed to synchronize device activity on rank %d. Do not attempt " "to initialize NCCL clique.", - device_rank.rank); + device_rank.rank.value()); } } @@ -296,7 +297,7 @@ static absl::StatusOr> InitializeNcclClique( std::vector> created_comms, NcclApi::Default()->CommInitRanks(nranks, clique_id, ranks, config)); - absl::btree_map> comms; + absl::btree_map> comms; for (size_t i = 0; i < ranks.size(); ++i) { comms[ranks[i].rank] = std::move(created_comms[i]); } @@ -332,7 +333,7 @@ static absl::StatusOr> InitializeNcclClique( auto rendezvous_key = std::make_tuple(run_id, clique_key); auto initialization_rendezvous_name = absl::StrFormat("initialize clique for rank %d; clique=%s; run_id=%d", - rank, clique_key.ToString(), run_id.ToInt()); + rank.value(), clique_key.ToString(), run_id.ToInt()); NcclApi::DeviceRank device_rank = {device, rank}; bool synchronized = device->SynchronizeAllActivity(); @@ -375,17 +376,18 @@ static int32_t GetCommSplitColor(const NcclCliqueKey& clique_key) { static absl::StatusOr> InitializeNcclClique( se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, std::shared_ptr parent_clique, - int32_t num_local_participants, int32_t rank, NcclApi::Config& config) { + int32_t num_local_participants, RankId rank, NcclApi::Config& config) { // Find our rank in the parent clique. const NcclCliqueKey& parent_clique_key = (*parent_clique)->clique_key(); - int32_t parent_rank = *parent_clique_key.rank(clique_key.devices()[rank]); + RankId parent_rank = + *parent_clique_key.rank(clique_key.devices()[rank.value()]); VLOG(3) << "Initialize NCCL clique " << clique_key.ToString() << " rank #" - << rank << " by splitting rank #" << parent_rank + << rank << " by splitting rank #" << parent_rank.value() << " in parent clique " << parent_clique_key.ToString() << "; num_local_participants=" << num_local_participants; - using RankPair = std::pair; + using RankPair = std::pair; RankPair rank_pair = {parent_rank, rank}; // Current approach for communicator splitting works because of XLAs SPMD @@ -402,26 +404,26 @@ static absl::StatusOr> InitializeNcclClique( auto split = [&](absl::Span rank_pairs) -> absl::StatusOr { // Collect mapping from ranks in parent clique to ranks in a new clique. - absl::btree_map rank_mapping; + absl::btree_map rank_mapping; for (auto* rank_pair : rank_pairs) { rank_mapping[rank_pair->first] = rank_pair->second; } auto rank_mapping_formatter = [](std::string* str, auto mapping) { - absl::StrAppend(str, mapping.first, "->", mapping.second); + absl::StrAppend(str, mapping.first.value(), "->", mapping.second.value()); }; // Collect parent communicators we'll be splitting from and keys for // creating new communicators. std::vector parent_comms; - std::vector keys; + std::vector keys; for (auto& [parent_rank, split_rank] : rank_mapping) { auto parent_comm = (*parent_clique)->comm(parent_rank); if (!parent_comm.has_value()) { return absl::InvalidArgumentError(absl::StrFormat( "Parent clique %s does not have a communicator for rank %d", - parent_clique_key.ToString(), parent_rank)); + parent_clique_key.ToString(), parent_rank.value())); } parent_comms.push_back(*parent_comm); @@ -441,7 +443,7 @@ static absl::StatusOr> InitializeNcclClique( auto splitted_comms, NcclApi::Default()->CommSplit(parent_comms, color, keys, config)); - absl::btree_map> comms; + absl::btree_map> comms; for (size_t i = 0; i < splitted_comms.size(); ++i) { comms[keys[i]] = std::move(splitted_comms[i]); } @@ -476,8 +478,9 @@ static absl::StatusOr> InitializeNcclClique( // will update cliques state, and others will destroy unused communicators. auto rendezvous_key = std::make_tuple(run_id, clique_key, parent_clique_key); auto initialization_rendezvous_name = absl::StrFormat( - "initialize clique for rank %d; clique=%s; run_id=%d; parent=%s", rank, - clique_key.ToString(), run_id.ToInt(), parent_clique_key.ToString()); + "initialize clique for rank %d; clique=%s; run_id=%d; parent=%s", + rank.value(), clique_key.ToString(), run_id.ToInt(), + parent_clique_key.ToString()); return RendezvousSingle>( initialization_rendezvous_name, rendezvous_key, rank_pair, @@ -490,7 +493,7 @@ using AcquiredCliquesMap = NcclClique::AcquiredCliquesMap; absl::StatusOr> AcquireNcclClique( se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, - const NcclCliqueIdCallback& clique_id_callback, int32_t rank, + const NcclCliqueIdCallback& clique_id_callback, RankId rank, size_t num_local_participants, const AcquiredCliquesMap& acquired_cliques, int64_t max_nchannels) { VLOG(2) << "Acquire NCCL clique " << clique_key.ToString() << "; run" @@ -502,8 +505,8 @@ absl::StatusOr> AcquireNcclClique( // members participate in XLA run. auto rendezvous_key = std::make_tuple(run_id, clique_key); auto rendezvous_name = - absl::StrFormat("acquire clique for rank %d; clique=%s; run_id=%d", rank, - clique_key.ToString(), run_id.ToInt()); + absl::StrFormat("acquire clique for rank %d; clique=%s; run_id=%d", + rank.value(), clique_key.ToString(), run_id.ToInt()); TF_ASSIGN_OR_RETURN( std::shared_ptr clique, @@ -552,7 +555,7 @@ absl::Status NcclClique::CheckAsyncErrors() { absl::Status NcclCliqueCommunicators::AsyncErrorChecker::Check() { absl::Status status = absl::OkStatus(); - communicators_.ForEachComm([&status](int32_t rank, Communicator* comm) { + communicators_.ForEachComm([&status](RankId rank, Communicator* comm) { // Do not overwrite previous errors. if (!status.ok()) return; status = NcclApi::Default()->CommGetAsyncError(comm); diff --git a/xla/service/gpu/runtime/nccl_clique.h b/xla/service/gpu/runtime/nccl_clique.h index 66fc5031c882e..32ed45fea33ca 100644 --- a/xla/service/gpu/runtime/nccl_clique.h +++ b/xla/service/gpu/runtime/nccl_clique.h @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/executable_run_options.h" #include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/lockable.h" @@ -94,17 +95,17 @@ class NcclCliqueCommunicators { NcclCliqueCommunicators( NcclCliqueKey clique_key, std::optional clique_id, - absl::btree_map> communicators); + absl::btree_map> communicators); // Returns a NCCL communicator for a given rank if it's in a clique. - std::optional comm(int32_t rank); + std::optional comm(RankId rank); // Return true if clique is local: all communicators belong to current // process. Non-local cliques spans multiple processes (typically hosts). bool IsLocal() const; // Calls `fn` for each communicator in the clique. - void ForEachComm(absl::FunctionRef fn); + void ForEachComm(absl::FunctionRef fn); const NcclCliqueKey& clique_key() const { return clique_key_; } const std::optional& clique_id() const { return clique_id_; } @@ -119,7 +120,7 @@ class NcclCliqueCommunicators { std::optional clique_id_; // TODO(ezhulenev): Switch this map to GlobalDeviceId key. - absl::btree_map> communicators_; + absl::btree_map> communicators_; }; struct NcclCliqueName { @@ -143,7 +144,7 @@ class NcclClique : public Lockable { // to the communicators from an acquired lock. NcclClique( NcclCliqueKey clique_key, std::optional clique_id, - absl::btree_map> communicators) + absl::btree_map> communicators) : Lockable(std::move(clique_key), clique_id, std::move(communicators)), async_error_checker_(Acquire()->GetChecker()) {} @@ -167,7 +168,7 @@ class NcclClique : public Lockable { // cliques. absl::StatusOr> AcquireNcclClique( se::StreamExecutor* device, RunId run_id, NcclCliqueKey clique_key, - const NcclCliqueIdCallback& clique_id_callback, int32_t rank, + const NcclCliqueIdCallback& clique_id_callback, RankId rank, size_t num_local_participants, const NcclClique::AcquiredCliquesMap& acquired_cliques, int64_t max_nchannels = 0); diff --git a/xla/service/gpu/runtime/nccl_clique_key.cc b/xla/service/gpu/runtime/nccl_clique_key.cc index 8aae26356ef26..f0ce2861a6d5c 100644 --- a/xla/service/gpu/runtime/nccl_clique_key.cc +++ b/xla/service/gpu/runtime/nccl_clique_key.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/global_device_id.h" #include "tsl/platform/logging.h" @@ -65,9 +66,9 @@ absl::Span NcclCliqueKey::devices() const { NcclStreamId NcclCliqueKey::stream_id() const { return stream_id_; } -std::optional NcclCliqueKey::rank(GlobalDeviceId id) const { +std::optional NcclCliqueKey::rank(GlobalDeviceId id) const { if (auto it = absl::c_find(devices_, id); it != devices_.end()) { - return it - devices_.begin(); + return RankId(it - devices_.begin()); } return std::nullopt; } diff --git a/xla/service/gpu/runtime/nccl_clique_key.h b/xla/service/gpu/runtime/nccl_clique_key.h index d14b3a903a071..84246b7743f3e 100644 --- a/xla/service/gpu/runtime/nccl_clique_key.h +++ b/xla/service/gpu/runtime/nccl_clique_key.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/global_device_id.h" #include "xla/tsl/lib/gtl/int_type.h" @@ -88,7 +89,7 @@ class NcclCliqueKey { NcclStreamId stream_id() const; // Returns the rank of the global device in the clique. - std::optional rank(GlobalDeviceId id) const; + std::optional rank(GlobalDeviceId id) const; // Returns true if this clique is a subset of `other`: both cliques have the // same `stream_id` and all clique devices are part of `other` clique. diff --git a/xla/service/gpu/runtime/nccl_collective_thunk.cc b/xla/service/gpu/runtime/nccl_collective_thunk.cc index ffb7c7009bf5f..ce2df0c3c7d58 100644 --- a/xla/service/gpu/runtime/nccl_collective_thunk.cc +++ b/xla/service/gpu/runtime/nccl_collective_thunk.cc @@ -34,8 +34,8 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" -#include "mlir/IR/Value.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/layout_util.h" @@ -44,7 +44,6 @@ limitations under the License. #include "xla/service/computation_placer.h" #include "xla/service/global_device_id.h" #include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/runtime/nccl_api.h" #include "xla/service/gpu/runtime/nccl_clique.h" #include "xla/service/gpu/runtime/nccl_clique_key.h" @@ -227,8 +226,8 @@ absl::StatusOr GetNcclCliqueKey( GetParticipatingDevices(global_device_id, *params.device_assn, replica_groups, group_mode)); - // If splitting is enabled, particpating groups must match in order for a - // clique to be reused from the cache. We can ignore the particpating groups + // If splitting is enabled, participating groups must match in order for a + // clique to be reused from the cache. We can ignore the participating groups // otherwise. static const int64_t enable_nccl_comm_splitting = xla::GetDebugOptionsFromFlags().xla_gpu_enable_nccl_comm_splitting(); @@ -263,7 +262,7 @@ absl::StatusOr GetNcclComm( GetNcclCliqueKey(params, replica_groups, group_mode, stream_id, stream_kind)); - std::optional rank = clique_key.rank(params.global_device_id); + std::optional rank = clique_key.rank(params.global_device_id); TF_ASSIGN_OR_RETURN(bool is_local, collective_cliques.is_local_clique(clique_key)); TF_ASSIGN_OR_RETURN(Communicator * comm, @@ -462,12 +461,13 @@ absl::Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { params.collective_cliques->num_communicators(clique_key)); auto global_device_id = params.collective_params->global_device_id; + RankId rank = clique_key.rank(global_device_id).value_or(RankId(-1)); VLOG(1) << "Do a rendezvous after a first call to " << Thunk::KindToString(kind()) << "; run_id=" << params.collective_params->run_id.ToInt() << "; op_id=" << config().op_id << "; num_local_participants=" << num_local_participants - << "; rank=" << clique_key.rank(global_device_id).value_or(-1) + << "; rank=" << rank.value() << "; clique_key=" << clique_key.ToString(); auto rendezvous_key = FirstCallRendezvousKey{std::move(clique_key)}; diff --git a/xla/service/gpu/runtime/thunk.cc b/xla/service/gpu/runtime/thunk.cc index bf28cdc01baae..3eee55a5031d4 100644 --- a/xla/service/gpu/runtime/thunk.cc +++ b/xla/service/gpu/runtime/thunk.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/executable_run_options.h" #include "xla/ffi/execution_context.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -55,7 +56,7 @@ Thunk::CollectiveCliques::CollectiveCliques( : cliques_map_(std::move(cliques_map)) {} absl::StatusOr Thunk::CollectiveCliques::GetComm( - const NcclCliqueKey& clique_key, int32_t rank) const { + const NcclCliqueKey& clique_key, RankId rank) const { // Check that we locked access to a clique for `clique_key`. auto clique = cliques_map_.find(clique_key); if (clique == cliques_map_.end()) { @@ -66,9 +67,9 @@ absl::StatusOr Thunk::CollectiveCliques::GetComm( // Check that clique has a communicator for our rank. auto communicator = (*clique->second)->comm(rank); if (!communicator.has_value()) { - return absl::InternalError(absl::StrCat("Communicator for rank ", rank, - " not found in a NCCL clique ", - clique_key.ToString())); + return absl::InternalError( + absl::StrCat("Communicator for rank ", rank.value(), + " not found in a NCCL clique ", clique_key.ToString())); } return *communicator; diff --git a/xla/service/gpu/runtime/thunk.h b/xla/service/gpu/runtime/thunk.h index a445392193a14..2634b119942e0 100644 --- a/xla/service/gpu/runtime/thunk.h +++ b/xla/service/gpu/runtime/thunk.h @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/executable_run_options.h" #include "xla/ffi/execution_context.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -214,7 +215,7 @@ class Thunk { explicit CollectiveCliques(NcclClique::AcquiredCliquesMap cliques_map); absl::StatusOr GetComm(const NcclCliqueKey& clique_key, - int32_t rank) const; + RankId rank) const; // Returns the number of communicators in a collective clique. Returns error // if we do not have an acquired clique for a given key.