From b4b0899bc61d45bc71f9a220b5046f1fe5355c35 Mon Sep 17 00:00:00 2001
From: mikolajblaz
Date: Fri, 29 Sep 2023 00:07:27 +0200
Subject: [PATCH] Avoid duplicated checkpoint save (#7555)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Mikołaj Błaż
---
 nemo/collections/nlp/parts/nlp_overrides.py | 35 +++++++++++++--------
 1 file changed, 22 insertions(+), 13 deletions(-)

diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 5332a9d2f115..91b1026b6bfb 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -291,6 +291,10 @@ def save_checkpoint(
             checkpoint_dir = ckpt_to_dir(filepath)
             fs = get_filesystem(checkpoint_dir)
+            if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
+                logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving')
+                return
+
             if is_global_rank_zero():
                 fs.makedirs(checkpoint_dir, exist_ok=True)
@@ -477,19 +481,24 @@ def save_to(self, model, save_path: str):
                 # model weights is a directory
                 dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))
                 fs = get_filesystem(dist_ckpt_dir)
-                if is_global_rank_zero():
-                    fs.makedirs(dist_ckpt_dir, exist_ok=True)
-                sharded_state_dict = model.sharded_state_dict()
-                # dist checkpoint needs torch.distributed to save the checkpoint
-                if parallel_state.is_unitialized():
-
-                    def dummy():
-                        return
-
-                    if model.trainer.strategy.launcher is not None:
-                        model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
-                    model.trainer.strategy.setup_environment()
-                dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=dist_ckpt_dir)
+
+                if fs.isdir(dist_ckpt_dir) and dist_checkpointing.check_is_distributed_checkpoint(dist_ckpt_dir):
+                    logging.info(f'Distributed checkpoint at path {dist_ckpt_dir} already exists, skipping saving')
+                else:
+                    if is_global_rank_zero():
+                        fs.makedirs(dist_ckpt_dir, exist_ok=True)
+
+                    sharded_state_dict = model.sharded_state_dict()
+                    # dist checkpoint needs torch.distributed to save the checkpoint
+                    if parallel_state.is_unitialized():
+
+                        def dummy():
+                            return
+
+                        if model.trainer.strategy.launcher is not None:
+                            model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
+                        model.trainer.strategy.setup_environment()
+                    dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=dist_ckpt_dir)
             else:
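
Note: both hunks apply the same guard: before writing, check whether the target directory already holds a complete distributed checkpoint and, if so, skip the save. A minimal standalone sketch of that pattern follows; is_complete_checkpoint() and save_fn are hypothetical stand-ins for dist_checkpointing.check_is_distributed_checkpoint() and dist_checkpointing.save() from the patch, not the NeMo implementation itself.

import logging
import os


def save_checkpoint_once(state, checkpoint_dir, save_fn):
    """Skip the save when a finished checkpoint already exists at checkpoint_dir."""

    def is_complete_checkpoint(path):
        # Hypothetical marker-file check; the real check inspects checkpoint metadata.
        return os.path.isfile(os.path.join(path, 'metadata.json'))

    if os.path.isdir(checkpoint_dir) and is_complete_checkpoint(checkpoint_dir):
        logging.info('Checkpoint at path %s already exists, skipping saving', checkpoint_dir)
        return

    os.makedirs(checkpoint_dir, exist_ok=True)
    save_fn(state, checkpoint_dir)  # plays the role of dist_checkpointing.save

Called twice with the same checkpoint_dir, the expensive save runs only once, which is the behavior the patch is after when a checkpoint path has already been written.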