From 3c0edb6b4a1d05c7cad1a6134863a0335aa632e8 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Sun, 1 Dec 2024 10:35:10 -0800 Subject: [PATCH] [xla:collectives] NFC: Move NcclCliqueKey to GpuCliqueKey PiperOrigin-RevId: 701724894 --- xla/backends/gpu/collectives/BUILD | 38 +++++- .../gpu/collectives/gpu_clique_key.cc} | 33 +++--- xla/backends/gpu/collectives/gpu_clique_key.h | 109 ++++++++++++++++++ .../gpu/collectives/gpu_clique_key_test.cc} | 75 ++++++------ .../gpu/collectives/gpu_collectives.cc | 31 ----- .../gpu/collectives/gpu_collectives.h | 27 ----- xla/service/gpu/runtime/BUILD | 18 +-- xla/service/gpu/runtime/nccl_clique_key.h | 65 +---------- 8 files changed, 205 insertions(+), 191 deletions(-) rename xla/{service/gpu/runtime/nccl_clique_key.cc => backends/gpu/collectives/gpu_clique_key.cc} (77%) create mode 100644 xla/backends/gpu/collectives/gpu_clique_key.h rename xla/{service/gpu/runtime/nccl_clique_key_test.cc => backends/gpu/collectives/gpu_clique_key_test.cc} (66%) delete mode 100644 xla/backends/gpu/collectives/gpu_collectives.cc diff --git a/xla/backends/gpu/collectives/BUILD b/xla/backends/gpu/collectives/BUILD index 2c34bc47f4053e..98a32611484373 100644 --- a/xla/backends/gpu/collectives/BUILD +++ b/xla/backends/gpu/collectives/BUILD @@ -1,4 +1,5 @@ load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load("//xla:xla.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") load("//xla/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") @@ -16,13 +17,46 @@ package_group( ], ) +cc_library( + name = "gpu_clique_key", + srcs = ["gpu_clique_key.cc"], + hdrs = ["gpu_clique_key.h"], + deps = [ + "//xla/core/collectives", + "//xla/core/collectives:clique_id", + "//xla/core/collectives:clique_key", + "//xla/service:global_device_id", + "//xla/tsl/lib/gtl:int_type", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:casts", + "@tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "gpu_clique_key_test", + srcs = ["gpu_clique_key_test.cc"], + deps = [ + ":gpu_clique_key", + "//xla/core/collectives:clique_id", + "//xla/service:global_device_id", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "gpu_collectives", - srcs = ["gpu_collectives.cc"], hdrs = ["gpu_collectives.h"], deps = [ "//xla/core/collectives", - "//xla/tsl/lib/gtl:int_type", ], ) diff --git a/xla/service/gpu/runtime/nccl_clique_key.cc b/xla/backends/gpu/collectives/gpu_clique_key.cc similarity index 77% rename from xla/service/gpu/runtime/nccl_clique_key.cc rename to xla/backends/gpu/collectives/gpu_clique_key.cc index 5836cf0e1c51d9..d949fb52da85a1 100644 --- a/xla/service/gpu/runtime/nccl_clique_key.cc +++ b/xla/backends/gpu/collectives/gpu_clique_key.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime/nccl_clique_key.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" +#include #include #include #include @@ -24,7 +25,6 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/clique_key.h" #include "xla/service/global_device_id.h" #include "tsl/platform/casts.h" @@ -32,11 +32,16 @@ limitations under the License. namespace xla::gpu { -//===----------------------------------------------------------------------===// -// NcclCliqueKey -//===----------------------------------------------------------------------===// +CollectiveStreamId GetCollectiveStreamId(bool is_async, + AsyncStreamKind stream_kind) { + // TODO(ezhulenev): This implementation does not look correct as stream IDs + // are not really unique. Figure out if it's the case and fix either the code + // or the documentation. + int64_t stream_id = static_cast(stream_kind); + return CollectiveStreamId(is_async ? stream_id + 1 : 0); +} -NcclCliqueKey::NcclCliqueKey( +GpuCliqueKey::GpuCliqueKey( std::vector devices, CollectiveStreamId stream_id, AsyncStreamKind stream_kind, std::vector> participant_groups) @@ -57,10 +62,10 @@ NcclCliqueKey::NcclCliqueKey( absl::c_sort(participant_groups_, compare_groups); } -CollectiveStreamId NcclCliqueKey::stream_id() const { return stream_id_; } +CollectiveStreamId GpuCliqueKey::stream_id() const { return stream_id_; } -bool NcclCliqueKey::IsSubsetOf(const CliqueKey& other) const { - auto* other_nccl = tsl::down_cast(&other); +bool GpuCliqueKey::IsSubsetOf(const CliqueKey& other) const { + auto* other_nccl = tsl::down_cast(&other); if (other_nccl == nullptr) return false; return stream_id_ == other_nccl->stream_id_ && @@ -69,7 +74,7 @@ bool NcclCliqueKey::IsSubsetOf(const CliqueKey& other) const { }); } -std::string NcclCliqueKey::ToString() const { +std::string GpuCliqueKey::ToString() const { std::string group_string = ""; if (!participant_groups_.empty()) { std::vector values; @@ -84,17 +89,17 @@ std::string NcclCliqueKey::ToString() const { group_string); } -void NcclCliqueKey::HashValue(absl::HashState state) const { +void GpuCliqueKey::HashValue(absl::HashState state) const { absl::HashState::combine(std::move(state), devices(), stream_id_, participant_groups_); } -bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) { +bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b) { return a.devices() == b.devices() && a.stream_id_ == b.stream_id_ && a.participant_groups_ == b.participant_groups_; } -bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b) { +bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b) { if (a.devices().size() < b.devices().size()) return true; if (b.devices().size() < a.devices().size()) return false; @@ -104,7 +109,7 @@ bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b) { return a.stream_id_.value() < b.stream_id_.value(); } -bool operator>(const NcclCliqueKey& a, const NcclCliqueKey& b) { +bool operator>(const GpuCliqueKey& a, const GpuCliqueKey& b) { if (a.devices().size() > b.devices().size()) return true; if (b.devices().size() > a.devices().size()) return false; diff --git a/xla/backends/gpu/collectives/gpu_clique_key.h b/xla/backends/gpu/collectives/gpu_clique_key.h new file mode 100644 index 00000000000000..7d6a1b79f433c9 --- /dev/null +++ b/xla/backends/gpu/collectives/gpu_clique_key.h @@ -0,0 +1,109 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_KEY_H_ +#define XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_KEY_H_ + +#include +#include +#include + +#include "absl/hash/hash.h" +#include "xla/core/collectives/clique_key.h" +#include "xla/service/global_device_id.h" +#include "xla/tsl/lib/gtl/int_type.h" + +namespace xla::gpu { + +// In XLA:GPU we use different streams for different kinds of collective +// operations, and include the async stream kind into the GPU clique key. +// +// We carefully isolate different kinds of collectives using separate +// communicators and guarantee that all collective operations have a total order +// that will not create a deadlock. +enum class AsyncStreamKind : int64_t { + kCollective = 0, // Stream for asynchronous collective ops. + kP2P0 = 1, // One Stream for P2P Send and Recv ops. + kP2P1 = 2, // Another Stream for P2P Send and Recv ops. + kMemCpyP2P = 3, // Stream for MemCpyP2P +}; + +inline constexpr int64_t kAsyncStreamTotal = + static_cast(AsyncStreamKind::kMemCpyP2P) + 1; + +// Strongly-typed wrapper to represent collective stream ID. +TSL_LIB_GTL_DEFINE_INT_TYPE(CollectiveStreamId, uint64_t); + +// Assigns a unique ID to a stream for asynchronous or synchronous execution. +// These IDs can be used, for example, to look up the NCCL communicator. +CollectiveStreamId GetCollectiveStreamId( + bool is_async, AsyncStreamKind stream_kind = AsyncStreamKind::kCollective); + +// Clique key for identifying a particular collectives clique on a GPU backend. +class GpuCliqueKey : public CliqueKey { + public: + explicit GpuCliqueKey( + std::vector devices, + CollectiveStreamId stream_id = CollectiveStreamId(0), + AsyncStreamKind stream_kind = AsyncStreamKind::kCollective, + std::vector> participant_groups = {}); + + CollectiveStreamId stream_id() const; + + // Returns true if this clique is a subset of `other`: both cliques have the + // same `stream_id` and all clique devices are part of `other` clique. + bool IsSubsetOf(const CliqueKey& other) const final; + + // Returns the stream kind for this clique key, stream kind will be used to + // specify what configuration to pass for each type of operation. + AsyncStreamKind stream_kind() const { return stream_kind_; } + + std::string ToString() const final; + + // GPU clique keys have a total order on which we rely on for acquiring + // cliques in the same order across all participating devices. + friend bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b); + friend bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b); + friend bool operator>(const GpuCliqueKey& a, const GpuCliqueKey& b); + + private: + void HashValue(absl::HashState state) const final; + + CollectiveStreamId stream_id_; + AsyncStreamKind stream_kind_; + + // The full list of groups across all devices which this clique is a part of. + // + // When GPU communicator splitting is enabled, this is used to distinguish + // which cliques can be reused from the cache or must be split in order to + // prevent a deadlock situation. + // + // For example, imagine we have a communicator with devices = [0,1] and + // groups = [0, 1] Later on, we may want to create communicators [0, 1] and + // [2, 3] by splitting [0, 1, 2, 3] If ranks 0 and 1 reuse the existing + // [0, 1] clique but ranks 2 and 3 initiate a split, there will be a deadlock + // since ranks 2, 3 and will be waiting forever for 0, 1 to join the split. + // + // Having the participating groups as part of the cache key will prevent such + // situations + std::vector> participant_groups_; +}; + +bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b); +bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b); + +} // namespace xla::gpu + +#endif // XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_KEY_H_ diff --git a/xla/service/gpu/runtime/nccl_clique_key_test.cc b/xla/backends/gpu/collectives/gpu_clique_key_test.cc similarity index 66% rename from xla/service/gpu/runtime/nccl_clique_key_test.cc rename to xla/backends/gpu/collectives/gpu_clique_key_test.cc index f6e59fc4c5e7f2..f55b72bdc18c42 100644 --- a/xla/service/gpu/runtime/nccl_clique_key_test.cc +++ b/xla/backends/gpu/collectives/gpu_clique_key_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/runtime/nccl_clique_key.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include #include @@ -24,116 +24,115 @@ limitations under the License. #include #include "absl/container/btree_map.h" -#include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/clique_id.h" #include "xla/service/global_device_id.h" #include "tsl/platform/test.h" namespace xla::gpu { -static NcclCliqueKey GetBaseCliqueKey() { - return NcclCliqueKey({GlobalDeviceId(0), GlobalDeviceId(1)}, - CollectiveStreamId(0), AsyncStreamKind::kCollective, - std::vector>{ - {GlobalDeviceId(0), GlobalDeviceId(1)}, - {GlobalDeviceId(2), GlobalDeviceId(3)}}); +static GpuCliqueKey GetBaseCliqueKey() { + return GpuCliqueKey({GlobalDeviceId(0), GlobalDeviceId(1)}, + CollectiveStreamId(0), AsyncStreamKind::kCollective, + std::vector>{ + {GlobalDeviceId(0), GlobalDeviceId(1)}, + {GlobalDeviceId(2), GlobalDeviceId(3)}}); } -TEST(NcclCliqueKeyTest, IsSubsetOf) { +TEST(GpuCliqueKeyTest, IsSubsetOf) { GlobalDeviceId id0 = GlobalDeviceId(0); GlobalDeviceId id1 = GlobalDeviceId(1); GlobalDeviceId id2 = GlobalDeviceId(2); GlobalDeviceId id3 = GlobalDeviceId(3); - NcclCliqueKey key0({id0, id1}, CollectiveStreamId(0)); - NcclCliqueKey key1({id0, id1, id2, id3}, CollectiveStreamId(0)); - NcclCliqueKey key2({id0, id1, id2, id3}, CollectiveStreamId(1)); - NcclCliqueKey key3({id1, id2, id3}, CollectiveStreamId(0)); + GpuCliqueKey key0({id0, id1}, CollectiveStreamId(0)); + GpuCliqueKey key1({id0, id1, id2, id3}, CollectiveStreamId(0)); + GpuCliqueKey key2({id0, id1, id2, id3}, CollectiveStreamId(1)); + GpuCliqueKey key3({id1, id2, id3}, CollectiveStreamId(0)); EXPECT_TRUE(key0.IsSubsetOf(key1)); EXPECT_FALSE(key0.IsSubsetOf(key2)); EXPECT_FALSE(key0.IsSubsetOf(key3)); } -TEST(NcclCliqueKeyTest, Compare) { +TEST(GpuCliqueKeyTest, Compare) { GlobalDeviceId id0 = GlobalDeviceId(0); GlobalDeviceId id1 = GlobalDeviceId(1); GlobalDeviceId id2 = GlobalDeviceId(2); GlobalDeviceId id3 = GlobalDeviceId(3); - NcclCliqueKey key0({id0, id1}, CollectiveStreamId(0)); - NcclCliqueKey key1({id1, id2, id3}, CollectiveStreamId(0)); - NcclCliqueKey key2({id1, id2, id3}, CollectiveStreamId(1)); + GpuCliqueKey key0({id0, id1}, CollectiveStreamId(0)); + GpuCliqueKey key1({id1, id2, id3}, CollectiveStreamId(0)); + GpuCliqueKey key2({id1, id2, id3}, CollectiveStreamId(1)); EXPECT_LT(key0, key1); EXPECT_GT(key1, key0); EXPECT_LT(key1, key2); } -TEST(NcclCliqueKeyTest, CompareWithParticipantGroups) { +TEST(GpuCliqueKeyTest, CompareWithParticipantGroups) { GlobalDeviceId id0 = GlobalDeviceId(0); GlobalDeviceId id1 = GlobalDeviceId(1); GlobalDeviceId id2 = GlobalDeviceId(2); GlobalDeviceId id3 = GlobalDeviceId(3); // The keys are not equal because the replica groups are different. - NcclCliqueKey key0({id0, id1}, CollectiveStreamId(0), - AsyncStreamKind::kCollective, - std::vector>{{id0, id1}}); - NcclCliqueKey key1( + GpuCliqueKey key0({id0, id1}, CollectiveStreamId(0), + AsyncStreamKind::kCollective, + std::vector>{{id0, id1}}); + GpuCliqueKey key1( {id0, id1}, CollectiveStreamId(0), AsyncStreamKind::kCollective, std::vector>{{id0, id1}, {id2, id3}}); EXPECT_FALSE(key0 == key1); // With no replica groups, the keys are equal - NcclCliqueKey key0_nogroups({id0, id1}, CollectiveStreamId(0)); - NcclCliqueKey key1_nogroups({id0, id1}, CollectiveStreamId(0)); + GpuCliqueKey key0_nogroups({id0, id1}, CollectiveStreamId(0)); + GpuCliqueKey key1_nogroups({id0, id1}, CollectiveStreamId(0)); EXPECT_EQ(key0_nogroups, key1_nogroups); } -TEST(NcclCliqueKeyTest, CompareWithPermutedParticipantGroups) { +TEST(GpuCliqueKeyTest, CompareWithPermutedParticipantGroups) { GlobalDeviceId id0 = GlobalDeviceId(0); GlobalDeviceId id1 = GlobalDeviceId(1); GlobalDeviceId id2 = GlobalDeviceId(2); GlobalDeviceId id3 = GlobalDeviceId(3); // The keys are equal because the replica groups are same up to permutation. - NcclCliqueKey key0( + GpuCliqueKey key0( {id0, id1}, CollectiveStreamId(0), AsyncStreamKind::kCollective, std::vector>{{id3, id2}, {id0, id1}}); - NcclCliqueKey key1( + GpuCliqueKey key1( {id0, id1}, CollectiveStreamId(0), AsyncStreamKind::kCollective, std::vector>{{id0, id1}, {id2, id3}}); EXPECT_EQ(key0, key1); - NcclCliqueKey key_other( + GpuCliqueKey key_other( {id0, id1}, CollectiveStreamId(0), AsyncStreamKind::kCollective, std::vector>{{id0, id2}, {id1, id3}}); EXPECT_FALSE(key0 == key_other); } -TEST(NcclCliqueKeyTest, BtreeIterationOrder) { +TEST(GpuCliqueKeyTest, BtreeIterationOrder) { GlobalDeviceId id0 = GlobalDeviceId(0); GlobalDeviceId id1 = GlobalDeviceId(1); GlobalDeviceId id2 = GlobalDeviceId(2); GlobalDeviceId id3 = GlobalDeviceId(3); - NcclCliqueKey key0({id0, id2}, CollectiveStreamId(0)); - NcclCliqueKey key1({id0, id1, id2, id3}, CollectiveStreamId(0)); + GpuCliqueKey key0({id0, id2}, CollectiveStreamId(0)); + GpuCliqueKey key1({id0, id1, id2, id3}, CollectiveStreamId(0)); - absl::btree_map> map; + absl::btree_map> map; map[key0] = 0; map[key1] = 1; EXPECT_EQ(map.begin()->first, key1); } -TEST(NcclCliqueKeyGettersTest, Devices) { +TEST(GpuCliqueKeyGettersTest, Devices) { EXPECT_THAT( GetBaseCliqueKey().devices(), ::testing::UnorderedElementsAre(GlobalDeviceId(0), GlobalDeviceId(1))); } -TEST(NcclCliqueKeyGettersTest, Rank) { +TEST(GpuCliqueKeyGettersTest, Rank) { auto key = GetBaseCliqueKey(); EXPECT_EQ(key.rank(GlobalDeviceId(0)), 0); EXPECT_EQ(key.rank(GlobalDeviceId(1)), 1); @@ -141,23 +140,23 @@ TEST(NcclCliqueKeyGettersTest, Rank) { EXPECT_EQ(key.rank(GlobalDeviceId(3)), std::nullopt); } -TEST(NcclCliqueKeyGettersTest, StreamId) { +TEST(GpuCliqueKeyGettersTest, StreamId) { EXPECT_EQ(GetBaseCliqueKey().stream_id(), CollectiveStreamId(0)); } -TEST(NcclCliqueKeyGetterTest, ToString) { +TEST(GpuCliqueKeyGetterTest, ToString) { EXPECT_EQ(GetBaseCliqueKey().ToString(), "devices=[0,1]; stream=0; groups=[[0,1],[2,3]]"); } -TEST(NcclCliqueIdGettersTest, Data) { +TEST(GpuCliqueIdGettersTest, Data) { std::array id; std::fill(id.begin(), id.end(), 0x01); CliqueId clique_id(id.data()); EXPECT_EQ(std::memcmp(clique_id.data().data(), id.data(), 128), 0); } -TEST(NcclCliqueIdStringTest, ToString) { +TEST(GpuCliqueIdStringTest, ToString) { std::array id; std::fill(id.begin(), id.end(), 0x01); CliqueId clique_id(id.data()); diff --git a/xla/backends/gpu/collectives/gpu_collectives.cc b/xla/backends/gpu/collectives/gpu_collectives.cc deleted file mode 100644 index 456d56b99d1233..00000000000000 --- a/xla/backends/gpu/collectives/gpu_collectives.cc +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/backends/gpu/collectives/gpu_collectives.h" - -#include - -namespace xla::gpu { - -CollectiveStreamId GetCollectiveStreamId(bool is_async, - AsyncStreamKind stream_kind) { - // TODO(ezhulenev): This implementation does not look correct as stream IDs - // are not really unique. Figure out if it's the case and fix either the code - // or the documentation. - int64_t stream_id = static_cast(stream_kind); - return CollectiveStreamId(is_async ? stream_id + 1 : 0); -} - -} // namespace xla::gpu diff --git a/xla/backends/gpu/collectives/gpu_collectives.h b/xla/backends/gpu/collectives/gpu_collectives.h index 4dcbc29873cc4a..13e8c1cd1aea3f 100644 --- a/xla/backends/gpu/collectives/gpu_collectives.h +++ b/xla/backends/gpu/collectives/gpu_collectives.h @@ -16,37 +16,10 @@ limitations under the License. #ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_COLLECTIVES_H_ #define XLA_BACKENDS_GPU_COLLECTIVES_GPU_COLLECTIVES_H_ -#include - #include "xla/core/collectives/collectives.h" -#include "xla/tsl/lib/gtl/int_type.h" namespace xla::gpu { -// In XLA:GPU we use different streams for different kinds of collective -// operations, and include the async stream kind into the GPU clique key. -// -// We carefully isolate different kinds of collectives using separate -// communicators and guarantee that all collective operations have a total order -// that will not create a deadlock. -enum class AsyncStreamKind : int64_t { - kCollective = 0, // Stream for asynchronous collective ops. - kP2P0 = 1, // One Stream for P2P Send and Recv ops. - kP2P1 = 2, // Another Stream for P2P Send and Recv ops. - kMemCpyP2P = 3, // Stream for MemCpyP2P -}; - -inline constexpr int64_t kAsyncStreamTotal = - static_cast(AsyncStreamKind::kMemCpyP2P) + 1; - -// Strongly-typed wrapper to represent collective stream ID. -TSL_LIB_GTL_DEFINE_INT_TYPE(CollectiveStreamId, uint64_t); - -// Assigns a unique ID to a stream for asynchronous or synchronous execution. -// These IDs can be used, for example, to look up the NCCL communicator. -CollectiveStreamId GetCollectiveStreamId( - bool is_async, AsyncStreamKind stream_kind = AsyncStreamKind::kCollective); - // XLA:GPU extension of the Collectives interface with GPU-specific APIs. class GpuCollectives : public Collectives { public: diff --git a/xla/service/gpu/runtime/BUILD b/xla/service/gpu/runtime/BUILD index e102433f09c11a..62fa7f6a373d2e 100644 --- a/xla/service/gpu/runtime/BUILD +++ b/xla/service/gpu/runtime/BUILD @@ -319,10 +319,10 @@ cc_library( cc_library( name = "nccl_clique_key", - srcs = ["nccl_clique_key.cc"], hdrs = ["nccl_clique_key.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:clique_id", "//xla/core/collectives:clique_key", @@ -343,22 +343,6 @@ cc_library( ], ) -xla_cc_test( - name = "nccl_clique_key_test", - srcs = ["nccl_clique_key_test.cc"], - deps = [ - ":nccl_clique_key", - "//xla/backends/gpu/collectives:gpu_collectives", - "//xla/core/collectives:clique_id", - "//xla/service:global_device_id", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/status", - "@tsl//tsl/platform:status_matchers", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_main", - ], -) - #===-------------------------------------------------------------------------------------------===// # XLA Thunks Runtime #===-------------------------------------------------------------------------------------------===// diff --git a/xla/service/gpu/runtime/nccl_clique_key.h b/xla/service/gpu/runtime/nccl_clique_key.h index 4367ef45fa9d9e..c46395647548ee 100644 --- a/xla/service/gpu/runtime/nccl_clique_key.h +++ b/xla/service/gpu/runtime/nccl_clique_key.h @@ -16,17 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_CLIQUE_KEY_H_ #define XLA_SERVICE_GPU_RUNTIME_NCCL_CLIQUE_KEY_H_ -#include -#include -#include -#include - -#include "absl/hash/hash.h" -#include "absl/status/statusor.h" -#include "xla/backends/gpu/collectives/gpu_collectives.h" -#include "xla/core/collectives/clique_id.h" -#include "xla/core/collectives/clique_key.h" -#include "xla/service/global_device_id.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" namespace xla::gpu { @@ -34,57 +24,8 @@ namespace xla::gpu { // NcclCliqueKey //===----------------------------------------------------------------------===// -// Key for naming up a particular NCCL clique. This is just a set of unique -// device IDs (i.e. GPU IDs) and a stream_id. The device IDs must be global -// within a cluster. The stream_id is used to create different NCCL clique and -// communicators for collectives executed on different streams within an -// executable. -class NcclCliqueKey : public CliqueKey { - public: - explicit NcclCliqueKey( - std::vector devices, - CollectiveStreamId stream_id = CollectiveStreamId(0), - AsyncStreamKind stream_kind = AsyncStreamKind::kCollective, - std::vector> participant_groups = {}); - - CollectiveStreamId stream_id() const; - - // Returns true if this clique is a subset of `other`: both cliques have the - // same `stream_id` and all clique devices are part of `other` clique. - bool IsSubsetOf(const CliqueKey& other) const final; - - // Returns the stream kind for this clique key, - // stream kind will be used to specify what configuration - // to pass for each type of operation. - AsyncStreamKind stream_kind() const { return stream_kind_; } - - std::string ToString() const final; - - friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b); - friend bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b); - friend bool operator>(const NcclCliqueKey& a, const NcclCliqueKey& b); - - private: - void HashValue(absl::HashState state) const final; - - CollectiveStreamId stream_id_; - AsyncStreamKind stream_kind_; - - // The full list of groups across all devices which this clique is a part of. - // When enable_nccl_comm_splitting is enabled, this is used to distinguish - // which cliques can be reused from the cache or must be split in order to - // prevent a deadlock situation. - // For example, imagine we have a communicator with devices = [0,1] and groups - // = [0, 1] Later on, we may want to create communicators [0, 1] and [2, 3] by - // splitting [0, 1, 2, 3] If ranks 0 and 1 reuse the exisiting [0, 1] clique - // but ranks 2 and 3 initiate a split, there will be a deadlock since ranks 2, - // 3 and will be waiting forever for 0, 1 to join the split. Having the - // particating groups as part of the cache key will prevent such situations - std::vector> participant_groups_; -}; - -bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b); -bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b); +// TODO(b/380457503): Delete this alias. +using NcclCliqueKey = GpuCliqueKey; } // namespace xla::gpu