Skip to content

Commit

Permalink
[xla:collectives] NFC: Move NcclApi::CommCount to Communicator API
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702113027
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Dec 4, 2024
1 parent 41910db commit 9ee3c1a
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 45 deletions.
1 change: 1 addition & 0 deletions xla/backends/gpu/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ cc_library(
"//xla:util",
"//xla/core/collectives:communicator",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@tsl//tsl/platform:logging",
] + if_cuda_is_configured([
Expand Down
9 changes: 9 additions & 0 deletions xla/backends/gpu/collectives/nccl_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.

#include "xla/backends/gpu/collectives/nccl_communicator.h"

#include <cstddef>
#include <cstdint>
#include <string>

#include "absl/status/status.h"
Expand Down Expand Up @@ -61,6 +63,13 @@ absl::Status NcclCommunicator::HealthCheck() const {
ncclGetLastError(comm_), ncclGetErrorString(async_err));
}

// Returns the number of ranks in the underlying NCCL communicator as reported
// by ncclCommCount. If the NCCL call fails, the error is returned as a status
// via XLA_NCCL_RETURN_IF_ERROR.
absl::StatusOr<size_t> NcclCommunicator::NumRanks() const {
  VLOG(5) << "Get the number of ranks in NCCL communicator: " << ToString();
  // Zero-initialize so the value is never indeterminate, even if the NCCL call
  // returns without writing to the output parameter.
  int32_t count = 0;
  XLA_NCCL_RETURN_IF_ERROR(ncclCommCount(comm_, &count));
  // NCCL reports the count as a signed int; make the conversion to the
  // unsigned return type explicit rather than relying on an implicit one.
  return static_cast<size_t>(count);
}

// Returns a human-readable description of this communicator that includes the
// address of the wrapped ncclComm_t handle; used in log messages.
std::string NcclCommunicator::ToString() const {
  // The class name was previously misspelled as "NccCommunicator" here.
  return absl::StrFormat("NcclCommunicator(ncclComm_t=%p)", comm_);
}
Expand Down
3 changes: 3 additions & 0 deletions xla/backends/gpu/collectives/nccl_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ limitations under the License.
#ifndef XLA_BACKENDS_GPU_COLLECTIVES_NCCL_COMMUNICATOR_H_
#define XLA_BACKENDS_GPU_COLLECTIVES_NCCL_COMMUNICATOR_H_

#include <cstddef>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/core/collectives/communicator.h"

#if TENSORFLOW_USE_ROCM
Expand All @@ -42,6 +44,7 @@ class NcclCommunicator : public Communicator {

absl::Status Abort() final;
absl::Status HealthCheck() const final;
absl::StatusOr<size_t> NumRanks() const final;

std::string ToString() const final;

Expand Down
1 change: 1 addition & 0 deletions xla/core/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ cc_library(
hdrs = ["communicator.h"],
deps = [
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

Expand Down
5 changes: 5 additions & 0 deletions xla/core/collectives/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ limitations under the License.
#ifndef XLA_CORE_COLLECTIVES_COMMUNICATOR_H_
#define XLA_CORE_COLLECTIVES_COMMUNICATOR_H_

#include <cstddef>
#include <ostream>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

namespace xla {

Expand All @@ -38,6 +40,9 @@ class Communicator {
// have to wait for the completion of scheduled operations.
virtual absl::Status HealthCheck() const = 0;

// Returns the number of ranks in the communicator.
virtual absl::StatusOr<size_t> NumRanks() const = 0;

virtual std::string ToString() const = 0;
};

Expand Down
7 changes: 3 additions & 4 deletions xla/service/gpu/runtime/nccl_all_reduce_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,21 +232,20 @@ absl::Status RunReduceScatter(NcclApi* nccl_api, ReductionKind reduction_kind,
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(nccl_api, stream.parent(), buffers, comm));

TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm));
TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks());

TF_RETURN_IF_ERROR(nccl_api->GroupStart());

for (DeviceBufferPair& buffer : buffers) {
// buffer.element_count is the source buffers element count. For
// ncclReduceScatter, we need the destination buffers element count.
TF_RET_CHECK(buffer.element_count % num_participants == 0)
TF_RET_CHECK(buffer.element_count % num_ranks == 0)
<< "Source buffer was not an exact multiple of the number of "
"participants.";

TF_RETURN_IF_ERROR(nccl_api->ReduceScatter(
buffer.source_buffer, buffer.destination_buffer, buffer.element_type,
buffer.element_count / num_participants, reduction_kind, comm,
&stream));
buffer.element_count / num_ranks, reduction_kind, comm, &stream));
}

return nccl_api->GroupEnd();
Expand Down
43 changes: 20 additions & 23 deletions xla/service/gpu/runtime/nccl_all_to_all_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,12 @@ absl::Status NcclAllToAllStartThunk::Initialize(
GetNcclComm(*params.collective_params, *params.collective_cliques,
config().replica_groups, config().group_mode, stream_id,
stream_kind));
TF_ASSIGN_OR_RETURN(int32_t num_participants,
nccl_api()->CommCount(comm_handle.comm));
int local_id = params.stream->parent()->device_ordinal() % num_participants;
TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm_handle.comm->NumRanks());
int local_id = params.stream->parent()->device_ordinal() % num_ranks;
{
absl::MutexLock lock(&pointer_maps_mutex_);
if (!send_pointer_maps_.count(local_id)) {
for (int i = 0; i < num_participants; ++i) {
for (int i = 0; i < num_ranks; ++i) {
if (!params.stream->parent()->HostMemoryRegister(
&send_pointer_maps_[local_id][i], sizeof(void*))) {
VLOG(5) << "Registering host send pointer for memcpy failed.";
Expand All @@ -144,10 +143,9 @@ absl::Status NcclAllToAllStartThunk::Cleanup(const CleanupParams& params) {
GetNcclComm(*params.collective_params, *params.collective_cliques,
config().replica_groups, config().group_mode, stream_id,
stream_kind));
TF_ASSIGN_OR_RETURN(int32_t num_participants,
nccl_api()->CommCount(comm_handle.comm));
TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm_handle.comm->NumRanks());

int local_id = params.executor->device_ordinal() % num_participants;
int local_id = params.executor->device_ordinal() % num_ranks;
{
absl::MutexLock lock(&pointer_maps_mutex_);
if (send_pointer_maps_.count(local_id)) {
Expand Down Expand Up @@ -176,11 +174,10 @@ absl::Status NcclAllToAllStartThunk::RunNcclCollective(
std::vector<DeviceBufferPair> device_buffers,
ConvertToDeviceBuffers(params, buffers_,
config_.config.operand_element_type));
TF_ASSIGN_OR_RETURN(int32_t num_participants,
nccl_api()->CommCount(comm_handle.comm));
TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm_handle.comm->NumRanks());

if (is_local() && p2p_memcpy_enabled_) {
int local_id = stream.parent()->device_ordinal() % num_participants;
int local_id = stream.parent()->device_ordinal() % num_ranks;
absl::flat_hash_map<int64_t, uint64_t>* send_pointer_map = nullptr;
absl::flat_hash_map<int64_t, uint64_t>* receive_pointer_map = nullptr;
{
Expand Down Expand Up @@ -222,7 +219,7 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension,
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(nccl_api, stream.parent(), buffers, comm));

TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm));
TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks());

TF_RETURN_IF_ERROR(nccl_api->GroupStart());

Expand All @@ -232,12 +229,12 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension,
// and produces a tuple of outputs.
if (has_split_dimension) {
for (DeviceBufferPair& buffer : buffers) {
TF_RET_CHECK(buffer.element_count % num_participants == 0)
TF_RET_CHECK(buffer.element_count % num_ranks == 0)
<< "Buffer was not an exact multiple of the number of participants.";

size_t chunk_elements = buffer.element_count / num_participants;
size_t chunk_elements = buffer.element_count / num_ranks;

for (int peer = 0; peer < num_participants; ++peer) {
for (int peer = 0; peer < num_ranks; ++peer) {
se::DeviceMemoryBase send_slice =
NcclApi::Slice(buffer.source_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements);
Expand All @@ -254,7 +251,7 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension,
}
}
} else {
TF_RET_CHECK(buffers.size() == num_participants)
TF_RET_CHECK(buffers.size() == num_ranks)
<< "Number of inputs didn't match the number of participants.";

for (size_t i = 0; i < buffers.size(); ++i) {
Expand Down Expand Up @@ -285,21 +282,21 @@ absl::Status RunMemCpyAllToAll(
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(nccl_api, stream.parent(), buffers, comm));

TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm));
TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks());

// AllToAll can operate in two modes. Either it specifies a split dimension,
// in which case inputs are split and outputs concatenated in that dimension
// (here, we only support dimension 0), or it takes a list of inputs
// and produces a tuple of outputs.
if (has_split_dimension) {
for (DeviceBufferPair& buffer : buffers) {
TF_RET_CHECK(buffer.element_count % num_participants == 0)
TF_RET_CHECK(buffer.element_count % num_ranks == 0)
<< "Buffer was not an exact multiple of the number of participants.";

size_t chunk_elements = buffer.element_count / num_participants;
size_t chunk_elements = buffer.element_count / num_ranks;

TF_RETURN_IF_ERROR(nccl_api->GroupStart());
for (int peer = 0; peer < num_participants; ++peer) {
for (int peer = 0; peer < num_ranks; ++peer) {
se::DeviceMemoryBase recv_slice =
NcclApi::Slice(buffer.destination_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements);
Expand All @@ -313,7 +310,7 @@ absl::Status RunMemCpyAllToAll(
TF_RETURN_IF_ERROR(nccl_api->GroupEnd());
TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());

for (int peer = 0; peer < num_participants; ++peer) {
for (int peer = 0; peer < num_ranks; ++peer) {
se::DeviceMemoryBase send_slice =
NcclApi::Slice(buffer.source_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements);
Expand All @@ -324,11 +321,11 @@ absl::Status RunMemCpyAllToAll(
}
}
} else {
TF_RET_CHECK(buffers.size() == num_participants)
TF_RET_CHECK(buffers.size() == num_ranks)
<< "Number of inputs didn't match the number of participants.";

TF_RETURN_IF_ERROR(nccl_api->GroupStart());
for (int peer = 0; peer < num_participants; ++peer) {
for (int peer = 0; peer < num_ranks; ++peer) {
send_pointer_map[peer] =
(uint64_t)buffers[peer].destination_buffer.opaque();

Expand All @@ -340,7 +337,7 @@ absl::Status RunMemCpyAllToAll(
TF_RETURN_IF_ERROR(nccl_api->GroupEnd());
TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());

for (int peer = 0; peer < num_participants; ++peer) {
for (int peer = 0; peer < num_ranks; ++peer) {
// double buffer, exchange data with peer
se::DeviceMemoryBase dst_addr =
se::DeviceMemoryBase((void*)receive_pointer_map[peer]);
Expand Down
9 changes: 0 additions & 9 deletions xla/service/gpu/runtime/nccl_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,6 @@ class DefaultNcclApi final : public NcclApi {
absl::Span<const Communicator* const> comms, int32_t color,
absl::Span<const RankId> keys, std::optional<Config> config) final;

absl::StatusOr<int32_t> CommCount(Communicator* comm) final;

absl::Status GroupStart() final;
absl::Status GroupEnd() final;

Expand Down Expand Up @@ -478,13 +476,6 @@ DefaultNcclApi::CommSplit(absl::Span<const Communicator* const> comms,
#endif // !defined(TENSORFLOW_USE_ROCM) || TF_ROCM_VERSION >= 60000
}

absl::StatusOr<int32_t> DefaultNcclApi::CommCount(Communicator* comm) {
VLOG(5) << "Get the number of ranks in NCCL communicator: " << comm;
int32_t count;
XLA_NCCL_RETURN_IF_ERROR(ncclCommCount(Cast(comm), &count));
return count;
}

absl::Status DefaultNcclApi::GroupStart() {
VLOG(5) << "Start NCCL group";
return XLA_NCCL_STATUS(ncclGroupStart());
Expand Down
5 changes: 0 additions & 5 deletions xla/service/gpu/runtime/nccl_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,6 @@ class NcclApi : public GpuCollectives {
absl::Span<const Communicator* const> comms, int32_t color,
absl::Span<const RankId> keys, std::optional<Config> config) = 0;

// Returns the number of ranks in the NCCL communicator comm.
//
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommcount
virtual absl::StatusOr<int32_t> CommCount(Communicator* comm) = 0;

// Starts a group call.
//
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/group.html#ncclgroupstart
Expand Down
4 changes: 0 additions & 4 deletions xla/service/gpu/runtime/nccl_api_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ class NcclApiStub final : public NcclApi {
return UnimplementedError();
}

absl::StatusOr<int32_t> CommCount(Communicator*) final {
return UnimplementedError();
}

absl::Status GroupStart() final { return UnimplementedError(); }
absl::Status GroupEnd() final { return UnimplementedError(); }

Expand Down

0 comments on commit 9ee3c1a

Please sign in to comment.