diff --git a/axlearn/cloud/gcp/tpu_health_check.py b/axlearn/cloud/gcp/tpu_health_check.py index 99e21e84..5f7c1899 100644 --- a/axlearn/cloud/gcp/tpu_health_check.py +++ b/axlearn/cloud/gcp/tpu_health_check.py @@ -17,7 +17,7 @@ The main API is the `setup` function, which is commonly enabled via context manager: ``` -with setup(spec, output_dir=...): +with setup(spec): # Initialize jax distributed. ``` """ @@ -33,7 +33,7 @@ from typing import Literal, Optional, Union import tensorflow as tf -from absl import logging +from absl import flags, logging from axlearn.cloud.gcp import tpu_health_check_main @@ -128,11 +128,11 @@ def _run_health_check_program( @contextmanager -def setup(check_spec: str, *, output_dir: str): - _pre_init_health_check(check_spec, output_dir=output_dir) +def setup(check_spec: str): + _pre_init_health_check(check_spec, output_dir=flags.FLAGS.trainer_dir) yield # Skip global health check if there's an exception. - global_health_check(check_spec, output_dir=output_dir) + global_health_check(check_spec, output_dir=flags.FLAGS.trainer_dir) def _pre_init_health_check(check_spec: str, *, output_dir: str):