Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(dipu): Add timing feature in cluster communication. #933

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "csrc_dipu/runtime/core/DIPUStream.h"
#include "csrc_dipu/runtime/core/allocator/DIPUCachingAllocator.h"
#include "csrc_dipu/runtime/devproxy/diclproxy.h"
#include "csrc_dipu/utils/Log.h"
#include "csrc_dipu/utils/helpfunc.hpp"
#include <csrc_dipu/vendor/vendorapi.h>

Expand Down Expand Up @@ -192,6 +193,16 @@ ProcessGroupDICL::WorkDICL::getFuture() {
return future_;
}

// Returns the elapsed time between the start and end event recorded for this
// work item, as computed by DIPUEvent::elapsed_time.
// NOTE(review): the unit is presumably milliseconds — confirm against
// DIPUEvent::elapsed_time.
//
// Preconditions, each enforced with TORCH_CHECK:
//   - the timing events exist (the WorkDICL constructor only allocates them
//     when the DIPU_DICL_ENABLE_TIMING environment variable is set),
//   - exactly one start event and one end event are tracked,
//   - the end event has already completed.
float ProcessGroupDICL::WorkDICL::getDuration() const {
  // Both event vectors are null when timing was disabled at construction
  // time; report that clearly instead of dereferencing a null shared_ptr.
  TORCH_CHECK(diclStartEvents_ != nullptr && diclEndEvents_ != nullptr,
              "getDuration requires DIPU_DICL_ENABLE_TIMING to be set.");
  TORCH_CHECK(diclStartEvents_->size() == 1,
              "getDuration only works for single device per ProcessGroup.");
  TORCH_CHECK(diclEndEvents_->size() == 1,
              "getDuration only works for single device per ProcessGroup.");
  TORCH_CHECK((*diclEndEvents_)[0].query(),
              "getDuration can only be called after work is succeeded.");
  return (*diclStartEvents_)[0].elapsed_time((*diclEndEvents_)[0]);
}

// end WorkDICL

ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr<Store>& store,
Expand All @@ -213,10 +224,34 @@ ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr<Store>& store,
throw std::runtime_error("Invalid value for environment variable: " +
std::string(DICL_BLOCKING_WAIT));
}

char* timingVar = getenv("DIPU_DICL_ENABLE_TIMING");
if (timingVar != nullptr) timingEnabled_ = true;

if (timingEnabled_) {
char* samplingVar = getenv("DIPU_DICL_SAMPLING_INTERVAL");
if (samplingVar != nullptr) samplingInterval_ = std::stoi(samplingVar);

samplingCount_ = samplingInterval_;
}
}

// No custom teardown: the compiler-generated destructor is used.
ProcessGroupDICL::~ProcessGroupDICL() = default;

// Writes one timing sample to the DIPU info log: the process group's unique
// id, this rank, the device index, the payload size in bytes and the measured
// duration.
//
// duration  - value to report as the communication duration.
// comm_size - payload size in bytes.
// deviceID  - index of the device the communication ran on.
//
// Fails a TORCH_CHECK unless the sampling countdown has reached zero.
void ProcessGroupDICL::printInfo(float duration, int comm_size,
                                 int deviceID) const {
  TORCH_CHECK(samplingCount_ == 0, "Print count hasn't reached 0 yet.");

  // At most the first 16 bytes of the unique id are printed. The loop is
  // bounded by the vector's actual size so that a shorter (or still empty)
  // uniqueidVec_ is never indexed out of range.
  constexpr size_t kMaxPrintedIdBytes = 16;
  const size_t printedIdBytes = uniqueidVec_.size() < kMaxPrintedIdBytes
                                    ? uniqueidVec_.size()
                                    : kMaxPrintedIdBytes;

  std::ostringstream oss;
  oss << "PG uniqueId = ";
  for (size_t i = 0; i < printedIdBytes; ++i) {
    // Cast so each byte is streamed as a number rather than as a character.
    oss << static_cast<int>(uniqueidVec_[i]);
  }
  oss << ", rank =" << rank_ << ", deviceId = " << deviceID
      << ", comm_size = " << comm_size << " bytes"
      << ", duration = " << duration << '\n';
  DIPU_LOG_INFO << oss.str();
}

void ProcessGroupDICL::broadcastUniqueID(commUniqueId* uniqueId,
const std::string& storeKey,
int commRank) {
Expand All @@ -238,6 +273,7 @@ void ProcessGroupDICL::broadcastUniqueID(commUniqueId* uniqueId,
auto vec = std::vector<uint8_t>(reinterpret_cast<uint8_t*>(uniqueId),
reinterpret_cast<uint8_t*>(uniqueId) +
devapis::DICL_UNIQUE_ID_BYTES_SIZE);
uniqueidVec_ = vec;
store_->set(storeKey, vec);
} else {
auto vec = store_->get(storeKey);
Expand All @@ -246,6 +282,7 @@ void ProcessGroupDICL::broadcastUniqueID(commUniqueId* uniqueId,
"Unexpected DICL unique ID length received "
"from the store");
}
uniqueidVec_ = vec;
std::memcpy(uniqueId, vec.data(), vec.size());
}
}
Expand Down Expand Up @@ -442,6 +479,16 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::doComm(
auto work = c10::make_intrusive<ProcessGroupDICL::WorkDICL>(
diclComms, blockingWait_, opTimeout_);

if (timingEnabled_ && opType == OpType::ALLREDUCE) {
samplingCount_--;
if (samplingCount_ == 0) {
for (const auto i : c10::irange(inputs.size())) {
DIPUStream& diclStream = diclComms[i]->diclStream_;
(*work->diclStartEvents_)[i].record(diclStream);
}
}
}

OptionalDIPUGuard dipuGuard;
pre(diclComms);

Expand All @@ -466,6 +513,7 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::doComm(
}

post(diclComms);

work->record();

work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
Expand All @@ -478,6 +526,21 @@ c10::intrusive_ptr<Work> ProcessGroupDICL::doComm(
c10::ListType::create(c10::TensorType::get()), devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
}

if (timingEnabled_ && opType == OpType::ALLREDUCE && samplingCount_ == 0) {
for (const auto i : c10::irange(inputs.size())) {
DIPUStream& diclStream = diclComms[i]->diclStream_;
(*work->diclEndEvents_)[i].record(diclStream);
}
for (const auto i : c10::irange(inputs.size())) {
DIPUStream& diclStream = diclComms[i]->diclStream_;
(*work->diclEndEvents_)[i].synchronize();
}
printInfo(work->getDuration(), inputs[0].element_size() * inputs[0].numel(),
static_cast<int>(diclComms[0]->diclStream_.device_index()));
samplingCount_ = samplingInterval_;
}

return work;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ class DIPU_API ProcessGroupDICL : public Backend {
opTimeout_(opTimeout),
workStartTime_(std::chrono::steady_clock::now()) {
workEvents_.resize(diclComms_.size());

const char* timingVar = std::getenv("DIPU_DICL_ENABLE_TIMING");
if (timingVar != nullptr) {
diclStartEvents_ = std::make_shared<std::vector<DIPUEvent>>();
diclStartEvents_->reserve(diclComms_.size());
diclEndEvents_ = std::make_shared<std::vector<DIPUEvent>>();
diclEndEvents_->reserve(diclComms_.size());
for (uint32_t i = 0; i < diclComms_.size(); ++i) {
diclStartEvents_->emplace_back(DIPUEvent());
diclEndEvents_->emplace_back(DIPUEvent());
}
}
}

~WorkDICL() override = default;
Expand All @@ -111,6 +123,8 @@ class DIPU_API ProcessGroupDICL : public Backend {

c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

// Returns the elapsed time between this work item's start event and end
// event. Fails a TORCH_CHECK unless exactly one start event and one end
// event are tracked and the end event has already completed.
// NOTE(review): unit is presumably milliseconds — confirm against
// DIPUEvent::elapsed_time.
float getDuration() const;

protected:
// Store a reference to DICL collective's outputs, used by result and to
// give a more descriptive message when representing the Work as a string.
Expand All @@ -125,6 +139,15 @@ class DIPU_API ProcessGroupDICL : public Backend {
// The DIPU events used to sync DICL work on comm stream
std::vector<DIPUEvent> workEvents_;

// Start events for this work item, one per DICL communicator. The vector is
// only allocated by the constructor when the DIPU_DICL_ENABLE_TIMING
// environment variable is set; otherwise this pointer stays null.
// getDuration() uses the first entry as the beginning of the measured span.
std::shared_ptr<std::vector<DIPUEvent>> diclStartEvents_;

// End events for this work item, one per DICL communicator. Allocated under
// the same condition as diclStartEvents_; getDuration() uses the first entry
// as the end of the measured span.
std::shared_ptr<std::vector<DIPUEvent>> diclEndEvents_;

// Just checks whether DIPU execution has completed, without modifying
// exception_ptr.
bool finishedDICLExecutionInternal() const;
Expand Down Expand Up @@ -310,6 +333,16 @@ class DIPU_API ProcessGroupDICL : public Backend {
bool blockingWait_ = false;

std::chrono::milliseconds opTimeout_ = kBackendDefaultTimeout;

// Logs one timing sample (unique id, rank, device index, payload size in
// bytes, duration). Must only be called once samplingCount_ has reached 0.
void printInfo(float duration, int comm_size, int deviceID) const;

// Number of ALLREDUCE calls per timing sample. The constructor overrides the
// default with DIPU_DICL_SAMPLING_INTERVAL when that variable is set.
int samplingInterval_ = 100;

// Countdown to the next timing sample: decremented on every ALLREDUCE while
// timing is enabled and reset to samplingInterval_ after a sample is logged.
// Initialized here so the member never holds an indeterminate value when
// timing is disabled and the constructor does not assign it.
int samplingCount_ = 0;

// True when the DIPU_DICL_ENABLE_TIMING environment variable is set.
bool timingEnabled_ = false;

// Copy of the communicator unique id bytes, stored by broadcastUniqueID and
// printed by printInfo.
std::vector<uint8_t> uniqueidVec_;
};

namespace dicl_hook {
Expand Down
Loading