diff --git a/xla/backends/gpu/collectives/nccl_communicator.cc b/xla/backends/gpu/collectives/nccl_communicator.cc
index 2a1f238f4099cf..ccfd0cce0a4ddb 100644
--- a/xla/backends/gpu/collectives/nccl_communicator.cc
+++ b/xla/backends/gpu/collectives/nccl_communicator.cc
@@ -45,6 +45,11 @@ NcclCommunicator::~NcclCommunicator() {
   XLA_NCCL_LOG_IF_ERROR(ncclCommDestroy(comm_));
 }
 
+absl::Status NcclCommunicator::Abort() {
+  VLOG(1) << "Abort NCCL communicator: " << ToString();
+  return XLA_NCCL_STATUS(ncclCommAbort(comm_));
+}
+
 absl::Status NcclCommunicator::HealthCheck() const {
   VLOG(5) << "Get last async error for NCCL communicator: " << ToString();
 
diff --git a/xla/backends/gpu/collectives/nccl_communicator.h b/xla/backends/gpu/collectives/nccl_communicator.h
index 0472fb0a154413..ebae31d68d4362 100644
--- a/xla/backends/gpu/collectives/nccl_communicator.h
+++ b/xla/backends/gpu/collectives/nccl_communicator.h
@@ -40,6 +40,7 @@ class NcclCommunicator : public Communicator {
   explicit NcclCommunicator(ncclComm_t comm);
   ~NcclCommunicator() override;
 
+  absl::Status Abort() final;
   absl::Status HealthCheck() const final;
   std::string ToString() const final;
 
diff --git a/xla/core/collectives/communicator.h b/xla/core/collectives/communicator.h
index 284b8ce68d9447..942a6b0efca237 100644
--- a/xla/core/collectives/communicator.h
+++ b/xla/core/collectives/communicator.h
@@ -28,6 +28,11 @@ class Communicator {
  public:
   virtual ~Communicator() = default;
 
+  // Aborts any uncompleted operations and destroys the underlying communicator
+  // object. It is undefined behavior to use the communicator after calling
+  // this method.
+  virtual absl::Status Abort() = 0;
+
   // Checks the health of the communicator. It might return an error from the
   // previously launched asynchronous collective operations, and it does not
   // have to wait for the completion of scheduled operations.
diff --git a/xla/service/gpu/runtime/nccl_api.cc b/xla/service/gpu/runtime/nccl_api.cc
index 4ac6b6a43b4757..68a28a8418dd07 100644
--- a/xla/service/gpu/runtime/nccl_api.cc
+++ b/xla/service/gpu/runtime/nccl_api.cc
@@ -302,9 +302,6 @@ class DefaultNcclApi final : public NcclApi {
       absl::Span comms, int32_t color,
       absl::Span keys, std::optional config) final;
 
-  absl::Status CommAbort(Communicator* comm) final;
-  absl::Status CommFinalize(Communicator* comm) final;
-
   absl::StatusOr<int32_t> CommCount(Communicator* comm) final;
 
   absl::Status GroupStart() final;
@@ -481,16 +478,6 @@ DefaultNcclApi::CommSplit(absl::Span comms,
 #endif  // !defined(TENSORFLOW_USE_ROCM) || TF_ROCM_VERSION >= 60000
 }
 
-absl::Status DefaultNcclApi::CommAbort(Communicator* comm) {
-  VLOG(1) << "Abort NCCL communicator: " << comm;
-  return XLA_NCCL_STATUS(ncclCommAbort(Cast(comm)));
-}
-
-absl::Status DefaultNcclApi::CommFinalize(Communicator* comm) {
-  VLOG(1) << "Finalize NCCL communicator: " << comm;
-  return XLA_NCCL_STATUS(ncclCommFinalize(Cast(comm)));
-}
-
 absl::StatusOr<int32_t> DefaultNcclApi::CommCount(Communicator* comm) {
   VLOG(5) << "Get the number of ranks in NCCL communicator: " << comm;
   int32_t count;
diff --git a/xla/service/gpu/runtime/nccl_api.h b/xla/service/gpu/runtime/nccl_api.h
index ab5a6aeeeb3e70..1293ffab01e2ea 100644
--- a/xla/service/gpu/runtime/nccl_api.h
+++ b/xla/service/gpu/runtime/nccl_api.h
@@ -161,17 +161,6 @@ class NcclApi : public GpuCollectives {
       absl::Span comms, int32_t color,
       absl::Span keys, std::optional config) = 0;
 
-  // Abort any uncompleted operations and destroys the communicator. Frees
-  // resources that are allocated to a communicator object comm.
-  //
-  // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommabort
-  virtual absl::Status CommAbort(Communicator* comm) = 0;
-
-  // Finalize a communicator object comm.
-  //
-  // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommdestroy
-  virtual absl::Status CommFinalize(Communicator* comm) = 0;
-
   // Returns the number of ranks in the NCCL communicator comm.
   //
   // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommcount
diff --git a/xla/service/gpu/runtime/nccl_api_stub.cc b/xla/service/gpu/runtime/nccl_api_stub.cc
index f24525df725ed4..78501bd1d72f55 100644
--- a/xla/service/gpu/runtime/nccl_api_stub.cc
+++ b/xla/service/gpu/runtime/nccl_api_stub.cc
@@ -100,12 +100,6 @@ class NcclApiStub final : public NcclApi {
     return UnimplementedError();
   }
 
-  absl::Status CommAbort(Communicator*) final { return UnimplementedError(); }
-
-  absl::Status CommFinalize(Communicator*) final {
-    return UnimplementedError();
-  }
-
   absl::StatusOr<int32_t> CommCount(Communicator*) final {
     return UnimplementedError();
   }
diff --git a/xla/service/gpu/runtime/nccl_clique.cc b/xla/service/gpu/runtime/nccl_clique.cc
index cf5f0f8c440a66..206b54683c81e0 100644
--- a/xla/service/gpu/runtime/nccl_clique.cc
+++ b/xla/service/gpu/runtime/nccl_clique.cc
@@ -138,13 +138,13 @@ static NcclCliques& GetNcclCliques() {
 // error state. It will free resources that are allocated to a communicator
 // and abort any uncompleted operations before destroying the communicator.
 static absl::Status CheckComm(Communicator* comm) {
-  absl::Status async_err = comm->HealthCheck();
-  if (!async_err.ok()) {
+  absl::Status health = comm->HealthCheck();
+  if (!health.ok()) {
     LOG(ERROR) << "Aborting communicator: " << comm
-               << " due to async NCCL error: " << async_err;
-    TF_RETURN_IF_ERROR(NcclApi::Default()->CommAbort(comm));
+               << " due to error: " << health;
+    TF_RETURN_IF_ERROR(comm->Abort());
   }
-  return async_err;
+  return health;
 }
 
 // Runs async check on all communicators in a clique.
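The call-site pattern this change enables: aborting an unhealthy communicator no longer requires a `NcclApi` handle, only the `Communicator` itself. Below is a minimal, self-contained sketch of that pattern. The trimmed `Communicator` interface and the `CheckAndMaybeAbort` helper are illustrative stand-ins, not part of this change; the real interface in `xla/core/collectives/communicator.h` declares many more pure-virtual methods.

```cpp
#include <string>

#include "absl/status/status.h"

// Illustrative stand-in for xla's Communicator, trimmed to the methods used
// here; the real interface also declares the collective operations.
class Communicator {
 public:
  virtual ~Communicator() = default;

  // Aborts any uncompleted operations and destroys the underlying
  // communicator. Using the communicator afterwards is undefined behavior.
  virtual absl::Status Abort() = 0;

  virtual absl::Status HealthCheck() const = 0;
  virtual std::string ToString() const = 0;
};

// Mirrors the updated CheckComm in nccl_clique.cc: on a failed health check,
// abort through the communicator itself instead of going through
// NcclApi::Default()->CommAbort(comm).
absl::Status CheckAndMaybeAbort(Communicator* comm) {
  absl::Status health = comm->HealthCheck();
  if (!health.ok()) {
    // Abort() may itself fail; surface that error first. After Abort()
    // returns, `comm` must not be touched again.
    absl::Status aborted = comm->Abort();
    if (!aborted.ok()) return aborted;
  }
  return health;  // The original health-check error (or OK).
}
```

Moving `Abort` onto the `Communicator` interface keeps backend-specific teardown (`ncclCommAbort` in `NcclCommunicator::Abort`) behind the virtual call, so clique-management code such as `CheckComm` no longer needs to reach for `NcclApi` just to tear down a bad communicator.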