diff --git a/xla/backends/gpu/collectives/BUILD b/xla/backends/gpu/collectives/BUILD
index 868920737faef..477f8a6681e46 100644
--- a/xla/backends/gpu/collectives/BUILD
+++ b/xla/backends/gpu/collectives/BUILD
@@ -18,6 +18,20 @@ package_group(
     ],
 )
 
+cc_library(
+    name = "gpu_clique",
+    srcs = ["gpu_clique.cc"],
+    hdrs = ["gpu_clique.h"],
+    deps = [
+        ":gpu_clique_key",
+        "//xla/core/collectives:clique",
+        "//xla/core/collectives:clique_id",
+        "//xla/core/collectives:communicator",
+        "//xla/core/collectives:rank_id",
+        "@com_google_absl//absl/container:btree",
+    ],
+)
+
 cc_library(
     name = "gpu_clique_key",
     srcs = ["gpu_clique_key.cc"],
@@ -72,6 +86,7 @@ cc_library(
     ],
 )
 
+# TODO(b/380457503): Update visibility to "//visibility:private".
 cc_library(
     name = "nccl_collectives",
     hdrs = if_gpu_is_configured(["nccl_collectives.h"]),
@@ -91,6 +106,26 @@ cc_library(
     ]),
 )
 
+# TODO(b/380457503): Update visibility to "//visibility:private".
+cc_library(
+    name = "nccl_clique",
+    srcs = ["nccl_clique.cc"],
+    hdrs = ["nccl_clique.h"],
+    deps = [
+        ":gpu_clique",
+        "//xla/core/collectives:clique",
+        "//xla/core/collectives:clique_id",
+        "//xla/core/collectives:communicator",
+        "//xla/core/collectives:rank_id",
+        "//xla/service/gpu/runtime:nccl_api",
+        "@com_google_absl//absl/container:btree",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@tsl//tsl/platform:logging",
+    ],
+)
+
 # TODO(b/380457503): Update visibility to "//visibility:private".
 cc_library(
     name = "nccl_communicator",
diff --git a/xla/backends/gpu/collectives/gpu_clique.cc b/xla/backends/gpu/collectives/gpu_clique.cc
new file mode 100644
index 0000000000000..6a7a6c495b6eb
--- /dev/null
+++ b/xla/backends/gpu/collectives/gpu_clique.cc
@@ -0,0 +1,40 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/gpu/collectives/gpu_clique.h"
+
+#include <memory>
+#include <optional>
+#include <utility>
+
+#include "absl/container/btree_map.h"
+#include "xla/backends/gpu/collectives/gpu_clique_key.h"
+#include "xla/core/collectives/clique.h"
+#include "xla/core/collectives/clique_id.h"
+#include "xla/core/collectives/communicator.h"
+#include "xla/core/collectives/rank_id.h"
+
+namespace xla::gpu {
+
+// The arguments are by-value sinks: move them into the base class and the
+// members instead of copying them a second time.
+GpuClique::GpuClique(
+    GpuCliqueKey clique_key, std::optional<CliqueId> clique_id,
+    absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators)
+    : Clique(std::move(communicators)),
+      clique_key_(std::move(clique_key)),
+      clique_id_(std::move(clique_id)) {}
+
+}  // namespace xla::gpu
diff --git a/xla/backends/gpu/collectives/gpu_clique.h b/xla/backends/gpu/collectives/gpu_clique.h
new file mode 100644
index 0000000000000..5571f25f2a495
--- /dev/null
+++ b/xla/backends/gpu/collectives/gpu_clique.h
@@ -0,0 +1,57 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_H_
+#define XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_H_
+
+#include <memory>
+#include <optional>
+
+#include "absl/container/btree_map.h"
+#include "xla/backends/gpu/collectives/gpu_clique_key.h"
+#include "xla/core/collectives/clique.h"
+#include "xla/core/collectives/clique_id.h"
+#include "xla/core/collectives/communicator.h"
+#include "xla/core/collectives/rank_id.h"
+
+namespace xla::gpu {
+
+// A group of GPU communicators making up a clique for a given clique key.
+class GpuClique : public Clique {
+ public:
+  // Creates a clique for `clique_key` (with an optional `clique_id`) that
+  // takes ownership of the given per-rank `communicators`.
+  GpuClique(
+      GpuCliqueKey clique_key, std::optional<CliqueId> clique_id,
+      absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators);
+
+  // Returns true if clique is local: all communicators belong to current
+  // process. Non-local cliques span multiple processes (typically hosts).
+  bool IsLocal() const {
+    return num_communicators() == clique_key_.devices().size();
+  }
+
+  // The key and the (optional) id this clique was constructed with.
+  const GpuCliqueKey& clique_key() const { return clique_key_; }
+  const std::optional<CliqueId>& clique_id() const { return clique_id_; }
+
+ private:
+  GpuCliqueKey clique_key_;
+  std::optional<CliqueId> clique_id_;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_H_
diff --git a/xla/backends/gpu/collectives/gpu_clique_key.h b/xla/backends/gpu/collectives/gpu_clique_key.h
index 7d6a1b79f433c..d563db28b0b00 100644
--- a/xla/backends/gpu/collectives/gpu_clique_key.h
+++ b/xla/backends/gpu/collectives/gpu_clique_key.h
@@ -60,6 +60,12 @@ class GpuCliqueKey : public CliqueKey {
       AsyncStreamKind stream_kind = AsyncStreamKind::kCollective,
       std::vector<std::vector<GlobalDeviceId>> participant_groups = {});
 
+  GpuCliqueKey(const GpuCliqueKey&) = default;
+  GpuCliqueKey& operator=(const GpuCliqueKey&) = default;
+
+  GpuCliqueKey(GpuCliqueKey&&) = default;
+  GpuCliqueKey& operator=(GpuCliqueKey&&) = default;
+
   CollectiveStreamId stream_id() const;
 
   // Returns true if this clique is a subset of `other`: both cliques have the
diff --git a/xla/backends/gpu/collectives/nccl_clique.cc b/xla/backends/gpu/collectives/nccl_clique.cc
new file mode 100644
index 0000000000000..51570d9e41a77
--- /dev/null
+++ b/xla/backends/gpu/collectives/nccl_clique.cc
@@ -0,0 +1,58 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/gpu/collectives/nccl_clique.h"
+
+#include <cstdint>
+#include <string>
+#include <utility>
+
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "xla/core/collectives/communicator.h"
+#include "xla/core/collectives/rank_id.h"
+#include "xla/service/gpu/runtime/nccl_api.h"
+#include "tsl/platform/logging.h"
+
+namespace xla::gpu {
+
+std::string NcclCliqueImpl::DebugString() const {
+  std::string out = absl::StrFormat(
+      "clique_key: %s; fingerprint(id): %d; size: %d; communicators: ",
+      clique_key().ToString(),
+      clique_id().has_value() ? clique_id()->fingerprint() : 0,
+      num_communicators());
+  int32_t cnt = 0;
+  ForEachComm([&](RankId rank, Communicator* comm) {
+    if (cnt++) absl::StrAppend(&out, ", ");
+    absl::StrAppendFormat(&out, "[rank=%d, comm=%p]", rank.value(), comm);
+  });
+  return out;
+}
+
+absl::Status NcclCliqueImpl::HealthCheck() const {
+  absl::Status health_check = absl::OkStatus();
+  ForEachComm([&health_check](RankId rank, Communicator* comm) {
+    // TODO(b/380457503): Move error checking API to communicator base class.
+    if (auto s = NcclApi::Default()->CommGetAsyncError(comm); !s.ok()) {
+      LOG(ERROR) << "NCCL async error (rank " << rank << "): " << s;
+      if (health_check.ok()) health_check = std::move(s);  // return first error
+    }
+  });
+  return health_check;
+}
+
+}  // namespace xla::gpu
diff --git a/xla/backends/gpu/collectives/nccl_clique.h b/xla/backends/gpu/collectives/nccl_clique.h
new file mode 100644
index 0000000000000..e3bdb68a0fb66
--- /dev/null
+++ b/xla/backends/gpu/collectives/nccl_clique.h
@@ -0,0 +1,40 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_GPU_COLLECTIVES_NCCL_CLIQUE_H_
+#define XLA_BACKENDS_GPU_COLLECTIVES_NCCL_CLIQUE_H_
+
+#include <string>
+
+#include "absl/status/status.h"
+#include "xla/backends/gpu/collectives/gpu_clique.h"
+
+namespace xla::gpu {
+
+// A GPU clique that is implemented on top of NCCL communicators.
+//
+// TODO(b/380457503): Remove `Impl` suffix once we migrate all users to
+// LockableGpuClique.
+class NcclCliqueImpl : public GpuClique {
+ public:
+  using GpuClique::GpuClique;
+
+  std::string DebugString() const final;
+  absl::Status HealthCheck() const final;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_BACKENDS_GPU_COLLECTIVES_NCCL_CLIQUE_H_
diff --git a/xla/core/collectives/clique_key.h b/xla/core/collectives/clique_key.h
index de00008059383..0541177343150 100644
--- a/xla/core/collectives/clique_key.h
+++ b/xla/core/collectives/clique_key.h
@@ -42,6 +42,12 @@ class CliqueKey {
   explicit CliqueKey(std::vector<GlobalDeviceId> devices);
   virtual ~CliqueKey() = default;
 
+  CliqueKey(const CliqueKey& other) = default;
+  CliqueKey& operator=(const CliqueKey& other) = default;
+
+  CliqueKey(CliqueKey&& other) = default;
+  CliqueKey& operator=(CliqueKey&& other) = default;
+
   // Returns the rank of the global device in the clique.
   std::optional<RankId> rank(GlobalDeviceId id) const;
 
diff --git a/xla/service/gpu/runtime/BUILD b/xla/service/gpu/runtime/BUILD
index 33cab54c5819a..66a98966edcf2 100644
--- a/xla/service/gpu/runtime/BUILD
+++ b/xla/service/gpu/runtime/BUILD
@@ -288,6 +288,7 @@ cc_library(
         "//xla:types",
         "//xla:util",
         "//xla/backends/gpu/collectives:gpu_clique_key",
+        "//xla/backends/gpu/collectives:nccl_clique",
         "//xla/core/collectives:clique",
         "//xla/core/collectives:clique_id",
         "//xla/core/collectives:clique_key",
diff --git a/xla/service/gpu/runtime/nccl_clique.cc b/xla/service/gpu/runtime/nccl_clique.cc
index bafcdb576f4a7..0c4e67e9f79b6 100644
--- a/xla/service/gpu/runtime/nccl_clique.cc
+++ b/xla/service/gpu/runtime/nccl_clique.cc
@@ -112,42 +112,6 @@ static bool TerminateOnNcclError() {
 // NcclClique
 //===----------------------------------------------------------------------===//
 
-NcclCliqueCommunicators::NcclCliqueCommunicators(
-    GpuCliqueKey clique_key, std::optional<CliqueId> clique_id,
-    absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators)
-    : Clique(std::move(communicators)),
-      clique_key_(std::move(clique_key)),
-      clique_id_(std::move(clique_id)) {}
-
-bool NcclCliqueCommunicators::IsLocal() const {
-  return num_communicators() == clique_key_.devices().size();
-}
-
-std::string NcclCliqueCommunicators::DebugString() const {
-  std::string out = absl::StrFormat(
-      "clique_key: %s; fingerprint(id): %d; size: %d; communicators: ",
-      clique_key_.ToString(),
-      clique_id_.has_value() ? clique_id_->fingerprint() : 0,
-      num_communicators());
-  int32_t cnt = 0;
-  ForEachComm([&](RankId rank, Communicator* comm) {
-    if (cnt++) absl::StrAppend(&out, ", ");
-    absl::StrAppendFormat(&out, "[rank=%d, comm=%p]", rank.value(), comm);
-  });
-  return out;
-}
-
-absl::Status NcclCliqueCommunicators::HealthCheck() const {
-  absl::Status health_check = absl::OkStatus();
-  ForEachComm([&health_check](RankId rank, Communicator* comm) {
-    if (auto s = NcclApi::Default()->CommGetAsyncError(comm); !s.ok()) {
-      LOG(ERROR) << "NCCL async error (rank " << rank << "): " << s;
-      if (health_check.ok()) health_check = std::move(s);  // return first error
-    }
-  });
-  return health_check;
-}
-
 std::string NcclClique::DebugString() const {
   return absl::StrFormat("NcclClique: %s", value().DebugString());
 }
diff --git a/xla/service/gpu/runtime/nccl_clique.h b/xla/service/gpu/runtime/nccl_clique.h
index 6bea290f6daba..ad048570a852b 100644
--- a/xla/service/gpu/runtime/nccl_clique.h
+++ b/xla/service/gpu/runtime/nccl_clique.h
@@ -29,6 +29,7 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
 #include "xla/backends/gpu/collectives/gpu_clique_key.h"
+#include "xla/backends/gpu/collectives/nccl_clique.h"
 #include "xla/core/collectives/clique.h"
 #include "xla/core/collectives/clique_id.h"
 #include "xla/core/collectives/communicator.h"
@@ -76,39 +77,13 @@ absl::StatusOr GetCliqueIdCallback(
 // NcclClique
 //===----------------------------------------------------------------------===//
 
-// A group of NCCL communicators making up a clique. With NCCL it's notoriously
-// easy to get a deadlock, so we take extra care by grouping communicators into
-// cliques and making sure that we have a well defined order of all collective
-// operations that does not lead to deadlocks.
-class NcclCliqueCommunicators : public Clique {
- public:
-  NcclCliqueCommunicators(
-      GpuCliqueKey clique_key, std::optional<CliqueId> clique_id,
-      absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators);
-
-  // Return true if clique is local: all communicators belong to current
-  // process. Non-local cliques spans multiple processes (typically hosts).
-  bool IsLocal() const;
-
-  const GpuCliqueKey& clique_key() const { return clique_key_; }
-  const std::optional<CliqueId>& clique_id() const { return clique_id_; }
-
-  std::string DebugString() const final;
-
-  absl::Status HealthCheck() const final;
-
- private:
-  GpuCliqueKey clique_key_;
-  std::optional<CliqueId> clique_id_;
-};
-
 struct NcclCliqueName {
-  static std::string ToString(const NcclCliqueCommunicators& comms) {
+  static std::string ToString(const NcclCliqueImpl& comms) {
     return absl::StrFormat("lockable clique %s", comms.clique_key().ToString());
   }
 };
 
-class NcclClique : public Lockable<NcclCliqueCommunicators, NcclCliqueName> {
+class NcclClique : public Lockable<NcclCliqueImpl, NcclCliqueName> {
  public:
   // We keep acquired cliques in a sorted container to guarantee that all
   // participants iterate over cliques in the same order.