From 197b42aafa9bd8dfd1a978050a76b194de1bf2d1 Mon Sep 17 00:00:00 2001 From: liuweiyu Date: Tue, 20 Aug 2024 19:25:21 +0800 Subject: [PATCH 01/11] The timing feature has been added in cluster communication to detect slow nodes. --- .../torch_dipu/csrc_dipu/binding/ExportRT.cpp | 23 +++++++ .../runtime/distributed/ProcessGroupDICL.cpp | 30 +++++++++ .../runtime/distributed/ProcessGroupDICL.h | 66 +++++++++++++++++++ 3 files changed, 119 insertions(+) diff --git a/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp b/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp index 4574f2b4f..7a8f393c8 100644 --- a/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp +++ b/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp @@ -249,6 +249,29 @@ void exportCommunicator(py::module& m) { return kBackendDefaultTimeout; }); + py::class_, + ProcessGroupDICL::PyWorkDICL>(m, "WorkDICL") + .def(py::init([](std::vector>& comms, bool blockingWait, + std::chrono::milliseconds opTimeout) { + return ProcessGroupDICL::WorkDICL(comms, blockingWait, opTimeout); + }), + py::arg("comms"), py::arg("blockingWait"), py::arg("opTimeout"), + py::call_guard()) + .def( + "_get_duration", + &ProcessGroupDICL::WorkDICL::getDuration, + py::call_guard(), + R"( + Returns: + Duration of the corresponding collective communication. + + .. warning :: + This API works for DICL backend for now and must set + DIPU_ENABLE_TIMING environment variable. + )" + ); + // py::object mdist = py::module::import("torch.distributed"); // py::object register_backend = // mdist.attr("Backend").attr("register_backend"); The first parameter is the diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index ea958b12e..72b4ef227 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -192,6 +192,21 @@ ProcessGroupDICL::WorkDICL::getFuture() { return future_; } +float ProcessGroupDICL::WorkDICL::getDuration() const { + // return 0.0; + TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled") + TORCH_CHECK( + diclStartEvents_->size() == 1, + "getDuration only works for single device per ProcessGroup."); + TORCH_CHECK( + diclEndEvents_->size() == 1, + "getDuration only works for single device per ProcessGroup."); + TORCH_CHECK( + (*diclEndEvents_)[0].query(), + "getDuration can only be called after work is succeeded.") + return (*diclStartEvents_)[0].elapsed_time((*diclEndEvents_)[0]); +} + // end WorkDICL ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr& store, @@ -442,6 +457,13 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( auto work = c10::make_intrusive( diclComms, blockingWait_, opTimeout_); + if (work->timingEnabled_) { + for (const auto i : c10::irange(inputs.size())) { + DIPUStream& diclStream = diclComms[i]->diclStream_; + (*work->diclStartEvents_)[i].record(diclStream); + } + } + OptionalDIPUGuard dipuGuard; pre(diclComms); @@ -466,6 +488,14 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( } post(diclComms); + + if (work->timingEnabled_) { + for (const auto i : c10::irange(inputs.size())) { + DIPUStream& diclStream = diclComms[i]->diclStream_; + (*work->diclEndEvents_)[i].record(diclStream); + } + } + work->record(); work->outputs_ = std::make_shared>(outputs); diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h index d5ba9da1e..6a07e3fbd 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "csrc_dipu/base/basedef.h" #include "csrc_dipu/runtime/core/DIPUEvent.h" @@ -86,6 +87,30 @@ class DIPU_API ProcessGroupDICL : public Backend { opTimeout_(opTimeout), workStartTime_(std::chrono::steady_clock::now()) { workEvents_.resize(diclComms_.size()); + const char* timingVar = std::getenv("DIPU_DICL_ENABLE_TIMING"); + if(timingVar != nullptr) + timingEnabled_ = true; + + if (timingEnabled_) { + diclStartEvents_ = std::make_shared>(); + diclStartEvents_->reserve(diclComms_.size()); + diclEndEvents_ = std::make_shared>(); + diclEndEvents_->reserve(diclComms_.size()); + for (uint32_t i = 0; i < diclComms_.size(); ++i) { + diclStartEvents_->emplace_back(DIPUEvent()); + diclEndEvents_->emplace_back(DIPUEvent()); + } + } + } + + WorkDICL(const WorkDICL& w) + : diclComms_(w.diclComms_), + blockingWait_(w.blockingWait_), + opTimeout_(w.opTimeout_), + workStartTime_(w.workStartTime_), + diclStartEvents_(w.diclStartEvents_), + diclEndEvents_(w.diclEndEvents_) { + workEvents_.resize(w.workEvents_.size()); } ~WorkDICL() override = default; @@ -111,6 +136,8 @@ class DIPU_API ProcessGroupDICL : public Backend { c10::intrusive_ptr getFuture() override; + float getDuration() const; + protected: // Store a reference to DICL collective's outputs, used by result and to // give a more descriptive message when representing the Work as a string. @@ -125,6 +152,15 @@ class DIPU_API ProcessGroupDICL : public Backend { // The DIPU events used to sync DICL work on comm stream std::vector workEvents_; + // The start DIPU events of DICL operator tracking this work item on + // multiple DIPU devices. These start DIPU events are needed by desync + // debugging if enabled. + std::shared_ptr> diclStartEvents_; + + // The end DIPU events of DICL operator tracking this work item on + // multiple DIPU devices. + std::shared_ptr> diclEndEvents_; + // Just checks whether DIPU execution has completed, without modifying // exception_ptr. bool finishedDICLExecutionInternal() const; @@ -141,6 +177,36 @@ class DIPU_API ProcessGroupDICL : public Backend { private: friend class ProcessGroupDICL; + bool timingEnabled_; + }; + +class PyWorkDICL : public WorkDICL { + public: + PyWorkDICL() = default; + + bool wait(std::chrono::milliseconds timeout = kNoTimeout) override { + PYBIND11_OVERRIDE( + bool, /* Return type */ + WorkDICL, /* Parent class */ + wait, /* Name of function in C++ */ + timeout); + } + + c10::intrusive_ptr getFuture() override { + pybind11::gil_scoped_acquire gil; + auto override = + pybind11::get_override(static_cast(this), "get_future"); + + if (override) { + pybind11::object o = override(); + auto futWrapper = + o.cast>(); + return futWrapper->fut; + } + + return WorkDICL::getFuture(); + } + }; struct DIPU_API Options : Backend::Options { From ff428d9d0470799bd3f04b425f8f30c70245b54f Mon Sep 17 00:00:00 2001 From: liuweiyu Date: Tue, 20 Aug 2024 19:51:16 +0800 Subject: [PATCH 02/11] Fix format --- .../torch_dipu/csrc_dipu/binding/ExportRT.cpp | 21 ++++++++----------- .../runtime/distributed/ProcessGroupDICL.cpp | 15 ++++++------- .../runtime/distributed/ProcessGroupDICL.h | 19 +++++++---------- 3 files changed, 23 insertions(+), 32 deletions(-) diff --git a/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp b/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp index 7a8f393c8..7583051a1 100644 --- a/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp +++ b/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp @@ -249,28 +249,25 @@ void exportCommunicator(py::module& m) { return kBackendDefaultTimeout; }); - py::class_, - ProcessGroupDICL::PyWorkDICL>(m, "WorkDICL") - .def(py::init([](std::vector>& comms, bool blockingWait, - std::chrono::milliseconds opTimeout) { + py::class_, + ProcessGroupDICL::PyWorkDICL>(m, "WorkDICL") + .def(py::init([](std::vector>& comms, + bool blockingWait, std::chrono::milliseconds opTimeout) { return ProcessGroupDICL::WorkDICL(comms, blockingWait, opTimeout); }), py::arg("comms"), py::arg("blockingWait"), py::arg("opTimeout"), py::call_guard()) - .def( - "_get_duration", - &ProcessGroupDICL::WorkDICL::getDuration, - py::call_guard(), - R"( + .def("_get_duration", &ProcessGroupDICL::WorkDICL::getDuration, + py::call_guard(), + R"( Returns: Duration of the corresponding collective communication. .. warning :: This API works for DICL backend for now and must set DIPU_ENABLE_TIMING environment variable. - )" - ); + )"); // py::object mdist = py::module::import("torch.distributed"); // py::object register_backend = diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index 72b4ef227..51c7d432f 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -195,15 +195,12 @@ ProcessGroupDICL::WorkDICL::getFuture() { float ProcessGroupDICL::WorkDICL::getDuration() const { // return 0.0; TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled") - TORCH_CHECK( - diclStartEvents_->size() == 1, - "getDuration only works for single device per ProcessGroup."); - TORCH_CHECK( - diclEndEvents_->size() == 1, - "getDuration only works for single device per ProcessGroup."); - TORCH_CHECK( - (*diclEndEvents_)[0].query(), - "getDuration can only be called after work is succeeded.") + TORCH_CHECK(diclStartEvents_->size() == 1, + "getDuration only works for single device per ProcessGroup."); + TORCH_CHECK(diclEndEvents_->size() == 1, + "getDuration only works for single device per ProcessGroup."); + TORCH_CHECK((*diclEndEvents_)[0].query(), + "getDuration can only be called after work is succeeded.") return (*diclStartEvents_)[0].elapsed_time((*diclEndEvents_)[0]); } diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h index 6a07e3fbd..37a9c9ac7 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h @@ -88,8 +88,7 @@ class DIPU_API ProcessGroupDICL : public Backend { workStartTime_(std::chrono::steady_clock::now()) { workEvents_.resize(diclComms_.size()); const char* timingVar = std::getenv("DIPU_DICL_ENABLE_TIMING"); - if(timingVar != nullptr) - timingEnabled_ = true; + if (timingVar != nullptr) timingEnabled_ = true; if (timingEnabled_) { diclStartEvents_ = std::make_shared>(); @@ -180,22 +179,21 @@ class DIPU_API ProcessGroupDICL : public Backend { bool timingEnabled_; }; -class PyWorkDICL : public WorkDICL { + class PyWorkDICL : public WorkDICL { public: PyWorkDICL() = default; bool wait(std::chrono::milliseconds timeout = kNoTimeout) override { - PYBIND11_OVERRIDE( - bool, /* Return type */ - WorkDICL, /* Parent class */ - wait, /* Name of function in C++ */ - timeout); + PYBIND11_OVERRIDE(bool, /* Return type */ + WorkDICL, /* Parent class */ + wait, /* Name of function in C++ */ + timeout); } c10::intrusive_ptr getFuture() override { pybind11::gil_scoped_acquire gil; - auto override = - pybind11::get_override(static_cast(this), "get_future"); + auto override = pybind11::get_override(static_cast(this), + "get_future"); if (override) { pybind11::object o = override(); @@ -206,7 +204,6 @@ class PyWorkDICL : public WorkDICL { return WorkDICL::getFuture(); } - }; struct DIPU_API Options : Backend::Options { From 81b973c66a92bce523dd340f44cf4f45cb9a950b Mon Sep 17 00:00:00 2001 From: liuweiyu Date: Thu, 22 Aug 2024 16:58:41 +0800 Subject: [PATCH 03/11] After a certain number of allreduce operations, print the execution time. --- .../torch_dipu/csrc_dipu/binding/ExportRT.cpp | 20 ----- .../runtime/distributed/ProcessGroupDICL.cpp | 86 ++++++++++++++----- .../runtime/distributed/ProcessGroupDICL.h | 81 ++++------------- 3 files changed, 82 insertions(+), 105 deletions(-) diff --git a/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp b/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp index 7583051a1..4574f2b4f 100644 --- a/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp +++ b/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp @@ -249,26 +249,6 @@ void exportCommunicator(py::module& m) { return kBackendDefaultTimeout; }); - py::class_, - ProcessGroupDICL::PyWorkDICL>(m, "WorkDICL") - .def(py::init([](std::vector>& comms, - bool blockingWait, std::chrono::milliseconds opTimeout) { - return ProcessGroupDICL::WorkDICL(comms, blockingWait, opTimeout); - }), - py::arg("comms"), py::arg("blockingWait"), py::arg("opTimeout"), - py::call_guard()) - .def("_get_duration", &ProcessGroupDICL::WorkDICL::getDuration, - py::call_guard(), - R"( - Returns: - Duration of the corresponding collective communication. - - .. warning :: - This API works for DICL backend for now and must set - DIPU_ENABLE_TIMING environment variable. - )"); - // py::object mdist = py::module::import("torch.distributed"); // py::object register_backend = // mdist.attr("Backend").attr("register_backend"); The first parameter is the diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index 51c7d432f..5eabb14ec 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -23,6 +23,7 @@ #include "csrc_dipu/runtime/core/DIPUStream.h" #include "csrc_dipu/runtime/core/allocator/DIPUCachingAllocator.h" #include "csrc_dipu/runtime/devproxy/diclproxy.h" +#include "csrc_dipu/utils/Log.h" #include "csrc_dipu/utils/helpfunc.hpp" #include @@ -192,18 +193,6 @@ ProcessGroupDICL::WorkDICL::getFuture() { return future_; } -float ProcessGroupDICL::WorkDICL::getDuration() const { - // return 0.0; - TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled") - TORCH_CHECK(diclStartEvents_->size() == 1, - "getDuration only works for single device per ProcessGroup."); - TORCH_CHECK(diclEndEvents_->size() == 1, - "getDuration only works for single device per ProcessGroup."); - TORCH_CHECK((*diclEndEvents_)[0].query(), - "getDuration can only be called after work is succeeded.") - return (*diclStartEvents_)[0].elapsed_time((*diclEndEvents_)[0]); -} - // end WorkDICL ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr& store, @@ -225,10 +214,49 @@ ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr& store, throw std::runtime_error("Invalid value for environment variable: " + std::string(DICL_BLOCKING_WAIT)); } + + char* timingVar = getenv("DIPU_DICL_ENABLE_TIMING"); + if (timingVar != nullptr) timingEnabled_ = true; + + // For now, only works for single device per ProcessGroup. + int devicePerProcessGroup = 1; + if (timingEnabled_) { + diclStartEvents_ = std::make_shared>(); + diclStartEvents_->reserve(devicePerProcessGroup); + diclEndEvents_ = std::make_shared>(); + diclEndEvents_->reserve(devicePerProcessGroup); + for (uint32_t i = 0; i < devicePerProcessGroup; ++i) { + diclStartEvents_->emplace_back(DIPUEvent()); + diclEndEvents_->emplace_back(DIPUEvent()); + } + + char* frequencyVar = getenv("DIPU_DICL_PRINT_FREQUENCY"); + if (frequencyVar != nullptr) printFrequency_ = std::stoi(frequencyVar); + + printCount_ = printFrequency_; + } } ProcessGroupDICL::~ProcessGroupDICL() = default; +float ProcessGroupDICL::getDuration() const { + TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled") + TORCH_CHECK(diclStartEvents_->size() == 1, + "getDuration only works for single device per ProcessGroup."); + TORCH_CHECK(diclEndEvents_->size() == 1, + "getDuration only works for single device per ProcessGroup."); + TORCH_CHECK((*diclEndEvents_)[0].query(), + "getDuration can only be called after work is succeeded.") + return (*diclStartEvents_)[0].elapsed_time((*diclEndEvents_)[0]); +} + +void ProcessGroupDICL::printInfo() const { + TORCH_CHECK(printCount_ == 0, "Print count hasn't reached 0 yet.") + std::ostringstream oss; + oss << "Rank " << rank_ << " duration = " << getDuration() << std::endl; + DIPU_LOG_INFO << oss.str(); +} + void ProcessGroupDICL::broadcastUniqueID(commUniqueId* uniqueId, const std::string& storeKey, int commRank) { @@ -454,11 +482,14 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( auto work = c10::make_intrusive( diclComms, blockingWait_, opTimeout_); - if (work->timingEnabled_) { - for (const auto i : c10::irange(inputs.size())) { - DIPUStream& diclStream = diclComms[i]->diclStream_; - (*work->diclStartEvents_)[i].record(diclStream); + if (timingEnabled_ && opType == OpType::ALLREDUCE) { + if (printCount_ == printFrequency_) { + for (const auto i : c10::irange(inputs.size())) { + DIPUStream& diclStream = diclComms[i]->diclStream_; + (*diclStartEvents_)[i].record(diclStream); + } } + printCount_--; } OptionalDIPUGuard dipuGuard; @@ -486,13 +517,6 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( post(diclComms); - if (work->timingEnabled_) { - for (const auto i : c10::irange(inputs.size())) { - DIPUStream& diclStream = diclComms[i]->diclStream_; - (*work->diclEndEvents_)[i].record(diclStream); - } - } - work->record(); work->outputs_ = std::make_shared>(outputs); @@ -505,6 +529,22 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( c10::ListType::create(c10::TensorType::get()), devices); work->future_->markCompleted(at::IValue(*work->outputs_)); } + + if (timingEnabled_ && opType == OpType::ALLREDUCE) { + if (printCount_ == 0) { + for (const auto i : c10::irange(inputs.size())) { + DIPUStream& diclStream = diclComms[i]->diclStream_; + (*diclEndEvents_)[i].record(diclStream); + } + for (const auto i : c10::irange(inputs.size())) { + DIPUStream& diclStream = diclComms[i]->diclStream_; + (*diclEndEvents_)[i].synchronize(); + } + printInfo(); + printCount_ = printFrequency_; + } + } + return work; } diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h index 37a9c9ac7..8253c8f94 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h @@ -11,7 +11,6 @@ #include #include #include -#include #include "csrc_dipu/base/basedef.h" #include "csrc_dipu/runtime/core/DIPUEvent.h" @@ -87,29 +86,6 @@ class DIPU_API ProcessGroupDICL : public Backend { opTimeout_(opTimeout), workStartTime_(std::chrono::steady_clock::now()) { workEvents_.resize(diclComms_.size()); - const char* timingVar = std::getenv("DIPU_DICL_ENABLE_TIMING"); - if (timingVar != nullptr) timingEnabled_ = true; - - if (timingEnabled_) { - diclStartEvents_ = std::make_shared>(); - diclStartEvents_->reserve(diclComms_.size()); - diclEndEvents_ = std::make_shared>(); - diclEndEvents_->reserve(diclComms_.size()); - for (uint32_t i = 0; i < diclComms_.size(); ++i) { - diclStartEvents_->emplace_back(DIPUEvent()); - diclEndEvents_->emplace_back(DIPUEvent()); - } - } - } - - WorkDICL(const WorkDICL& w) - : diclComms_(w.diclComms_), - blockingWait_(w.blockingWait_), - opTimeout_(w.opTimeout_), - workStartTime_(w.workStartTime_), - diclStartEvents_(w.diclStartEvents_), - diclEndEvents_(w.diclEndEvents_) { - workEvents_.resize(w.workEvents_.size()); } ~WorkDICL() override = default; @@ -135,8 +111,6 @@ class DIPU_API ProcessGroupDICL : public Backend { c10::intrusive_ptr getFuture() override; - float getDuration() const; - protected: // Store a reference to DICL collective's outputs, used by result and to // give a more descriptive message when representing the Work as a string. @@ -151,15 +125,6 @@ class DIPU_API ProcessGroupDICL : public Backend { // The DIPU events used to sync DICL work on comm stream std::vector workEvents_; - // The start DIPU events of DICL operator tracking this work item on - // multiple DIPU devices. These start DIPU events are needed by desync - // debugging if enabled. - std::shared_ptr> diclStartEvents_; - - // The end DIPU events of DICL operator tracking this work item on - // multiple DIPU devices. - std::shared_ptr> diclEndEvents_; - // Just checks whether DIPU execution has completed, without modifying // exception_ptr. bool finishedDICLExecutionInternal() const; @@ -179,33 +144,6 @@ class DIPU_API ProcessGroupDICL : public Backend { bool timingEnabled_; }; - class PyWorkDICL : public WorkDICL { - public: - PyWorkDICL() = default; - - bool wait(std::chrono::milliseconds timeout = kNoTimeout) override { - PYBIND11_OVERRIDE(bool, /* Return type */ - WorkDICL, /* Parent class */ - wait, /* Name of function in C++ */ - timeout); - } - - c10::intrusive_ptr getFuture() override { - pybind11::gil_scoped_acquire gil; - auto override = pybind11::get_override(static_cast(this), - "get_future"); - - if (override) { - pybind11::object o = override(); - auto futWrapper = - o.cast>(); - return futWrapper->fut; - } - - return WorkDICL::getFuture(); - } - }; - struct DIPU_API Options : Backend::Options { // NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for // operations. This is only used when blockingWait_ is enabled. @@ -373,6 +311,25 @@ class DIPU_API ProcessGroupDICL : public Backend { bool blockingWait_ = false; std::chrono::milliseconds opTimeout_ = kBackendDefaultTimeout; + + // The start DIPU events of DICL operator tracking this work item on + // multiple DIPU devices. These start DIPU events are needed by desync + // debugging if enabled. + std::shared_ptr> diclStartEvents_; + + // The end DIPU events of DICL operator tracking this work item on + // multiple DIPU devices. + std::shared_ptr> diclEndEvents_; + + float getDuration() const; + + void printInfo() const; + + int printFrequency_ = 100; + + int printCount_; + + bool timingEnabled_ = false; }; namespace dicl_hook { From 55b167658fe29043515dac2e5d739884eba0e83e Mon Sep 17 00:00:00 2001 From: liuweiyu Date: Thu, 22 Aug 2024 18:34:34 +0800 Subject: [PATCH 04/11] Print average duration. --- .../csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index 5eabb14ec..0f159d709 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -253,7 +253,8 @@ float ProcessGroupDICL::getDuration() const { void ProcessGroupDICL::printInfo() const { TORCH_CHECK(printCount_ == 0, "Print count hasn't reached 0 yet.") std::ostringstream oss; - oss << "Rank " << rank_ << " duration = " << getDuration() << std::endl; + oss << "Rank " << rank_ << " duration = " << getDuration() / printFrequency_ + << std::endl; DIPU_LOG_INFO << oss.str(); } From 5ea06f6b991c471edc2967869765907f190b5e80 Mon Sep 17 00:00:00 2001 From: liuweiyu Date: Thu, 22 Aug 2024 20:01:41 +0800 Subject: [PATCH 05/11] Add device ID printing. --- .../csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp | 8 ++++---- .../csrc_dipu/runtime/distributed/ProcessGroupDICL.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index 0f159d709..884ce6ee3 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -250,11 +250,11 @@ float ProcessGroupDICL::getDuration() const { return (*diclStartEvents_)[0].elapsed_time((*diclEndEvents_)[0]); } -void ProcessGroupDICL::printInfo() const { +void ProcessGroupDICL::printInfo(int deviceID) const { TORCH_CHECK(printCount_ == 0, "Print count hasn't reached 0 yet.") std::ostringstream oss; - oss << "Rank " << rank_ << " duration = " << getDuration() / printFrequency_ - << std::endl; + oss << "Rank " << rank_ << ": deviceId = " << deviceID + << " duration = " << getDuration() / printFrequency_ << std::endl; DIPU_LOG_INFO << oss.str(); } @@ -541,7 +541,7 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( DIPUStream& diclStream = diclComms[i]->diclStream_; (*diclEndEvents_)[i].synchronize(); } - printInfo(); + printInfo(static_cast(diclComms[0]->diclStream_.device_index())); printCount_ = printFrequency_; } } diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h index 8253c8f94..d75a5ee4c 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h @@ -323,7 +323,7 @@ class DIPU_API ProcessGroupDICL : public Backend { float getDuration() const; - void printInfo() const; + void printInfo(int deviceID) const; int printFrequency_ = 100; From 95656f31b0fc5718b49549aaa54b4b20ea21ac7e Mon Sep 17 00:00:00 2001 From: liuweiyu Date: Fri, 23 Aug 2024 11:53:03 +0800 Subject: [PATCH 06/11] Fix a logical error. --- .../runtime/distributed/ProcessGroupDICL.cpp | 64 ++++++++----------- .../runtime/distributed/ProcessGroupDICL.h | 37 +++++++---- 2 files changed, 51 insertions(+), 50 deletions(-) diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index 884ce6ee3..ce7cd2972 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -193,6 +193,16 @@ ProcessGroupDICL::WorkDICL::getFuture() { return future_; } +float ProcessGroupDICL::WorkDICL::getDuration() const { + TORCH_CHECK(diclStartEvents_->size() == 1, + "getDuration only works for single device per ProcessGroup."); + TORCH_CHECK(diclEndEvents_->size() == 1, + "getDuration only works for single device per ProcessGroup."); + TORCH_CHECK((*diclEndEvents_)[0].query(), + "getDuration can only be called after work is succeeded.") + return (*diclStartEvents_)[0].elapsed_time((*diclEndEvents_)[0]); +} + // end WorkDICL ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr& store, @@ -218,18 +228,7 @@ ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr& store, char* timingVar = getenv("DIPU_DICL_ENABLE_TIMING"); if (timingVar != nullptr) timingEnabled_ = true; - // For now, only works for single device per ProcessGroup. - int devicePerProcessGroup = 1; if (timingEnabled_) { - diclStartEvents_ = std::make_shared>(); - diclStartEvents_->reserve(devicePerProcessGroup); - diclEndEvents_ = std::make_shared>(); - diclEndEvents_->reserve(devicePerProcessGroup); - for (uint32_t i = 0; i < devicePerProcessGroup; ++i) { - diclStartEvents_->emplace_back(DIPUEvent()); - diclEndEvents_->emplace_back(DIPUEvent()); - } - char* frequencyVar = getenv("DIPU_DICL_PRINT_FREQUENCY"); if (frequencyVar != nullptr) printFrequency_ = std::stoi(frequencyVar); @@ -239,22 +238,12 @@ ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr& store, ProcessGroupDICL::~ProcessGroupDICL() = default; -float ProcessGroupDICL::getDuration() const { - TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled") - TORCH_CHECK(diclStartEvents_->size() == 1, - "getDuration only works for single device per ProcessGroup."); - TORCH_CHECK(diclEndEvents_->size() == 1, - "getDuration only works for single device per ProcessGroup."); - TORCH_CHECK((*diclEndEvents_)[0].query(), - "getDuration can only be called after work is succeeded.") - return (*diclStartEvents_)[0].elapsed_time((*diclEndEvents_)[0]); -} - void ProcessGroupDICL::printInfo(int deviceID) const { TORCH_CHECK(printCount_ == 0, "Print count hasn't reached 0 yet.") std::ostringstream oss; - oss << "Rank " << rank_ << ": deviceId = " << deviceID - << " duration = " << getDuration() / printFrequency_ << std::endl; + oss << "Rank " << rank_ << ": deviceId = " << deviceID + << ", Average duration = " << totalDuration_ / printFrequency_ + << std::endl; DIPU_LOG_INFO << oss.str(); } @@ -484,13 +473,11 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( diclComms, blockingWait_, opTimeout_); if (timingEnabled_ && opType == OpType::ALLREDUCE) { - if (printCount_ == printFrequency_) { - for (const auto i : c10::irange(inputs.size())) { - DIPUStream& diclStream = diclComms[i]->diclStream_; - (*diclStartEvents_)[i].record(diclStream); - } - } printCount_--; + for (const auto i : c10::irange(inputs.size())) { + DIPUStream& diclStream = diclComms[i]->diclStream_; + (*work->diclStartEvents_)[i].record(diclStream); + } } OptionalDIPUGuard dipuGuard; @@ -532,15 +519,16 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( } if (timingEnabled_ && opType == OpType::ALLREDUCE) { + for (const auto i : c10::irange(inputs.size())) { + DIPUStream& diclStream = diclComms[i]->diclStream_; + (*work->diclEndEvents_)[i].record(diclStream); + } + for (const auto i : c10::irange(inputs.size())) { + DIPUStream& diclStream = diclComms[i]->diclStream_; + (*work->diclEndEvents_)[i].synchronize(); + } + totalDuration_ += work->getDuration(); if (printCount_ == 0) { - for (const auto i : c10::irange(inputs.size())) { - DIPUStream& diclStream = diclComms[i]->diclStream_; - (*diclEndEvents_)[i].record(diclStream); - } - for (const auto i : c10::irange(inputs.size())) { - DIPUStream& diclStream = diclComms[i]->diclStream_; - (*diclEndEvents_)[i].synchronize(); - } printInfo(static_cast(diclComms[0]->diclStream_.device_index())); printCount_ = printFrequency_; } diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h index d75a5ee4c..6726ff275 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h @@ -86,6 +86,18 @@ class DIPU_API ProcessGroupDICL : public Backend { opTimeout_(opTimeout), workStartTime_(std::chrono::steady_clock::now()) { workEvents_.resize(diclComms_.size()); + + const char* timingVar = std::getenv("DIPU_DICL_ENABLE_TIMING"); + if (timingVar != nullptr) { + diclStartEvents_ = std::make_shared>(); + diclStartEvents_->reserve(diclComms_.size()); + diclEndEvents_ = std::make_shared>(); + diclEndEvents_->reserve(diclComms_.size()); + for (uint32_t i = 0; i < diclComms_.size(); ++i) { + diclStartEvents_->emplace_back(DIPUEvent()); + diclEndEvents_->emplace_back(DIPUEvent()); + } + } } ~WorkDICL() override = default; @@ -111,6 +123,8 @@ class DIPU_API ProcessGroupDICL : public Backend { c10::intrusive_ptr getFuture() override; + float getDuration() const; + protected: // Store a reference to DICL collective's outputs, used by result and to // give a more descriptive message when representing the Work as a string. @@ -125,6 +139,15 @@ class DIPU_API ProcessGroupDICL : public Backend { // The DIPU events used to sync DICL work on comm stream std::vector workEvents_; + // The start DIPU events of DICL operator tracking this work item on + // multiple DIPU devices. These start DIPU events are needed by desync + // debugging if enabled. + std::shared_ptr> diclStartEvents_; + + // The end DIPU events of DICL operator tracking this work item on + // multiple DIPU devices. + std::shared_ptr> diclEndEvents_; + // Just checks whether DIPU execution has completed, without modifying // exception_ptr. bool finishedDICLExecutionInternal() const; @@ -141,7 +164,6 @@ class DIPU_API ProcessGroupDICL : public Backend { private: friend class ProcessGroupDICL; - bool timingEnabled_; }; struct DIPU_API Options : Backend::Options { @@ -312,23 +334,14 @@ class DIPU_API ProcessGroupDICL : public Backend { std::chrono::milliseconds opTimeout_ = kBackendDefaultTimeout; - // The start DIPU events of DICL operator tracking this work item on - // multiple DIPU devices. These start DIPU events are needed by desync - // debugging if enabled. - std::shared_ptr> diclStartEvents_; - - // The end DIPU events of DICL operator tracking this work item on - // multiple DIPU devices. - std::shared_ptr> diclEndEvents_; - - float getDuration() const; - void printInfo(int deviceID) const; int printFrequency_ = 100; int printCount_; + float totalDuration_ = 0.; + bool timingEnabled_ = false; }; From 6d20e2ae99628f3c16b1f0cf6e646fe709446863 Mon Sep 17 00:00:00 2001 From: liuweiyu Date: Fri, 23 Aug 2024 16:18:38 +0800 Subject: [PATCH 07/11] Fix bug --- .../csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index ce7cd2972..c635d4e69 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -531,6 +531,7 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( if (printCount_ == 0) { printInfo(static_cast(diclComms[0]->diclStream_.device_index())); printCount_ = printFrequency_; + totalDuration_ = 0; } } From 7fc6f6c783ef946ccfe3d4a62158578365ed491a Mon Sep 17 00:00:00 2001 From: liuweiyu Date: Wed, 28 Aug 2024 15:26:59 +0800 Subject: [PATCH 08/11] Change the statistical method to sampling. --- .../runtime/distributed/ProcessGroupDICL.cpp | 33 +++++++++++-------- .../runtime/distributed/ProcessGroupDICL.h | 10 +++--- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index c635d4e69..e58a8567c 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -229,21 +229,26 @@ ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr& store, if (timingVar != nullptr) timingEnabled_ = true; if (timingEnabled_) { - char* frequencyVar = getenv("DIPU_DICL_PRINT_FREQUENCY"); - if (frequencyVar != nullptr) printFrequency_ = std::stoi(frequencyVar); + char* samplingVar = getenv("DIPU_DICL_SAMPLING_INTERVAL"); + if (samplingVar != nullptr) samplingInterval_ = std::stoi(samplingVar); - printCount_ = printFrequency_; + samplingCount_ = samplingInterval_; } } ProcessGroupDICL::~ProcessGroupDICL() = default; -void ProcessGroupDICL::printInfo(int deviceID) const { - TORCH_CHECK(printCount_ == 0, "Print count hasn't reached 0 yet.") +void ProcessGroupDICL::printInfo(float duration, int deviceID) const { + TORCH_CHECK(samplingCount_ == 0, "Print count hasn't reached 0 yet.") std::ostringstream oss; - oss << "Rank " << rank_ << ": deviceId = " << deviceID - << ", Average duration = " << totalDuration_ / printFrequency_ - << std::endl; + oss << "PG uniqueId = "; + for (int i = 0; i < 16; ++i) { + oss << static_cast(diclID_->internal[i]); + } + // oss << "PG uniqueId = " << diclID_->internal << ", rank =" << rank_ + oss << ", rank =" << rank_ + << ", deviceId = " << deviceID + << ", duration = " << duration << std::endl; DIPU_LOG_INFO << oss.str(); } @@ -310,6 +315,7 @@ std::vector>& ProcessGroupDICL::getDICLComms( ? localCommsKey : std::to_string(diclCommCounter_++); broadcastUniqueID(&diclID, bcastKey, commsRank); + diclID_ = &diclID; OptionalDIPUGuard dipuGuard; @@ -473,7 +479,7 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( diclComms, blockingWait_, opTimeout_); if (timingEnabled_ && opType == OpType::ALLREDUCE) { - printCount_--; + samplingCount_--; for (const auto i : c10::irange(inputs.size())) { DIPUStream& diclStream = diclComms[i]->diclStream_; (*work->diclStartEvents_)[i].record(diclStream); @@ -527,11 +533,10 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( DIPUStream& diclStream = diclComms[i]->diclStream_; (*work->diclEndEvents_)[i].synchronize(); } - totalDuration_ += work->getDuration(); - if (printCount_ == 0) { - printInfo(static_cast(diclComms[0]->diclStream_.device_index())); - printCount_ = printFrequency_; - totalDuration_ = 0; + if (samplingCount_ == 0) { + printInfo(work->getDuration(), + static_cast(diclComms[0]->diclStream_.device_index())); + samplingCount_ = samplingInterval_; } } diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h index 6726ff275..8e830b8b5 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h @@ -334,15 +334,15 @@ class DIPU_API ProcessGroupDICL : public Backend { std::chrono::milliseconds opTimeout_ = kBackendDefaultTimeout; - void printInfo(int deviceID) const; + void printInfo(float duration, int deviceID) const; - int printFrequency_ = 100; + int samplingInterval_ = 100; - int printCount_; - - float totalDuration_ = 0.; + int samplingCount_; bool timingEnabled_ = false; + + commUniqueId* diclID_; }; namespace dicl_hook { From e47b0504451082087839ffbbd595e5c0dd3b0bab Mon Sep 17 00:00:00 2001 From: liuweiyu Date: Wed, 28 Aug 2024 16:44:42 +0800 Subject: [PATCH 09/11] Fix print bug --- .../csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp | 6 +++--- .../csrc_dipu/runtime/distributed/ProcessGroupDICL.h | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index e58a8567c..029ccb216 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -243,9 +243,8 @@ void ProcessGroupDICL::printInfo(float duration, int deviceID) const { std::ostringstream oss; oss << "PG uniqueId = "; for (int i = 0; i < 16; ++i) { - oss << static_cast(diclID_->internal[i]); + oss << static_cast(uniqueidVec_[i]); } - // oss << "PG uniqueId = " << diclID_->internal << ", rank =" << rank_ oss << ", rank =" << rank_ << ", deviceId = " << deviceID << ", duration = " << duration << std::endl; @@ -273,6 +272,7 @@ void ProcessGroupDICL::broadcastUniqueID(commUniqueId* uniqueId, auto vec = std::vector(reinterpret_cast(uniqueId), reinterpret_cast(uniqueId) + devapis::DICL_UNIQUE_ID_BYTES_SIZE); + uniqueidVec_ = vec; store_->set(storeKey, vec); } else { auto vec = store_->get(storeKey); @@ -281,6 +281,7 @@ void ProcessGroupDICL::broadcastUniqueID(commUniqueId* uniqueId, "Unexpected DICL unique ID length received " "from the store"); } + uniqueidVec_ = vec; std::memcpy(uniqueId, vec.data(), vec.size()); } } @@ -315,7 +316,6 @@ std::vector>& ProcessGroupDICL::getDICLComms( ? localCommsKey : std::to_string(diclCommCounter_++); broadcastUniqueID(&diclID, bcastKey, commsRank); - diclID_ = &diclID; OptionalDIPUGuard dipuGuard; diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h index 8e830b8b5..1fed65506 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h @@ -342,7 +342,7 @@ class DIPU_API ProcessGroupDICL : public Backend { bool timingEnabled_ = false; - commUniqueId* diclID_; + std::vector uniqueidVec_; }; namespace dicl_hook { From a6aa591398fa0fd4462bf5609a64aff6dfa5f722 Mon Sep 17 00:00:00 2001 From: liuweiyu Date: Wed, 28 Aug 2024 17:34:53 +0800 Subject: [PATCH 10/11] Add print information. --- .../csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp | 10 ++++++---- .../csrc_dipu/runtime/distributed/ProcessGroupDICL.h | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index 029ccb216..b2bb6489c 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -238,15 +238,16 @@ ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr& store, ProcessGroupDICL::~ProcessGroupDICL() = default; -void ProcessGroupDICL::printInfo(float duration, int deviceID) const { +void ProcessGroupDICL::printInfo(float duration, int comm_size, + int deviceID) const { TORCH_CHECK(samplingCount_ == 0, "Print count hasn't reached 0 yet.") std::ostringstream oss; oss << "PG uniqueId = "; for (int i = 0; i < 16; ++i) { - oss << static_cast(uniqueidVec_[i]); + oss << static_cast(uniqueidVec_[i]); } - oss << ", rank =" << rank_ - << ", deviceId = " << deviceID + oss << ", rank =" << rank_ << ", deviceId = " << deviceID + << " comm_size = " << comm_size << " bytes " << ", duration = " << duration << std::endl; DIPU_LOG_INFO << oss.str(); } @@ -535,6 +536,7 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( } if (samplingCount_ == 0) { printInfo(work->getDuration(), + inputs[0].element_size() * inputs[0].numel(), static_cast(diclComms[0]->diclStream_.device_index())); samplingCount_ = samplingInterval_; } diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h index 1fed65506..535dad881 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h @@ -334,7 +334,7 @@ class DIPU_API ProcessGroupDICL : public Backend { std::chrono::milliseconds opTimeout_ = kBackendDefaultTimeout; - void printInfo(float duration, int deviceID) const; + void printInfo(float duration, int comm_size, int deviceID) const; int samplingInterval_ = 100; From b2b10874a838aad56fce3d99ab9406c992cebfee Mon Sep 17 00:00:00 2001 From: liuweiyu Date: Wed, 28 Aug 2024 18:12:57 +0800 Subject: [PATCH 11/11] Fix bug. --- .../runtime/distributed/ProcessGroupDICL.cpp | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index b2bb6489c..9fed59d76 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -247,7 +247,7 @@ void ProcessGroupDICL::printInfo(float duration, int comm_size, oss << static_cast(uniqueidVec_[i]); } oss << ", rank =" << rank_ << ", deviceId = " << deviceID - << " comm_size = " << comm_size << " bytes " + << ", comm_size = " << comm_size << " bytes" << ", duration = " << duration << std::endl; DIPU_LOG_INFO << oss.str(); } @@ -481,9 +481,11 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( if (timingEnabled_ && opType == OpType::ALLREDUCE) { samplingCount_--; - for (const auto i : c10::irange(inputs.size())) { - DIPUStream& diclStream = diclComms[i]->diclStream_; - (*work->diclStartEvents_)[i].record(diclStream); + if (samplingCount_ == 0) { + for (const auto i : c10::irange(inputs.size())) { + DIPUStream& diclStream = diclComms[i]->diclStream_; + (*work->diclStartEvents_)[i].record(diclStream); + } } } @@ -525,7 +527,7 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( work->future_->markCompleted(at::IValue(*work->outputs_)); } - if (timingEnabled_ && opType == OpType::ALLREDUCE) { + if (timingEnabled_ && opType == OpType::ALLREDUCE && samplingCount_ == 0) { for (const auto i : c10::irange(inputs.size())) { DIPUStream& diclStream = diclComms[i]->diclStream_; (*work->diclEndEvents_)[i].record(diclStream); @@ -534,12 +536,9 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( DIPUStream& diclStream = diclComms[i]->diclStream_; (*work->diclEndEvents_)[i].synchronize(); } - if (samplingCount_ == 0) { - printInfo(work->getDuration(), - inputs[0].element_size() * inputs[0].numel(), - static_cast(diclComms[0]->diclStream_.device_index())); - samplingCount_ = samplingInterval_; - } + printInfo(work->getDuration(), inputs[0].element_size() * inputs[0].numel(), + static_cast(diclComms[0]->diclStream_.device_index())); + samplingCount_ = samplingInterval_; } return work;