From fb2a3e3e78bb5fbf1a551b15923d5d08f8c27d3c Mon Sep 17 00:00:00 2001
From: liuweiyu
Date: Sat, 12 Oct 2024 09:55:48 +0000
Subject: [PATCH] Do not use the pool when measuring all_reduce time.

---
 .../torch_dipu/csrc_dipu/binding/ExportRT.cpp |  3 ++-
 .../csrc_dipu/runtime/core/DIPUEvent.h        | 21 +++++++++++++++----
 .../runtime/distributed/ProcessGroupDICL.cpp  |  5 ++---
 3 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp b/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp
index a2a9f9f8e..f6ef6bbf0 100644
--- a/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp
+++ b/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp
@@ -205,7 +205,8 @@ void exportEvent(py::module& m) {
            py::arg("enable_timing") = false, py::arg("blocking") = false,
            py::arg("interprocess") = false)
       .def("record", py::overload_cast<>(&DIPUEvent::record), "record event")
-      .def("record", py::overload_cast<const DIPUStream&>(&DIPUEvent::record),
+      .def("record", py::overload_cast<const DIPUStream&, bool>(&DIPUEvent::record),
+           py::arg("stream"), py::arg("use_pool") = true,
            "record event on stream")
       .def("elapsed_time", &dipu::DIPUEvent::elapsed_time)
       .def("synchronize",
diff --git a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUEvent.h b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUEvent.h
index 1af20b840..3165215dd 100644
--- a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUEvent.h
+++ b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUEvent.h
@@ -20,6 +20,7 @@ class DIPU_API DIPUEvent {
   deviceEvent_t event_{nullptr};
   c10::DeviceIndex device_index_{-1};
   c10::StreamId last_recorded_stream_id_{-1};
+  bool use_pool_{true};
 
  public:
   DIPUEvent(const DIPUEvent&) = delete;
@@ -29,7 +30,8 @@ class DIPU_API DIPUEvent {
   constexpr DIPUEvent(DIPUEvent&& other) noexcept
       : event_(other.event_),
         device_index_(other.device_index_),
-        last_recorded_stream_id_(other.last_recorded_stream_id_) {
+        last_recorded_stream_id_(other.last_recorded_stream_id_),
+        use_pool_(other.use_pool_) {
     other.unsafe_reset();
   }
 
@@ -39,6 +41,7 @@ class DIPU_API DIPUEvent {
       event_ = other.event_;
       device_index_ = other.device_index_;
       last_recorded_stream_id_ = other.last_recorded_stream_id_;
+      use_pool_ = other.use_pool_;
       other.unsafe_reset();
     }
     return *this;
@@ -76,8 +79,9 @@ class DIPU_API DIPUEvent {
 
   void record() { record(getCurrentDIPUStream()); }
 
-  void record(const DIPUStream& stream) {
+  void record(const DIPUStream& stream, bool use_pool = true) {
     if (!initialized()) {
+      use_pool_ = use_pool;
       create_event(stream.device_index());
     }
 
@@ -124,14 +128,23 @@ class DIPU_API DIPUEvent {
   void create_event(c10::DeviceIndex device_index) {
     device_index_ = device_index;
     DIPUGuard guard(device_index_);
-    devproxy::createEvent(&event_);
+    if(use_pool_) {
+      devproxy::createEvent(&event_);
+    } else {
+      devapis::createEvent(&event_);
+    }
   }
 
   void release_event() {
     if (initialized()) {
       DIPUGuard guard(device_index_);
-      devproxy::destroyEvent(event_);
+      if(use_pool_) {
+        devproxy::destroyEvent(event_);
+      } else {
+        devapis::destroyEvent(event_);
+      }
       event_ = nullptr;
+      use_pool_ = true;
     }
   }
 };
diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp
index df8d90f02..8d06f0780 100644
--- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp
+++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp
@@ -139,7 +139,7 @@ class WorkStore {
     std::lock_guard lock(mtx_);
     info_vec_.push_back(WorkInfo());
     size_t index = info_vec_.size() - 1;
-    info_vec_[index].startEvent_.record(stream);
+    info_vec_[index].startEvent_.record(stream, false);
     info_vec_[index].rank_ = rank;
     info_vec_[index].comm_size_ = comm_size;
 
@@ -148,12 +148,11 @@ class WorkStore {
 
   void recordEnd(const DIPUStream& stream, size_t index) {
     std::lock_guard lock(mtx_);
-    info_vec_[index].endEvent_.record(stream);
+    info_vec_[index].endEvent_.record(stream, false);
   }
 
   void dump(std::string& path) {
     for (auto& wi : info_vec_) {
-      wi.startEvent_.synchronize();
       wi.endEvent_.synchronize();
       float duration = wi.startEvent_.elapsed_time(wi.endEvent_);
       std::ostringstream oss;