Skip to content

Commit

Permalink
Fix bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
lwysense committed Aug 28, 2024
1 parent a6aa591 commit b2b1087
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ void ProcessGroupDICL::printInfo(float duration, int comm_size,
oss << static_cast<int>(uniqueidVec_[i]);
}
oss << ", rank =" << rank_ << ", deviceId = " << deviceID
<< " comm_size = " << comm_size << " bytes "
<< ", comm_size = " << comm_size << " bytes"
<< ", duration = " << duration << std::endl;
DIPU_LOG_INFO << oss.str();
}
Expand Down Expand Up @@ -481,9 +481,11 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::doComm(

if (timingEnabled_ && opType == OpType::ALLREDUCE) {
samplingCount_--;
for (const auto i : c10::irange(inputs.size())) {
DIPUStream& diclStream = diclComms[i]->diclStream_;
(*work->diclStartEvents_)[i].record(diclStream);
if (samplingCount_ == 0) {
for (const auto i : c10::irange(inputs.size())) {
DIPUStream& diclStream = diclComms[i]->diclStream_;
(*work->diclStartEvents_)[i].record(diclStream);
}
}
}

Expand Down Expand Up @@ -525,7 +527,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::doComm(
work->future_->markCompleted(at::IValue(*work->outputs_));
}

if (timingEnabled_ && opType == OpType::ALLREDUCE) {
if (timingEnabled_ && opType == OpType::ALLREDUCE && samplingCount_ == 0) {
for (const auto i : c10::irange(inputs.size())) {
DIPUStream& diclStream = diclComms[i]->diclStream_;
(*work->diclEndEvents_)[i].record(diclStream);
Expand All @@ -534,12 +536,9 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::doComm(
DIPUStream& diclStream = diclComms[i]->diclStream_;
(*work->diclEndEvents_)[i].synchronize();
}
if (samplingCount_ == 0) {
printInfo(work->getDuration(),
inputs[0].element_size() * inputs[0].numel(),
static_cast<int>(diclComms[0]->diclStream_.device_index()));
samplingCount_ = samplingInterval_;
}
printInfo(work->getDuration(), inputs[0].element_size() * inputs[0].numel(),
static_cast<int>(diclComms[0]->diclStream_.device_index()));
samplingCount_ = samplingInterval_;
}

return work;
Expand Down

0 comments on commit b2b1087

Please sign in to comment.