Skip to content

Commit

Permalink
fix (apple#852)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 authored Dec 2, 2024
1 parent b1b6c25 commit c20387c
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions axlearn/cloud/gcp/tpu_health_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
The main API is the `setup` function, which is commonly enabled via context manager:
```
with setup(spec, output_dir=...):
with setup(spec):
# Initialize jax distributed.
```
"""
Expand All @@ -33,7 +33,7 @@
from typing import Literal, Optional, Union

import tensorflow as tf
from absl import logging
from absl import flags, logging

from axlearn.cloud.gcp import tpu_health_check_main

Expand Down Expand Up @@ -128,11 +128,11 @@ def _run_health_check_program(


@contextmanager
def setup(check_spec: str, *, output_dir: str):
_pre_init_health_check(check_spec, output_dir=output_dir)
def setup(check_spec: str):
_pre_init_health_check(check_spec, output_dir=flags.FLAGS.trainer_dir)
yield
# Skip global health check if there's an exception.
global_health_check(check_spec, output_dir=output_dir)
global_health_check(check_spec, output_dir=flags.FLAGS.trainer_dir)


def _pre_init_health_check(check_spec: str, *, output_dir: str):
Expand Down

0 comments on commit c20387c

Please sign in to comment.