Skip to content

Commit

Permalink
Fix bug where EventMetadata Id spaces were clashing when aggregating TPU device xplanes.
Browse files Browse the repository at this point in the history

Add error logging when this is likely encountered.

PiperOrigin-RevId: 666010039
  • Loading branch information
bmass02 authored and copybara-github committed Aug 22, 2024
1 parent e7885c1 commit ba5aff4
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 85 deletions.
1 change: 1 addition & 0 deletions third_party/tsl/tsl/profiler/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ cc_library(
"//tsl/platform:types",
"//tsl/profiler/protobuf:xplane_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/meta:type_traits",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
Expand Down
9 changes: 9 additions & 0 deletions third_party/tsl/tsl/profiler/utils/xplane_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tsl/platform/types.h"
Expand Down Expand Up @@ -55,11 +56,19 @@ XPlaneBuilder::XPlaneBuilder(XPlane* plane)

// Looks up the event metadata stored under `metadata_id` in the plane's
// event_metadata map (the map's operator[] supplies an entry when the key is
// absent), stamps it with `metadata_id`, and returns a pointer to it.
XEventMetadata* XPlaneBuilder::GetOrCreateEventMetadata(int64_t metadata_id) {
  auto& metadata_by_id = *plane_->mutable_event_metadata();
  XEventMetadata& entry = metadata_by_id[metadata_id];
  // Must be evaluated before set_id() below: an entry that already carries a
  // non-zero id while the sequential id counter is also non-zero is treated
  // as a sign that explicit ids and auto-assigned ids are being mixed.
  const bool id_spaces_mixed =
      last_event_metadata_id_ != 0 && entry.id() != 0;
  LOG_IF_EVERY_N_SEC(ERROR, id_spaces_mixed, 1)
      << "Both overloads of GetOrCreateEventMetadata have been called on the "
         "same XPlane which is forbidden.";
  entry.set_id(metadata_id);
  return &entry;
}

// Advances the sequential id counter and returns the event metadata entry
// registered under the new id.
XEventMetadata* XPlaneBuilder::CreateEventMetadata() {
  // The number of stored entries is expected to equal the id counter; a
  // mismatch is reported as an error before the new entry is created.
  const bool counter_out_of_sync =
      plane_->event_metadata_size() != last_event_metadata_id_;
  LOG_IF_EVERY_N_SEC(ERROR, counter_out_of_sync, 1)
      << "Both overloads of GetOrCreateEventMetadata have been called on the "
         "same XPlane which is forbidden.";
  const auto next_id = ++last_event_metadata_id_;
  return GetOrCreateEventMetadata(next_id);
}

Expand Down
8 changes: 5 additions & 3 deletions third_party/tsl/tsl/profiler/utils/xplane_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -608,10 +608,12 @@ void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace) {
XLineBuilder aggregated_line = aggregated_plane.GetOrCreateLine(line_id);
for (const auto& [group_id, stat_by_event] : stats_by_group) {
for (const auto& [event_id, event_stat] : stat_by_event) {
const auto& src_event_metadata = *plane.GetEventMetadata(event_id);
XEventMetadata& event_metadata =
*aggregated_plane.GetOrCreateEventMetadata(event_id);
CopyEventMetadata(*plane.GetEventMetadata(event_id), plane,
event_metadata, aggregated_plane);
*aggregated_plane.GetOrCreateEventMetadata(
src_event_metadata.name());
CopyEventMetadata(src_event_metadata, plane, event_metadata,
aggregated_plane);
XEventBuilder aggregated_event =
aggregated_line.AddEvent(event_metadata);
aggregated_event.SetNumOccurrences(event_stat.stat.count());
Expand Down
180 changes: 98 additions & 82 deletions third_party/tsl/tsl/profiler/utils/xplane_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ using ::testing::UnorderedElementsAre;

#if defined(PLATFORM_GOOGLE)
using ::testing::EqualsProto;
using ::testing::proto::IgnoringFields;
using ::testing::proto::IgnoringRepeatedFieldOrdering;
using ::testing::proto::Partially;
#endif
Expand Down Expand Up @@ -394,36 +395,45 @@ TEST(XplaneUtilsTest, FindMutablePlanesWithPredicate) {
TEST(XplaneUtilsTest, TestAggregateXPlanes) {
XPlane xplane;
XPlaneBuilder builder(&xplane);
XEventMetadata* event_metadata1 = builder.GetOrCreateEventMetadata(1);
event_metadata1->set_name("EventMetadata1");
XEventMetadata* event_metadata2 = builder.GetOrCreateEventMetadata(2);
event_metadata2->set_name("EventMetadata2");
XEventMetadata* event_metadata3 = builder.GetOrCreateEventMetadata(3);
event_metadata3->set_name("EventMetadata3");
XEventMetadata* event_metadata4 = builder.GetOrCreateEventMetadata(4);
event_metadata4->set_name("EventMetadata4");

XLineBuilder line = builder.GetOrCreateLine(1);
auto& event_metadata1 = *builder.GetOrCreateEventMetadata("EventMetadata1");
auto& event_metadata2 = *builder.GetOrCreateEventMetadata("EventMetadata2");
auto& event_metadata3 = *builder.GetOrCreateEventMetadata("EventMetadata3");
auto& event_metadata4 = *builder.GetOrCreateEventMetadata("EventMetadata4");
auto& step_event_metadata1 =
*builder.GetOrCreateEventMetadata("StepEventMetadata1");
auto& step_event_metadata2 =
*builder.GetOrCreateEventMetadata("StepEventMetadata2");

XLineBuilder step_line = builder.GetOrCreateLine(1);
step_line.SetName(kStepLineName);
XEventBuilder step1 = step_line.AddEvent(step_event_metadata1);
step1.SetOffsetNs(0);
step1.SetDurationNs(10);
XEventBuilder step2 = step_line.AddEvent(step_event_metadata2);
step2.SetOffsetNs(10);
step2.SetDurationNs(10);

XLineBuilder line = builder.GetOrCreateLine(2);
line.SetName(kTensorFlowOpLineName);
XEventBuilder event1 = line.AddEvent(*event_metadata1);
XEventBuilder event1 = line.AddEvent(event_metadata1);
event1.SetOffsetNs(0);
event1.SetDurationNs(5);
XEventBuilder event3 = line.AddEvent(*event_metadata3);
XEventBuilder event3 = line.AddEvent(event_metadata3);
event3.SetOffsetNs(0);
event3.SetDurationNs(2);
XEventBuilder event2 = line.AddEvent(*event_metadata2);
XEventBuilder event2 = line.AddEvent(event_metadata2);
event2.SetOffsetNs(5);
event2.SetDurationNs(5);
XEventBuilder event4 = line.AddEvent(*event_metadata2);
XEventBuilder event4 = line.AddEvent(event_metadata2);
event4.SetOffsetNs(10);
event4.SetDurationNs(5);
XEventBuilder event5 = line.AddEvent(*event_metadata4);
XEventBuilder event5 = line.AddEvent(event_metadata4);
event5.SetOffsetNs(15);
event5.SetDurationNs(6);
XEventBuilder event6 = line.AddEvent(*event_metadata1);
XEventBuilder event6 = line.AddEvent(event_metadata1);
event6.SetOffsetNs(15);
event6.SetDurationNs(4);
XEventBuilder event7 = line.AddEvent(*event_metadata3);
XEventBuilder event7 = line.AddEvent(event_metadata3);
event7.SetOffsetNs(15);
event7.SetDurationNs(3);

Expand All @@ -433,71 +443,77 @@ TEST(XplaneUtilsTest, TestAggregateXPlanes) {
// Protobuf matchers are unavailable in OSS (b/169705709)
#if defined(PLATFORM_GOOGLE)
// TODO(b/238349654): Proto matcher are ineffective for XPlanes.
ASSERT_THAT(aggregated_xplane,
IgnoringRepeatedFieldOrdering(EqualsProto(
R"pb(lines {
id: 1
name: "Framework Ops"
events {
metadata_id: 1
duration_ps: 9000
stats { metadata_id: 2 int64_value: 4000 }
stats { metadata_id: 3 int64_value: 4000 }
num_occurrences: 2
}
events {
metadata_id: 3
duration_ps: 5000
stats { metadata_id: 2 int64_value: 2000 }
num_occurrences: 2
}
events {
metadata_id: 4
duration_ps: 6000
stats { metadata_id: 3 int64_value: 2000 }
num_occurrences: 1
}
events {
metadata_id: 2
duration_ps: 10000
stats { metadata_id: 2 int64_value: 5000 }
num_occurrences: 2
}
}
event_metadata {
key: 1
value { id: 1 name: "EventMetadata1" }
}
event_metadata {
key: 2
value { id: 2 name: "EventMetadata2" }
}
event_metadata {
key: 3
value { id: 3 name: "EventMetadata3" }
}
event_metadata {
key: 4
value { id: 4 name: "EventMetadata4" }
}
stat_metadata {
key: 1
value { id: 1 name: "total_profile_duration_ps" }
}
stat_metadata {
key: 2
value { id: 2 name: "min_duration_ps" }
}
stat_metadata {
key: 3
value { id: 3 name: "self_duration_ps" }
}
stat_metadata {
key: 4
value { id: 4 name: "group_id" }
}
stats { metadata_id: 1 uint64_value: 21000 }
)pb")));
ASSERT_THAT(
aggregated_xplane,
IgnoringFields(
{"tensorflow.profiler.XEvent.metadata_id",
"tensorflow.profiler.XPlane.event_metadata"},
IgnoringRepeatedFieldOrdering(EqualsProto(
R"pb(lines {
id: 1
name: "Steps"
events { metadata_id: 1 offset_ps: 0 duration_ps: 10000 }
events {
metadata_id: 2
offset_ps: 10000
duration_ps: 10000
}
}
lines {
id: 2
name: "Framework Ops"
events {
metadata_id: 3
duration_ps: 10000
stats { metadata_id: 2 int64_value: 5000 }
num_occurrences: 2
}
events {
metadata_id: 4
duration_ps: 5000
stats { metadata_id: 2 int64_value: 2000 }
num_occurrences: 2
}
events {
metadata_id: 5
duration_ps: 9000
stats { metadata_id: 2 int64_value: 4000 }
stats { metadata_id: 3 int64_value: 4000 }
num_occurrences: 2
}
events {
metadata_id: 6
duration_ps: 6000
stats { metadata_id: 3 int64_value: 2000 }
num_occurrences: 1
}
}
stat_metadata {
key: 1
value { id: 1 name: "total_profile_duration_ps" }
}
stat_metadata {
key: 2
value { id: 2 name: "min_duration_ps" }
}
stat_metadata {
key: 3
value { id: 3 name: "self_duration_ps" }
}
stat_metadata {
key: 4
value { id: 4 name: "group_id" }
}
stats { metadata_id: 1 uint64_value: 21000 }
)pb"))));
std::vector<std::string> event_metadata_names;
for (const auto& [id, event_metadata] : aggregated_xplane.event_metadata()) {
event_metadata_names.push_back(event_metadata.name());
}
EXPECT_THAT(event_metadata_names,
UnorderedElementsAre("EventMetadata1", "EventMetadata2",
"EventMetadata3", "EventMetadata4",
"StepEventMetadata1", "StepEventMetadata2"));
#endif
}

Expand Down

0 comments on commit ba5aff4

Please sign in to comment.