Skip to content

Commit

Permalink
[xla:collectives] NFC: Remove unused NcclApi CommFinalize function
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702108779
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Dec 3, 2024
1 parent e9947dd commit 503738a
Show file tree
Hide file tree
Showing 38 changed files with 424 additions and 318 deletions.
21 changes: 21 additions & 0 deletions xla/backends/gpu/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@ package_group(
],
)

cc_library(
name = "gpu_clique",
srcs = ["gpu_clique.cc"],
hdrs = ["gpu_clique.h"],
deps = [
":gpu_clique_key",
"//xla/core/collectives:clique",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@tsl//tsl/platform:logging",
],
)

cc_library(
name = "gpu_clique_key",
srcs = ["gpu_clique_key.cc"],
Expand Down Expand Up @@ -71,6 +89,7 @@ cc_library(
],
)

# TODO(b/380457503): Update visibility to "//visibility:private".
cc_library(
name = "nccl_collectives",
hdrs = if_gpu_is_configured(["nccl_collectives.h"]),
Expand Down Expand Up @@ -102,7 +121,9 @@ cc_library(
]),
deps = [
":nccl_errors",
"//xla:util",
"//xla/core/collectives:communicator",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@tsl//tsl/platform:logging",
] + if_cuda_is_configured([
Expand Down
66 changes: 66 additions & 0 deletions xla/backends/gpu/collectives/gpu_clique.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/gpu/collectives/gpu_clique.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "absl/container/btree_map.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/core/collectives/clique.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "tsl/platform/logging.h"

namespace xla::gpu {

GpuClique::GpuClique(
GpuCliqueKey key, std::optional<CliqueId> id,
absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators)
: Clique(std::move(communicators)), key_(key), id_(id) {}

std::string GpuClique::DebugString() const {
std::string out =
absl::StrFormat("key: %s; fingerprint(id): %d; size: %d; communicators: ",
key_.ToString(), id_.has_value() ? id_->fingerprint() : 0,
num_communicators());
int32_t cnt = 0;
ForEachComm([&](RankId rank, Communicator* comm) {
if (cnt++) absl::StrAppend(&out, ", ");
absl::StrAppendFormat(&out, "[rank=%d, comm=%p]", rank.value(), comm);
});
return out;
}

absl::Status GpuClique::HealthCheck() const {
absl::Status health_check = absl::OkStatus();
ForEachComm([&health_check](RankId rank, Communicator* comm) {
if (auto s = comm->HealthCheck(); !s.ok()) {
LOG(ERROR) << "GPU communicator error (rank " << rank << "): " << s;
if (health_check.ok()) health_check = std::move(s); // return first error
}
});
return health_check;
}

} // namespace xla::gpu
57 changes: 57 additions & 0 deletions xla/backends/gpu/collectives/gpu_clique.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_H_
#define XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_H_

#include <memory>
#include <optional>
#include <string>

#include "absl/container/btree_map.h"
#include "absl/status/status.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/core/collectives/clique.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"

namespace xla::gpu {

// A group of GPU communicators making up a clique for a given clique key.
class GpuClique : public Clique {
public:
GpuClique(
GpuCliqueKey key, std::optional<CliqueId> id,
absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators);

// Returns true if clique is local: all communicators belong to current
// process. Non-local cliques spans multiple processes (typically hosts).
bool IsLocal() const { return num_communicators() == key_.devices().size(); }

const GpuCliqueKey& key() const { return key_; }
const std::optional<CliqueId>& id() const { return id_; }

std::string DebugString() const final;
absl::Status HealthCheck() const final;

private:
GpuCliqueKey key_;
std::optional<CliqueId> id_;
};

} // namespace xla::gpu

#endif // XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_H_
6 changes: 6 additions & 0 deletions xla/backends/gpu/collectives/gpu_clique_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ class GpuCliqueKey : public CliqueKey {
AsyncStreamKind stream_kind = AsyncStreamKind::kCollective,
std::vector<std::vector<GlobalDeviceId>> participant_groups = {});

GpuCliqueKey(const GpuCliqueKey&) = default;
GpuCliqueKey& operator=(const GpuCliqueKey&) = default;

GpuCliqueKey(GpuCliqueKey&&) = default;
GpuCliqueKey& operator=(GpuCliqueKey&&) = default;

CollectiveStreamId stream_id() const;

// Returns true if this clique is a subset of `other`: both cliques have the
Expand Down
18 changes: 18 additions & 0 deletions xla/backends/gpu/collectives/nccl_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ limitations under the License.

#include <string>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "xla/backends/gpu/collectives/nccl_errors.h"
#include "xla/util.h"
#include "tsl/platform/logging.h"

#if TENSORFLOW_USE_ROCM
Expand All @@ -43,6 +45,22 @@ NcclCommunicator::~NcclCommunicator() {
XLA_NCCL_LOG_IF_ERROR(ncclCommDestroy(comm_));
}

absl::Status NcclCommunicator::Abort() {
VLOG(1) << "Abort NCCL communicator: " << ToString();
return XLA_NCCL_STATUS(ncclCommAbort(comm_));
}

absl::Status NcclCommunicator::HealthCheck() const {
VLOG(5) << "Get last async error for NCCL communicator: " << ToString();

ncclResult_t async_err;
XLA_NCCL_RETURN_IF_ERROR(ncclCommGetAsyncError(comm_, &async_err));
if (async_err == ncclSuccess) return absl::OkStatus();

return Internal("%s. Last NCCL error (maybe unrelated): %s",
ncclGetLastError(comm_), ncclGetErrorString(async_err));
}

std::string NcclCommunicator::ToString() const {
return absl::StrFormat("NccCommunicator(ncclComm_t=%p)", comm_);
}
Expand Down
4 changes: 4 additions & 0 deletions xla/backends/gpu/collectives/nccl_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

#include <string>

#include "absl/status/status.h"
#include "xla/core/collectives/communicator.h"

#if TENSORFLOW_USE_ROCM
Expand All @@ -39,6 +40,9 @@ class NcclCommunicator : public Communicator {
explicit NcclCommunicator(ncclComm_t comm);
~NcclCommunicator() override;

absl::Status Abort() final;
absl::Status HealthCheck() const final;

std::string ToString() const final;

ncclComm_t comm() const { return comm_; }
Expand Down
18 changes: 18 additions & 0 deletions xla/core/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@ package_group(
],
)

cc_library(
name = "clique",
srcs = ["clique.cc"],
hdrs = ["clique.h"],
deps = [
":clique_id",
":communicator",
":rank_id",
"//xla:util",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/status",
],
)

cc_library(
name = "collectives",
hdrs = ["collectives.h"],
Expand All @@ -25,6 +40,9 @@ cc_library(
cc_library(
name = "communicator",
hdrs = ["communicator.h"],
deps = [
"@com_google_absl//absl/status",
],
)

cc_library(
Expand Down
47 changes: 47 additions & 0 deletions xla/core/collectives/clique.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/core/collectives/clique.h"

#include <memory>
#include <optional>
#include <utility>

#include "absl/container/btree_map.h"
#include "absl/functional/function_ref.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"

namespace xla {

Clique::Clique(
absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators)
: communicators_(std::move(communicators)) {}

std::optional<Communicator*> Clique::comm(RankId rank) const {
if (auto it = communicators_.find(rank); it != communicators_.end()) {
return it->second.get();
}
return std::nullopt;
}

void Clique::ForEachComm(
absl::FunctionRef<void(RankId, Communicator*)> fn) const {
for (auto& [rank, comm] : communicators_) {
fn(rank, comm.get());
}
}

} // namespace xla
71 changes: 71 additions & 0 deletions xla/core/collectives/clique.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_CORE_COLLECTIVES_CLIQUE_H_
#define XLA_CORE_COLLECTIVES_CLIQUE_H_

// A group of NCCL communicators making up a clique. With NCCL it's notoriously
// easy to get a deadlock, so we take extra care by grouping communicators into
// cliques and making sure that we have a well defined order of all collective
// operations that does not lead to deadlocks.

#include <cstddef>
#include <memory>
#include <optional>
#include <string>

#include "absl/container/btree_map.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"

namespace xla {

// A group of collective communicators for make up a clique.
//
// We use clique mechanism to group communicators to be able to efficiently
// get exclusive access to all communicators in a clique, as we typically have
// to guarantee that collective operations on all ranks are executed in the
// same order across all devices.
class Clique {
public:
explicit Clique(
absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators);
virtual ~Clique() = default;

// Returns a communicator for a given rank if it's in a clique.
std::optional<Communicator*> comm(RankId rank) const;

// Calls `fn` for each communicator in the clique.
void ForEachComm(absl::FunctionRef<void(RankId, Communicator*)> fn) const;

// Checks that all communicators in the clique are in a healthy state.
virtual absl::Status HealthCheck() const = 0;

// Returns a human-readable string representation of the clique.
virtual std::string DebugString() const = 0;

size_t num_communicators() const { return communicators_.size(); }

private:
// We keep communicators in a sorted order by rank to guarantee deterministic
// traversal order in `ForEachComm`.
absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators_;
};

} // namespace xla

#endif // XLA_CORE_COLLECTIVES_CLIQUE_H_
6 changes: 6 additions & 0 deletions xla/core/collectives/clique_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class CliqueKey {
explicit CliqueKey(std::vector<GlobalDeviceId> devices);
virtual ~CliqueKey() = default;

CliqueKey(const CliqueKey& other) = default;
CliqueKey& operator=(const CliqueKey& other) = default;

CliqueKey(CliqueKey&& other) = default;
CliqueKey& operator=(CliqueKey&& other) = default;

// Returns the rank of the global device in the clique.
std::optional<RankId> rank(GlobalDeviceId id) const;

Expand Down
Loading

0 comments on commit 503738a

Please sign in to comment.