Skip to content

Commit

Permalink
Do not use the pool when measuring all_reduce time.
Browse files Browse the repository at this point in the history
  • Loading branch information
lwysense committed Oct 12, 2024
1 parent bda6ef6 commit fb2a3e3
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
3 changes: 2 additions & 1 deletion dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ void exportEvent(py::module& m) {
py::arg("enable_timing") = false, py::arg("blocking") = false,
py::arg("interprocess") = false)
.def("record", py::overload_cast<>(&DIPUEvent::record), "record event")
.def("record", py::overload_cast<const DIPUStream&>(&DIPUEvent::record),
.def("record", py::overload_cast<const DIPUStream&, bool>(&DIPUEvent::record),
py::arg("stream"), py::arg("use_pool") = true,
"record event on stream")
.def("elapsed_time", &dipu::DIPUEvent::elapsed_time)
.def("synchronize",
Expand Down
21 changes: 17 additions & 4 deletions dipu/torch_dipu/csrc_dipu/runtime/core/DIPUEvent.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class DIPU_API DIPUEvent {
deviceEvent_t event_{nullptr};
c10::DeviceIndex device_index_{-1};
c10::StreamId last_recorded_stream_id_{-1};
bool use_pool_{true};

public:
DIPUEvent(const DIPUEvent&) = delete;
Expand All @@ -29,7 +30,8 @@ class DIPU_API DIPUEvent {
// Move constructor: copies the event handle, device index, last recorded
// stream id and pool flag out of `other`, then calls other.unsafe_reset().
// NOTE(review): unsafe_reset() is defined outside this view; presumably it
// clears `other` without destroying the handle so that only one object ends
// up releasing it -- confirm that it also restores other.use_pool_.
constexpr DIPUEvent(DIPUEvent&& other) noexcept
    : event_(other.event_),
      device_index_(other.device_index_),
      last_recorded_stream_id_(other.last_recorded_stream_id_),
      use_pool_(other.use_pool_) {
  other.unsafe_reset();
}

Expand All @@ -39,6 +41,7 @@ class DIPU_API DIPUEvent {
event_ = other.event_;
device_index_ = other.device_index_;
last_recorded_stream_id_ = other.last_recorded_stream_id_;
use_pool_ = other.use_pool_;
other.unsafe_reset();
}
return *this;
Expand Down Expand Up @@ -76,8 +79,9 @@ class DIPU_API DIPUEvent {

// Convenience overload: records this event on whatever stream
// getCurrentDIPUStream() returns, leaving every other argument of the
// stream-taking overload at its default.
void record() {
  record(getCurrentDIPUStream());
}

void record(const DIPUStream& stream) {
void record(const DIPUStream& stream, bool use_pool = true) {
if (!initialized()) {
use_pool_ = use_pool;
create_event(stream.device_index());
}

Expand Down Expand Up @@ -124,14 +128,23 @@ class DIPU_API DIPUEvent {
// Creates the underlying device event for `device_index` and stores that
// index in device_index_.
//
// Exactly one creation call is made: devproxy::createEvent when use_pool_ is
// true, devapis::createEvent otherwise. release_event() selects the matching
// destroy call from the same flag, so use_pool_ must not change while the
// event is alive.
// NOTE(review): going by the flag name, the devproxy path is presumably the
// pooled one and the devapis path the direct one -- confirm.
void create_event(c10::DeviceIndex device_index) {
  device_index_ = device_index;
  // Guard is scoped to this function and constructed with the target index
  // before the creation call (presumably selecting that device -- confirm).
  DIPUGuard guard(device_index_);
  if (use_pool_) {
    devproxy::createEvent(&event_);
  } else {
    devapis::createEvent(&event_);
  }
}

// Destroys the underlying device event if initialized() reports one, then
// returns the object to its default state (event_ == nullptr,
// use_pool_ == true). Does nothing when no event has been created.
//
// The handle is destroyed exactly once, through the counterpart of the call
// that created it: devproxy::destroyEvent when use_pool_ is true,
// devapis::destroyEvent otherwise.
void release_event() {
  if (initialized()) {
    DIPUGuard guard(device_index_);
    if (use_pool_) {
      devproxy::destroyEvent(event_);
    } else {
      devapis::destroyEvent(event_);
    }
    event_ = nullptr;
    // Restore the default so that a later record() without an explicit
    // use_pool argument behaves as it would on a fresh object.
    use_pool_ = true;
  }
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class WorkStore {
std::lock_guard<std::mutex> lock(mtx_);
info_vec_.push_back(WorkInfo());
size_t index = info_vec_.size() - 1;
info_vec_[index].startEvent_.record(stream);
info_vec_[index].startEvent_.record(stream, false);
info_vec_[index].rank_ = rank;
info_vec_[index].comm_size_ = comm_size;

Expand All @@ -148,12 +148,11 @@ class WorkStore {

// Records the end event of the work item stored at `index` on `stream`.
//
// mtx_ is held for the duration of the call. The event is recorded once,
// with use_pool = false, which asks DIPUEvent::record for a non-pooled event
// (per the commit intent: keep pooled events out of all_reduce timing).
// `index` is not range-checked here; it is expected to be a position
// previously produced for info_vec_ by the start-recording path.
void recordEnd(const DIPUStream& stream, size_t index) {
  std::lock_guard<std::mutex> lock(mtx_);
  info_vec_[index].endEvent_.record(stream, false);
}

void dump(std::string& path) {
for (auto& wi : info_vec_) {
wi.startEvent_.synchronize();
wi.endEvent_.synchronize();
float duration = wi.startEvent_.elapsed_time(wi.endEvent_);
std::ostringstream oss;
Expand Down

0 comments on commit fb2a3e3

Please sign in to comment.