From 90a52d376cb20eed62e5b0c847f9ccf2012097d9 Mon Sep 17 00:00:00 2001 From: hanchao Date: Wed, 20 Nov 2024 02:17:03 +0000 Subject: [PATCH 01/19] Happy Init --- CMakeLists.txt | 18 + cmake/Modules/FindXCCL.cmake | 65 ++ cmake/XCCL.cmake | 20 + src/BuildOnLinux.cmake | 27 +- src/CMakeLists.txt | 1 + src/xccl/CMakeLists.txt | 7 + src/xccl/ProcessGroupXCCL.cpp | 1795 +++++++++++++++++++++++++++++++++ src/xccl/ProcessGroupXCCL.hpp | 372 +++++++ 8 files changed, 2304 insertions(+), 1 deletion(-) create mode 100644 cmake/Modules/FindXCCL.cmake create mode 100644 cmake/XCCL.cmake create mode 100644 src/xccl/CMakeLists.txt create mode 100644 src/xccl/ProcessGroupXCCL.cpp create mode 100644 src/xccl/ProcessGroupXCCL.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index a0ba1fd99..33d2d7667 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,24 @@ list(APPEND CMAKE_MODULE_PATH ${TORCH_XPU_OPS_ROOT}/cmake/Modules) include(${TORCH_XPU_OPS_ROOT}/cmake/SYCL.cmake) include(${TORCH_XPU_OPS_ROOT}/cmake/BuildFlags.cmake) +option(BUILD_WITH_XCCL "Build with XCCL support" ON) +if (DEFINED ENV{BUILD_WITH_XCCL}) + string(TOLOWER "$ENV{BUILD_WITH_XCCL}" BUILD_WITH_XCCL_LOWER) + + if (NOT (BUILD_WITH_XCCL_LOWER STREQUAL "1" OR + BUILD_WITH_XCCL_LOWER STREQUAL "on" OR + BUILD_WITH_XCCL_LOWER STREQUAL "yes")) + set(BUILD_WITH_XCCL OFF CACHE BOOL "Build with XCCL support" FORCE) + else() + set(BUILD_WITH_XCCL ON CACHE BOOL "Build with XCCL support" FORCE) + endif() +endif() + +if(NOT WIN32 AND BUILD_WITH_XCCL) + include(${TORCH_XPU_OPS_ROOT}/cmake/XCCL.cmake) + set(USE_C10D_XCCL) +endif() + if(BUILD_TEST) add_subdirectory(${TORCH_XPU_OPS_ROOT}/test/sycl ${CMAKE_BINARY_DIR}/test_sycl) endif() diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake new file mode 100644 index 000000000..29571065c --- /dev/null +++ b/cmake/Modules/FindXCCL.cmake @@ -0,0 +1,65 @@ +# This will define the following variables: +# XCCL_FOUND : True if the system has the XCCL library. +# XCCL_INCLUDE_DIR : Include directories needed to use XCCL. +# XCCL_LIBRARY_DIR :The path to the XCCL library. +# XCCL_LIBRARY : XCCL library fullname. + +include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) + +set(XCCL_ROOT $ENV{CCL_ROOT}) +string(COMPARE EQUAL "${XCCL_ROOT}" "" nocclfound) +if(nocclfound) + set(XCCL_FOUND False) + set(XCCL_REASON_FAILURE "OneCCL library not found!!") + set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}") + return() +endif() + +# Find include path from binary. +find_file( + XCCL_INCLUDE_DIR + NAMES include + HINTS ${XCCL_ROOT} + NO_DEFAULT_PATH +) + +# Find include/oneapi path from include path. +find_file( + XCCL_INCLUDE_ONEAPI_DIR + NAMES oneapi + HINTS ${XCCL_ROOT}/include/ + NO_DEFAULT_PATH +) + +list(APPEND XCCL_INCLUDE_DIR ${XCCL_INCLUDE_ONEAPI_DIR}) + +# Find library directory from binary. +find_file( + XCCL_LIBRARY_DIR + NAMES lib + HINTS ${XCCL_ROOT} + NO_DEFAULT_PATH +) + +# Find XCCL library fullname. 
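+# find_library resolves the oneCCL runtime (libccl) under the lib directory located above;
+# NO_DEFAULT_PATH keeps the search confined to the CCL_ROOT installation.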
+find_library( + XCCL_LIBRARY + NAMES ccl + HINTS ${XCCL_LIBRARY_DIR} + NO_DEFAULT_PATH +) + +if((NOT XCCL_INCLUDE_DIR) OR (NOT XCCL_LIBRARY_DIR) OR (NOT XCCL_LIBRARY)) + set(XCCL_FOUND False) + set(XCCL_REASON_FAILURE "OneCCL library not found!!") + set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}") + return() +endif() + +find_package_handle_standard_args( + XCCL + FOUND_VAR XCCL_FOUND + REQUIRED_VARS XCCL_INCLUDE_DIR XCCL_LIBRARY_DIR XCCL_LIBRARY + REASON_FAILURE_MESSAGE "${XCCL_REASON_FAILURE}" +) + diff --git a/cmake/XCCL.cmake b/cmake/XCCL.cmake new file mode 100644 index 000000000..50e1bdcf5 --- /dev/null +++ b/cmake/XCCL.cmake @@ -0,0 +1,20 @@ +if(NOT __XCCL_INCLUDED) + set(__XCCL_INCLUDED TRUE) + + # XCCL_ROOT, XCCL_LIBRARY_DIR, XCCL_INCLUDE_DIR are handled by FindXCCL.cmake. + find_package(XCCL REQUIRED) + if(NOT XCCL_FOUND) + message("${XCCL_NOT_FOUND_MESSAGE") + return() + endif() + if(XCCL_FOUND) + add_library(torch::xccl INTERFACE IMPORTED) + set_property( + TARGET torch::xccl PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${XCCL_INCLUDE_DIR}) + set_property( + TARGET torch::xccl PROPERTY INTERFACE_LINK_LIBRARIES + ${XCCL_LIBRARY}) + endif() +endif() + diff --git a/src/BuildOnLinux.cmake b/src/BuildOnLinux.cmake index 1590919c0..f32f840cd 100644 --- a/src/BuildOnLinux.cmake +++ b/src/BuildOnLinux.cmake @@ -8,7 +8,8 @@ add_library( STATIC ${ATen_XPU_CPP_SRCS} ${ATen_XPU_NATIVE_CPP_SRCS} - ${ATen_XPU_GEN_SRCS}) + ${ATen_XPU_GEN_SRCS} + ${ATen_XPU_XCCL_SRCS}) if(BUILD_SEPARATE_OPS) foreach(sycl_src ${ATen_XPU_SYCL_SRCS}) @@ -24,6 +25,20 @@ if(BUILD_SEPARATE_OPS) # Decouple with PyTorch cmake definition. install(TARGETS ${sycl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") endforeach() + if(USE_C10D_XCCL) + foreach(xccl_src ${ATen_XPU_XCCL_SRCS}) + get_filename_component(name ${xccl_src} NAME_WLE REALPATH) + set(xccl_lib torch-xpu-ops-xccl-${name}) + target_link_libraries(xccl_lib PRIVATE torch::xccl) + sycl_add_library( + ${xccl_lib} + SHARED + CXX_SOURCES ${xccl_src}) + target_link_libraries(torch_xpu_ops PUBLIC ${xccl_lib}) + list(APPEND TORCH_XPU_OPS_LIBRARIES ${xccl_lib}) + install(TARGETS ${xccl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + endforeach() + endif() else() # Split SYCL kernels into 4 libraries as categories 1) Unary+Binary 2) Reduce 3) Foreach 4) Others. set(ATen_XPU_SYCL_UNARY_BINARY_SRCS) @@ -102,6 +117,16 @@ else() # Decouple with PyTorch cmake definition. 
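+  # The USE_C10D_XCCL block added below compiles all xccl/*.cpp sources into a single
+  # torch_xpu_ops_xccl shared library and links it against oneCCL through torch::xccl.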
install(TARGETS ${sycl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + if(USE_C10D_XCCL) + set(xccl_lib torch_xpu_ops_xccl) + target_link_libraries(xccl_lib PRIVATE torch::xccl) + sycl_add_library( + ${xccl_lib} + SHARED + CXX_SOURCES ${ATen_XPU_XCCL_SRCS}) + target_link_libraries(torch_xpu_ops PUBLIC ${xccl_lib}) + install(TARGETS ${xccl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + endif() endif() set(SYCL_LINK_LIBRARIES_KEYWORD) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0716ca5af..a1d7f49be 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -4,6 +4,7 @@ include(${TORCH_XPU_OPS_ROOT}/cmake/Codegen.cmake) set(ATen_XPU_CPP_SRCS) set(ATen_XPU_NATIVE_CPP_SRCS) set(ATen_XPU_SYCL_SRCS) +set(ATen_XPU_XCCL_SRCS) set(ATen_XPU_INCLUDE_DIRS ${TORCH_XPU_OPS_ROOT}/src CACHE STRING "ATen XPU Include directory") diff --git a/src/xccl/CMakeLists.txt b/src/xccl/CMakeLists.txt new file mode 100644 index 000000000..809181d55 --- /dev/null +++ b/src/xccl/CMakeLists.txt @@ -0,0 +1,7 @@ +# XCCL sources + +file(GLOB xccl_cpp "*.cpp") + +list(APPEND ATen_XPU_XCCL_SRCS ${xccl_cpp}) + +set(ATen_XPU_XCCL_SRCS ${ATen_XPU_XCCL_SRCS} PARENT_SCOPE) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp new file mode 100644 index 000000000..4f407d133 --- /dev/null +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -0,0 +1,1795 @@ +#ifdef USE_C10D_XCCL + +#include +#include +#include + +namespace c10d { + +namespace { +const std::map xcclOps = { + {ReduceOp::MIN, ccl::reduction::min}, + {ReduceOp::MAX, ccl::reduction::max}, + {ReduceOp::SUM, ccl::reduction::sum}, + {ReduceOp::PRODUCT, ccl::reduction::prod}, +}; + +const std::map xcclDatatypes = { + {at::kByte, ccl::datatype::uint8}, + {at::kChar, ccl::datatype::int8}, + {at::kInt, ccl::datatype::int32}, + {at::kLong, ccl::datatype::int64}, + {at::kHalf, ccl::datatype::float16}, + {at::kFloat, ccl::datatype::float32}, + {at::kDouble, ccl::datatype::float64}, + {at::kBFloat16, ccl::datatype::bfloat16}, + {at::kBool, ccl::datatype::uint8}, + // use for allgather + {at::kFloat8_e5m2, ccl::datatype::uint8}, + {at::kFloat8_e4m3fn, ccl::datatype::uint8}, + {at::kFloat8_e4m3fnuz, ccl::datatype::uint8}, + {at::kFloat8_e5m2fnuz, ccl::datatype::uint8}, +}; + +bool computeLengthsAndCheckAndGetFlat( + const std::vector& tensors, + std::vector& lengths, + at::Tensor& flatTensor, + int64_t& flatLength) { + int64_t groupSize = tensors.size(); + auto firstTensor = tensors[0]; + int64_t totalSize = 0; + bool isFlat = true; + + auto storage = firstTensor.storage(); + int64_t firstStorageOffset = firstTensor.storage_offset(); + + for (int i = 0; i < groupSize; i++) { + auto& curTensor = tensors[i]; + int64_t length = curTensor.numel(); + lengths[i] = length; + totalSize += length; + + if (isFlat && + (!storage.is_alias_of(curTensor.storage()) || + curTensor.storage_offset() != + firstStorageOffset + totalSize - length)) { + isFlat = false; + } + } + + flatLength = totalSize; + + if (isFlat) { + flatTensor = firstTensor; + } else { + flatTensor = at::empty({totalSize}, firstTensor.options()); + } + + return isFlat; +} + +bool check_same_size(const std::vector& input_tensors) { + for (const auto& input_tensor : input_tensors) { + if (!input_tensors[0].is_same_size(input_tensor)) { + return false; + } + } + return true; +} + +void check_xpu_single_tensor( + const at::Tensor& tensor, + const bool p2p = false // whether operation is a P2P operation +) { + if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) { + C10_THROW_ERROR( + ValueError, 
"Tensors must be XPU and dense and non-complex"); + + // Skip the following requirements for P2P operations + if (!tensor.is_contiguous(tensor.suggest_memory_format())) { + if (p2p) { + TORCH_WARN_ONCE( + "Detected non-contiguous tensor in P2P operations. It is user " + "responsibility to guarantee that source and destination tensors have " + "the same contiguity format."); + } else { + C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); + } + } + } +} + +int64_t check_xpu_tensors_same_device(const std::vector& tensors) { + TORCH_CHECK_WITH( + ValueError, tensors.size() != 0, "Tensor list must be nonempty"); + + const auto& first = tensors.front(); + + int64_t total_numel = 0; + for (const auto& t : tensors) { + if (!t.is_xpu() || t.is_sparse() || t.is_complex()) { + C10_THROW_ERROR( + ValueError, "Tensors must be XPU and dense and non-complex"); + } + if (t.scalar_type() != first.scalar_type()) { + C10_THROW_ERROR(TypeError, "Tensors must have identical type"); + } + TORCH_CHECK_WITH( + ValueError, + t.get_device() == tensors[0].get_device(), + "Expected list of tensors on the same device"); + total_numel += t.numel(); + } + + return total_numel; +} + +ccl::datatype getXcclDataType( + at::ScalarType type, + bool is_reduction_op = false) { + if (is_reduction_op) + TORCH_CHECK( + !isFloat8Type(type), + "Float8 dtypes are not currenlty supported for XCCL reductions"); + auto it = xcclDatatypes.find(type); + TORCH_CHECK_WITH( + TypeError, + it != xcclDatatypes.end(), + "Input tensor data type is not supported for XCCL process group: ", + type); + return it->second; +} + +ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { + try { + if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) { + // Map sum to max for bool tensors to avoid overflow issues with sum. + return ccl::reduction::max; + } + // WA due to oneCCL not support AVG + if (reduceOp == ReduceOp::AVG) { + return ccl::reduction::sum; + } + return xcclOps.at(reduceOp); + } catch (const std::out_of_range&) { + C10_THROW_ERROR( + ValueError, + "Cannot use ReduceOp." 
+ reduceOpToString(reduceOp) + " with XCCL"); + } +} + +void syncStream( + at::Device& device, + at::xpu::XPUEvent& xcclEvent, + at::xpu::XPUStream& xcclStream) { + xcclEvent.record(at::xpu::getCurrentXPUStream(device.index())); + xcclEvent.block(xcclStream); +} + +} // namespace + +constexpr int64_t kSynchronizeBusyWaitMillis = 10; +thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0; + +ProcessGroupXCCL::WorkXCCL::WorkXCCL( + at::Device& device, + int rank, + OpType opType, + uint64_t seq, + const char* profilingTitle, + const std::optional>& inputs) + : Work(rank, opType, profilingTitle, inputs), + device_(device), + workStartTime_(std::chrono::steady_clock::now()), + seq_(seq) { + xcclEndEvent_ = std::make_shared(); +} + +ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w) + : Work(w.rank_, w.opType_), + device_(w.device_), + xcclEndEvent_(w.xcclEndEvent_), + blockingWait_(w.blockingWait_), + workStartTime_(w.workStartTime_), + seq_(w.seq_) {} + +ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default; + +bool ProcessGroupXCCL::WorkXCCL::isCompleted() { + if (xcclEndEvent_ && xcclEndEvent_->query()) { + return true; + } + return false; +} + +void ProcessGroupXCCL::WorkXCCL::synchronize() { + synchronizeInternal(kNoTimeout); +} + +void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( + std::chrono::milliseconds timeout) { + auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); + xcclEndEvent_->block(currentStream); + if (blockingWait_) { + while (!isCompleted()) { + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + if (timeElapsed >= timeout) { + std::string exceptionMsg = c10::str( + "Work ran time out after ", timeElapsed.count(), " milliseconds."); + TORCH_CHECK(false, exceptionMsg) + } + std::this_thread::sleep_for( + std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); + } + } + if (barrierTensor_.defined()) { + auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); + currentStream.synchronize(); + } +} + +bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { + synchronizeInternal(timeout); + return true; +} + +constexpr const char* MULTI_DEVICE_ERROR_MSG = + "Expecting one tensor only but got multiple"; + +ProcessGroupXCCL::ProcessGroupXCCL( + const c10::intrusive_ptr& store, + int rank, + int size) + : Backend(rank, size), store_(store), xcclCommCounter_(0) { + blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false); + init(); +} + +ProcessGroupXCCL::~ProcessGroupXCCL() = default; + +void ProcessGroupXCCL::setSequenceNumberForGroup() {} + +uint64_t ProcessGroupXCCL::getSequenceNumberForGroup() { + return seqCollective_; +} + +c10::intrusive_ptr ProcessGroupXCCL::initWork( + at::Device& device, + int rank, + OpType opType, + const char* profilingTitle, + const std::vector& inputs, + const std::vector& outputs) { + auto r = c10::make_intrusive( + device, + rank, + opType, + seqCollective_, + profilingTitle, + std::optional>(inputs)); + return r; +} + +std::shared_ptr ProcessGroupXCCL::getXCCLComm( + const std::string& deviceKey, + at::Device& device, + OpType opType, + int p2pRank, + bool isSendRecvSelf) { + if (deviceKey.empty()) { + C10_THROW_ERROR( + DistBackendError, + "Not able to create/get the XCCL Communicator since " + "the devices are empty "); + } + + usedDeviceIdxs_.insert(device.index()); + + { + std::lock_guard lock(mutex_); + if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { + return 
devXCCLCommMap_[deviceKey]; + } + } + + std::shared_ptr XCCLComm; + + bool batchP2P = xcclActiveGroupCounter_ > 0; + bool singleP2POp = isP2POp(opType, batchP2P); + + at::xpu::OptionalXPUGuard gpuGuard(device); + + int numRanks, rank; + if (!singleP2POp) { + numRanks = getSize(); + rank = getRank(); + } else if (isSendRecvSelf) { + numRanks = 1; + rank = 0; + } else { + numRanks = 2; + rank = p2pRank; + } + + c10::impl::VirtualGuardImpl impl(device.type()); + c10::Stream stream = + impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false); + sycl::queue& q = c10::xpu::XPUStream(stream).queue(); + + auto ctx = ccl::create_context(q.get_context()); + ccl::vector_class> devs_rank; + devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); + + auto xccl_kvs = get_kvs(rank_, *store_, singleP2POp, deviceKey, p2pRank); + auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs); + XCCLComm = std::make_shared(std::move(comms[0])); + + RECORD_PARAM_COMMS( + 0, // seq + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank, // rank + "init", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + size_); // worldSize + + std::lock_guard lock(mutex_); + devXCCLCommMap_.emplace(deviceKey, XCCLComm); + xcclStreamsMap_.emplace(deviceKey, std::move(stream)); + xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent()); + + return XCCLComm; +} + +void ProcessGroupXCCL::groupStart() { + ccl::group_start(); + ++xcclActiveGroupCounter_; +} + +void ProcessGroupXCCL::groupEnd() { + ccl::group_end(); + --xcclActiveGroupCounter_; +} + +// TODO: wait p2p enable +static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; +void ProcessGroupXCCL::startCoalescing() { + if (coalescing_state_ & CoalP2P) { + seqP2P_++; + } else { + seqCollective_++; + } + coalescedDevice_.set_index(-1); + coalescedComm_ = nullptr; + coalescing_state_ |= CoalActive; + groupStart(); +} + +c10::intrusive_ptr ProcessGroupXCCL::endCoalescing(OpType optype) { + if (coalescedComm_ == nullptr) { + // There is no actual work being coalesced, return here + groupEnd(); + coalescing_state_ = 0; + return nullptr; + } + TORCH_CHECK( + coalescedDevice_.index() >= 0, + "Somthing went wrong. 
Did you call end_coalescing before start_coalescing?"); + + auto comm = coalescedComm_; + auto device = coalescedDevice_; + + const auto key = std::to_string(device.index()); + auto stream = xcclStreamsMap_.at(key); + + auto work = initWork(device, rank_, optype); + work->blockingWait_ = blockingWait_; + + groupEnd(); + + work->xcclEndEvent_->record(stream); + + coalescing_state_ = 0; + coalescedComm_ = nullptr; + return work; +} + +c10::intrusive_ptr ProcessGroupXCCL::endCoalescing() { + // Default OpType to COALESCED if not specified + return endCoalescing(OpType::COALESCED); +} + +template +c10::intrusive_ptr ProcessGroupXCCL::collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle) { + seqCollective_++; + auto device = inputs[0].device(); + const auto key = std::to_string(device.index()); + auto comm = getXCCLComm(key, device, opType); + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalColl; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = comm; + } else { + TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG); + } + } + + auto stream = xcclStreamsMap_.at(key); + syncStream(device, xcclEventsMap_[key], stream); + + c10::intrusive_ptr work; + work = initWork(device, rank_, opType); + + work->outputs_ = std::make_shared>(outputs); + + at::xpu::OptionalXPUGuard gpuGuard(device); + + pre(stream, work); + + for (const auto i : c10::irange(inputs.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + inputs[i].storage().data_ptr(), stream); + fn(inputs[i], outputs[i], *comm, stream); + } + + post(stream, work); + + if (!coalescing_state_) { + work->xcclEndEvent_->record(stream); + } + + std::vector streams = {stream.unwrap()}; + c10::MultiStreamGuard streamGuard(streams); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); + work->blockingWait_ = blockingWait_; + + return work; +} + +template +c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + const char* profilingTitle) { + auto device = tensor.device(); + std::string key; + int p2pRank = 0, p2pTargetRank = 0; + bool isSendRecvSelf = false; + + bool batchP2P = xcclActiveGroupCounter_ > 0; + if (batchP2P) { + key = std::to_string(device.index()); + p2pRank = rank_; + p2pTargetRank = peer; + } else { + int lowRank = rank_ < peer ? rank_ : peer; + int highRank = rank_ < peer ? peer : rank_; + key = std::to_string(lowRank) + ":" + std::to_string(highRank); + p2pRank = rank_ <= peer ? 0 : 1; + isSendRecvSelf = rank_ == peer; + p2pTargetRank = isSendRecvSelf ? 
0 : 1 - p2pRank; + if (!coalescing_state_) { + seqP2P_++; + } + } + + auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf); + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalP2P; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = comm; + } else { + TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG); + } + } + + auto stream = xcclStreamsMap_.at(key); + syncStream(device, xcclEventsMap_[key], stream); + + if (!coalescing_state_) { + c10::intrusive_ptr work; + work = initWork(device, rank_, opType); + work->outputs_ = std::make_shared>(); + work->outputs_->push_back(tensor); + + at::xpu::OptionalXPUGuard gpuGuard(device); + + c10::xpu::XPUCachingAllocator::recordStream( + tensor.storage().data_ptr(), stream); + + fn(tensor, *comm, stream, p2pTargetRank); + + work->xcclEndEvent_->record(stream); + work->blockingWait_ = blockingWait_; + std::vector streams = {stream.unwrap()}; + c10::MultiStreamGuard streamGuard(streams); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); + return work; + } else { + at::xpu::OptionalXPUGuard gpuGuard(device); + + c10::xpu::XPUCachingAllocator::recordStream( + tensor.storage().data_ptr(), stream); + + fn(tensor, *comm, stream, p2pTargetRank); + + return nullptr; + } +} + +c10::intrusive_ptr ProcessGroupXCCL::send( + std::vector& tensors, + int dstRank, + int /* unused */) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto tensor = tensors.back(); + check_xpu_single_tensor(tensor, true); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + dstRank, // dst rank + "send", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + auto ret = pointToPoint( + tensor, + [&](at::Tensor& input, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + int dst) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::send( + input.data_ptr(), + (size_t)input.numel(), + xcclDataType, + dst, + comm, + ccl::create_stream(stream.queue())); + return; + }, + dstRank, + OpType::SEND, + c10::str("xccl:send ", rank_, "->", dstRank).c_str()); + return ret; +} + +c10::intrusive_ptr ProcessGroupXCCL::recv( + std::vector& tensors, + int srcRank, + int /* unused */) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto tensor = tensors.back(); + check_xpu_single_tensor(tensor, true); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + srcRank, // src rank + "recv", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + 
this->getSize()); // worldSize + + auto ret = pointToPoint( + tensor, + [&](at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + int src) { + auto xcclDataType = getXcclDataType(output.scalar_type()); + ccl::recv( + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + src, + comm, + ccl::create_stream(stream.queue())); + return; + }, + srcRank, + OpType::RECV, + c10::str("xccl:recv ", rank_, "<-", srcRank).c_str()); + return ret; +} + +c10::intrusive_ptr ProcessGroupXCCL::gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts) { + static auto invalidArgument = [](const std::string& msg) { + C10_THROW_ERROR(ValueError, "ProcessGroupXCCL::gather: " + msg); + }; + + assertRootRank(invalidArgument, opts.rootRank, size_); + + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto inputTensor = inputTensors.back(); + + std::vector outputs; + + if (getRank() == opts.rootRank) { + if (outputTensors.size() != 1) { + std::stringstream ss; + ss << "requires a single-element output list containing a list with " + << getSize() << " tensors."; + invalidArgument(ss.str()); + } else if (outputTensors[0].size() != static_cast(getSize())) { + std::stringstream ss; + ss << "Incorrect output list size " << outputTensors[0].size() + << ". Output list size should be " << getSize() + << ", same as size of the process group."; + invalidArgument(ss.str()); + } + + const auto& options = inputTensor.options(); + const auto& sizes = inputTensor.sizes(); + assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes); + outputs = outputTensors[0]; + } else { + // if not in the root rank, initialize outputs as empty list + if (outputTensors.size() != 0) { + invalidArgument("requires empty output on non-root"); + } + outputs = {}; + // append a empty tensor to the list, we don't use it but the + // `collective` template function requires it to invoke its function + outputs.emplace_back(); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + opts.rootRank, // root rank + "gather", // collective name + inputTensor.numel(), // inNelems + inputTensor.numel() * this->getSize(), // outNelems + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + auto inputs = std::vector{inputTensor}; + return collective( + inputs, + outputs, // just to fit the collective interface + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + const auto root = opts.rootRank; + if (getRank() == root) { + for (auto output : outputs) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + } + { + auto xcclDataType = getXcclDataType(inputTensor.scalar_type()); + if (rank_ == root) { + for (const auto r : c10::irange(size_)) { + if (r != root) { + // do receive + ccl::recv( + outputs[r].data_ptr(), + (size_t)inputTensor.numel(), + xcclDataType, + r, + comm, + ccl::create_stream(stream.queue())); + } else { + // on its own rank, simply copy from the input + outputs[r].copy_(inputTensor); + } + } + } else { + // do send + ccl::send( + inputTensor.data_ptr(), + (size_t)inputTensor.numel(), + xcclDataType, + root, + 
comm, + ccl::create_stream(stream.queue())); + } + return; + } + }, + OpType::GATHER); +} + +c10::intrusive_ptr ProcessGroupXCCL::scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts) { + static auto invalidArgument = [](const std::string& msg) { + C10_THROW_ERROR(ValueError, "ProcessGroupXCCL::scatter: " + msg); + }; + + assertRootRank(invalidArgument, opts.rootRank, size_); + + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto outputTensor = outputTensors.back(); + + std::vector inputs; + + if (getRank() == opts.rootRank) { + if (inputTensors.size() != 1) { + std::stringstream ss; + ss << "requires a single-element input list containing a list with " + << getSize() << " tensors."; + invalidArgument(ss.str()); + } else if (inputTensors[0].size() != static_cast(getSize())) { + std::stringstream ss; + ss << "Incorrect input list size " << inputTensors[0].size() + << ". Input list size should be " << getSize() + << ", same as size of the process group."; + invalidArgument(ss.str()); + } + + const auto& options = outputTensor.options(); + const auto& sizes = outputTensor.sizes(); + assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes); + inputs = inputTensors[0]; + } else { + // if not in the root rank, initialize inputTensors as empty place holder + // with an empty list + if (inputTensors.size() != 0) { + invalidArgument("requires empty input on non-root"); + } + inputs = {}; + // append a empty tensor to the list, we don't use it but the + // `collective` template function requires it to invoke its function + inputs.emplace_back(); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + opts.rootRank, // root rank + "scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + const auto root = opts.rootRank; + + auto outputs = std::vector{outputTensor}; + return collective( + outputs, + inputs, // just to fit the collective interface + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + if (getRank() == root) { + for (auto input : inputs) { + c10::xpu::XPUCachingAllocator::recordStream( + input.storage().data_ptr(), stream); + } + } + { + if (rank_ == root) { + for (const auto r : c10::irange(size_)) { + if (r != root) { + // do send + size_t send_count = inputs[r].numel(); + auto send_type = getXcclDataType(inputs[r].scalar_type()); + ccl::send( + inputs[r].data_ptr(), + send_count, + send_type, + r, + comm, + ccl::create_stream(stream.queue())); + } else { + // on its own rank, simply copy from the input + outputTensor.copy_(inputs[r]); + } + } + } else { + // do receive + size_t recv_count = outputTensor.numel(); + auto recv_type = getXcclDataType(outputTensor.scalar_type()); + ccl::recv( + outputTensor.data_ptr(), + recv_count, + recv_type, + root, + comm, + ccl::create_stream(stream.queue())); + } + + return; + } + }, + OpType::SCATTER); +} + +c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( + at::Tensor& tensor, + const AllreduceOptions& opts) { + return collective( + tensor, + tensor, + [&](at::Tensor& 
input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + auto ccl_stream = ccl::create_stream(stream.queue()); + ccl::allreduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // WA due to oneCCL not support AVG + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::ALLREDUCE, + "xccl:all_reduce"); +} + +c10::intrusive_ptr ProcessGroupXCCL::allreduce( + std::vector& tensors, + const AllreduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + check_xpu_single_tensor(tensor); + + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + size_); // worldSize + + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::allreduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // WA due to oneCCL not support AVG + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::ALLREDUCE, + "xccl:all_reduce"); +} + +c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts) { + auto total_numel = check_xpu_tensors_same_device(tensors); + + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce_coalesced", // collective name + total_numel, // inNelems + total_numel, // outNelems + tensors[0].scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collectiveCoalesced( + tensors, + tensors, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::allreduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // WA due to oneCCL not support AVG + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::COALESCED, + "xccl:allreduce_coalesced"); +} + +c10::intrusive_ptr ProcessGroupXCCL::broadcast( + std::vector& tensors, + const BroadcastOptions& opts) { + 
TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + check_xpu_single_tensor(tensor); + + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + opts.rootRank, // root rank + "broadcast", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + const auto root = opts.rootRank + opts.rootTensor; + + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::broadcast( + input.data_ptr(), + (size_t)input.numel(), + xcclDataType, + root, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::BROADCAST, + "nccl:broadcast"); +} + +c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const BroadcastOptions& opts) { + if (outputTensor.numel() != inputTensor.numel()) { + C10_THROW_ERROR( + ValueError, + "Tensor input and output of _broadcast_oop must have the same number of elements "); + } + const auto root = opts.rootRank + opts.rootTensor; + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::broadcast( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + root, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::BROADCAST, + "xccl:_broadcast_oop"); +} + +c10::intrusive_ptr ProcessGroupXCCL::reduce( + std::vector& tensors, + const ReduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + check_xpu_single_tensor(tensor); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + opts.rootRank, // root rank + "reduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + const int root = opts.rootRank + opts.rootTensor; + const auto xcclDataType = getXcclDataType(input.scalar_type(), true); + const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + root, + comm, + ccl::create_stream(stream.queue())); + // WA due to oneCCL not support AVG + if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::REDUCE, + "xccl:reduce"); +} + +c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceOptions& 
opts) { + TORCH_CHECK_WITH( + ValueError, + outputTensor.numel() == inputTensor.numel(), + "Tensor input and output of _reduce_oop must have the same number of elements"); + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + const int root = opts.rootRank + opts.rootTensor; + const auto xcclDataType = getXcclDataType(input.scalar_type(), true); + const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + root, + comm, + ccl::create_stream(stream.queue())); + // WA due to oneCCL not support AVG + if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::REDUCE, + "xccl:_reduce_oop"); +} + +c10::intrusive_ptr ProcessGroupXCCL::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts) { + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto inputTensor = inputTensors.back(); + check_xpu_single_tensor(inputTensor); + // @lint-ignore CLANGTIDY + std::vector& outputTensors_ = outputTensors.back(); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_gather", // collective name + inputTensor.numel(), // inNelems + inputTensor.numel() * // outNelems + this->getSize(), + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + bool same_size = check_same_size(outputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. + at::Tensor outputFlattened = newLikeFlat(outputTensors_); + + return collective( + inputTensor, + outputFlattened, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::allgather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + [](at::xpu::XPUStream&, + c10::intrusive_ptr& work) {}, + [&](at::xpu::XPUStream& Stream, + c10::intrusive_ptr& work) { + // Copy the flattened output tensors to the outputs. + c10::StreamGuard guard(Stream); + for (const auto j : c10::irange(outputTensors_.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + outputTensors_[j].storage().data_ptr(), Stream); + outputTensors_[j].copy_(outputFlattened[j], true); + } + }, + OpType::ALLGATHER, + "xccl:all_gather"); + } else { + const auto num_reduces = outputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& output = outputTensors_[i]; + auto& input = (i == rank_) ? 
inputTensor : output; + auto broadcastOpts = BroadcastOptions{ + static_cast(i), static_cast(0), opts.timeout}; + _broadcast_oop(output, input, broadcastOpts); + } + auto work = endCoalescing(OpType::ALLGATHER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const AllgatherOptions& opts) { + check_xpu_single_tensor(input_tensor); + check_xpu_single_tensor(output_tensor); + + TORCH_CHECK_WITH( + TypeError, + input_tensor.dtype() == output_tensor.dtype(), + "output tensor must have the same type as input tensor"); + TORCH_CHECK_WITH( + ValueError, + input_tensor.numel() * size_ == output_tensor.numel(), + "output tensor size must be equal to world_size times input tensor size"); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + input_tensor, // inputTensors + output_tensor, // outputTensors + rank_, // rank + "_allgather_base", // collective name + input_tensor.numel(), // inNelems + output_tensor.numel(), // outNelems + output_tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + input_tensor, + output_tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::allgather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::_ALLGATHER_BASE, + "xccl:_all_gather_base"); +} + +c10::intrusive_ptr ProcessGroupXCCL::allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts) { + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::allgather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::COALESCED, + "xccl:all_gather_into_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts) { + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto outputTensor = outputTensors.back(); + check_xpu_single_tensor(outputTensor); + // @lint-ignore CLANGTIDY + auto inputTensors_ = inputTensors.back(); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "reduce_scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + bool same_size = check_same_size(inputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, 
stacked tensor. + at::Tensor inputFlattened = newLikeFlat(inputTensors_); + return collective( + inputFlattened, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce_scatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // WA due to oneCCL not support AVG + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + [&](at::xpu::XPUStream& Stream, + c10::intrusive_ptr& work) { + // Copy the input tensors to the flattened inputs. + c10::StreamGuard guard(Stream); + for (const auto j : c10::irange(inputTensors_.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + inputTensors_[j].storage().data_ptr(), Stream); + inputFlattened[j].copy_(inputTensors_[j], true); + } + }, + [&](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + OpType::REDUCE_SCATTER, + "xccl:reduce_scatter"); + } else { + const auto num_reduces = inputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& input = inputTensors_[i]; + auto& output = (i == rank_) ? outputTensor : input; + auto reduceOpts = ReduceOptions{ + opts.reduceOp, + static_cast(i), + static_cast(0), + opts.timeout}; + _reduce_oop(output, input, reduceOpts); + } + auto work = endCoalescing(OpType::REDUCE_SCATTER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts) { + TORCH_CHECK_WITH( + TypeError, + inputTensor.dtype() == outputTensor.dtype(), + "input tensor must be the same type as the output tensor."); + TORCH_CHECK_WITH( + ValueError, + inputTensor.numel() == outputTensor.numel() * size_, + "input tensor must be the same size as output size times world size"); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "_reduce_scatter_base", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dtype + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce_scatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // WA due to oneCCL not support AVG + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::_REDUCE_SCATTER_BASE, + "xccl:_reduce_scatter_base"); +} + +c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( + std::vector& outputs, + 
std::vector& inputs, + const ReduceScatterOptions& opts) { + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce_scatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // WA due to oneCCL not support AVG + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::COALESCED, + "xccl:reduce_scatter_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { + RECORD_PARAM_COMMS( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank_, // rank + "barrier", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + // Device to use for barrier + int barDevIdx = -1; + + // See nccl barrier comments + if (!opts.device_ids.empty()) { + barDevIdx = opts.device_ids[0]; + } else if (getBoundDeviceId()) { + barDevIdx = (*getBoundDeviceId()).index(); + } else if (!usedDeviceIdxs_.empty()) { + barDevIdx = *usedDeviceIdxs_.begin(); + } else { + barDevIdx = + static_cast(rank_ % at::detail::getXPUHooks().getNumGPUs()); + } + + // todo: use barrier instead of allreduce + TORCH_CHECK_WITH( + ValueError, + barDevIdx >= 0, + "Failed to infer a GPU device id to perform barrier. 
"); + auto barDevice = at::Device(at::DeviceType::XPU, barDevIdx); + + at::Tensor barrierTensor = + at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); + + auto work = allreduce_impl(barrierTensor); + + auto xcclWork = dynamic_cast(work.get()); + TORCH_CHECK(xcclWork); + xcclWork->barrierTensor_ = std::move(barrierTensor); + return work; +} + +c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& /* unused */) { + check_xpu_single_tensor(outputTensor, true); + check_xpu_single_tensor(inputTensor, true); + if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_all", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + TORCH_CHECK( + outputTensor.numel() == inputTensor.numel() && + outputTensor.scalar_type() == inputTensor.scalar_type(), + "xpu_alltoall_base: tensors are not equal in size or data type"); + TORCH_CHECK( + outputTensor.size(0) % size_ == 0, + "xpu_alltoall_base: tensor's dim 0 does not divide equally across group size"); + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(output.scalar_type()); + ccl::alltoall( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel() / comm.size(), + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::ALLTOALL_BASE, + "xccl:all_to_all"); + } else { + c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); + c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_allv", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + inputSplitSizes, // inSplitSizes + outputSplitSizes, // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + std::vector sendCounts(size_); + std::vector recvCounts(size_); + bool inputSplitsEqual = inputSplitSizes.size() == 0; + bool outputSplitsEqual = outputSplitSizes.size() == 0; + + size_t inLen = input.numel(); + size_t outLen = output.numel(); + if (inLen) + inLen /= (inputSplitsEqual ? size_ : input.size(0)); + if (outLen) + outLen /= (outputSplitsEqual ? size_ : output.size(0)); + + for (int i = 0; i < size_; i++) { + sendCounts[i] = + (inputSplitsEqual ? inLen : inputSplitSizes[i] * inLen); + recvCounts[i] = + (outputSplitsEqual ? 
outLen : outputSplitSizes[i] * outLen); + } + auto xcclDataType = getXcclDataType(output.scalar_type()); + ccl::alltoallv( + input.data_ptr(), + sendCounts, + output.data_ptr(), + recvCounts, + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::ALLTOALL_BASE, + "xccl:all_to_all"); + } +} + +c10::intrusive_ptr ProcessGroupXCCL::alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& /* unused */) { + auto device = outputTensors[0].device(); + int64_t total_numel = 0; + for (const auto r : c10::irange(outputTensors.size())) { + check_xpu_single_tensor(outputTensors[r], true); + check_xpu_single_tensor(inputTensors[r], true); + TORCH_CHECK( + device == outputTensors[r].device() && + device == inputTensors[r].device(), + "Tensors must be on the same device") + total_numel += inputTensors[r].numel(); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_to_all", // collective name + total_numel, // inNelems + total_numel, // outNelems + inputTensors.front().scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensors, + outputTensors, + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::OptionalStreamGuard stream_guard(stream.unwrap()); + at::Tensor flatInput; + at::Tensor flatOutput; + + std::vector sendCounts(size_); + std::vector recvCounts(size_); + + int64_t flatSendCount; + int64_t flatRecvCount; + + bool isInputFlat = computeLengthsAndCheckAndGetFlat( + inputTensors, sendCounts, flatInput, flatSendCount); + bool isOutputFlat = computeLengthsAndCheckAndGetFlat( + outputTensors, recvCounts, flatOutput, flatRecvCount); + if (!isInputFlat) { + auto flatInputSplits = flatInput.split_with_sizes( + c10::IntArrayRef((int64_t*)sendCounts.data(), sendCounts.size()), + 0); + + for (int i = 0; i < size_; i++) { + flatInputSplits[i].copy_(inputTensors[i].view({-1})); + } + } + + auto xcclDataType = getXcclDataType(flatOutput.scalar_type()); + ccl::event ret_evt; + ret_evt = ccl::alltoallv( + flatInput.data_ptr(), + sendCounts, + flatOutput.data_ptr(), + recvCounts, + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + + if (!isOutputFlat) { + ret_evt.wait(); + auto flatOutputSplits = flatOutput.split_with_sizes( + c10::IntArrayRef((int64_t*)recvCounts.data(), recvCounts.size()), + 0); + + for (int i = 0; i < size_; i++) { + outputTensors[i].view({-1}).copy_(flatOutputSplits[i]); + } + } + stream.synchronize(); + return; + }, + OpType::ALLTOALL, + "xccl:all_to_all"); +} + +} // namespace c10d + +#endif // USE_C10D_XCCL + diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp new file mode 100644 index 000000000..2b2837bfb --- /dev/null +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -0,0 +1,372 @@ +#pragma once + +#ifdef USE_C10D_XCCL +// We will define those flags in XCCL backend file instead of passing to gcc +// compiler. 
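+// CCL_ENABLE_SYCL and CCL_ENABLE_ZE are defined here, before the oneCCL headers are
+// included, so that the SYCL/Level Zero interop APIs this backend relies on (e.g.
+// creating a ccl::stream from a sycl::queue) are visible at compile time.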
+#define CCL_ENABLE_ZE +#define CCL_ENABLE_SYCL + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +namespace c10d { + +static std::vector TORCH_XCCL_BLOCKING_WAIT = { + "TORCH_XCCL_BLOCKING_WAIT", + "XCCL_BLOCKING_WAIT"}; + +using xcclComm_t = ccl::communicator; +using XCCL_KVS = ccl::shared_ptr_class; +constexpr const char* XCCL_BACKEND_NAME = "xccl"; + +class TORCH_API ProcessGroupXCCL : public Backend { + public: + class WorkXCCL : public Work { + public: + WorkXCCL( + at::Device& device, + int rank, + OpType opType, + uint64_t seq, + const char* profilingTitle = nullptr, + const std::optional>& inputs = std::nullopt); + WorkXCCL(const WorkXCCL& w); + ~WorkXCCL() override; + + bool isCompleted() override; + + void abort() override { + TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::abort not implemented"); + } + + void synchronize() override; + + bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; + + c10::intrusive_ptr getFuture() override { + return future_; + } + + uint64_t getSequencenumber() const override { + return seq_; + } + + std::vector result() override { + return *outputs_; + } + + protected: + at::Device device_; + std::shared_ptr xcclEndEvent_; + at::Tensor barrierTensor_; + bool blockingWait_ = false; + std::chrono::time_point workStartTime_; + uint64_t seq_; + + private: + void synchronizeInternal(std::chrono::milliseconds timeout); + std::shared_ptr> outputs_; + c10::intrusive_ptr future_; + friend class ProcessGroupXCCL; + }; + + ProcessGroupXCCL(const c10::intrusive_ptr& store, int rank, int size); + + C10_DEPRECATED ProcessGroupXCCL( + const c10::intrusive_ptr& store, + int rank, + int size, + const std::string& groupName) + : ProcessGroupXCCL(store, rank, size) {} + + ~ProcessGroupXCCL() override; + + const std::string getBackendName() const override { + return std::string(XCCL_BACKEND_NAME); + } + + void startCoalescing() override; + + c10::intrusive_ptr endCoalescing() override; + + c10::intrusive_ptr endCoalescing(OpType optype); + + std::shared_ptr getXCCLComm( + const std::string& deviceKey, + at::Device& device, + OpType opType, + int p2pRank = 0, + bool isSendRecvSelf = false); + + virtual c10::intrusive_ptr initWork( + at::Device& device, + int rank, + OpType opType, + const char* profilingTitle = nullptr, + const std::vector& inputs = {}, + const std::vector& outputs = {}); + + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + opType, + profilingTitle); + } + + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective(inputs, outputs, fn, pre, post, opType, profilingTitle); + } + + template + c10::intrusive_ptr collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr) { + return collective( + inputs, + outputs, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + opType, + profilingTitle); + } + + template + c10::intrusive_ptr collective( + std::vector& inputs, + 
std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr); + + template + c10::intrusive_ptr collectiveCoalesced( + std::vector& input, + std::vector& output, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) { + ccl::group_start(); + }, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) { + ccl::group_end(); + }, + opType, + profilingTitle); + } + + template + c10::intrusive_ptr pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + const char* profilingTitle = nullptr); + + c10::intrusive_ptr allreduce_impl( + at::Tensor& tensor, + const AllreduceOptions& opts = AllreduceOptions()); + + c10::intrusive_ptr allreduce( + std::vector& tensors, + const AllreduceOptions& opts = AllreduceOptions()) override; + + c10::intrusive_ptr allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts = + AllreduceCoalescedOptions()) override; + + c10::intrusive_ptr reduce( + std::vector& tensors, + const ReduceOptions& opts = ReduceOptions()) override; + + c10::intrusive_ptr _reduce_oop( + at::Tensor& outputTensors, + at::Tensor& inputTensors, + const ReduceOptions& opts = ReduceOptions()); + + c10::intrusive_ptr broadcast( + std::vector& tensors, + const BroadcastOptions& opts = BroadcastOptions()) override; + + c10::intrusive_ptr _broadcast_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const BroadcastOptions& opts); + + c10::intrusive_ptr allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr _allgather_base( + at::Tensor& outputbuffer, + at::Tensor& inputbuffer, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr barrier( + const BarrierOptions& opts = BarrierOptions()) override; + + c10::intrusive_ptr alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& opts = AllToAllOptions()) override; + + c10::intrusive_ptr alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& opts = AllToAllOptions()) override; + + c10::intrusive_ptr send( + std::vector& tensors, + int dstRank, + int tag) override; + + c10::intrusive_ptr recv( + std::vector& tensors, + int srcRank, + int tag) override; + + void groupStart(); + + void groupEnd(); + + c10::intrusive_ptr gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts = GatherOptions()) override; + + c10::intrusive_ptr scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts = ScatterOptions()) override; + + void 
setSequenceNumberForGroup() override; + + uint64_t getSequenceNumberForGroup() override; + + protected: + std::unordered_map xcclStreamsMap_; + std::unordered_map xcclEventsMap_; + std::unordered_map> devXCCLCommMap_; + c10::intrusive_ptr store_; + uint64_t xcclCommCounter_{0}; + std::mutex mutex_; + std::set usedDeviceIdxs_; + int coalescing_state_ = 0; + at::Device coalescedDevice_ = at::Device("xpu"); + std::shared_ptr coalescedComm_ = nullptr; + bool blockingWait_ = false; + static thread_local uint64_t xcclActiveGroupCounter_; + uint64_t seqCollective_{0}; + uint64_t seqP2P_{0}; + + private: + std::mutex kvs_mutex; + + ccl::shared_ptr_class get_kvs( + int rank, + c10d::Store& store, + bool singleP2POp = false, + const std::string& p2pKey = "", + int p2pRank = 0) { + std::lock_guard lock(kvs_mutex); + ccl::shared_ptr_class kvs; + std::string storeKey; + if (!singleP2POp) { + storeKey = std::to_string(xcclCommCounter_++); + } else { + storeKey = p2pKey; + } + // Rank 0 broadcast the bootstrap network information to other ranks + if (rank == 0 || (singleP2POp && p2pRank == 0)) { + kvs = ccl::create_main_kvs(); + ccl::kvs::address_type main_addr = kvs->get_address(); + auto ccl_kvs_addr = + std::vector(main_addr.begin(), main_addr.end()); + store.set(storeKey, ccl_kvs_addr); + } else { + auto ccl_kvs_addr = store.get(storeKey); + if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { + throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); + } + ccl::kvs::address_type main_addr; + std::copy_n( + ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); + kvs = ccl::create_kvs(main_addr); + } + return kvs; + } +}; +} // namespace c10d + +#endif // USE_C10D_XCCL + From f01b17364cbb3ac494bb181b102bf7e73da1e25b Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Wed, 20 Nov 2024 22:19:28 +0800 Subject: [PATCH 02/19] update cmake --- CMakeLists.txt | 2 +- src/BuildOnLinux.cmake | 6 +++--- src/CMakeLists.txt | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 33d2d7667..2c5a380e5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,7 +53,7 @@ endif() if(NOT WIN32 AND BUILD_WITH_XCCL) include(${TORCH_XPU_OPS_ROOT}/cmake/XCCL.cmake) - set(USE_C10D_XCCL) + set(USE_C10D_XCCL ON) endif() if(BUILD_TEST) diff --git a/src/BuildOnLinux.cmake b/src/BuildOnLinux.cmake index f32f840cd..9d37ac303 100644 --- a/src/BuildOnLinux.cmake +++ b/src/BuildOnLinux.cmake @@ -29,14 +29,14 @@ if(BUILD_SEPARATE_OPS) foreach(xccl_src ${ATen_XPU_XCCL_SRCS}) get_filename_component(name ${xccl_src} NAME_WLE REALPATH) set(xccl_lib torch-xpu-ops-xccl-${name}) - target_link_libraries(xccl_lib PRIVATE torch::xccl) sycl_add_library( ${xccl_lib} SHARED CXX_SOURCES ${xccl_src}) target_link_libraries(torch_xpu_ops PUBLIC ${xccl_lib}) list(APPEND TORCH_XPU_OPS_LIBRARIES ${xccl_lib}) - install(TARGETS ${xccl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + install(TARGETS ${xccl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") + target_link_libraries(torch_xpu_ops PRIVATE torch::xccl) endforeach() endif() else() @@ -119,12 +119,12 @@ else() install(TARGETS ${sycl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") if(USE_C10D_XCCL) set(xccl_lib torch_xpu_ops_xccl) - target_link_libraries(xccl_lib PRIVATE torch::xccl) sycl_add_library( ${xccl_lib} SHARED CXX_SOURCES ${ATen_XPU_XCCL_SRCS}) target_link_libraries(torch_xpu_ops PUBLIC ${xccl_lib}) + target_link_libraries(torch_xpu_ops PRIVATE torch::xccl) install(TARGETS ${xccl_lib} DESTINATION 
"${TORCH_INSTALL_LIB_DIR}") endif() endif() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a1d7f49be..7a427e294 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -9,7 +9,9 @@ set(ATen_XPU_XCCL_SRCS) set(ATen_XPU_INCLUDE_DIRS ${TORCH_XPU_OPS_ROOT}/src CACHE STRING "ATen XPU Include directory") add_subdirectory(ATen) - +if(USE_C10D_XCCL) + add_subdirectory(xccl) +endif() # With the increasement of bin size, we have to split libtorch_xpu.so into # multiple libraries. Because of strict linkage requirements on Windows, # we add extra logics to resolve, 1) Cyclic dependence, 2) Make symbols visible. From 405013cd76be0fde0278e526896e4cb73afd05f7 Mon Sep 17 00:00:00 2001 From: hanchao Date: Wed, 20 Nov 2024 06:50:20 +0000 Subject: [PATCH 03/19] update --- src/BuildOnLinux.cmake | 2 ++ src/xccl/ProcessGroupXCCL.cpp | 3 +-- src/xccl/ProcessGroupXCCL.hpp | 27 ++++++++++++++++++++++++++- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/BuildOnLinux.cmake b/src/BuildOnLinux.cmake index 9d37ac303..fd08ceaeb 100644 --- a/src/BuildOnLinux.cmake +++ b/src/BuildOnLinux.cmake @@ -33,6 +33,7 @@ if(BUILD_SEPARATE_OPS) ${xccl_lib} SHARED CXX_SOURCES ${xccl_src}) + target_compile_definitions(${xccl_lib} PRIVATE USE_C10D_XCCL) target_link_libraries(torch_xpu_ops PUBLIC ${xccl_lib}) list(APPEND TORCH_XPU_OPS_LIBRARIES ${xccl_lib}) install(TARGETS ${xccl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") @@ -123,6 +124,7 @@ else() ${xccl_lib} SHARED CXX_SOURCES ${ATen_XPU_XCCL_SRCS}) + target_compile_definitions(${xccl_lib} PRIVATE USE_C10D_XCCL) target_link_libraries(torch_xpu_ops PUBLIC ${xccl_lib}) target_link_libraries(torch_xpu_ops PRIVATE torch::xccl) install(TARGETS ${xccl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 4f407d133..71e26fbe0 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include namespace c10d { @@ -1792,4 +1792,3 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall( } // namespace c10d #endif // USE_C10D_XCCL - diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index 2b2837bfb..e2099e658 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -368,5 +368,30 @@ class TORCH_API ProcessGroupXCCL : public Backend { }; } // namespace c10d +namespace { +inline std::string reduceOpToString(c10d::ReduceOp op) { + switch (op) { + case c10d::ReduceOp::SUM: + return "SUM"; + case c10d::ReduceOp::PRODUCT: + return "PRODUCT"; + case c10d::ReduceOp::MIN: + return "MIN"; + case c10d::ReduceOp::MAX: + return "MAX"; + case c10d::ReduceOp::BAND: + return "BAND"; + case c10d::ReduceOp::BOR: + return "BOR"; + case c10d::ReduceOp::BXOR: + return "BXOR"; + case c10d::ReduceOp::AVG: + return "AVG"; + case c10d::ReduceOp::PREMUL_SUM: + return "PREMUL_SUM"; + default: + return "UNKNOWN"; + } +} +} // namespace #endif // USE_C10D_XCCL - From 0d8bb51c82337f54e5ecf76d84e8bcc389bcfab9 Mon Sep 17 00:00:00 2001 From: hanchao Date: Wed, 20 Nov 2024 07:10:08 +0000 Subject: [PATCH 04/19] oneccl private for xccl --- src/BuildOnLinux.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/BuildOnLinux.cmake b/src/BuildOnLinux.cmake index fd08ceaeb..918989fef 100644 --- a/src/BuildOnLinux.cmake +++ b/src/BuildOnLinux.cmake @@ -34,10 +34,10 @@ if(BUILD_SEPARATE_OPS) SHARED CXX_SOURCES ${xccl_src}) target_compile_definitions(${xccl_lib} PRIVATE USE_C10D_XCCL) + 
target_link_libraries(${xccl_lib} PRIVATE torch::xccl) target_link_libraries(torch_xpu_ops PUBLIC ${xccl_lib}) list(APPEND TORCH_XPU_OPS_LIBRARIES ${xccl_lib}) install(TARGETS ${xccl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") - target_link_libraries(torch_xpu_ops PRIVATE torch::xccl) endforeach() endif() else() @@ -125,8 +125,8 @@ else() SHARED CXX_SOURCES ${ATen_XPU_XCCL_SRCS}) target_compile_definitions(${xccl_lib} PRIVATE USE_C10D_XCCL) + target_link_libraries(${xccl_lib} PRIVATE torch::xccl) target_link_libraries(torch_xpu_ops PUBLIC ${xccl_lib}) - target_link_libraries(torch_xpu_ops PRIVATE torch::xccl) install(TARGETS ${xccl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") endif() endif() From 77148858d82dd1e726f71ef7c67fd6caffc7e17f Mon Sep 17 00:00:00 2001 From: hanchao Date: Tue, 26 Nov 2024 03:06:31 +0000 Subject: [PATCH 05/19] update cmake --- src/BuildOnLinux.cmake | 31 +++++-------------------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/src/BuildOnLinux.cmake b/src/BuildOnLinux.cmake index 918989fef..d0f28ad29 100644 --- a/src/BuildOnLinux.cmake +++ b/src/BuildOnLinux.cmake @@ -11,6 +11,11 @@ add_library( ${ATen_XPU_GEN_SRCS} ${ATen_XPU_XCCL_SRCS}) +if(USE_C10D_XCCL) + target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL) + target_link_libraries(torch_xpu_ops PUBLIC torch::xccl) +endif() + if(BUILD_SEPARATE_OPS) foreach(sycl_src ${ATen_XPU_SYCL_SRCS}) get_filename_component(name ${sycl_src} NAME_WLE REALPATH) @@ -25,21 +30,6 @@ if(BUILD_SEPARATE_OPS) # Decouple with PyTorch cmake definition. install(TARGETS ${sycl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") endforeach() - if(USE_C10D_XCCL) - foreach(xccl_src ${ATen_XPU_XCCL_SRCS}) - get_filename_component(name ${xccl_src} NAME_WLE REALPATH) - set(xccl_lib torch-xpu-ops-xccl-${name}) - sycl_add_library( - ${xccl_lib} - SHARED - CXX_SOURCES ${xccl_src}) - target_compile_definitions(${xccl_lib} PRIVATE USE_C10D_XCCL) - target_link_libraries(${xccl_lib} PRIVATE torch::xccl) - target_link_libraries(torch_xpu_ops PUBLIC ${xccl_lib}) - list(APPEND TORCH_XPU_OPS_LIBRARIES ${xccl_lib}) - install(TARGETS ${xccl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") - endforeach() - endif() else() # Split SYCL kernels into 4 libraries as categories 1) Unary+Binary 2) Reduce 3) Foreach 4) Others. set(ATen_XPU_SYCL_UNARY_BINARY_SRCS) @@ -118,17 +108,6 @@ else() # Decouple with PyTorch cmake definition. 
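# With USE_C10D_XCCL handled above (the XCCL sources are compiled directly into
# torch_xpu_ops and torch::xccl is linked PUBLIC), the standalone xccl shared
# libraries removed below are no longer built.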
install(TARGETS ${sycl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") - if(USE_C10D_XCCL) - set(xccl_lib torch_xpu_ops_xccl) - sycl_add_library( - ${xccl_lib} - SHARED - CXX_SOURCES ${ATen_XPU_XCCL_SRCS}) - target_compile_definitions(${xccl_lib} PRIVATE USE_C10D_XCCL) - target_link_libraries(${xccl_lib} PRIVATE torch::xccl) - target_link_libraries(torch_xpu_ops PUBLIC ${xccl_lib}) - install(TARGETS ${xccl_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}") - endif() endif() set(SYCL_LINK_LIBRARIES_KEYWORD) From b7706408c6ba6d6ef2a4be45eceff6650bfa6f80 Mon Sep 17 00:00:00 2001 From: hanchao Date: Wed, 27 Nov 2024 01:01:33 +0000 Subject: [PATCH 06/19] update commit and add register --- CMakeLists.txt | 20 +-- src/xccl/CMakeLists.txt | 5 + src/xccl/Register.cpp | 313 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 328 insertions(+), 10 deletions(-) create mode 100644 src/xccl/Register.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 2c5a380e5..4ea88dd75 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,20 +38,20 @@ list(APPEND CMAKE_MODULE_PATH ${TORCH_XPU_OPS_ROOT}/cmake/Modules) include(${TORCH_XPU_OPS_ROOT}/cmake/SYCL.cmake) include(${TORCH_XPU_OPS_ROOT}/cmake/BuildFlags.cmake) -option(BUILD_WITH_XCCL "Build with XCCL support" ON) -if (DEFINED ENV{BUILD_WITH_XCCL}) - string(TOLOWER "$ENV{BUILD_WITH_XCCL}" BUILD_WITH_XCCL_LOWER) - - if (NOT (BUILD_WITH_XCCL_LOWER STREQUAL "1" OR - BUILD_WITH_XCCL_LOWER STREQUAL "on" OR - BUILD_WITH_XCCL_LOWER STREQUAL "yes")) - set(BUILD_WITH_XCCL OFF CACHE BOOL "Build with XCCL support" FORCE) +option(USE_XCCL "Build with XCCL support" ON) +if (DEFINED ENV{USE_XCCL}) + string(TOLOWER "$ENV{USE_XCCL}" USE_XCCL_LOWER) + + if (NOT (USE_XCCL_LOWER STREQUAL "1" OR + USE_XCCL_LOWER STREQUAL "on" OR + USE_XCCL_LOWER STREQUAL "yes")) + set(USE_XCCL OFF CACHE BOOL "Build with XCCL support" FORCE) else() - set(BUILD_WITH_XCCL ON CACHE BOOL "Build with XCCL support" FORCE) + set(USE_XCCL ON CACHE BOOL "Build with XCCL support" FORCE) endif() endif() -if(NOT WIN32 AND BUILD_WITH_XCCL) +if(NOT WIN32 AND USE_XCCL) include(${TORCH_XPU_OPS_ROOT}/cmake/XCCL.cmake) set(USE_C10D_XCCL ON) endif() diff --git a/src/xccl/CMakeLists.txt b/src/xccl/CMakeLists.txt index 809181d55..242cf7e26 100644 --- a/src/xccl/CMakeLists.txt +++ b/src/xccl/CMakeLists.txt @@ -1,7 +1,12 @@ # XCCL sources +file(GLOB xccl_h "*.hpp") file(GLOB xccl_cpp "*.cpp") list(APPEND ATen_XPU_XCCL_SRCS ${xccl_cpp}) set(ATen_XPU_XCCL_SRCS ${ATen_XPU_XCCL_SRCS} PARENT_SCOPE) + +foreach(HEADER ${xccl_h}) + file(COPY ${HEADER} DESTINATION "${CMAKE_BINARY_DIR}/torch/csrc/distributed/c10d") +endforeach() diff --git a/src/xccl/Register.cpp b/src/xccl/Register.cpp new file mode 100644 index 000000000..3716c7a90 --- /dev/null +++ b/src/xccl/Register.cpp @@ -0,0 +1,313 @@ +#include +#include +#include +#include +#include + +namespace c10d { +namespace ops { +namespace { +c10::intrusive_ptr send_XPU( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t dstRank, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->send(tensor_vec, static_cast(dstRank), static_cast(tag)); +} + +c10::intrusive_ptr recv_XPU( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t srcRank, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->recv(tensor_vec, static_cast(srcRank), static_cast(tag)); +} + +c10::intrusive_ptr recv_any_source_XPU( + at::TensorList 
tensors, + const c10::intrusive_ptr& process_group, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->recvAnysource(tensor_vec, static_cast(tag)); +} + +c10::intrusive_ptr reduce_XPU( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t root_rank, + int64_t root_tensor, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->reduce( + tensor_vec, + ReduceOptions{ + *reduce_op.get(), + root_rank, + root_tensor, + std::chrono::milliseconds(timeout)}); +} + +std::tuple, c10::intrusive_ptr> broadcast_XPU( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t root_tensor, + bool asyncOp, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + auto work = process_group->getBackend(c10::DeviceType::XPU) + ->broadcast( + tensor_vec, + BroadcastOptions{ + root_rank, + root_tensor, + std::chrono::milliseconds(timeout), + asyncOp}); + return std::tuple, c10::intrusive_ptr>( + std::move(tensor_vec), work); +} + +std::tuple, c10::intrusive_ptr> allreduce_XPU( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + const std::optional& sparse_indices, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->allreduce( + tensor_vec, + AllreduceOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); + return std::tuple, c10::intrusive_ptr>( + std::move(tensor_vec), work); +} + +c10::intrusive_ptr allreduce_coalesced_XPU( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + return process_group->getBackend(c10::DeviceType::XPU) + ->allreduce_coalesced(tensor_vec, opts); +} + +std::tuple>, c10::intrusive_ptr> +allgather_XPU( + const std::vector>& output_tensors, + at::TensorList input_tensors, + const c10::intrusive_ptr& process_group, + int64_t timeout) { + auto input_tensors_vec = input_tensors.vec(); + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->allgather( + const_cast>&>(output_tensors), + input_tensors_vec, + AllgatherOptions{std::chrono::milliseconds(timeout)}); + return std:: + tuple>, c10::intrusive_ptr>( + output_tensors, work); +} + +std::tuple> _allgather_base_XPU( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group, + bool asyncOp, + int64_t timeout) { + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->_allgather_base( + output_tensor, + input_tensor, + AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp}); + return std::tuple>(output_tensor, work); +} + +c10::intrusive_ptr allgather_coalesced_XPU( + const std::vector>& output_lists, + const at::TensorList& input_list, + const c10::intrusive_ptr& process_group) { + auto input_list_vec = input_list.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->allgather_coalesced( + const_cast>&>(output_lists), + input_list_vec); +} + +c10::intrusive_ptr allgather_into_tensor_coalesced_XPU( + at::TensorList outputs, + at::TensorList inputs, + const c10::intrusive_ptr& process_group) { + auto output_vec = outputs.vec(); + auto 
input_vec = inputs.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->allgather_into_tensor_coalesced(output_vec, input_vec); +} + +std::tuple, c10::intrusive_ptr> reduce_scatter_XPU( + const at::TensorList& output_tensors, + const std::vector>& input_tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto output_tensors_vec = output_tensors.vec(); + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->reduce_scatter( + output_tensors_vec, + const_cast>&>(input_tensors), + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); + return std::tuple, c10::intrusive_ptr>( + output_tensors_vec, work); +} + +std::tuple> _reduce_scatter_base_XPU( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + bool asyncOp, + int64_t timeout) { + auto work = process_group->getBackend(c10::DeviceType::XPU) + ->_reduce_scatter_base( + output_tensor, + input_tensor, + ReduceScatterOptions{ + *reduce_op.get(), + std::chrono::milliseconds(timeout), + asyncOp}); + return std::tuple>(output_tensor, work); +} + +c10::intrusive_ptr reduce_scatter_tensor_coalesced_XPU( + at::TensorList outputs, + at::TensorList inputs, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto output_vec = outputs.vec(); + auto input_vec = inputs.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->reduce_scatter_tensor_coalesced( + output_vec, + input_vec, + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr gather_XPU( + const std::vector>& output_tensors, + const at::TensorList& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t timeout) { + auto input_tensors_vec = input_tensors.vec(); + return process_group->getBackend(c10::DeviceType::XPU) + ->gather( + const_cast>&>(output_tensors), + input_tensors_vec, + GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); +} + +std::tuple, c10::intrusive_ptr> scatter_XPU( + const at::TensorList& output_tensors, + const std::vector>& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + bool asyncOp, + int64_t timeout) { + auto output_tensors_vec = output_tensors.vec(); + auto work = + process_group->getBackend(c10::DeviceType::XPU) + ->scatter( + output_tensors_vec, + const_cast>&>(input_tensors), + ScatterOptions{ + root_rank, std::chrono::milliseconds(timeout), asyncOp}); + return std::tuple, c10::intrusive_ptr>( + std::move(output_tensors_vec), work); +} + +std::tuple, c10::intrusive_ptr> alltoall_XPU( + const at::TensorList& output_tensors, + const at::TensorList& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t timeout) { + auto output_tensors_vec = output_tensors.vec(); + auto input_tensors_vec = input_tensors.vec(); + auto work = process_group->getBackend(c10::DeviceType::XPU) + ->alltoall( + output_tensors_vec, + input_tensors_vec, + AllToAllOptions{std::chrono::milliseconds(timeout)}); + return std::tuple, c10::intrusive_ptr>( + std::move(output_tensors_vec), work); +} + +c10::intrusive_ptr alltoall_base_XPU( + at::Tensor& output, + at::Tensor& input, + const c10::intrusive_ptr& process_group, + std::vector output_split_sizes, + std::vector input_split_sizes, + int64_t timeout) { + return process_group->getBackend(c10::DeviceType::XPU) + ->alltoall_base( + output, + 
input, + output_split_sizes, + input_split_sizes, + AllToAllOptions{std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr barrier_XPU( + at::Tensor /* unused */, + const c10::intrusive_ptr& process_group, + const std::vector& device_ids, + int64_t timeout) { + return process_group->getBackend(c10::DeviceType::XPU) + ->barrier(BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); +} + +TORCH_LIBRARY_IMPL(c10d, XPU, m) { + m.impl("send", send_XPU); + m.impl("recv_", recv_XPU); + m.impl("recv_any_source_", recv_any_source_XPU); + m.impl("reduce_", reduce_XPU); + m.impl("broadcast_", broadcast_XPU); + m.impl("allreduce_", allreduce_XPU); + m.impl("allreduce_coalesced_", allreduce_coalesced_XPU); + m.impl("allgather_", allgather_XPU); + m.impl("_allgather_base_", _allgather_base_XPU); + m.impl("allgather_coalesced_", allgather_coalesced_XPU); + m.impl( + "allgather_into_tensor_coalesced_", allgather_into_tensor_coalesced_XPU); + m.impl("reduce_scatter_", reduce_scatter_XPU); + m.impl("_reduce_scatter_base_", _reduce_scatter_base_XPU); + m.impl( + "reduce_scatter_tensor_coalesced_", reduce_scatter_tensor_coalesced_XPU); + m.impl("gather_", gather_XPU); + m.impl("scatter_", scatter_XPU); + m.impl("alltoall_", alltoall_XPU); + m.impl("alltoall_base_", alltoall_base_XPU); + m.impl("barrier", barrier_XPU); +} +} // namespace + +} // namespace ops +} // namespace c10d \ No newline at end of file From 30f6cd2db60101dc739720b033b148bd464bacfa Mon Sep 17 00:00:00 2001 From: hanchao Date: Wed, 27 Nov 2024 05:15:05 +0000 Subject: [PATCH 07/19] update --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4ea88dd75..fe279c1ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -54,6 +54,7 @@ endif() if(NOT WIN32 AND USE_XCCL) include(${TORCH_XPU_OPS_ROOT}/cmake/XCCL.cmake) set(USE_C10D_XCCL ON) + set(USE_C10D_XCCL ${USE_C10D_XCCL} PARENT_SCOPE) endif() if(BUILD_TEST) From fb851b1139bdcf1ffffd727f2baaa5d20b497907 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Thu, 5 Dec 2024 18:58:04 +0800 Subject: [PATCH 08/19] imple allreduce and strcture --- cmake/Modules/FindXCCL.cmake | 7 +- src/xccl/CMakeLists.txt | 1 + src/xccl/ProcessGroupXCCL.cpp | 1559 +-------------------------------- src/xccl/ProcessGroupXCCL.hpp | 187 +--- 4 files changed, 38 insertions(+), 1716 deletions(-) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index 29571065c..1881cf3aa 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -6,7 +6,12 @@ include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) -set(XCCL_ROOT $ENV{CCL_ROOT}) +set(XCCL_ROOT "/opt/intel/oneapi/ccl/latest") +if (NOT EXISTS "${XCCL_ROOT}") + message(STATUS "Default OneCCL not found, using current environment OneCCL") + set(XCCL_ROOT $ENV{CCL_ROOT}) +endif() + string(COMPARE EQUAL "${XCCL_ROOT}" "" nocclfound) if(nocclfound) set(XCCL_FOUND False) diff --git a/src/xccl/CMakeLists.txt b/src/xccl/CMakeLists.txt index 242cf7e26..7e16ea8ff 100644 --- a/src/xccl/CMakeLists.txt +++ b/src/xccl/CMakeLists.txt @@ -7,6 +7,7 @@ list(APPEND ATen_XPU_XCCL_SRCS ${xccl_cpp}) set(ATen_XPU_XCCL_SRCS ${ATen_XPU_XCCL_SRCS} PARENT_SCOPE) +# Copy the header file to the build directory so that the PyTorch registration file can locate it. 
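# The headers then resolve through the in-tree c10d layout from the build
# directory, e.g. #include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>.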
foreach(HEADER ${xccl_h}) file(COPY ${HEADER} DESTINATION "${CMAKE_BINARY_DIR}/torch/csrc/distributed/c10d") endforeach() diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 71e26fbe0..4a82e3cbc 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -24,114 +24,24 @@ const std::map xcclDatatypes = { {at::kDouble, ccl::datatype::float64}, {at::kBFloat16, ccl::datatype::bfloat16}, {at::kBool, ccl::datatype::uint8}, - // use for allgather - {at::kFloat8_e5m2, ccl::datatype::uint8}, - {at::kFloat8_e4m3fn, ccl::datatype::uint8}, - {at::kFloat8_e4m3fnuz, ccl::datatype::uint8}, - {at::kFloat8_e5m2fnuz, ccl::datatype::uint8}, }; -bool computeLengthsAndCheckAndGetFlat( - const std::vector& tensors, - std::vector& lengths, - at::Tensor& flatTensor, - int64_t& flatLength) { - int64_t groupSize = tensors.size(); - auto firstTensor = tensors[0]; - int64_t totalSize = 0; - bool isFlat = true; - - auto storage = firstTensor.storage(); - int64_t firstStorageOffset = firstTensor.storage_offset(); - - for (int i = 0; i < groupSize; i++) { - auto& curTensor = tensors[i]; - int64_t length = curTensor.numel(); - lengths[i] = length; - totalSize += length; - - if (isFlat && - (!storage.is_alias_of(curTensor.storage()) || - curTensor.storage_offset() != - firstStorageOffset + totalSize - length)) { - isFlat = false; - } - } - - flatLength = totalSize; - - if (isFlat) { - flatTensor = firstTensor; - } else { - flatTensor = at::empty({totalSize}, firstTensor.options()); - } - - return isFlat; -} - -bool check_same_size(const std::vector& input_tensors) { - for (const auto& input_tensor : input_tensors) { - if (!input_tensors[0].is_same_size(input_tensor)) { - return false; - } - } - return true; -} - -void check_xpu_single_tensor( - const at::Tensor& tensor, - const bool p2p = false // whether operation is a P2P operation -) { +void checkXPUTensor(at::Tensor& tensor) { if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) { C10_THROW_ERROR( ValueError, "Tensors must be XPU and dense and non-complex"); - - // Skip the following requirements for P2P operations if (!tensor.is_contiguous(tensor.suggest_memory_format())) { - if (p2p) { - TORCH_WARN_ONCE( - "Detected non-contiguous tensor in P2P operations. 
It is user " - "responsibility to guarantee that source and destination tensors have " - "the same contiguity format."); - } else { - C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); - } + C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); } } } -int64_t check_xpu_tensors_same_device(const std::vector& tensors) { - TORCH_CHECK_WITH( - ValueError, tensors.size() != 0, "Tensor list must be nonempty"); - - const auto& first = tensors.front(); - - int64_t total_numel = 0; - for (const auto& t : tensors) { - if (!t.is_xpu() || t.is_sparse() || t.is_complex()) { - C10_THROW_ERROR( - ValueError, "Tensors must be XPU and dense and non-complex"); - } - if (t.scalar_type() != first.scalar_type()) { - C10_THROW_ERROR(TypeError, "Tensors must have identical type"); - } - TORCH_CHECK_WITH( - ValueError, - t.get_device() == tensors[0].get_device(), - "Expected list of tensors on the same device"); - total_numel += t.numel(); - } - - return total_numel; -} - ccl::datatype getXcclDataType( at::ScalarType type, bool is_reduction_op = false) { - if (is_reduction_op) - TORCH_CHECK( - !isFloat8Type(type), - "Float8 dtypes are not currenlty supported for XCCL reductions"); + TORCH_CHECK( + !isFloat8Type(type) && is_reduction_op, + "Float8 dtypes are not currenlty supported for XCCL reductions"); auto it = xcclDatatypes.find(type); TORCH_CHECK_WITH( TypeError, @@ -147,10 +57,6 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { // Map sum to max for bool tensors to avoid overflow issues with sum. return ccl::reduction::max; } - // WA due to oneCCL not support AVG - if (reduceOp == ReduceOp::AVG) { - return ccl::reduction::sum; - } return xcclOps.at(reduceOp); } catch (const std::out_of_range&) { C10_THROW_ERROR( @@ -166,11 +72,9 @@ void syncStream( xcclEvent.record(at::xpu::getCurrentXPUStream(device.index())); xcclEvent.block(xcclStream); } - } // namespace constexpr int64_t kSynchronizeBusyWaitMillis = 10; -thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0; ProcessGroupXCCL::WorkXCCL::WorkXCCL( at::Device& device, @@ -225,10 +129,6 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); } } - if (barrierTensor_.defined()) { - auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); - currentStream.synchronize(); - } } bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { @@ -236,9 +136,6 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { return true; } -constexpr const char* MULTI_DEVICE_ERROR_MSG = - "Expecting one tensor only but got multiple"; - ProcessGroupXCCL::ProcessGroupXCCL( const c10::intrusive_ptr& store, int rank, @@ -250,12 +147,6 @@ ProcessGroupXCCL::ProcessGroupXCCL( ProcessGroupXCCL::~ProcessGroupXCCL() = default; -void ProcessGroupXCCL::setSequenceNumberForGroup() {} - -uint64_t ProcessGroupXCCL::getSequenceNumberForGroup() { - return seqCollective_; -} - c10::intrusive_ptr ProcessGroupXCCL::initWork( at::Device& device, int rank, @@ -275,19 +166,12 @@ c10::intrusive_ptr ProcessGroupXCCL::initWork( std::shared_ptr ProcessGroupXCCL::getXCCLComm( const std::string& deviceKey, - at::Device& device, - OpType opType, - int p2pRank, - bool isSendRecvSelf) { - if (deviceKey.empty()) { - C10_THROW_ERROR( - DistBackendError, - "Not able to create/get the XCCL Communicator since " - "the devices are empty "); - } - - usedDeviceIdxs_.insert(device.index()); - + at::Device& device) { + TORCH_CHECK_WITH( + DistBackendError, + 
!deviceKey.empty(), + "Not able to create/get " + "XCCL Communicator since the devices are empty "); { std::lock_guard lock(mutex_); if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { @@ -295,24 +179,9 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( } } - std::shared_ptr XCCLComm; - - bool batchP2P = xcclActiveGroupCounter_ > 0; - bool singleP2POp = isP2POp(opType, batchP2P); - - at::xpu::OptionalXPUGuard gpuGuard(device); - int numRanks, rank; - if (!singleP2POp) { - numRanks = getSize(); - rank = getRank(); - } else if (isSendRecvSelf) { - numRanks = 1; - rank = 0; - } else { - numRanks = 2; - rank = p2pRank; - } + numRanks = getSize(); + rank = getRank(); c10::impl::VirtualGuardImpl impl(device.type()); c10::Stream stream = @@ -323,23 +192,10 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( ccl::vector_class> devs_rank; devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); - auto xccl_kvs = get_kvs(rank_, *store_, singleP2POp, deviceKey, p2pRank); + auto xccl_kvs = get_kvs(rank_, *store_); auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs); - XCCLComm = std::make_shared(std::move(comms[0])); - - RECORD_PARAM_COMMS( - 0, // seq - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - rank, // rank - "init", // collective name - 0, // inNelems - 0, // outNelems - at::kByte, // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - size_); // worldSize + std::shared_ptr XCCLComm = + std::make_shared(std::move(comms[0])); std::lock_guard lock(mutex_); devXCCLCommMap_.emplace(deviceKey, XCCLComm); @@ -349,64 +205,6 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( return XCCLComm; } -void ProcessGroupXCCL::groupStart() { - ccl::group_start(); - ++xcclActiveGroupCounter_; -} - -void ProcessGroupXCCL::groupEnd() { - ccl::group_end(); - --xcclActiveGroupCounter_; -} - -// TODO: wait p2p enable -static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; -void ProcessGroupXCCL::startCoalescing() { - if (coalescing_state_ & CoalP2P) { - seqP2P_++; - } else { - seqCollective_++; - } - coalescedDevice_.set_index(-1); - coalescedComm_ = nullptr; - coalescing_state_ |= CoalActive; - groupStart(); -} - -c10::intrusive_ptr ProcessGroupXCCL::endCoalescing(OpType optype) { - if (coalescedComm_ == nullptr) { - // There is no actual work being coalesced, return here - groupEnd(); - coalescing_state_ = 0; - return nullptr; - } - TORCH_CHECK( - coalescedDevice_.index() >= 0, - "Somthing went wrong. 
Did you call end_coalescing before start_coalescing?"); - - auto comm = coalescedComm_; - auto device = coalescedDevice_; - - const auto key = std::to_string(device.index()); - auto stream = xcclStreamsMap_.at(key); - - auto work = initWork(device, rank_, optype); - work->blockingWait_ = blockingWait_; - - groupEnd(); - - work->xcclEndEvent_->record(stream); - - coalescing_state_ = 0; - coalescedComm_ = nullptr; - return work; -} - -c10::intrusive_ptr ProcessGroupXCCL::endCoalescing() { - // Default OpType to COALESCED if not specified - return endCoalescing(OpType::COALESCED); -} - template c10::intrusive_ptr ProcessGroupXCCL::collective( std::vector& inputs, @@ -417,49 +215,28 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( OpType opType, const char* profilingTitle) { seqCollective_++; + auto device = inputs[0].device(); const auto key = std::to_string(device.index()); - auto comm = getXCCLComm(key, device, opType); - - if (coalescing_state_ & CoalActive) { - coalescing_state_ |= CoalColl; - if (coalescedDevice_.index() < 0) { - coalescedDevice_ = device; - } else { - TORCH_CHECK( - coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); - } - if (coalescedComm_ == nullptr) { - coalescedComm_ = comm; - } else { - TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG); - } - } + auto comm = getXCCLComm(key, device); auto stream = xcclStreamsMap_.at(key); syncStream(device, xcclEventsMap_[key], stream); c10::intrusive_ptr work; - work = initWork(device, rank_, opType); - + work = initWork(device, rank_, opType, profilingTitle); work->outputs_ = std::make_shared>(outputs); at::xpu::OptionalXPUGuard gpuGuard(device); - pre(stream, work); - for (const auto i : c10::irange(inputs.size())) { c10::xpu::XPUCachingAllocator::recordStream( inputs[i].storage().data_ptr(), stream); fn(inputs[i], outputs[i], *comm, stream); } - post(stream, work); - if (!coalescing_state_) { - work->xcclEndEvent_->record(stream); - } - + work->xcclEndEvent_->record(stream); std::vector streams = {stream.unwrap()}; c10::MultiStreamGuard streamGuard(streams); std::vector devices{device}; @@ -471,468 +248,13 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( return work; } -template -c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( - at::Tensor& tensor, - Fn fn, - int peer, - OpType opType, - const char* profilingTitle) { - auto device = tensor.device(); - std::string key; - int p2pRank = 0, p2pTargetRank = 0; - bool isSendRecvSelf = false; - - bool batchP2P = xcclActiveGroupCounter_ > 0; - if (batchP2P) { - key = std::to_string(device.index()); - p2pRank = rank_; - p2pTargetRank = peer; - } else { - int lowRank = rank_ < peer ? rank_ : peer; - int highRank = rank_ < peer ? peer : rank_; - key = std::to_string(lowRank) + ":" + std::to_string(highRank); - p2pRank = rank_ <= peer ? 0 : 1; - isSendRecvSelf = rank_ == peer; - p2pTargetRank = isSendRecvSelf ? 
0 : 1 - p2pRank; - if (!coalescing_state_) { - seqP2P_++; - } - } - - auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf); - - if (coalescing_state_ & CoalActive) { - coalescing_state_ |= CoalP2P; - if (coalescedDevice_.index() < 0) { - coalescedDevice_ = device; - } else { - TORCH_CHECK( - coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); - } - if (coalescedComm_ == nullptr) { - coalescedComm_ = comm; - } else { - TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG); - } - } - - auto stream = xcclStreamsMap_.at(key); - syncStream(device, xcclEventsMap_[key], stream); - - if (!coalescing_state_) { - c10::intrusive_ptr work; - work = initWork(device, rank_, opType); - work->outputs_ = std::make_shared>(); - work->outputs_->push_back(tensor); - - at::xpu::OptionalXPUGuard gpuGuard(device); - - c10::xpu::XPUCachingAllocator::recordStream( - tensor.storage().data_ptr(), stream); - - fn(tensor, *comm, stream, p2pTargetRank); - - work->xcclEndEvent_->record(stream); - work->blockingWait_ = blockingWait_; - std::vector streams = {stream.unwrap()}; - c10::MultiStreamGuard streamGuard(streams); - std::vector devices{device}; - work->future_ = c10::make_intrusive( - c10::ListType::create(c10::TensorType::get()), devices); - work->future_->markCompleted(at::IValue(*work->outputs_)); - return work; - } else { - at::xpu::OptionalXPUGuard gpuGuard(device); - - c10::xpu::XPUCachingAllocator::recordStream( - tensor.storage().data_ptr(), stream); - - fn(tensor, *comm, stream, p2pTargetRank); - - return nullptr; - } -} - -c10::intrusive_ptr ProcessGroupXCCL::send( - std::vector& tensors, - int dstRank, - int /* unused */) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto tensor = tensors.back(); - check_xpu_single_tensor(tensor, true); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - dstRank, // dst rank - "send", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - auto ret = pointToPoint( - tensor, - [&](at::Tensor& input, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - int dst) { - auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::send( - input.data_ptr(), - (size_t)input.numel(), - xcclDataType, - dst, - comm, - ccl::create_stream(stream.queue())); - return; - }, - dstRank, - OpType::SEND, - c10::str("xccl:send ", rank_, "->", dstRank).c_str()); - return ret; -} - -c10::intrusive_ptr ProcessGroupXCCL::recv( - std::vector& tensors, - int srcRank, - int /* unused */) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto tensor = tensors.back(); - check_xpu_single_tensor(tensor, true); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - srcRank, // src rank - "recv", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - 
this->getSize()); // worldSize - - auto ret = pointToPoint( - tensor, - [&](at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - int src) { - auto xcclDataType = getXcclDataType(output.scalar_type()); - ccl::recv( - output.data_ptr(), - (size_t)output.numel(), - xcclDataType, - src, - comm, - ccl::create_stream(stream.queue())); - return; - }, - srcRank, - OpType::RECV, - c10::str("xccl:recv ", rank_, "<-", srcRank).c_str()); - return ret; -} - -c10::intrusive_ptr ProcessGroupXCCL::gather( - std::vector>& outputTensors, - std::vector& inputTensors, - const GatherOptions& opts) { - static auto invalidArgument = [](const std::string& msg) { - C10_THROW_ERROR(ValueError, "ProcessGroupXCCL::gather: " + msg); - }; - - assertRootRank(invalidArgument, opts.rootRank, size_); - - TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto inputTensor = inputTensors.back(); - - std::vector outputs; - - if (getRank() == opts.rootRank) { - if (outputTensors.size() != 1) { - std::stringstream ss; - ss << "requires a single-element output list containing a list with " - << getSize() << " tensors."; - invalidArgument(ss.str()); - } else if (outputTensors[0].size() != static_cast(getSize())) { - std::stringstream ss; - ss << "Incorrect output list size " << outputTensors[0].size() - << ". Output list size should be " << getSize() - << ", same as size of the process group."; - invalidArgument(ss.str()); - } - - const auto& options = inputTensor.options(); - const auto& sizes = inputTensor.sizes(); - assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes); - outputs = outputTensors[0]; - } else { - // if not in the root rank, initialize outputs as empty list - if (outputTensors.size() != 0) { - invalidArgument("requires empty output on non-root"); - } - outputs = {}; - // append a empty tensor to the list, we don't use it but the - // `collective` template function requires it to invoke its function - outputs.emplace_back(); - } - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - opts.rootRank, // root rank - "gather", // collective name - inputTensor.numel(), // inNelems - inputTensor.numel() * this->getSize(), // outNelems - inputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSize - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - auto inputs = std::vector{inputTensor}; - return collective( - inputs, - outputs, // just to fit the collective interface - [&](at::Tensor& /* unused */, - at::Tensor& /* unused */, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - const auto root = opts.rootRank; - if (getRank() == root) { - for (auto output : outputs) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } - } - { - auto xcclDataType = getXcclDataType(inputTensor.scalar_type()); - if (rank_ == root) { - for (const auto r : c10::irange(size_)) { - if (r != root) { - // do receive - ccl::recv( - outputs[r].data_ptr(), - (size_t)inputTensor.numel(), - xcclDataType, - r, - comm, - ccl::create_stream(stream.queue())); - } else { - // on its own rank, simply copy from the input - outputs[r].copy_(inputTensor); - } - } - } else { - // do send - ccl::send( - inputTensor.data_ptr(), - (size_t)inputTensor.numel(), - xcclDataType, - root, - 
comm, - ccl::create_stream(stream.queue())); - } - return; - } - }, - OpType::GATHER); -} - -c10::intrusive_ptr ProcessGroupXCCL::scatter( - std::vector& outputTensors, - std::vector>& inputTensors, - const ScatterOptions& opts) { - static auto invalidArgument = [](const std::string& msg) { - C10_THROW_ERROR(ValueError, "ProcessGroupXCCL::scatter: " + msg); - }; - - assertRootRank(invalidArgument, opts.rootRank, size_); - - TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto outputTensor = outputTensors.back(); - - std::vector inputs; - - if (getRank() == opts.rootRank) { - if (inputTensors.size() != 1) { - std::stringstream ss; - ss << "requires a single-element input list containing a list with " - << getSize() << " tensors."; - invalidArgument(ss.str()); - } else if (inputTensors[0].size() != static_cast(getSize())) { - std::stringstream ss; - ss << "Incorrect input list size " << inputTensors[0].size() - << ". Input list size should be " << getSize() - << ", same as size of the process group."; - invalidArgument(ss.str()); - } - - const auto& options = outputTensor.options(); - const auto& sizes = outputTensor.sizes(); - assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes); - inputs = inputTensors[0]; - } else { - // if not in the root rank, initialize inputTensors as empty place holder - // with an empty list - if (inputTensors.size() != 0) { - invalidArgument("requires empty input on non-root"); - } - inputs = {}; - // append a empty tensor to the list, we don't use it but the - // `collective` template function requires it to invoke its function - inputs.emplace_back(); - } - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - opts.rootRank, // root rank - "scatter", // collective name - outputTensor.numel() * this->getSize(), // inNelems - outputTensor.numel(), // outNelems - outputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSize - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - const auto root = opts.rootRank; - - auto outputs = std::vector{outputTensor}; - return collective( - outputs, - inputs, // just to fit the collective interface - [&](at::Tensor& /* unused */, - at::Tensor& /* unused */, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - if (getRank() == root) { - for (auto input : inputs) { - c10::xpu::XPUCachingAllocator::recordStream( - input.storage().data_ptr(), stream); - } - } - { - if (rank_ == root) { - for (const auto r : c10::irange(size_)) { - if (r != root) { - // do send - size_t send_count = inputs[r].numel(); - auto send_type = getXcclDataType(inputs[r].scalar_type()); - ccl::send( - inputs[r].data_ptr(), - send_count, - send_type, - r, - comm, - ccl::create_stream(stream.queue())); - } else { - // on its own rank, simply copy from the input - outputTensor.copy_(inputs[r]); - } - } - } else { - // do receive - size_t recv_count = outputTensor.numel(); - auto recv_type = getXcclDataType(outputTensor.scalar_type()); - ccl::recv( - outputTensor.data_ptr(), - recv_count, - recv_type, - root, - comm, - ccl::create_stream(stream.queue())); - } - - return; - } - }, - OpType::SCATTER); -} - -c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( - at::Tensor& tensor, - const AllreduceOptions& opts) { - return collective( - tensor, - tensor, - [&](at::Tensor& 
input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - auto xcclDataType = getXcclDataType(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - auto ccl_stream = ccl::create_stream(stream.queue()); - ccl::allreduce( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - xcclReduceOp, - comm, - ccl::create_stream(stream.queue())); - // WA due to oneCCL not support AVG - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - output.div_(divisor); - } - return; - }, - OpType::ALLREDUCE, - "xccl:all_reduce"); -} - c10::intrusive_ptr ProcessGroupXCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + TORCH_CHECK( + tensors.size() == 1, "Expecting one tensor only but got multiple"); auto tensor = tensors.back(); - check_xpu_single_tensor(tensor); - - // @lint-ignore CLANGTIDY - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - rank_, // rank - "allreduce", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - size_); // worldSize + checkXPUTensor(tensor); return collective( tensor, @@ -943,6 +265,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + auto ccl_stream = ccl::create_stream(stream.queue()); ccl::allreduce( input.data_ptr(), output.data_ptr(), @@ -950,845 +273,13 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( xcclDataType, xcclReduceOp, comm, - ccl::create_stream(stream.queue())); - // WA due to oneCCL not support AVG - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - output.div_(divisor); - } + ccl_stream); return; }, OpType::ALLREDUCE, "xccl:all_reduce"); } -c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( - std::vector& tensors, - const AllreduceCoalescedOptions& opts) { - auto total_numel = check_xpu_tensors_same_device(tensors); - - // @lint-ignore CLANGTIDY - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - rank_, // rank - "allreduce_coalesced", // collective name - total_numel, // inNelems - total_numel, // outNelems - tensors[0].scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collectiveCoalesced( - tensors, - tensors, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - auto xcclDataType = getXcclDataType(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::allreduce( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - xcclReduceOp, - comm, - ccl::create_stream(stream.queue())); - // WA due to oneCCL not support AVG - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - output.div_(divisor); - } - return; - }, - 
OpType::COALESCED, - "xccl:allreduce_coalesced"); -} - -c10::intrusive_ptr ProcessGroupXCCL::broadcast( - std::vector& tensors, - const BroadcastOptions& opts) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto tensor = tensors.back(); - check_xpu_single_tensor(tensor); - - // @lint-ignore CLANGTIDY - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - opts.rootRank, // root rank - "broadcast", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - const auto root = opts.rootRank + opts.rootTensor; - - return collective( - tensor, - tensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::broadcast( - input.data_ptr(), - (size_t)input.numel(), - xcclDataType, - root, - comm, - ccl::create_stream(stream.queue())); - return; - }, - OpType::BROADCAST, - "nccl:broadcast"); -} - -c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - const BroadcastOptions& opts) { - if (outputTensor.numel() != inputTensor.numel()) { - C10_THROW_ERROR( - ValueError, - "Tensor input and output of _broadcast_oop must have the same number of elements "); - } - const auto root = opts.rootRank + opts.rootTensor; - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::broadcast( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - root, - comm, - ccl::create_stream(stream.queue())); - return; - }, - OpType::BROADCAST, - "xccl:_broadcast_oop"); -} - -c10::intrusive_ptr ProcessGroupXCCL::reduce( - std::vector& tensors, - const ReduceOptions& opts) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto tensor = tensors.back(); - check_xpu_single_tensor(tensor); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - opts.rootRank, // root rank - "reduce", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - tensor, - tensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - const int root = opts.rootRank + opts.rootTensor; - const auto xcclDataType = getXcclDataType(input.scalar_type(), true); - const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::reduce( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - xcclReduceOp, - root, - comm, - ccl::create_stream(stream.queue())); - // WA due to oneCCL not support AVG - if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { - auto divisor = getSize(); - output.div_(divisor); - } - return; - }, - 
OpType::REDUCE, - "xccl:reduce"); -} - -c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceOptions& opts) { - TORCH_CHECK_WITH( - ValueError, - outputTensor.numel() == inputTensor.numel(), - "Tensor input and output of _reduce_oop must have the same number of elements"); - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - const int root = opts.rootRank + opts.rootTensor; - const auto xcclDataType = getXcclDataType(input.scalar_type(), true); - const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::reduce( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - xcclReduceOp, - root, - comm, - ccl::create_stream(stream.queue())); - // WA due to oneCCL not support AVG - if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { - auto divisor = getSize(); - output.div_(divisor); - } - return; - }, - OpType::REDUCE, - "xccl:_reduce_oop"); -} - -c10::intrusive_ptr ProcessGroupXCCL::allgather( - std::vector>& outputTensors, - std::vector& inputTensors, - const AllgatherOptions& opts) { - TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto inputTensor = inputTensors.back(); - check_xpu_single_tensor(inputTensor); - // @lint-ignore CLANGTIDY - std::vector& outputTensors_ = outputTensors.back(); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - rank_, // rank - "all_gather", // collective name - inputTensor.numel(), // inNelems - inputTensor.numel() * // outNelems - this->getSize(), - inputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSize - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - bool same_size = check_same_size(outputTensors_); - if (same_size) { - // Flatten a vector of tensors into a single, stacked tensor. - at::Tensor outputFlattened = newLikeFlat(outputTensors_); - - return collective( - inputTensor, - outputFlattened, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::allgather( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - comm, - ccl::create_stream(stream.queue())); - return; - }, - [](at::xpu::XPUStream&, - c10::intrusive_ptr& work) {}, - [&](at::xpu::XPUStream& Stream, - c10::intrusive_ptr& work) { - // Copy the flattened output tensors to the outputs. - c10::StreamGuard guard(Stream); - for (const auto j : c10::irange(outputTensors_.size())) { - c10::xpu::XPUCachingAllocator::recordStream( - outputTensors_[j].storage().data_ptr(), Stream); - outputTensors_[j].copy_(outputFlattened[j], true); - } - }, - OpType::ALLGATHER, - "xccl:all_gather"); - } else { - const auto num_reduces = outputTensors_.size(); - startCoalescing(); - for (const int i : c10::irange(num_reduces)) { - auto& output = outputTensors_[i]; - auto& input = (i == rank_) ? 
inputTensor : output; - auto broadcastOpts = BroadcastOptions{ - static_cast(i), static_cast(0), opts.timeout}; - _broadcast_oop(output, input, broadcastOpts); - } - auto work = endCoalescing(OpType::ALLGATHER); - return work; - } -} - -c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( - at::Tensor& output_tensor, - at::Tensor& input_tensor, - const AllgatherOptions& opts) { - check_xpu_single_tensor(input_tensor); - check_xpu_single_tensor(output_tensor); - - TORCH_CHECK_WITH( - TypeError, - input_tensor.dtype() == output_tensor.dtype(), - "output tensor must have the same type as input tensor"); - TORCH_CHECK_WITH( - ValueError, - input_tensor.numel() * size_ == output_tensor.numel(), - "output tensor size must be equal to world_size times input tensor size"); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - input_tensor, // inputTensors - output_tensor, // outputTensors - rank_, // rank - "_allgather_base", // collective name - input_tensor.numel(), // inNelems - output_tensor.numel(), // outNelems - output_tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSize - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - input_tensor, - output_tensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::allgather( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - comm, - ccl::create_stream(stream.queue())); - return; - }, - OpType::_ALLGATHER_BASE, - "xccl:_all_gather_base"); -} - -c10::intrusive_ptr ProcessGroupXCCL::allgather_into_tensor_coalesced( - std::vector& outputs, - std::vector& inputs, - const AllgatherOptions& opts) { - return collectiveCoalesced( - inputs, - outputs, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::allgather( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - comm, - ccl::create_stream(stream.queue())); - return; - }, - OpType::COALESCED, - "xccl:all_gather_into_tensor_coalesced"); -} - -c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( - std::vector& outputTensors, - std::vector>& inputTensors, - const ReduceScatterOptions& opts) { - TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto outputTensor = outputTensors.back(); - check_xpu_single_tensor(outputTensor); - // @lint-ignore CLANGTIDY - auto inputTensors_ = inputTensors.back(); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - rank_, // rank - "reduce_scatter", // collective name - outputTensor.numel() * this->getSize(), // inNelems - outputTensor.numel(), // outNelems - outputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - bool same_size = check_same_size(inputTensors_); - if (same_size) { - // Flatten a vector of tensors into a single, 
stacked tensor. - at::Tensor inputFlattened = newLikeFlat(inputTensors_); - return collective( - inputFlattened, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::reduce_scatter( - input.data_ptr(), - output.data_ptr(), - (size_t)output.numel(), - xcclDataType, - xcclReduceOp, - comm, - ccl::create_stream(stream.queue())); - // WA due to oneCCL not support AVG - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - output.div_(divisor); - } - return; - }, - [&](at::xpu::XPUStream& Stream, - c10::intrusive_ptr& work) { - // Copy the input tensors to the flattened inputs. - c10::StreamGuard guard(Stream); - for (const auto j : c10::irange(inputTensors_.size())) { - c10::xpu::XPUCachingAllocator::recordStream( - inputTensors_[j].storage().data_ptr(), Stream); - inputFlattened[j].copy_(inputTensors_[j], true); - } - }, - [&](at::xpu::XPUStream&, - c10::intrusive_ptr&) {}, - OpType::REDUCE_SCATTER, - "xccl:reduce_scatter"); - } else { - const auto num_reduces = inputTensors_.size(); - startCoalescing(); - for (const int i : c10::irange(num_reduces)) { - auto& input = inputTensors_[i]; - auto& output = (i == rank_) ? outputTensor : input; - auto reduceOpts = ReduceOptions{ - opts.reduceOp, - static_cast(i), - static_cast(0), - opts.timeout}; - _reduce_oop(output, input, reduceOpts); - } - auto work = endCoalescing(OpType::REDUCE_SCATTER); - return work; - } -} - -c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts) { - TORCH_CHECK_WITH( - TypeError, - inputTensor.dtype() == outputTensor.dtype(), - "input tensor must be the same type as the output tensor."); - TORCH_CHECK_WITH( - ValueError, - inputTensor.numel() == outputTensor.numel() * size_, - "input tensor must be the same size as output size times world size"); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "_reduce_scatter_base", // collective name - inputTensor.numel(), // inNelems - outputTensor.numel(), // outNelems - outputTensor.scalar_type(), // dtype - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::reduce_scatter( - input.data_ptr(), - output.data_ptr(), - (size_t)output.numel(), - xcclDataType, - xcclReduceOp, - comm, - ccl::create_stream(stream.queue())); - // WA due to oneCCL not support AVG - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - output.div_(divisor); - } - return; - }, - OpType::_REDUCE_SCATTER_BASE, - "xccl:_reduce_scatter_base"); -} - -c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( - std::vector& outputs, - 
std::vector& inputs, - const ReduceScatterOptions& opts) { - return collectiveCoalesced( - inputs, - outputs, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::reduce_scatter( - input.data_ptr(), - output.data_ptr(), - (size_t)output.numel(), - xcclDataType, - xcclReduceOp, - comm, - ccl::create_stream(stream.queue())); - // WA due to oneCCL not support AVG - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - output.div_(divisor); - } - return; - }, - OpType::COALESCED, - "xccl:reduce_scatter_tensor_coalesced"); -} - -c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { - RECORD_PARAM_COMMS( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - rank_, // rank - "barrier", // collective name - 0, // inNelems - 0, // outNelems - at::kByte, // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - // Device to use for barrier - int barDevIdx = -1; - - // See nccl barrier comments - if (!opts.device_ids.empty()) { - barDevIdx = opts.device_ids[0]; - } else if (getBoundDeviceId()) { - barDevIdx = (*getBoundDeviceId()).index(); - } else if (!usedDeviceIdxs_.empty()) { - barDevIdx = *usedDeviceIdxs_.begin(); - } else { - barDevIdx = - static_cast(rank_ % at::detail::getXPUHooks().getNumGPUs()); - } - - // todo: use barrier instead of allreduce - TORCH_CHECK_WITH( - ValueError, - barDevIdx >= 0, - "Failed to infer a GPU device id to perform barrier. 
"); - auto barDevice = at::Device(at::DeviceType::XPU, barDevIdx); - - at::Tensor barrierTensor = - at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); - - auto work = allreduce_impl(barrierTensor); - - auto xcclWork = dynamic_cast(work.get()); - TORCH_CHECK(xcclWork); - xcclWork->barrierTensor_ = std::move(barrierTensor); - return work; -} - -c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - std::vector& outputSplitSizes, - std::vector& inputSplitSizes, - const AllToAllOptions& /* unused */) { - check_xpu_single_tensor(outputTensor, true); - check_xpu_single_tensor(inputTensor, true); - if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "all_to_all", // collective name - inputTensor.numel(), // inNelems - outputTensor.numel(), // outNelems - inputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - TORCH_CHECK( - outputTensor.numel() == inputTensor.numel() && - outputTensor.scalar_type() == inputTensor.scalar_type(), - "xpu_alltoall_base: tensors are not equal in size or data type"); - TORCH_CHECK( - outputTensor.size(0) % size_ == 0, - "xpu_alltoall_base: tensor's dim 0 does not divide equally across group size"); - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(output.scalar_type()); - ccl::alltoall( - input.data_ptr(), - output.data_ptr(), - (size_t)output.numel() / comm.size(), - xcclDataType, - comm, - ccl::create_stream(stream.queue())); - return; - }, - OpType::ALLTOALL_BASE, - "xccl:all_to_all"); - } else { - c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); - c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "all_to_allv", // collective name - inputTensor.numel(), // inNelems - outputTensor.numel(), // outNelems - inputTensor.scalar_type(), // dType - inputSplitSizes, // inSplitSizes - outputSplitSizes, // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - std::vector sendCounts(size_); - std::vector recvCounts(size_); - bool inputSplitsEqual = inputSplitSizes.size() == 0; - bool outputSplitsEqual = outputSplitSizes.size() == 0; - - size_t inLen = input.numel(); - size_t outLen = output.numel(); - if (inLen) - inLen /= (inputSplitsEqual ? size_ : input.size(0)); - if (outLen) - outLen /= (outputSplitsEqual ? size_ : output.size(0)); - - for (int i = 0; i < size_; i++) { - sendCounts[i] = - (inputSplitsEqual ? inLen : inputSplitSizes[i] * inLen); - recvCounts[i] = - (outputSplitsEqual ? 
outLen : outputSplitSizes[i] * outLen); - } - auto xcclDataType = getXcclDataType(output.scalar_type()); - ccl::alltoallv( - input.data_ptr(), - sendCounts, - output.data_ptr(), - recvCounts, - xcclDataType, - comm, - ccl::create_stream(stream.queue())); - return; - }, - OpType::ALLTOALL_BASE, - "xccl:all_to_all"); - } -} - -c10::intrusive_ptr ProcessGroupXCCL::alltoall( - std::vector& outputTensors, - std::vector& inputTensors, - const AllToAllOptions& /* unused */) { - auto device = outputTensors[0].device(); - int64_t total_numel = 0; - for (const auto r : c10::irange(outputTensors.size())) { - check_xpu_single_tensor(outputTensors[r], true); - check_xpu_single_tensor(inputTensors[r], true); - TORCH_CHECK( - device == outputTensors[r].device() && - device == inputTensors[r].device(), - "Tensors must be on the same device") - total_numel += inputTensors[r].numel(); - } - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - rank_, // rank - "all_to_all", // collective name - total_numel, // inNelems - total_numel, // outNelems - inputTensors.front().scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - inputTensors, - outputTensors, - [&](at::Tensor& /* unused */, - at::Tensor& /* unused */, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - c10::OptionalStreamGuard stream_guard(stream.unwrap()); - at::Tensor flatInput; - at::Tensor flatOutput; - - std::vector sendCounts(size_); - std::vector recvCounts(size_); - - int64_t flatSendCount; - int64_t flatRecvCount; - - bool isInputFlat = computeLengthsAndCheckAndGetFlat( - inputTensors, sendCounts, flatInput, flatSendCount); - bool isOutputFlat = computeLengthsAndCheckAndGetFlat( - outputTensors, recvCounts, flatOutput, flatRecvCount); - if (!isInputFlat) { - auto flatInputSplits = flatInput.split_with_sizes( - c10::IntArrayRef((int64_t*)sendCounts.data(), sendCounts.size()), - 0); - - for (int i = 0; i < size_; i++) { - flatInputSplits[i].copy_(inputTensors[i].view({-1})); - } - } - - auto xcclDataType = getXcclDataType(flatOutput.scalar_type()); - ccl::event ret_evt; - ret_evt = ccl::alltoallv( - flatInput.data_ptr(), - sendCounts, - flatOutput.data_ptr(), - recvCounts, - xcclDataType, - comm, - ccl::create_stream(stream.queue())); - - if (!isOutputFlat) { - ret_evt.wait(); - auto flatOutputSplits = flatOutput.split_with_sizes( - c10::IntArrayRef((int64_t*)recvCounts.data(), recvCounts.size()), - 0); - - for (int i = 0; i < size_; i++) { - outputTensors[i].view({-1}).copy_(flatOutputSplits[i]); - } - } - stream.synchronize(); - return; - }, - OpType::ALLTOALL, - "xccl:all_to_all"); -} - } // namespace c10d #endif // USE_C10D_XCCL diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index e2099e658..21269bd6f 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -26,7 +26,6 @@ static std::vector TORCH_XCCL_BLOCKING_WAIT = { "XCCL_BLOCKING_WAIT"}; using xcclComm_t = ccl::communicator; -using XCCL_KVS = ccl::shared_ptr_class; constexpr const char* XCCL_BACKEND_NAME = "xccl"; class TORCH_API ProcessGroupXCCL : public Backend { @@ -68,7 +67,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { protected: at::Device device_; std::shared_ptr xcclEndEvent_; - 
at::Tensor barrierTensor_; bool blockingWait_ = false; std::chrono::time_point workStartTime_; uint64_t seq_; @@ -95,18 +93,9 @@ class TORCH_API ProcessGroupXCCL : public Backend { return std::string(XCCL_BACKEND_NAME); } - void startCoalescing() override; - - c10::intrusive_ptr endCoalescing() override; - - c10::intrusive_ptr endCoalescing(OpType optype); - std::shared_ptr getXCCLComm( const std::string& deviceKey, - at::Device& device, - OpType opType, - int p2pRank = 0, - bool isSendRecvSelf = false); + at::Device& device); virtual c10::intrusive_ptr initWork( at::Device& device, @@ -123,39 +112,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { Fn fn, OpType opType, const char* profilingTitle = nullptr) { - return collective( - input, - output, - fn, - [](at::xpu::XPUStream&, - c10::intrusive_ptr&) {}, - [](at::xpu::XPUStream&, - c10::intrusive_ptr&) {}, - opType, - profilingTitle); - } - - template - c10::intrusive_ptr collective( - at::Tensor& input, - at::Tensor& output, - Fn fn, - PreProcess pre, - PostProcess post, - OpType opType, - const char* profilingTitle = nullptr) { auto inputs = std::vector{input}; auto outputs = std::vector{output}; - return collective(inputs, outputs, fn, pre, post, opType, profilingTitle); - } - - template - c10::intrusive_ptr collective( - std::vector& inputs, - std::vector& outputs, - Fn fn, - OpType opType, - const char* profilingTitle = nullptr) { return collective( inputs, outputs, @@ -164,8 +122,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr&) {}, [](at::xpu::XPUStream&, c10::intrusive_ptr&) {}, - opType, - profilingTitle); + opType); } template @@ -178,140 +135,14 @@ class TORCH_API ProcessGroupXCCL : public Backend { OpType opType, const char* profilingTitle = nullptr); - template - c10::intrusive_ptr collectiveCoalesced( - std::vector& input, - std::vector& output, - Fn fn, - OpType opType, - const char* profilingTitle = nullptr) { - return collective( - input, - output, - fn, - [](at::xpu::XPUStream&, - c10::intrusive_ptr&) { - ccl::group_start(); - }, - [](at::xpu::XPUStream&, - c10::intrusive_ptr&) { - ccl::group_end(); - }, - opType, - profilingTitle); - } - - template - c10::intrusive_ptr pointToPoint( - at::Tensor& tensor, - Fn fn, - int peer, - OpType opType, - const char* profilingTitle = nullptr); - - c10::intrusive_ptr allreduce_impl( - at::Tensor& tensor, - const AllreduceOptions& opts = AllreduceOptions()); - c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - c10::intrusive_ptr allreduce_coalesced( - std::vector& tensors, - const AllreduceCoalescedOptions& opts = - AllreduceCoalescedOptions()) override; - - c10::intrusive_ptr reduce( - std::vector& tensors, - const ReduceOptions& opts = ReduceOptions()) override; - - c10::intrusive_ptr _reduce_oop( - at::Tensor& outputTensors, - at::Tensor& inputTensors, - const ReduceOptions& opts = ReduceOptions()); - - c10::intrusive_ptr broadcast( - std::vector& tensors, - const BroadcastOptions& opts = BroadcastOptions()) override; - - c10::intrusive_ptr _broadcast_oop( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - const BroadcastOptions& opts); - - c10::intrusive_ptr allgather( - std::vector>& outputTensors, - std::vector& inputTensors, - const AllgatherOptions& opts = AllgatherOptions()) override; - - c10::intrusive_ptr _allgather_base( - at::Tensor& outputbuffer, - at::Tensor& inputbuffer, - const AllgatherOptions& opts = AllgatherOptions()) override; - - c10::intrusive_ptr 
allgather_into_tensor_coalesced( - std::vector& outputs, - std::vector& inputs, - const AllgatherOptions& opts = AllgatherOptions()) override; - - c10::intrusive_ptr reduce_scatter( - std::vector& outputTensors, - std::vector>& inputTensors, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - - c10::intrusive_ptr _reduce_scatter_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - - c10::intrusive_ptr reduce_scatter_tensor_coalesced( - std::vector& outputs, - std::vector& inputs, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - - c10::intrusive_ptr barrier( - const BarrierOptions& opts = BarrierOptions()) override; - - c10::intrusive_ptr alltoall_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - std::vector& outputSplitSizes, - std::vector& inputSplitSizes, - const AllToAllOptions& opts = AllToAllOptions()) override; - - c10::intrusive_ptr alltoall( - std::vector& outputTensors, - std::vector& inputTensors, - const AllToAllOptions& opts = AllToAllOptions()) override; - - c10::intrusive_ptr send( - std::vector& tensors, - int dstRank, - int tag) override; - - c10::intrusive_ptr recv( - std::vector& tensors, - int srcRank, - int tag) override; - - void groupStart(); - - void groupEnd(); - - c10::intrusive_ptr gather( - std::vector>& outputTensors, - std::vector& inputTensors, - const GatherOptions& opts = GatherOptions()) override; - - c10::intrusive_ptr scatter( - std::vector& outputTensors, - std::vector>& inputTensors, - const ScatterOptions& opts = ScatterOptions()) override; - - void setSequenceNumberForGroup() override; - - uint64_t getSequenceNumberForGroup() override; + void setSequenceNumberForGroup() override {} + uint64_t getSequenceNumberForGroup() override { + return seqCollective_; + } protected: std::unordered_map xcclStreamsMap_; @@ -320,14 +151,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr store_; uint64_t xcclCommCounter_{0}; std::mutex mutex_; - std::set usedDeviceIdxs_; - int coalescing_state_ = 0; - at::Device coalescedDevice_ = at::Device("xpu"); - std::shared_ptr coalescedComm_ = nullptr; bool blockingWait_ = false; - static thread_local uint64_t xcclActiveGroupCounter_; uint64_t seqCollective_{0}; - uint64_t seqP2P_{0}; private: std::mutex kvs_mutex; From b1aee2627818bc6419c872c1fbd91b9c21438ba7 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Fri, 13 Dec 2024 21:31:36 +0800 Subject: [PATCH 09/19] add non-reduction datatype --- src/xccl/ProcessGroupXCCL.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 4a82e3cbc..024bc5a11 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -24,6 +24,11 @@ const std::map xcclDatatypes = { {at::kDouble, ccl::datatype::float64}, {at::kBFloat16, ccl::datatype::bfloat16}, {at::kBool, ccl::datatype::uint8}, + // use for allgather + {at::kFloat8_e5m2, ccl::datatype::uint8}, + {at::kFloat8_e4m3fn, ccl::datatype::uint8}, + {at::kFloat8_e4m3fnuz, ccl::datatype::uint8}, + {at::kFloat8_e5m2fnuz, ccl::datatype::uint8}, }; void checkXPUTensor(at::Tensor& tensor) { From c55b16e1040442ab8b977a9aa8a4089dc9094486 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Fri, 13 Dec 2024 21:45:06 +0800 Subject: [PATCH 10/19] add comment --- src/xccl/ProcessGroupXCCL.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/xccl/ProcessGroupXCCL.cpp 
b/src/xccl/ProcessGroupXCCL.cpp index 024bc5a11..54db563c1 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -24,7 +24,7 @@ const std::map xcclDatatypes = { {at::kDouble, ccl::datatype::float64}, {at::kBFloat16, ccl::datatype::bfloat16}, {at::kBool, ccl::datatype::uint8}, - // use for allgather + // use for non-reducetion op like allgather {at::kFloat8_e5m2, ccl::datatype::uint8}, {at::kFloat8_e4m3fn, ccl::datatype::uint8}, {at::kFloat8_e4m3fnuz, ccl::datatype::uint8}, From d139548f85b304882697924402d8473cb55ef9bf Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Fri, 13 Dec 2024 22:58:51 +0800 Subject: [PATCH 11/19] Simply cmake logit --- CMakeLists.txt | 13 ------------- cmake/Modules/FindXCCL.cmake | 11 ++++------- cmake/XCCL.cmake | 2 ++ src/xccl/CMakeLists.txt | 5 ++++- 4 files changed, 10 insertions(+), 21 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index fe279c1ef..9874c48c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,22 +39,9 @@ include(${TORCH_XPU_OPS_ROOT}/cmake/SYCL.cmake) include(${TORCH_XPU_OPS_ROOT}/cmake/BuildFlags.cmake) option(USE_XCCL "Build with XCCL support" ON) -if (DEFINED ENV{USE_XCCL}) - string(TOLOWER "$ENV{USE_XCCL}" USE_XCCL_LOWER) - - if (NOT (USE_XCCL_LOWER STREQUAL "1" OR - USE_XCCL_LOWER STREQUAL "on" OR - USE_XCCL_LOWER STREQUAL "yes")) - set(USE_XCCL OFF CACHE BOOL "Build with XCCL support" FORCE) - else() - set(USE_XCCL ON CACHE BOOL "Build with XCCL support" FORCE) - endif() -endif() if(NOT WIN32 AND USE_XCCL) include(${TORCH_XPU_OPS_ROOT}/cmake/XCCL.cmake) - set(USE_C10D_XCCL ON) - set(USE_C10D_XCCL ${USE_C10D_XCCL} PARENT_SCOPE) endif() if(BUILD_TEST) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index 1881cf3aa..063670a90 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -12,11 +12,9 @@ if (NOT EXISTS "${XCCL_ROOT}") set(XCCL_ROOT $ENV{CCL_ROOT}) endif() -string(COMPARE EQUAL "${XCCL_ROOT}" "" nocclfound) -if(nocclfound) +if(NOT DEFINED $ENV{CCL_ROOT}) set(XCCL_FOUND False) - set(XCCL_REASON_FAILURE "OneCCL library not found!!") - set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}") + set(XCCL_NOT_FOUND_MESSAGE "OneCCL library not found!!") return() endif() @@ -56,8 +54,7 @@ find_library( if((NOT XCCL_INCLUDE_DIR) OR (NOT XCCL_LIBRARY_DIR) OR (NOT XCCL_LIBRARY)) set(XCCL_FOUND False) - set(XCCL_REASON_FAILURE "OneCCL library not found!!") - set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}") + set(XCCL_NOT_FOUND_MESSAGE "OneCCL library not found!!") return() endif() @@ -65,6 +62,6 @@ find_package_handle_standard_args( XCCL FOUND_VAR XCCL_FOUND REQUIRED_VARS XCCL_INCLUDE_DIR XCCL_LIBRARY_DIR XCCL_LIBRARY - REASON_FAILURE_MESSAGE "${XCCL_REASON_FAILURE}" + REASON_FAILURE_MESSAGE "${XCCL_NOT_FOUND_MESSAGE}" ) diff --git a/cmake/XCCL.cmake b/cmake/XCCL.cmake index 50e1bdcf5..94979c438 100644 --- a/cmake/XCCL.cmake +++ b/cmake/XCCL.cmake @@ -15,6 +15,8 @@ if(NOT __XCCL_INCLUDED) set_property( TARGET torch::xccl PROPERTY INTERFACE_LINK_LIBRARIES ${XCCL_LIBRARY}) + set(USE_C10D_XCCL ON) + set(USE_C10D_XCCL ${USE_C10D_XCCL} PARENT_SCOPE) endif() endif() diff --git a/src/xccl/CMakeLists.txt b/src/xccl/CMakeLists.txt index 7e16ea8ff..f147b55ca 100644 --- a/src/xccl/CMakeLists.txt +++ b/src/xccl/CMakeLists.txt @@ -7,7 +7,10 @@ list(APPEND ATen_XPU_XCCL_SRCS ${xccl_cpp}) set(ATen_XPU_XCCL_SRCS ${ATen_XPU_XCCL_SRCS} PARENT_SCOPE) -# Copy the header file to the build directory so that the PyTorch registration file can locate it. 
+# Why copy the header file to the build directory? +# We want register XCCL backend to PyTorch c10d in torch/csrc/distributed/c10d/init.cpp#L27-L29. +# To align with other backends, we need to copy the header file to the build torch/csrc/distributed/c10d directory. +# Further solution is add find path for torch/csrc/distributed/c10d/init.cpp#L27-L29. foreach(HEADER ${xccl_h}) file(COPY ${HEADER} DESTINATION "${CMAKE_BINARY_DIR}/torch/csrc/distributed/c10d") endforeach() From b8e9f30917ee03f68c7a921858a3569ad3e7c613 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Fri, 13 Dec 2024 23:18:49 +0800 Subject: [PATCH 12/19] update --- cmake/Modules/FindXCCL.cmake | 7 ------- cmake/XCCL.cmake | 3 +-- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index 063670a90..da637ef4d 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -12,12 +12,6 @@ if (NOT EXISTS "${XCCL_ROOT}") set(XCCL_ROOT $ENV{CCL_ROOT}) endif() -if(NOT DEFINED $ENV{CCL_ROOT}) - set(XCCL_FOUND False) - set(XCCL_NOT_FOUND_MESSAGE "OneCCL library not found!!") - return() -endif() - # Find include path from binary. find_file( XCCL_INCLUDE_DIR @@ -64,4 +58,3 @@ find_package_handle_standard_args( REQUIRED_VARS XCCL_INCLUDE_DIR XCCL_LIBRARY_DIR XCCL_LIBRARY REASON_FAILURE_MESSAGE "${XCCL_NOT_FOUND_MESSAGE}" ) - diff --git a/cmake/XCCL.cmake b/cmake/XCCL.cmake index 94979c438..ffe040291 100644 --- a/cmake/XCCL.cmake +++ b/cmake/XCCL.cmake @@ -4,7 +4,7 @@ if(NOT __XCCL_INCLUDED) # XCCL_ROOT, XCCL_LIBRARY_DIR, XCCL_INCLUDE_DIR are handled by FindXCCL.cmake. find_package(XCCL REQUIRED) if(NOT XCCL_FOUND) - message("${XCCL_NOT_FOUND_MESSAGE") + message("${XCCL_NOT_FOUND_MESSAGE}") return() endif() if(XCCL_FOUND) @@ -19,4 +19,3 @@ if(NOT __XCCL_INCLUDED) set(USE_C10D_XCCL ${USE_C10D_XCCL} PARENT_SCOPE) endif() endif() - From 86f09cbd1aa2db1c8c99d3ef1b5611b0127a6d8c Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Mon, 16 Dec 2024 23:15:24 +0800 Subject: [PATCH 13/19] update findxccl logit like mkl --- cmake/Modules/FindXCCL.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index da637ef4d..2904b2e35 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -8,8 +8,8 @@ include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) set(XCCL_ROOT "/opt/intel/oneapi/ccl/latest") if (NOT EXISTS "${XCCL_ROOT}") - message(STATUS "Default OneCCL not found, using current environment OneCCL") - set(XCCL_ROOT $ENV{CCL_ROOT}) + message(STATUS "Default OneCCL not found, using current environment OneAPI") + set(XCCL_ROOT $ENV{ONEAPI_ROOT}/ccl/latest) endif() # Find include path from binary. 
From d8c1e979ae0fb82f44fbacbb9b130a68a7ec6f5c Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Mon, 16 Dec 2024 23:23:07 +0800 Subject: [PATCH 14/19] add oneccl path to cmake include --- cmake/Modules/FindXCCL.cmake | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index 2904b2e35..faf5165f6 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -52,6 +52,11 @@ if((NOT XCCL_INCLUDE_DIR) OR (NOT XCCL_LIBRARY_DIR) OR (NOT XCCL_LIBRARY)) return() endif() +SET(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH} + "${XCCL_INCLUDE_DIR}") +SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} + "${XCCL_LIBRARY_DIR}") + find_package_handle_standard_args( XCCL FOUND_VAR XCCL_FOUND From 4b0eba06fee54d1c11a132ebf9d8d0077ab0bc31 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Mon, 16 Dec 2024 23:53:02 +0800 Subject: [PATCH 15/19] add deault oneapi path --- cmake/Modules/FindXCCL.cmake | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index faf5165f6..2ae3a3536 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -6,7 +6,17 @@ include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) -set(XCCL_ROOT "/opt/intel/oneapi/ccl/latest") +SET(DEFAULT_INTEL_CCL_DIR "/opt/intel/ccl") +SET(DEFAULT_INTEL_ONEAPI_DIR "/opt/intel/oneapi") +if (EXISTS "${DEFAULT_INTEL_ONEAPI_DIR}") + if (EXISTS "${DEFAULT_INTEL_ONEAPI_DIR}/ccl/latest") + SET(DEFAULT_INTEL_CCL_DIR "${DEFAULT_INTEL_ONEAPI_DIR}/ccl/latest") + endif() +endif() + +SET(XCCL_ROOT "${DEFAULT_INTEL_CCL_DIR}" CACHE STRING + "Root directory of the Intel CCL (standalone)") + if (NOT EXISTS "${XCCL_ROOT}") message(STATUS "Default OneCCL not found, using current environment OneAPI") set(XCCL_ROOT $ENV{ONEAPI_ROOT}/ccl/latest) From 72b2687a8c3d6eb3c200f2398a1f89ccc266a2a8 Mon Sep 17 00:00:00 2001 From: hanchao Date: Tue, 17 Dec 2024 01:05:35 +0000 Subject: [PATCH 16/19] rm default find path due to user source oneapi mandatory --- cmake/Modules/FindXCCL.cmake | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index 2ae3a3536..a54977434 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -6,21 +6,7 @@ include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) -SET(DEFAULT_INTEL_CCL_DIR "/opt/intel/ccl") -SET(DEFAULT_INTEL_ONEAPI_DIR "/opt/intel/oneapi") -if (EXISTS "${DEFAULT_INTEL_ONEAPI_DIR}") - if (EXISTS "${DEFAULT_INTEL_ONEAPI_DIR}/ccl/latest") - SET(DEFAULT_INTEL_CCL_DIR "${DEFAULT_INTEL_ONEAPI_DIR}/ccl/latest") - endif() -endif() - -SET(XCCL_ROOT "${DEFAULT_INTEL_CCL_DIR}" CACHE STRING - "Root directory of the Intel CCL (standalone)") - -if (NOT EXISTS "${XCCL_ROOT}") - message(STATUS "Default OneCCL not found, using current environment OneAPI") - set(XCCL_ROOT $ENV{ONEAPI_ROOT}/ccl/latest) -endif() +set(XCCL_ROOT $ENV{ONEAPI_ROOT}/ccl/latest) # Find include path from binary. 
find_file( From 5a40bd482e29fc28612ee6d8b3b53257432a0be2 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Wed, 18 Dec 2024 17:45:10 +0800 Subject: [PATCH 17/19] add simple xccl test --- test/xpu/xccl/test_c10d_xccl.py | 307 ++++++++++++++++++++++++++++++++ 1 file changed, 307 insertions(+) create mode 100644 test/xpu/xccl/test_c10d_xccl.py diff --git a/test/xpu/xccl/test_c10d_xccl.py b/test/xpu/xccl/test_c10d_xccl.py new file mode 100644 index 000000000..c842b6f80 --- /dev/null +++ b/test/xpu/xccl/test_c10d_xccl.py @@ -0,0 +1,307 @@ +# Owner(s): ["oncall: distributed"] + +import math +import os +import sys +import time +from datetime import timedelta +from unittest import mock + +import torch +import torch.distributed as c10d + + +if not c10d.is_available() or not c10d.is_xccl_available(): + print("c10d XCCL not available, skipping tests", file=sys.stderr) + sys.exit(0) + +import torch.distributed as dist +import torch.testing._internal.common_utils as common +from torch.testing._internal.common_distributed import ( + init_multigpu_helper, + MultiProcessTestCase, +) +from torch.testing._internal.common_utils import ( + retry_on_connect_failures, + run_tests, + skip_but_pass_in_sandcastle_if, + TEST_XPU, + TestCase, +) + +def requires_xccl(): + return skip_but_pass_in_sandcastle_if( + not c10d.is_xccl_available(), + "c10d was not compiled with the XCCL backend", + ) + +def simple_reduce_tests(rank, world_size): + tests = [ + ( + c10d.ReduceOp.SUM, + torch.tensor([rank + 1.0]), + torch.tensor([float(world_size * (world_size + 1) / 2)]), + ), + ( + c10d.ReduceOp.PRODUCT, + torch.tensor([rank + 1.0]), + torch.tensor([float(math.factorial(world_size))]), + ), + ( + c10d.ReduceOp.MIN, + torch.tensor([rank + 1.0]), + torch.tensor([1.0]), + ), + ( + c10d.ReduceOp.MAX, + torch.tensor([rank + 1.0]), + torch.tensor([world_size]), + ), + ] + + return tests + + +TEST_MULTIXPU = torch.xpu.device_count() > 1 + + +class RendezvousEnvTest(TestCase): + @retry_on_connect_failures + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test") + def test_common_errors(self): + vars = { + "WORLD_SIZE": "1", + "RANK": "0", + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": str(common.find_free_port()), + } + + class Env: + def __init__(self, vars): + self.env_patcher = mock.patch.dict(os.environ, vars, clear=True) + + def __enter__(self): + self.env_patcher.start() + + def __exit__(self, type, value, traceback): + self.env_patcher.stop() + + def without(d, key): + d = d.copy() + d.pop(key) + return d + + def withouts(d, keys): + d = d.copy() + for key in keys: + d.pop(key) + return d + + with Env(without(vars, "WORLD_SIZE")): + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + with self.assertRaisesRegex(ValueError, "WORLD_SIZE expected"): + gen = c10d.rendezvous("env://") + next(gen) + c10d.init_process_group(backend="xccl", world_size=1) + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(without(vars, "RANK")): + self.assertEqual(None, os.environ.get("RANK")) + with self.assertRaisesRegex(ValueError, "RANK expected"): + gen = c10d.rendezvous("env://") + next(gen) + c10d.init_process_group(backend="xccl", rank=0) + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(withouts(vars, ["RANK", "WORLD_SIZE"])): + self.assertEqual(None, os.environ.get("RANK")) + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + 
c10d.init_process_group(backend="xccl", rank=0, world_size=1) + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(vars): + c10d.init_process_group(backend="xccl") + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(without(vars, "MASTER_ADDR")): + self.assertEqual(None, os.environ.get("MASTER_ADDR")) + with self.assertRaisesRegex(ValueError, "MASTER_ADDR expected"): + gen = c10d.rendezvous("env://") + next(gen) + + with Env(without(vars, "MASTER_PORT")): + self.assertEqual(None, os.environ.get("MASTER_PORT")) + with self.assertRaisesRegex(ValueError, "MASTER_PORT expected"): + gen = c10d.rendezvous("env://") + next(gen) + + with Env(without(vars, "WORLD_SIZE")): + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + gen = c10d.rendezvous(f"env://?world_size={1}") + _, _, size = next(gen) + self.assertEqual(size, 1) + + with Env(without(vars, "RANK")): + self.assertEqual(None, os.environ.get("RANK")) + gen = c10d.rendezvous(f"env://?rank={0}") + _, rank, _ = next(gen) + self.assertEqual(rank, 0) + + with Env(withouts(vars, ["RANK", "WORLD_SIZE"])): + self.assertEqual(None, os.environ.get("RANK")) + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + gen = c10d.rendezvous(f"env://?rank={0}&world_size={1}") + _, rank, size = next(gen) + self.assertEqual(rank, 0) + self.assertEqual(size, 1) + +class ProcessGroupXCCLTest(MultiProcessTestCase): + def _create_process_group_xccl( + self, timeout=timedelta(seconds=600), device_id=None + ): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + "xccl", + world_size=self.world_size, + rank=self.rank, + store=store, + timeout=timeout, + device_id=device_id, + ) + pg = c10d.distributed_c10d._get_default_group() + return pg + + def setUp(self): + super().setUp() + self._spawn_processes() + + def tearDown(self): + super().tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def world_size(self): + return 2 + + @property + def rank_to_GPU(self): + # return rank to GPU map + nGPUs = torch.xpu.device_count() + visible_devices = range(nGPUs) + nGPUs_per_process = 1 + if self.world_size > nGPUs: + nGPUs_per_process = nGPUs // self.world_size + GPUs = { + i: list(visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process]) + for i in range(self.world_size) + } + return GPUs + + @requires_xccl() + @skip_but_pass_in_sandcastle_if( + torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" + ) + def test_close_multi_pg_unordered(self): + pg = self._create_process_group_xccl() + device = self.rank_to_GPU[self.rank][0] + t = torch.rand(10, 10, device=device) + # First allreduce to initialize default PG's communicator. 
+ pg.allreduce(t).wait() + new_pg1 = c10d.new_group([0, 1]) + new_pg2 = c10d.new_group([0, 1]) + if self.rank == 0 or self.rank == 1: + t1 = torch.rand(10, 10, device=device) + t2 = torch.rand(10, 10, device=device) + new_pg1.allreduce(t1).wait() + new_pg2.allreduce(t2).wait() + if self.rank == 0: + dist.destroy_process_group(new_pg2) + # force destruction of pg2 first + del new_pg2 + dist.destroy_process_group(new_pg1) + del new_pg1 + if self.rank == 1: + c10d.destroy_process_group(new_pg1) + # force destruction of pg1 first + del new_pg1 + dist.destroy_process_group(new_pg2) + del new_pg2 + dist.destroy_process_group() + + @requires_xccl() + @skip_but_pass_in_sandcastle_if( + torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" + ) + def test_file_store_check(self): + # self.file_name is created using "delete=False" + # e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="xccl", rank=self.rank, world_size=self.world_size, store=store + ) + pg = dist.distributed_c10d._get_default_group() + self.assertEqual(pg.rank(), self.rank) + self.assertEqual(pg.size(), self.world_size) + # give enough time for check() to be executed multiple times + time.sleep(2) + dist.destroy_process_group() + + # todo: https://github.com/pytorch/pytorch/blob/c06b5048ba866e2dd39e5da5399fe8261322c7ca/torch/distributed/distributed_c10d.py#L1862 device agnostic + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIXPU, "XCCL test requires 2+ GPUs") + # def test_set_process_group_desc(self): + # device = torch.device(f"xpu:{self.rank}") + # pg_default = self._create_process_group_xccl(device_id=device) + # self.assertEqual(pg_default.group_desc, "default_pg") + # pg_1 = c10d.new_group([0, 1], group_desc="test_purpose") + # self.assertEqual(pg_1.group_desc, "test_purpose") + # pg_2 = c10d.new_group([0, 1]) + # self.assertEqual(pg_2.group_desc, "undefined") + + def _test_allreduce_basics(self, fn): + pg = self._create_process_group_xccl() + device = torch.device("xpu:" + str(self.rank)) + # Single input tests + tests = simple_reduce_tests(self.rank, self.world_size) + for op, input, expected in tests: + opts = c10d.AllreduceOptions() + opts.reduceOp = op + tensor = fn(input.to(device)) + fut = pg.allreduce([tensor], opts).get_future() + fut.wait() + result = fut.value() + self.assertEqual(expected, result[0], exact_dtype=False) + + x = fn(torch.tensor([self.rank + 1.0], device=device)) + fut = pg.allreduce(x).get_future() + fut.wait() + result = fut.value() + self.assertEqual( + torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]), + result[0], + ) + + @requires_xccl() + def test_allreduce_basics(self): + self._test_allreduce_basics(lambda t: t.clone()) + + +if __name__ == "__main__": + assert ( + not torch.xpu._initialized + ), "test_distributed must not have initialized XPU context on main process" + + run_tests() + From 198926229a06686049a282ef5de3834036acb8e0 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Wed, 18 Dec 2024 18:34:46 +0800 Subject: [PATCH 18/19] update find ccl --- cmake/Modules/FindXCCL.cmake | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index a54977434..b211af9a9 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -6,7 +6,8 @@ include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) -set(XCCL_ROOT $ENV{ONEAPI_ROOT}/ccl/latest) +# we 
need source OneCCL environment before building. +set(XCCL_ROOT $ENV{CCL_ROOT}) # Find include path from binary. find_file( From a71447e55d8076e102f5db280b7c986cd97ef333 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Mon, 23 Dec 2024 21:53:42 +0800 Subject: [PATCH 19/19] rm ut --- test/xpu/xccl/test_c10d_xccl.py | 307 -------------------------------- 1 file changed, 307 deletions(-) delete mode 100644 test/xpu/xccl/test_c10d_xccl.py diff --git a/test/xpu/xccl/test_c10d_xccl.py b/test/xpu/xccl/test_c10d_xccl.py deleted file mode 100644 index c842b6f80..000000000 --- a/test/xpu/xccl/test_c10d_xccl.py +++ /dev/null @@ -1,307 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -import math -import os -import sys -import time -from datetime import timedelta -from unittest import mock - -import torch -import torch.distributed as c10d - - -if not c10d.is_available() or not c10d.is_xccl_available(): - print("c10d XCCL not available, skipping tests", file=sys.stderr) - sys.exit(0) - -import torch.distributed as dist -import torch.testing._internal.common_utils as common -from torch.testing._internal.common_distributed import ( - init_multigpu_helper, - MultiProcessTestCase, -) -from torch.testing._internal.common_utils import ( - retry_on_connect_failures, - run_tests, - skip_but_pass_in_sandcastle_if, - TEST_XPU, - TestCase, -) - -def requires_xccl(): - return skip_but_pass_in_sandcastle_if( - not c10d.is_xccl_available(), - "c10d was not compiled with the XCCL backend", - ) - -def simple_reduce_tests(rank, world_size): - tests = [ - ( - c10d.ReduceOp.SUM, - torch.tensor([rank + 1.0]), - torch.tensor([float(world_size * (world_size + 1) / 2)]), - ), - ( - c10d.ReduceOp.PRODUCT, - torch.tensor([rank + 1.0]), - torch.tensor([float(math.factorial(world_size))]), - ), - ( - c10d.ReduceOp.MIN, - torch.tensor([rank + 1.0]), - torch.tensor([1.0]), - ), - ( - c10d.ReduceOp.MAX, - torch.tensor([rank + 1.0]), - torch.tensor([world_size]), - ), - ] - - return tests - - -TEST_MULTIXPU = torch.xpu.device_count() > 1 - - -class RendezvousEnvTest(TestCase): - @retry_on_connect_failures - @requires_xccl() - @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test") - def test_common_errors(self): - vars = { - "WORLD_SIZE": "1", - "RANK": "0", - "MASTER_ADDR": "127.0.0.1", - "MASTER_PORT": str(common.find_free_port()), - } - - class Env: - def __init__(self, vars): - self.env_patcher = mock.patch.dict(os.environ, vars, clear=True) - - def __enter__(self): - self.env_patcher.start() - - def __exit__(self, type, value, traceback): - self.env_patcher.stop() - - def without(d, key): - d = d.copy() - d.pop(key) - return d - - def withouts(d, keys): - d = d.copy() - for key in keys: - d.pop(key) - return d - - with Env(without(vars, "WORLD_SIZE")): - self.assertEqual(None, os.environ.get("WORLD_SIZE")) - with self.assertRaisesRegex(ValueError, "WORLD_SIZE expected"): - gen = c10d.rendezvous("env://") - next(gen) - c10d.init_process_group(backend="xccl", world_size=1) - self.assertEqual(c10d.get_rank(), 0) - self.assertEqual(c10d.get_world_size(), 1) - c10d.destroy_process_group() - - with Env(without(vars, "RANK")): - self.assertEqual(None, os.environ.get("RANK")) - with self.assertRaisesRegex(ValueError, "RANK expected"): - gen = c10d.rendezvous("env://") - next(gen) - c10d.init_process_group(backend="xccl", rank=0) - self.assertEqual(c10d.get_rank(), 0) - self.assertEqual(c10d.get_world_size(), 1) - c10d.destroy_process_group() - - with Env(withouts(vars, ["RANK", "WORLD_SIZE"])): - 
self.assertEqual(None, os.environ.get("RANK")) - self.assertEqual(None, os.environ.get("WORLD_SIZE")) - c10d.init_process_group(backend="xccl", rank=0, world_size=1) - self.assertEqual(c10d.get_rank(), 0) - self.assertEqual(c10d.get_world_size(), 1) - c10d.destroy_process_group() - - with Env(vars): - c10d.init_process_group(backend="xccl") - self.assertEqual(c10d.get_rank(), 0) - self.assertEqual(c10d.get_world_size(), 1) - c10d.destroy_process_group() - - with Env(without(vars, "MASTER_ADDR")): - self.assertEqual(None, os.environ.get("MASTER_ADDR")) - with self.assertRaisesRegex(ValueError, "MASTER_ADDR expected"): - gen = c10d.rendezvous("env://") - next(gen) - - with Env(without(vars, "MASTER_PORT")): - self.assertEqual(None, os.environ.get("MASTER_PORT")) - with self.assertRaisesRegex(ValueError, "MASTER_PORT expected"): - gen = c10d.rendezvous("env://") - next(gen) - - with Env(without(vars, "WORLD_SIZE")): - self.assertEqual(None, os.environ.get("WORLD_SIZE")) - gen = c10d.rendezvous(f"env://?world_size={1}") - _, _, size = next(gen) - self.assertEqual(size, 1) - - with Env(without(vars, "RANK")): - self.assertEqual(None, os.environ.get("RANK")) - gen = c10d.rendezvous(f"env://?rank={0}") - _, rank, _ = next(gen) - self.assertEqual(rank, 0) - - with Env(withouts(vars, ["RANK", "WORLD_SIZE"])): - self.assertEqual(None, os.environ.get("RANK")) - self.assertEqual(None, os.environ.get("WORLD_SIZE")) - gen = c10d.rendezvous(f"env://?rank={0}&world_size={1}") - _, rank, size = next(gen) - self.assertEqual(rank, 0) - self.assertEqual(size, 1) - -class ProcessGroupXCCLTest(MultiProcessTestCase): - def _create_process_group_xccl( - self, timeout=timedelta(seconds=600), device_id=None - ): - store = c10d.FileStore(self.file_name, self.world_size) - c10d.init_process_group( - "xccl", - world_size=self.world_size, - rank=self.rank, - store=store, - timeout=timeout, - device_id=device_id, - ) - pg = c10d.distributed_c10d._get_default_group() - return pg - - def setUp(self): - super().setUp() - self._spawn_processes() - - def tearDown(self): - super().tearDown() - try: - os.remove(self.file_name) - except OSError: - pass - - @property - def world_size(self): - return 2 - - @property - def rank_to_GPU(self): - # return rank to GPU map - nGPUs = torch.xpu.device_count() - visible_devices = range(nGPUs) - nGPUs_per_process = 1 - if self.world_size > nGPUs: - nGPUs_per_process = nGPUs // self.world_size - GPUs = { - i: list(visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process]) - for i in range(self.world_size) - } - return GPUs - - @requires_xccl() - @skip_but_pass_in_sandcastle_if( - torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" - ) - def test_close_multi_pg_unordered(self): - pg = self._create_process_group_xccl() - device = self.rank_to_GPU[self.rank][0] - t = torch.rand(10, 10, device=device) - # First allreduce to initialize default PG's communicator. 
- pg.allreduce(t).wait() - new_pg1 = c10d.new_group([0, 1]) - new_pg2 = c10d.new_group([0, 1]) - if self.rank == 0 or self.rank == 1: - t1 = torch.rand(10, 10, device=device) - t2 = torch.rand(10, 10, device=device) - new_pg1.allreduce(t1).wait() - new_pg2.allreduce(t2).wait() - if self.rank == 0: - dist.destroy_process_group(new_pg2) - # force destruction of pg2 first - del new_pg2 - dist.destroy_process_group(new_pg1) - del new_pg1 - if self.rank == 1: - c10d.destroy_process_group(new_pg1) - # force destruction of pg1 first - del new_pg1 - dist.destroy_process_group(new_pg2) - del new_pg2 - dist.destroy_process_group() - - @requires_xccl() - @skip_but_pass_in_sandcastle_if( - torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" - ) - def test_file_store_check(self): - # self.file_name is created using "delete=False" - # e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name - store = dist.FileStore(self.file_name, self.world_size) - dist.init_process_group( - backend="xccl", rank=self.rank, world_size=self.world_size, store=store - ) - pg = dist.distributed_c10d._get_default_group() - self.assertEqual(pg.rank(), self.rank) - self.assertEqual(pg.size(), self.world_size) - # give enough time for check() to be executed multiple times - time.sleep(2) - dist.destroy_process_group() - - # todo: https://github.com/pytorch/pytorch/blob/c06b5048ba866e2dd39e5da5399fe8261322c7ca/torch/distributed/distributed_c10d.py#L1862 device agnostic - # @requires_xccl() - # @skip_but_pass_in_sandcastle_if(not TEST_MULTIXPU, "XCCL test requires 2+ GPUs") - # def test_set_process_group_desc(self): - # device = torch.device(f"xpu:{self.rank}") - # pg_default = self._create_process_group_xccl(device_id=device) - # self.assertEqual(pg_default.group_desc, "default_pg") - # pg_1 = c10d.new_group([0, 1], group_desc="test_purpose") - # self.assertEqual(pg_1.group_desc, "test_purpose") - # pg_2 = c10d.new_group([0, 1]) - # self.assertEqual(pg_2.group_desc, "undefined") - - def _test_allreduce_basics(self, fn): - pg = self._create_process_group_xccl() - device = torch.device("xpu:" + str(self.rank)) - # Single input tests - tests = simple_reduce_tests(self.rank, self.world_size) - for op, input, expected in tests: - opts = c10d.AllreduceOptions() - opts.reduceOp = op - tensor = fn(input.to(device)) - fut = pg.allreduce([tensor], opts).get_future() - fut.wait() - result = fut.value() - self.assertEqual(expected, result[0], exact_dtype=False) - - x = fn(torch.tensor([self.rank + 1.0], device=device)) - fut = pg.allreduce(x).get_future() - fut.wait() - result = fut.value() - self.assertEqual( - torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]), - result[0], - ) - - @requires_xccl() - def test_allreduce_basics(self): - self._test_allreduce_basics(lambda t: t.clone()) - - -if __name__ == "__main__": - assert ( - not torch.xpu._initialized - ), "test_distributed must not have initialized XPU context on main process" - - run_tests() -
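
Two oneCCL gaps show up repeatedly in the collectives touched by this series: FP8 tensors are mapped to ccl::datatype::uint8 and are intended only for non-reduction collectives such as allgather, and ReduceOp::AVG is emulated as a SUM reduce followed by output.div_(world_size) (applied only on the root rank for the rooted reduce variants). The standalone C++ sketch below illustrates both workarounds with stand-in names; it does not use the real ProcessGroupXCCL or oneCCL API, and the check that rejects FP8 for reductions is an assumption drawn from the "use for non-reduction op" comment rather than confirmed getXcclDataType behavior.

// A minimal standalone sketch (no oneCCL / ATen): the type names and helpers
// here are illustrative stand-ins, not the real ProcessGroupXCCL API.
#include <cstddef>
#include <iostream>
#include <map>
#include <stdexcept>
#include <vector>

enum class ScalarType { Float, BFloat16, Float8_e5m2, Float8_e4m3fn };
enum class CclDatatype { float32, bfloat16, uint8 };

// FP8 dtypes are carried as uint8 so byte-copy collectives (allgather,
// alltoall) can move them; rejecting them for reductions is an assumed
// policy here, since summing reinterpreted uint8 bit patterns is meaningless.
CclDatatype toCclDatatype(ScalarType t, bool is_reduction) {
  static const std::map<ScalarType, CclDatatype> table = {
      {ScalarType::Float, CclDatatype::float32},
      {ScalarType::BFloat16, CclDatatype::bfloat16},
      {ScalarType::Float8_e5m2, CclDatatype::uint8},
      {ScalarType::Float8_e4m3fn, CclDatatype::uint8},
  };
  const bool is_fp8 =
      t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn;
  if (is_reduction && is_fp8) {
    throw std::invalid_argument("FP8 is only supported for non-reduction ops");
  }
  return table.at(t);
}

// oneCCL has no AVG reduce op, so AVG is emulated as SUM followed by an
// elementwise divide by the world size (output.div_(divisor) in the patch).
std::vector<float> allreduceAvg(
    const std::vector<std::vector<float>>& per_rank_input, int world_size) {
  std::vector<float> out(per_rank_input.front().size(), 0.0f);
  for (const auto& contribution : per_rank_input) { // stands in for the SUM allreduce
    for (size_t i = 0; i < out.size(); ++i) {
      out[i] += contribution[i];
    }
  }
  for (auto& v : out) { // the post-collective divide step
    v /= static_cast<float>(world_size);
  }
  return out;
}

int main() {
  // allgather of FP8 data is fine: the bytes are simply copied.
  toCclDatatype(ScalarType::Float8_e5m2, /*is_reduction=*/false);

  // AVG over 2 ranks: ranks contribute {1, 2} and {3, 4} -> average {2, 3}.
  auto avg = allreduceAvg({{1.0f, 2.0f}, {3.0f, 4.0f}}, /*world_size=*/2);
  std::cout << avg[0] << " " << avg[1] << "\n"; // prints: 2 3
  return 0;
}

For the rooted ccl::reduce path the patch applies the divide only when getRank() == root, because only the root holds the reduced result; allreduce and reduce_scatter divide on every rank.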