diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index c635d4e69..e58a8567c 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -229,21 +229,26 @@ ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr& store, if (timingVar != nullptr) timingEnabled_ = true; if (timingEnabled_) { - char* frequencyVar = getenv("DIPU_DICL_PRINT_FREQUENCY"); - if (frequencyVar != nullptr) printFrequency_ = std::stoi(frequencyVar); + char* samplingVar = getenv("DIPU_DICL_SAMPLING_INTERVAL"); + if (samplingVar != nullptr) samplingInterval_ = std::stoi(samplingVar); - printCount_ = printFrequency_; + samplingCount_ = samplingInterval_; } } ProcessGroupDICL::~ProcessGroupDICL() = default; -void ProcessGroupDICL::printInfo(int deviceID) const { - TORCH_CHECK(printCount_ == 0, "Print count hasn't reached 0 yet.") +void ProcessGroupDICL::printInfo(float duration, int deviceID) const { + TORCH_CHECK(samplingCount_ == 0, "Print count hasn't reached 0 yet.") std::ostringstream oss; - oss << "Rank " << rank_ << ": deviceId = " << deviceID - << ", Average duration = " << totalDuration_ / printFrequency_ - << std::endl; + oss << "PG uniqueId = "; + for (int i = 0; i < 16; ++i) { + oss << static_cast(diclID_->internal[i]); + } + // oss << "PG uniqueId = " << diclID_->internal << ", rank =" << rank_ + oss << ", rank =" << rank_ + << ", deviceId = " << deviceID + << ", duration = " << duration << std::endl; DIPU_LOG_INFO << oss.str(); } @@ -310,6 +315,7 @@ std::vector>& ProcessGroupDICL::getDICLComms( ? localCommsKey : std::to_string(diclCommCounter_++); broadcastUniqueID(&diclID, bcastKey, commsRank); + diclID_ = &diclID; OptionalDIPUGuard dipuGuard; @@ -473,7 +479,7 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( diclComms, blockingWait_, opTimeout_); if (timingEnabled_ && opType == OpType::ALLREDUCE) { - printCount_--; + samplingCount_--; for (const auto i : c10::irange(inputs.size())) { DIPUStream& diclStream = diclComms[i]->diclStream_; (*work->diclStartEvents_)[i].record(diclStream); @@ -527,11 +533,10 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( DIPUStream& diclStream = diclComms[i]->diclStream_; (*work->diclEndEvents_)[i].synchronize(); } - totalDuration_ += work->getDuration(); - if (printCount_ == 0) { - printInfo(static_cast(diclComms[0]->diclStream_.device_index())); - printCount_ = printFrequency_; - totalDuration_ = 0; + if (samplingCount_ == 0) { + printInfo(work->getDuration(), + static_cast(diclComms[0]->diclStream_.device_index())); + samplingCount_ = samplingInterval_; } } diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h index 6726ff275..8e830b8b5 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h @@ -334,15 +334,15 @@ class DIPU_API ProcessGroupDICL : public Backend { std::chrono::milliseconds opTimeout_ = kBackendDefaultTimeout; - void printInfo(int deviceID) const; + void printInfo(float duration, int deviceID) const; - int printFrequency_ = 100; + int samplingInterval_ = 100; - int printCount_; - - float totalDuration_ = 0.; + int samplingCount_; bool timingEnabled_ = false; + + commUniqueId* diclID_; }; namespace dicl_hook {