Skip to content

Commit

Permalink
Ensure HLO instruction to_apply has matching execution thread.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 660553975
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Aug 7, 2024
1 parent be3181c commit 5f4c02b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
9 changes: 9 additions & 0 deletions xla/service/hlo_verifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2956,6 +2956,15 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
}
}

if (instruction->has_to_apply() &&
instruction->to_apply()->execution_thread() !=
instruction->parent()->execution_thread()) {
return Internal(
"%s top_apply computation execution thread does not match (%s vs %s)",
instruction->name(), instruction->to_apply()->execution_thread(),
instruction->parent()->execution_thread());
}

return absl::OkStatus();
}

Expand Down
16 changes: 8 additions & 8 deletions xla/service/hlo_verifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ TEST_F(HloVerifierTest, CheckCallThreadMismatch) {
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
HasSubstr("expects parent computation thread name same as called "
"computation's thread name"));
HasSubstr("mycall top_apply computation execution thread does "
"not match (parallel_thread vs main)"));
}

TEST_F(HloVerifierTest, CompositeCall) {
Expand Down Expand Up @@ -2165,10 +2165,10 @@ TEST_F(HloVerifierTest, FusionNestedComputationThreadVerifier) {
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kModuleStr));
EXPECT_THAT(verifier().Run(module.get()).status().message(),
HasSubstr("Nested computations expects same computation's thread "
"name: parallel_thread vs main, in called computation "
"`add` vs caller computation `fused_computation`"));
EXPECT_THAT(
verifier().Run(module.get()).status().message(),
HasSubstr("crs0 top_apply computation execution thread does not match "
"(parallel_thread vs main)"));
}

TEST_F(HloVerifierTest, AllReduceVerifier) {
Expand Down Expand Up @@ -2804,8 +2804,8 @@ TEST_F(HloVerifierTest, VerifyCustomCallThread) {
.status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
HasSubstr("expects parent computation thread name same as called "
"computation's thread name"));
HasSubstr("custom top_apply computation execution thread does "
"not match (parallel_thread vs main)"));
}

TEST_F(HloVerifierTest, CheckWhileThread) {
Expand Down

0 comments on commit 5f4c02b

Please sign in to comment.