Skip to content

Commit

Permalink
[xla:collectives] NFC: Move NcclCliqueKey to GpuCliqueKey
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 701724894
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Dec 3, 2024
1 parent 50098e0 commit 3c0edb6
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 191 deletions.
38 changes: 36 additions & 2 deletions xla/backends/gpu/collectives/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
load("//xla:xla.bzl", "xla_cc_test")
load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
load("//xla/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")
Expand All @@ -16,13 +17,46 @@ package_group(
],
)

# Library target for the GPU clique key: compiles gpu_clique_key.cc and exports
# gpu_clique_key.h.
cc_library(
name = "gpu_clique_key",
srcs = ["gpu_clique_key.cc"],
hdrs = ["gpu_clique_key.h"],
deps = [
"//xla/core/collectives",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/service:global_device_id",
"//xla/tsl/lib/gtl:int_type",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:logging",
],
)

# Test target for :gpu_clique_key; the test sources live in
# gpu_clique_key_test.cc.
xla_cc_test(
name = "gpu_clique_key_test",
srcs = ["gpu_clique_key_test.cc"],
deps = [
":gpu_clique_key",
"//xla/core/collectives:clique_id",
"//xla/service:global_device_id",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/status",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)

# Library target built from gpu_collectives.cc and gpu_collectives.h.
cc_library(
name = "gpu_collectives",
srcs = ["gpu_collectives.cc"],
hdrs = ["gpu_collectives.h"],
deps = [
"//xla/core/collectives",
"//xla/tsl/lib/gtl:int_type",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/runtime/nccl_clique_key.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"

#include <cstdint>
#include <string>
#include <utility>
#include <vector>
Expand All @@ -24,19 +25,23 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/backends/gpu/collectives/gpu_collectives.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/service/global_device_id.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/logging.h"

namespace xla::gpu {

//===----------------------------------------------------------------------===//
// NcclCliqueKey
//===----------------------------------------------------------------------===//
// Maps an (is_async, stream_kind) pair to a collective stream ID.
//
// Synchronous execution (`is_async == false`) always maps to ID 0. Asynchronous
// execution maps to the numeric value of `stream_kind` plus one.
CollectiveStreamId GetCollectiveStreamId(bool is_async,
                                         AsyncStreamKind stream_kind) {
  // TODO(ezhulenev): This implementation does not look correct as stream IDs
  // are not really unique. Figure out if it's the case and fix either the code
  // or the documentation.
  int64_t id_value = 0;
  if (is_async) {
    // Shift by one so that no asynchronous kind collides with the synchronous
    // ID 0.
    id_value = static_cast<int64_t>(stream_kind) + 1;
  }
  return CollectiveStreamId(id_value);
}

NcclCliqueKey::NcclCliqueKey(
GpuCliqueKey::GpuCliqueKey(
std::vector<GlobalDeviceId> devices, CollectiveStreamId stream_id,
AsyncStreamKind stream_kind,
std::vector<std::vector<GlobalDeviceId>> participant_groups)
Expand All @@ -57,10 +62,10 @@ NcclCliqueKey::NcclCliqueKey(
absl::c_sort(participant_groups_, compare_groups);
}

CollectiveStreamId NcclCliqueKey::stream_id() const { return stream_id_; }
// Returns the collective stream ID stored in this key.
CollectiveStreamId GpuCliqueKey::stream_id() const { return stream_id_; }

bool NcclCliqueKey::IsSubsetOf(const CliqueKey& other) const {
auto* other_nccl = tsl::down_cast<const NcclCliqueKey*>(&other);
bool GpuCliqueKey::IsSubsetOf(const CliqueKey& other) const {
auto* other_nccl = tsl::down_cast<const GpuCliqueKey*>(&other);
if (other_nccl == nullptr) return false;

return stream_id_ == other_nccl->stream_id_ &&
Expand All @@ -69,7 +74,7 @@ bool NcclCliqueKey::IsSubsetOf(const CliqueKey& other) const {
});
}

std::string NcclCliqueKey::ToString() const {
std::string GpuCliqueKey::ToString() const {
std::string group_string = "";
if (!participant_groups_.empty()) {
std::vector<std::string> values;
Expand All @@ -84,17 +89,17 @@ std::string NcclCliqueKey::ToString() const {
group_string);
}

void NcclCliqueKey::HashValue(absl::HashState state) const {
// Mixes the clique devices, the stream ID and the participant groups into
// `state`. Note that `stream_kind_` is not part of the hash.
void GpuCliqueKey::HashValue(absl::HashState state) const {
absl::HashState::combine(std::move(state), devices(), stream_id_,
participant_groups_);
}

bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) {
// Two clique keys are equal when their devices, stream IDs and participant
// groups all match. The stream kind is not taken into account.
bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b) {
  // Checks run in the same order as the fields are listed above and stop at
  // the first mismatch.
  if (!(a.devices() == b.devices())) {
    return false;
  }
  if (!(a.stream_id_ == b.stream_id_)) {
    return false;
  }
  return a.participant_groups_ == b.participant_groups_;
}

bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b) {
bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b) {
if (a.devices().size() < b.devices().size()) return true;
if (b.devices().size() < a.devices().size()) return false;

Expand All @@ -104,7 +109,7 @@ bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b) {
return a.stream_id_.value() < b.stream_id_.value();
}

bool operator>(const NcclCliqueKey& a, const NcclCliqueKey& b) {
bool operator>(const GpuCliqueKey& a, const GpuCliqueKey& b) {
if (a.devices().size() > b.devices().size()) return true;
if (b.devices().size() > a.devices().size()) return false;

Expand Down
109 changes: 109 additions & 0 deletions xla/backends/gpu/collectives/gpu_clique_key.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_KEY_H_
#define XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_KEY_H_

#include <cstdint>
#include <string>
#include <vector>

#include "absl/hash/hash.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/service/global_device_id.h"
#include "xla/tsl/lib/gtl/int_type.h"

namespace xla::gpu {

// In XLA:GPU we use different streams for different kinds of collective
// operations, and include the async stream kind into the GPU clique key.
//
// We carefully isolate different kinds of collectives using separate
// communicators and guarantee that all collective operations have a total order
// that will not create a deadlock.
// Kinds of asynchronous streams. The enumerators are numbered consecutively
// starting at 0; kAsyncStreamTotal is derived from the last one (kMemCpyP2P),
// so new kinds added after it require updating that constant.
enum class AsyncStreamKind : int64_t {
kCollective = 0, // Stream for asynchronous collective ops.
kP2P0 = 1, // One stream for P2P Send and Recv ops.
kP2P1 = 2, // Another stream for P2P Send and Recv ops.
kMemCpyP2P = 3, // Stream for MemCpyP2P.
};

// Total number of AsyncStreamKind values, computed as the numeric value of the
// last enumerator (kMemCpyP2P) plus one.
inline constexpr int64_t kAsyncStreamTotal =
static_cast<int64_t>(AsyncStreamKind::kMemCpyP2P) + 1;

// Strongly-typed wrapper to represent collective stream ID. Declared with
// uint64_t as the underlying value type.
TSL_LIB_GTL_DEFINE_INT_TYPE(CollectiveStreamId, uint64_t);

// Assigns a unique ID to a stream for asynchronous or synchronous execution.
// These IDs can be used, for example, to look up the NCCL communicator.
//
// The current definition returns 0 when `is_async` is false, and the numeric
// value of `stream_kind` plus one otherwise. `stream_kind` defaults to
// AsyncStreamKind::kCollective. The definition carries a TODO questioning
// whether the resulting IDs are really unique.
CollectiveStreamId GetCollectiveStreamId(
bool is_async, AsyncStreamKind stream_kind = AsyncStreamKind::kCollective);

// Clique key for identifying a particular collectives clique on a GPU backend.
class GpuCliqueKey : public CliqueKey {
public:
// Constructs a clique key for `devices`. All other arguments are optional:
// `stream_id` defaults to 0, `stream_kind` to AsyncStreamKind::kCollective
// and `participant_groups` to an empty list.
explicit GpuCliqueKey(
std::vector<GlobalDeviceId> devices,
CollectiveStreamId stream_id = CollectiveStreamId(0),
AsyncStreamKind stream_kind = AsyncStreamKind::kCollective,
std::vector<std::vector<GlobalDeviceId>> participant_groups = {});

// Returns the collective stream ID stored in this key.
CollectiveStreamId stream_id() const;

// Returns true if this clique is a subset of `other`: both cliques have the
// same `stream_id` and all clique devices are part of `other` clique.
bool IsSubsetOf(const CliqueKey& other) const final;

// Returns the stream kind for this clique key, stream kind will be used to
// specify what configuration to pass for each type of operation.
AsyncStreamKind stream_kind() const { return stream_kind_; }

// Returns a human-readable string representation of this key.
std::string ToString() const final;

// GPU clique keys have a total order on which we rely for acquiring cliques
// in the same order across all participating devices.
friend bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b);
friend bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b);
friend bool operator>(const GpuCliqueKey& a, const GpuCliqueKey& b);

private:
// Hashing hook: feeds this key into `state`.
void HashValue(absl::HashState state) const final;

CollectiveStreamId stream_id_;
AsyncStreamKind stream_kind_;

// The full list of groups across all devices which this clique is a part of.
//
// When GPU communicator splitting is enabled, this is used to distinguish
// which cliques can be reused from the cache or must be split in order to
// prevent a deadlock situation.
//
// For example, imagine we have a communicator with devices = [0, 1] and
// groups = [0, 1]. Later on, we may want to create communicators [0, 1] and
// [2, 3] by splitting [0, 1, 2, 3]. If ranks 0 and 1 reuse the existing
// [0, 1] clique but ranks 2 and 3 initiate a split, there will be a deadlock
// since ranks 2 and 3 will be waiting forever for 0 and 1 to join the split.
//
// Having the participating groups as part of the cache key will prevent such
// situations.
std::vector<std::vector<GlobalDeviceId>> participant_groups_;
};

bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b);
bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b);

} // namespace xla::gpu

#endif // XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_KEY_H_
Loading

0 comments on commit 3c0edb6

Please sign in to comment.