From 9ee3c1ab807ef47cdae03b9bb6f4ebc2f6e5d9a5 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Mon, 2 Dec 2024 16:04:17 -0800
Subject: [PATCH] [xla:collectives] NFC: Move NcclApi::CommCount to Communicator API

PiperOrigin-RevId: 702113027
---
 xla/backends/gpu/collectives/BUILD           |  1 +
 .../gpu/collectives/nccl_communicator.cc     |  9 ++++
 .../gpu/collectives/nccl_communicator.h      |  3 ++
 xla/core/collectives/BUILD                   |  1 +
 xla/core/collectives/communicator.h          |  5 +++
 .../gpu/runtime/nccl_all_reduce_thunk.cc     |  7 ++-
 .../gpu/runtime/nccl_all_to_all_thunk.cc     | 43 +++++++++----------
 xla/service/gpu/runtime/nccl_api.cc          |  9 ----
 xla/service/gpu/runtime/nccl_api.h           |  5 ---
 xla/service/gpu/runtime/nccl_api_stub.cc     |  4 --
 10 files changed, 42 insertions(+), 45 deletions(-)

diff --git a/xla/backends/gpu/collectives/BUILD b/xla/backends/gpu/collectives/BUILD
index 0e07a204eda80a..ba44aeb5730468 100644
--- a/xla/backends/gpu/collectives/BUILD
+++ b/xla/backends/gpu/collectives/BUILD
@@ -125,6 +125,7 @@ cc_library(
         "//xla:util",
         "//xla/core/collectives:communicator",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
         "@tsl//tsl/platform:logging",
     ] + if_cuda_is_configured([
diff --git a/xla/backends/gpu/collectives/nccl_communicator.cc b/xla/backends/gpu/collectives/nccl_communicator.cc
index ccfd0cce0a4ddb..bc5ca202a7b1da 100644
--- a/xla/backends/gpu/collectives/nccl_communicator.cc
+++ b/xla/backends/gpu/collectives/nccl_communicator.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "xla/backends/gpu/collectives/nccl_communicator.h"
 
+#include <cstddef>
+#include <cstdint>
 #include <string>
 
 #include "absl/status/status.h"
@@ -61,6 +63,13 @@ absl::Status NcclCommunicator::HealthCheck() const {
                          ncclGetLastError(comm_), ncclGetErrorString(async_err));
 }
 
+absl::StatusOr<size_t> NcclCommunicator::NumRanks() const {
+  VLOG(5) << "Get the number of ranks in NCCL communicator: " << ToString();
+  int32_t count;
+  XLA_NCCL_RETURN_IF_ERROR(ncclCommCount(comm_, &count));
+  return count;
+}
+
 std::string NcclCommunicator::ToString() const {
   return absl::StrFormat("NccCommunicator(ncclComm_t=%p)", comm_);
 }
diff --git a/xla/backends/gpu/collectives/nccl_communicator.h b/xla/backends/gpu/collectives/nccl_communicator.h
index ebae31d68d4362..031412f356658a 100644
--- a/xla/backends/gpu/collectives/nccl_communicator.h
+++ b/xla/backends/gpu/collectives/nccl_communicator.h
@@ -16,9 +16,11 @@ limitations under the License.
 #ifndef XLA_BACKENDS_GPU_COLLECTIVES_NCCL_COMMUNICATOR_H_
 #define XLA_BACKENDS_GPU_COLLECTIVES_NCCL_COMMUNICATOR_H_
 
+#include <cstddef>
 #include <string>
 
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "xla/core/collectives/communicator.h"
 
 #if TENSORFLOW_USE_ROCM
@@ -42,6 +44,7 @@ class NcclCommunicator : public Communicator {
 
   absl::Status Abort() final;
   absl::Status HealthCheck() const final;
+  absl::StatusOr<size_t> NumRanks() const final;
 
   std::string ToString() const final;
 
diff --git a/xla/core/collectives/BUILD b/xla/core/collectives/BUILD
index 25a0e39e26cba6..7a9f2f9edfb2e3 100644
--- a/xla/core/collectives/BUILD
+++ b/xla/core/collectives/BUILD
@@ -42,6 +42,7 @@ cc_library(
     hdrs = ["communicator.h"],
    deps = [
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
     ],
 )
 
diff --git a/xla/core/collectives/communicator.h b/xla/core/collectives/communicator.h
index 942a6b0efca237..17a93609ee6afe 100644
--- a/xla/core/collectives/communicator.h
+++ b/xla/core/collectives/communicator.h
@@ -16,10 +16,12 @@ limitations under the License.
 #ifndef XLA_CORE_COLLECTIVES_COMMUNICATOR_H_
 #define XLA_CORE_COLLECTIVES_COMMUNICATOR_H_
 
+#include <cstddef>
 #include <ostream>
 #include <string>
 
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 
 namespace xla {
 
@@ -38,6 +40,9 @@ class Communicator {
   // have to wait for the completion of scheduled operations.
   virtual absl::Status HealthCheck() const = 0;
 
+  // Returns the number of ranks in the communicator.
+  virtual absl::StatusOr<size_t> NumRanks() const = 0;
+
   virtual std::string ToString() const = 0;
 };
 
diff --git a/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc b/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc
index 54a7431ab9970d..d14b1a1f9c8ea0 100644
--- a/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc
+++ b/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc
@@ -232,21 +232,20 @@ absl::Status RunReduceScatter(NcclApi* nccl_api, ReductionKind reduction_kind,
   TF_RETURN_IF_ERROR(
       MaybeRegisterBuffers(nccl_api, stream.parent(), buffers, comm));
 
-  TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm));
+  TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks());
 
   TF_RETURN_IF_ERROR(nccl_api->GroupStart());
 
   for (DeviceBufferPair& buffer : buffers) {
     // buffer.element_count is the source buffers element count. For
     // ncclReduceScatter, we need the destination buffers element count.
-    TF_RET_CHECK(buffer.element_count % num_participants == 0)
+    TF_RET_CHECK(buffer.element_count % num_ranks == 0)
         << "Source buffer was not an exact multiple of the number of "
           "participants.";
 
     TF_RETURN_IF_ERROR(nccl_api->ReduceScatter(
         buffer.source_buffer, buffer.destination_buffer, buffer.element_type,
-        buffer.element_count / num_participants, reduction_kind, comm,
-        &stream));
+        buffer.element_count / num_ranks, reduction_kind, comm, &stream));
   }
 
   return nccl_api->GroupEnd();
diff --git a/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc b/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc
index cc42d145a67ace..b8adbfa9f5d5da 100644
--- a/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc
+++ b/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc
@@ -113,13 +113,12 @@ absl::Status NcclAllToAllStartThunk::Initialize(
       GetNcclComm(*params.collective_params, *params.collective_cliques,
                   config().replica_groups, config().group_mode, stream_id,
                   stream_kind));
-  TF_ASSIGN_OR_RETURN(int32_t num_participants,
-                      nccl_api()->CommCount(comm_handle.comm));
-  int local_id = params.stream->parent()->device_ordinal() % num_participants;
+  TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm_handle.comm->NumRanks());
+  int local_id = params.stream->parent()->device_ordinal() % num_ranks;
   {
     absl::MutexLock lock(&pointer_maps_mutex_);
     if (!send_pointer_maps_.count(local_id)) {
-      for (int i = 0; i < num_participants; ++i) {
+      for (int i = 0; i < num_ranks; ++i) {
         if (!params.stream->parent()->HostMemoryRegister(
                 &send_pointer_maps_[local_id][i], sizeof(void*))) {
           VLOG(5) << "Registering host send pointer for memcpy failed.";
@@ -144,10 +143,9 @@ absl::Status NcclAllToAllStartThunk::Cleanup(const CleanupParams& params) {
       GetNcclComm(*params.collective_params, *params.collective_cliques,
                   config().replica_groups, config().group_mode, stream_id,
                   stream_kind));
-  TF_ASSIGN_OR_RETURN(int32_t num_participants,
-                      nccl_api()->CommCount(comm_handle.comm));
+  TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm_handle.comm->NumRanks());
 
-  int local_id = params.executor->device_ordinal() % num_participants;
+  int local_id = params.executor->device_ordinal() % num_ranks;
   {
     absl::MutexLock lock(&pointer_maps_mutex_);
     if (send_pointer_maps_.count(local_id)) {
@@ -176,11 +174,10 @@ absl::Status NcclAllToAllStartThunk::RunNcclCollective(
       std::vector<DeviceBufferPair> device_buffers,
       ConvertToDeviceBuffers(params, buffers_,
                              config_.config.operand_element_type));
-  TF_ASSIGN_OR_RETURN(int32_t num_participants,
-                      nccl_api()->CommCount(comm_handle.comm));
+  TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm_handle.comm->NumRanks());
 
   if (is_local() && p2p_memcpy_enabled_) {
-    int local_id = stream.parent()->device_ordinal() % num_participants;
+    int local_id = stream.parent()->device_ordinal() % num_ranks;
     absl::flat_hash_map<int64_t, uint64_t>* send_pointer_map = nullptr;
     absl::flat_hash_map<int64_t, uint64_t>* receive_pointer_map = nullptr;
     {
@@ -222,7 +219,7 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension,
   TF_RETURN_IF_ERROR(
       MaybeRegisterBuffers(nccl_api, stream.parent(), buffers, comm));
 
-  TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm));
+  TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks());
 
   TF_RETURN_IF_ERROR(nccl_api->GroupStart());
 
@@ -232,12 +229,12 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension,
   // and produces a tuple of outputs.
   if (has_split_dimension) {
     for (DeviceBufferPair& buffer : buffers) {
-      TF_RET_CHECK(buffer.element_count % num_participants == 0)
+      TF_RET_CHECK(buffer.element_count % num_ranks == 0)
          << "Buffer was not an exact multiple of the number of participants.";
 
-      size_t chunk_elements = buffer.element_count / num_participants;
+      size_t chunk_elements = buffer.element_count / num_ranks;
 
-      for (int peer = 0; peer < num_participants; ++peer) {
+      for (int peer = 0; peer < num_ranks; ++peer) {
         se::DeviceMemoryBase send_slice =
             NcclApi::Slice(buffer.source_buffer, buffer.element_type,
                            peer * chunk_elements, chunk_elements);
@@ -254,7 +251,7 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension,
       }
     }
   } else {
-    TF_RET_CHECK(buffers.size() == num_participants)
+    TF_RET_CHECK(buffers.size() == num_ranks)
        << "Number of inputs didn't match the number of participants.";
 
     for (size_t i = 0; i < buffers.size(); ++i) {
@@ -285,7 +282,7 @@ absl::Status RunMemCpyAllToAll(
   TF_RETURN_IF_ERROR(
       MaybeRegisterBuffers(nccl_api, stream.parent(), buffers, comm));
 
-  TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm));
+  TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks());
 
   // AllToAll can operate in two modes. Either it specifies a split dimension,
   // in which case inputs are split and outputs concatenated in that dimension
@@ -293,13 +290,13 @@ absl::Status RunMemCpyAllToAll(
   // and produces a tuple of outputs.
   if (has_split_dimension) {
     for (DeviceBufferPair& buffer : buffers) {
-      TF_RET_CHECK(buffer.element_count % num_participants == 0)
+      TF_RET_CHECK(buffer.element_count % num_ranks == 0)
          << "Buffer was not an exact multiple of the number of participants.";
 
-      size_t chunk_elements = buffer.element_count / num_participants;
+      size_t chunk_elements = buffer.element_count / num_ranks;
 
       TF_RETURN_IF_ERROR(nccl_api->GroupStart());
-      for (int peer = 0; peer < num_participants; ++peer) {
+      for (int peer = 0; peer < num_ranks; ++peer) {
         se::DeviceMemoryBase recv_slice =
             NcclApi::Slice(buffer.destination_buffer, buffer.element_type,
                            peer * chunk_elements, chunk_elements);
@@ -313,7 +310,7 @@ absl::Status RunMemCpyAllToAll(
       TF_RETURN_IF_ERROR(nccl_api->GroupEnd());
       TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
 
-      for (int peer = 0; peer < num_participants; ++peer) {
+      for (int peer = 0; peer < num_ranks; ++peer) {
         se::DeviceMemoryBase send_slice =
             NcclApi::Slice(buffer.source_buffer, buffer.element_type,
                            peer * chunk_elements, chunk_elements);
@@ -324,11 +321,11 @@ absl::Status RunMemCpyAllToAll(
       }
     }
   } else {
-    TF_RET_CHECK(buffers.size() == num_participants)
+    TF_RET_CHECK(buffers.size() == num_ranks)
        << "Number of inputs didn't match the number of participants.";
 
     TF_RETURN_IF_ERROR(nccl_api->GroupStart());
 
-    for (int peer = 0; peer < num_participants; ++peer) {
+    for (int peer = 0; peer < num_ranks; ++peer) {
       send_pointer_map[peer] =
           (uint64_t)buffers[peer].destination_buffer.opaque();
@@ -340,7 +337,7 @@ absl::Status RunMemCpyAllToAll(
     TF_RETURN_IF_ERROR(nccl_api->GroupEnd());
     TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
 
-    for (int peer = 0; peer < num_participants; ++peer) {
+    for (int peer = 0; peer < num_ranks; ++peer) {
       // double buffer, exchange data with peer
       se::DeviceMemoryBase dst_addr =
           se::DeviceMemoryBase((void*)receive_pointer_map[peer]);
diff --git a/xla/service/gpu/runtime/nccl_api.cc b/xla/service/gpu/runtime/nccl_api.cc
index 68a28a8418dd07..5044c1ff8ffcc1 100644
--- a/xla/service/gpu/runtime/nccl_api.cc
+++ b/xla/service/gpu/runtime/nccl_api.cc
@@ -302,8 +302,6 @@ class DefaultNcclApi final : public NcclApi {
       absl::Span<const Communicator* const> comms, int32_t color,
       absl::Span<const RankId> keys, std::optional<Config> config) final;
 
-  absl::StatusOr<int32_t> CommCount(Communicator* comm) final;
-
   absl::Status GroupStart() final;
   absl::Status GroupEnd() final;
 
@@ -478,13 +476,6 @@ DefaultNcclApi::CommSplit(absl::Span<const Communicator* const> comms,
 #endif  // !defined(TENSORFLOW_USE_ROCM) || TF_ROCM_VERSION >= 60000
 }
 
-absl::StatusOr<int32_t> DefaultNcclApi::CommCount(Communicator* comm) {
-  VLOG(5) << "Get the number of ranks in NCCL communicator: " << comm;
-  int32_t count;
-  XLA_NCCL_RETURN_IF_ERROR(ncclCommCount(Cast(comm), &count));
-  return count;
-}
-
 absl::Status DefaultNcclApi::GroupStart() {
   VLOG(5) << "Start NCCL group";
   return XLA_NCCL_STATUS(ncclGroupStart());
diff --git a/xla/service/gpu/runtime/nccl_api.h b/xla/service/gpu/runtime/nccl_api.h
index 1293ffab01e2ea..f187d816026c13 100644
--- a/xla/service/gpu/runtime/nccl_api.h
+++ b/xla/service/gpu/runtime/nccl_api.h
@@ -161,11 +161,6 @@ class NcclApi : public GpuCollectives {
       absl::Span<const Communicator* const> comms, int32_t color,
       absl::Span<const RankId> keys, std::optional<Config> config) = 0;
 
-  // Returns the number of ranks in the NCCL communicator comm.
-  //
-  // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommcount
-  virtual absl::StatusOr<int32_t> CommCount(Communicator* comm) = 0;
-
   // Starts a group call.
   //
   // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/group.html#ncclgroupstart
diff --git a/xla/service/gpu/runtime/nccl_api_stub.cc b/xla/service/gpu/runtime/nccl_api_stub.cc
index 78501bd1d72f55..07375e91fab05a 100644
--- a/xla/service/gpu/runtime/nccl_api_stub.cc
+++ b/xla/service/gpu/runtime/nccl_api_stub.cc
@@ -100,10 +100,6 @@ class NcclApiStub final : public NcclApi {
     return UnimplementedError();
   }
 
-  absl::StatusOr<int32_t> CommCount(Communicator*) final {
-    return UnimplementedError();
-  }
-
   absl::Status GroupStart() final { return UnimplementedError(); }
   absl::Status GroupEnd() final { return UnimplementedError(); }
 
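
Usage sketch (not part of the patch): a minimal illustration of how a call site migrates from the removed NcclApi::CommCount to the new Communicator::NumRanks virtual, mirroring the thunk changes above. The helper function DescribeComm is hypothetical, and the include path assumed for TF_ASSIGN_OR_RETURN is tsl/platform/statusor.h.

    #include <cstdint>

    #include "absl/status/status.h"
    #include "tsl/platform/statusor.h"  // TF_ASSIGN_OR_RETURN (assumed path)
    #include "xla/core/collectives/communicator.h"

    namespace xla {

    // Hypothetical call site: the rank count now comes from the communicator
    // itself instead of the NcclApi singleton.
    absl::Status DescribeComm(Communicator* comm) {
      // Before this patch:
      //   TF_ASSIGN_OR_RETURN(int32_t num_ranks, nccl_api->CommCount(comm));
      TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks());

      // The thunks above use this value to size per-peer chunks and to pick a
      // local id via device_ordinal() % num_ranks.
      if (num_ranks <= 0) {
        return absl::InternalError("Communicator reports no ranks");
      }
      return absl::OkStatus();
    }

    }  // namespace xla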