Skip to content

Commit

Permalink
[xla:collectives] NFC: Extract NcclCommunicators into GpuClique and NcclCliqueImpl
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 702499628
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Dec 4, 2024
1 parent cb7ae7c commit d44da53
Show file tree
Hide file tree
Showing 10 changed files with 241 additions and 64 deletions.
35 changes: 35 additions & 0 deletions xla/backends/gpu/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@ package_group(
],
)

# Library built from gpu_clique.h / gpu_clique.cc, which declare and define
# the `GpuClique` class.
cc_library(
name = "gpu_clique",
srcs = ["gpu_clique.cc"],
hdrs = ["gpu_clique.h"],
deps = [
":gpu_clique_key",
"//xla/core/collectives:clique",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"@com_google_absl//absl/container:btree",
],
)

cc_library(
name = "gpu_clique_key",
srcs = ["gpu_clique_key.cc"],
Expand Down Expand Up @@ -72,6 +86,7 @@ cc_library(
],
)

# TODO(b/380457503): Update visibility to "//visibility:private".
cc_library(
name = "nccl_collectives",
hdrs = if_gpu_is_configured(["nccl_collectives.h"]),
Expand All @@ -91,6 +106,26 @@ cc_library(
]),
)

# TODO(b/380457503): Update visibility to "//visibility:private".
# Library built from nccl_clique.h / nccl_clique.cc, which declare and define
# `NcclCliqueImpl`, a `GpuClique` implemented on top of NCCL communicators.
cc_library(
name = "nccl_clique",
srcs = ["nccl_clique.cc"],
hdrs = ["nccl_clique.h"],
deps = [
":gpu_clique",
"//xla/core/collectives:clique",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/service/gpu/runtime:nccl_api",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@tsl//tsl/platform:logging",
],
)

# TODO(b/380457503): Update visibility to "//visibility:private".
cc_library(
name = "nccl_communicator",
Expand Down
38 changes: 38 additions & 0 deletions xla/backends/gpu/collectives/gpu_clique.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/gpu/collectives/gpu_clique.h"

#include <memory>
#include <optional>
#include <utility>

#include "absl/container/btree_map.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/core/collectives/clique.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"

namespace xla::gpu {

// Constructs a GPU clique for `clique_key`.
//
// `communicators` (rank -> communicator) is handed over to the `Clique` base
// class. `clique_key` and `clique_id` are sink parameters: they are taken by
// value and moved into the corresponding members, so the vectors owned by the
// key are not deep-copied a second time.
GpuClique::GpuClique(
    GpuCliqueKey clique_key, std::optional<CliqueId> clique_id,
    absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators)
    : Clique(std::move(communicators)),
      clique_key_(std::move(clique_key)),
      clique_id_(std::move(clique_id)) {}

} // namespace xla::gpu
54 changes: 54 additions & 0 deletions xla/backends/gpu/collectives/gpu_clique.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_H_
#define XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_H_

#include <memory>
#include <optional>

#include "absl/container/btree_map.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/core/collectives/clique.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"

namespace xla::gpu {

// A clique of GPU communicators that was created for one particular
// `GpuCliqueKey`, together with an optional `CliqueId`.
class GpuClique : public Clique {
 public:
  // Takes the key, the optional clique id and the rank -> communicator map
  // making up the clique.
  GpuClique(
      GpuCliqueKey clique_key, std::optional<CliqueId> clique_id,
      absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators);

  // The key this clique was constructed with.
  const GpuCliqueKey& clique_key() const { return clique_key_; }

  // The clique id this clique was constructed with (may be empty).
  const std::optional<CliqueId>& clique_id() const { return clique_id_; }

  // Returns true if the clique is local, i.e. all communicators belong to the
  // current process. Non-local cliques span multiple processes (typically
  // hosts). The check compares the number of communicators held by this
  // object with the number of devices listed in the clique key.
  bool IsLocal() const {
    const auto num_devices = clique_key_.devices().size();
    return num_communicators() == num_devices;
  }

 private:
  GpuCliqueKey clique_key_;
  std::optional<CliqueId> clique_id_;
};

} // namespace xla::gpu

#endif // XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_H_
6 changes: 6 additions & 0 deletions xla/backends/gpu/collectives/gpu_clique_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ class GpuCliqueKey : public CliqueKey {
AsyncStreamKind stream_kind = AsyncStreamKind::kCollective,
std::vector<std::vector<GlobalDeviceId>> participant_groups = {});

// GpuCliqueKey is copyable: both copy operations are explicitly defaulted.
GpuCliqueKey(const GpuCliqueKey&) = default;
GpuCliqueKey& operator=(const GpuCliqueKey&) = default;

// GpuCliqueKey is movable: both move operations are explicitly defaulted.
GpuCliqueKey(GpuCliqueKey&&) = default;
GpuCliqueKey& operator=(GpuCliqueKey&&) = default;

CollectiveStreamId stream_id() const;

// Returns true if this clique is a subset of `other`: both cliques have the
Expand Down
58 changes: 58 additions & 0 deletions xla/backends/gpu/collectives/nccl_clique.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/gpu/collectives/nccl_clique.h"

#include <cstdint>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/service/gpu/runtime/nccl_api.h"
#include "tsl/platform/logging.h"

namespace xla::gpu {

// Builds a single-line description of the clique: the clique key, the
// fingerprint of the clique id (0 when no id is set), the number of
// communicators, and a comma-separated list of `[rank, communicator pointer]`
// entries.
std::string NcclCliqueImpl::DebugString() const {
  const auto id_fingerprint =
      clique_id().has_value() ? clique_id()->fingerprint() : 0;

  std::string result = absl::StrFormat(
      "clique_key: %s; fingerprint(id): %d; size: %d; communicators: ",
      clique_key().ToString(), id_fingerprint, num_communicators());

  // Entries after the first one are preceded by a separator.
  bool is_first_entry = true;
  ForEachComm([&result, &is_first_entry](RankId rank, Communicator* comm) {
    if (is_first_entry) {
      is_first_entry = false;
    } else {
      absl::StrAppend(&result, ", ");
    }
    absl::StrAppendFormat(&result, "[rank=%d, comm=%p]", rank.value(), comm);
  });

  return result;
}

// Asks NCCL for the asynchronous error state of every communicator in the
// clique. Each failure is logged; the returned status is the first failure
// encountered, or OK when no communicator reported an error.
absl::Status NcclCliqueImpl::HealthCheck() const {
  absl::Status first_error = absl::OkStatus();

  ForEachComm([&first_error](RankId rank, Communicator* comm) {
    // TODO(b/380457503): Move error checking API to communicator base class.
    auto async_status = NcclApi::Default()->CommGetAsyncError(comm);
    if (async_status.ok()) return;

    LOG(ERROR) << "NCCL async error (rank " << rank << "): " << async_status;

    // Later failures are only logged; the first one is what gets returned.
    if (first_error.ok()) first_error = std::move(async_status);
  });

  return first_error;
}

} // namespace xla::gpu
40 changes: 40 additions & 0 deletions xla/backends/gpu/collectives/nccl_clique.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_GPU_COLLECTIVES_NCCL_CLIQUE_H_
#define XLA_BACKENDS_GPU_COLLECTIVES_NCCL_CLIQUE_H_

#include <string>

#include "absl/status/status.h"
#include "xla/backends/gpu/collectives/gpu_clique.h"

namespace xla::gpu {

// A GPU clique that is implemented on top of NCCL communicators.
//
// TODO(b/380457503): Remove `Impl` suffix once we migrate all users to
// LockableGpuClique.
class NcclCliqueImpl : public GpuClique {
public:
// Reuse GpuClique's constructors; this class declares no constructors or
// data members of its own.
using GpuClique::GpuClique;

// Returns a one-line description listing the clique key, the clique id
// fingerprint, the number of communicators and every (rank, communicator)
// pair.
std::string DebugString() const final;
// Queries each communicator for an asynchronous NCCL error. Returns OK if
// none is reported, otherwise the first error encountered.
absl::Status HealthCheck() const final;
};

} // namespace xla::gpu

#endif // XLA_BACKENDS_GPU_COLLECTIVES_NCCL_CLIQUE_H_
6 changes: 6 additions & 0 deletions xla/core/collectives/clique_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class CliqueKey {
explicit CliqueKey(std::vector<GlobalDeviceId> devices);
virtual ~CliqueKey() = default;

// CliqueKey is copyable: both copy operations are explicitly defaulted.
CliqueKey(const CliqueKey& other) = default;
CliqueKey& operator=(const CliqueKey& other) = default;

// The move operations are defaulted explicitly: because the class declares
// its destructor, the compiler would not generate them implicitly.
CliqueKey(CliqueKey&& other) = default;
CliqueKey& operator=(CliqueKey&& other) = default;

// Returns the rank of the global device in the clique.
std::optional<RankId> rank(GlobalDeviceId id) const;

Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ cc_library(
"//xla:types",
"//xla:util",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/backends/gpu/collectives:nccl_clique",
"//xla/core/collectives:clique",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
Expand Down
36 changes: 0 additions & 36 deletions xla/service/gpu/runtime/nccl_clique.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,42 +112,6 @@ static bool TerminateOnNcclError() {
// NcclClique
//===----------------------------------------------------------------------===//

// Builds the clique: the rank -> communicator map is handed to the `Clique`
// base class, while the key and the optional clique id are moved into this
// object's members.
NcclCliqueCommunicators::NcclCliqueCommunicators(
GpuCliqueKey clique_key, std::optional<CliqueId> clique_id,
absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators)
: Clique(std::move(communicators)),
clique_key_(std::move(clique_key)),
clique_id_(std::move(clique_id)) {}

// A clique is reported as local when it holds exactly as many communicators
// as there are devices listed in its clique key.
bool NcclCliqueCommunicators::IsLocal() const {
  const auto num_devices = clique_key_.devices().size();
  return num_communicators() == num_devices;
}

// Builds a single-line description of the clique: the clique key, the
// fingerprint of the clique id (0 when no id is set), the number of
// communicators, and a comma-separated list of `[rank, communicator pointer]`
// entries.
std::string NcclCliqueCommunicators::DebugString() const {
  const auto id_fingerprint =
      clique_id_.has_value() ? clique_id_->fingerprint() : 0;

  std::string result = absl::StrFormat(
      "clique_key: %s; fingerprint(id): %d; size: %d; communicators: ",
      clique_key_.ToString(), id_fingerprint, num_communicators());

  // Entries after the first one are preceded by a separator.
  bool is_first_entry = true;
  ForEachComm([&result, &is_first_entry](RankId rank, Communicator* comm) {
    if (is_first_entry) {
      is_first_entry = false;
    } else {
      absl::StrAppend(&result, ", ");
    }
    absl::StrAppendFormat(&result, "[rank=%d, comm=%p]", rank.value(), comm);
  });

  return result;
}

// Asks NCCL for the asynchronous error state of every communicator in the
// clique. Each failure is logged; the returned status is the first failure
// encountered, or OK when no communicator reported an error.
absl::Status NcclCliqueCommunicators::HealthCheck() const {
  absl::Status first_error = absl::OkStatus();

  ForEachComm([&first_error](RankId rank, Communicator* comm) {
    auto async_status = NcclApi::Default()->CommGetAsyncError(comm);
    if (async_status.ok()) return;

    LOG(ERROR) << "NCCL async error (rank " << rank << "): " << async_status;

    // Later failures are only logged; the first one is what gets returned.
    if (first_error.ok()) first_error = std::move(async_status);
  });

  return first_error;
}

// Returns the debug string of the wrapped clique value, prefixed with
// "NcclClique: ".
std::string NcclClique::DebugString() const {
  const std::string clique_description = value().DebugString();
  return absl::StrFormat("NcclClique: %s", clique_description);
}
Expand Down
31 changes: 3 additions & 28 deletions xla/service/gpu/runtime/nccl_clique.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/backends/gpu/collectives/nccl_clique.h"
#include "xla/core/collectives/clique.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/communicator.h"
Expand Down Expand Up @@ -76,39 +77,13 @@ absl::StatusOr<const CliqueIdCallback*> GetCliqueIdCallback(
// NcclClique
//===----------------------------------------------------------------------===//

// A group of NCCL communicators making up a clique. With NCCL it's notoriously
// easy to get a deadlock, so we take extra care by grouping communicators into
// cliques and making sure that we have a well defined order of all collective
// operations that does not lead to deadlocks.
class NcclCliqueCommunicators : public Clique {
public:
// Takes the clique key, an optional clique id, and the rank -> communicator
// map making up the clique.
NcclCliqueCommunicators(
GpuCliqueKey clique_key, std::optional<CliqueId> clique_id,
absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators);

// Returns true if the clique is local: all communicators belong to the
// current process. Non-local cliques span multiple processes (typically hosts).
bool IsLocal() const;

// Accessors for the key and the optional id this clique was constructed with.
const GpuCliqueKey& clique_key() const { return clique_key_; }
const std::optional<CliqueId>& clique_id() const { return clique_id_; }

// Returns a one-line description of the clique and its communicators.
std::string DebugString() const final;

// Returns OK if no communicator reports an asynchronous NCCL error, otherwise
// the first error encountered.
absl::Status HealthCheck() const final;

private:
GpuCliqueKey clique_key_;
std::optional<CliqueId> clique_id_;
};

struct NcclCliqueName {
static std::string ToString(const NcclCliqueCommunicators& comms) {
static std::string ToString(const NcclCliqueImpl& comms) {
return absl::StrFormat("lockable clique %s", comms.clique_key().ToString());
}
};

class NcclClique : public Lockable<NcclCliqueCommunicators, NcclCliqueName> {
class NcclClique : public Lockable<NcclCliqueImpl, NcclCliqueName> {
public:
// We keep acquired cliques in a sorted container to guarantee that all
// participants iterate over cliques in the same order.
Expand Down

0 comments on commit d44da53

Please sign in to comment.