From b85191267dd3580b0dea462493df94e029eb0294 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Mon, 9 Sep 2024 20:41:20 +0000 Subject: [PATCH] Add block_until_ready before peforming save operation Add block_until_ready operation before checkpoint saving operation. --- MaxText/train.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/MaxText/train.py b/MaxText/train.py index 943822550..bd80c11ac 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -25,8 +25,9 @@ import sys from etils import epath import functools +import time -from typing import Sequence +from typing import Sequence, Optional from absl import app from flax import linen as nn from flax.linen import partitioning as nn_partitioning @@ -166,8 +167,30 @@ def clear_buffered_metrics(): _buffered_step = None _buffered_metrics = None -def save_checkpoint(checkpoint_manager, step, state, dataset_type="c4", data_iterator=None): - """Wrapper for saving checkpoint""" +def save_checkpoint( + checkpoint_manager, + step, + state, + dataset_type="c4", + data_iterator=None, + config: Optional[pyconfig.config] = None, +) -> bool: + """Wrapper for saving checkpoint.""" + if config and config.enable_checkpointing: + if (step % config.checkpoint_period == 0) or ( + config.enable_emergency_checkpoint + and step % config.local_checkpoint_period == 0 + ): + blocking_until_ready_start = time.time() + max_logging.log(f"Waiting for step {step} to finish before checkpoint...") + # We block here on the step finishing so that our checkpointing metrics + # measure only checkpointing time, not training time. + jax.block_until_ready(state) + max_logging.log( + f"Waited {time.time() - blocking_until_ready_start} seconds for step " + f"{step} to finish before starting checkpointing." + ) + # specify chunk_byte_size to force orbax to control maximum file size in checkpoint save_args = jax.tree.map( lambda _: orbax.checkpoint.SaveArgs(chunk_byte_size=_CHUNK_BYTE_SIZE), state @@ -617,7 +640,7 @@ def train_loop(config, state=None): last_step_completion = new_time if checkpoint_manager is not None: - if save_checkpoint(checkpoint_manager, int(step), state, config.dataset_type, data_iterator): + if save_checkpoint(checkpoint_manager, int(step), state, config.dataset_type, data_iterator, config): max_logging.log(f"saved a checkpoint at step {step}") # Upon preemption, exit when and only when all ongoing saves are complete.