Skip to content

Commit

Permalink
Improve Jax distributed connect errors with possible user remedies.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666524115
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Aug 22, 2024
1 parent 7dcc8e9 commit 511141e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
12 changes: 11 additions & 1 deletion xla/pjrt/distributed/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,18 @@ absl::Status DistributedRuntimeCoordinationServiceClient::Connect() {
}
if (s.ok()) {
LOG(INFO) << "Connected to distributed JAX controller";
} else if (absl::IsDeadlineExceeded(s)) {
LOG(ERROR)
<< "Failed to connect to distributed JAX controller: waited too "
"long for some tasks to show up. This may be due to 1) some "
"tasks crashed earlier before connecting, 2) some tasks were never "
"scheduled, or 3) scheduling delays. Consider setting a longer "
"initialization timeout if such delays are expected, the timeout is "
"currently set to: "
<< absl::Milliseconds(config_.cluster_register_timeout_in_ms())
<< ".\n\nOriginal runtime error: " << s;
} else {
LOG(INFO) << "Failed to connect to distributed JAX controller: " << s;
LOG(ERROR) << "Failed to connect to distributed JAX controller: " << s;
}
return s;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ void CoordinationServiceStandaloneImpl::CheckBarrierTimeout() {
"Total Number of tasks already at the barrier: ",
barrier->tasks_at_barrier.size() - pending_task_count,
"/", barrier->tasks_at_barrier.size(),
". Timed out task names:\n%s", pending_tasks);
". Timed out task names:\n", pending_tasks);
}
const absl::Status error =
MakeCoordinationError(absl::DeadlineExceededError(error_message));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -944,8 +944,9 @@ void CoordinationServiceAgentImpl::WaitAtBarrierAsync(
leader_client_->BarrierAsync(
request.get(), response.get(),
[request, response, done = std::move(done)](const absl::Status& s) {
done(s);
VLOG(3) << "WaitAtBarrierResponse: " << s;
auto status = TrimCoordinationErrorMessage(s);
done(status);
VLOG(3) << "WaitAtBarrierResponse: " << status;
});
}

Expand Down

0 comments on commit 511141e

Please sign in to comment.