diff --git a/nemo/collections/llm/recipes/log/default.py b/nemo/collections/llm/recipes/log/default.py
index 93bd9f9470fa..d83580a1a543 100644
--- a/nemo/collections/llm/recipes/log/default.py
+++ b/nemo/collections/llm/recipes/log/default.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
+from datetime import timedelta
 from typing import Optional
 
 from nemo_run import Config, cli
@@ -50,7 +51,7 @@ def default_log(
         nl.ModelCheckpoint,
         save_last=True,
         save_top_k=10,
-        every_n_train_steps=200,
+        train_time_interval=Config(timedelta, minutes=15),
         filename="{model_name}--{val_loss:.2f}-{step}-{consumed_samples}",
     )
diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py
index 3ccbef536b99..41ce2d8f1117 100644
--- a/nemo/lightning/io/connector.py
+++ b/nemo/lightning/io/connector.py
@@ -183,6 +183,8 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer, dump_io: bool = True
         output_path = Path(output_path)
         output_path.mkdir(parents=True, exist_ok=True)
         trainer.save_checkpoint(ckpt_to_weights_subdir(output_path))
+        if getattr(trainer.strategy, "async_save", False):
+            trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True)
 
         from nemo.lightning.io.pl import TrainerContext
         from nemo.utils.get_rank import is_global_rank_zero
diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py
index 1e3cde0bbcde..15d0dd8ac2ab 100644
--- a/nemo/lightning/pytorch/callbacks/peft.py
+++ b/nemo/lightning/pytorch/callbacks/peft.py
@@ -14,6 +14,7 @@
 
 import json
 from abc import ABC, abstractmethod
+from functools import partial
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
 
@@ -27,6 +28,7 @@
 from nemo.lightning.io.pl import ckpt_to_dir
 from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
 from nemo.utils import logging
+from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO
 
 if TYPE_CHECKING:
     from megatron.core.dist_checkpointing.mapping import ShardedStateDict
@@ -97,11 +99,28 @@ def __call__(self, model: nn.Module) -> nn.Module:
         return model
 
     def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
+        from nemo.lightning.pytorch.strategies.utils import create_checkpoint_io
+
         super().setup(trainer, pl_module, stage=stage)
 
         trainer.strategy.trainer = trainer
-        self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io, self)
-        trainer.strategy._checkpoint_io = self.wrapped_io
+        wrapped_io = partial(WrappedAdapterIO, peft=self)
+        ckpt_io_kwargs = {
+            "save_ckpt_format": trainer.strategy.save_ckpt_format,
+            "async_save": trainer.strategy.async_save,
+            "torch_dist_multiproc": trainer.strategy.torch_dist_multiproc,
+            "assume_constant_structure": trainer.strategy.assume_constant_structure,
+            "parallel_save": trainer.strategy.parallel_save,
+            "parallel_save_within_dp": trainer.strategy.parallel_save_within_dp,
+            "parallel_load": trainer.strategy.parallel_load,
+            "load_directly_on_device": trainer.strategy.load_directly_on_device,
+        }
+        trainer.strategy._checkpoint_io = create_checkpoint_io(wrapping_ckpt_io=wrapped_io, **ckpt_io_kwargs)
+        self.wrapped_io = (
+            trainer.strategy._checkpoint_io._checkpoint_io
+            if trainer.strategy.async_save
+            else trainer.strategy._checkpoint_io
+        )
         trainer.strategy._init_model_parallel = False
         trainer.strategy._setup_optimizers = False
 
@@ -257,7 +276,7 @@ def load_state_dict(self, state_dict, strict=True):
         self.adapter.load_state_dict(adapter_state_dict, strict)
 
 
-class WrappedAdapterIO(_WrappingCheckpointIO):
+class WrappedAdapterIO(_WrappingCheckpointIO, AsyncCompatibleCheckpointIO):
     peft: Optional[PEFT] = None
     model_ckpt_path: Optional[Path] = None
     adapter_ckpt_path: Optional[Path] = None
@@ -273,7 +292,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
             checkpoint['sharded_state_dict'] = dict(
                 filter(lambda item: self.peft.adapter_key_filter(item[0]), checkpoint['sharded_state_dict'].items())
            )
-        self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options)
+        request = self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options)
 
         from nemo.utils.get_rank import is_global_rank_zero
 
@@ -282,6 +301,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
             adapter_meta_path = ckpt_to_dir(path) / _ADAPTER_META_FILENAME
             with open(adapter_meta_path, "w") as f:
                 json.dump(metadata, f)
+        return request
 
     @override
     def load_checkpoint(
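Note on the peft.py change above: with async save enabled, the strategy's checkpoint IO becomes a chain of wrappers, so the PEFT callback must unwrap one level to reach its adapter-aware IO. A minimal sketch of that layering, assuming the classes and attributes shown in this patch (the helper name is hypothetical, not NeMo API):

    from functools import partial

    def build_peft_checkpoint_io(strategy, peft):
        # Hypothetical helper mirroring PEFT.setup() in this patch.
        # async_save=True  -> AsyncFinalizableCheckpointIO(WrappedAdapterIO(MegatronCheckpointIO(...)))
        # async_save=False -> WrappedAdapterIO(MegatronCheckpointIO(...))
        from nemo.lightning.pytorch.callbacks.peft import WrappedAdapterIO
        from nemo.lightning.pytorch.strategies.utils import create_checkpoint_io

        ckpt_io = create_checkpoint_io(
            wrapping_ckpt_io=partial(WrappedAdapterIO, peft=peft),
            save_ckpt_format=strategy.save_ckpt_format,
            async_save=strategy.async_save,
        )
        # The adapter-aware IO sits one level down when the async wrapper is present.
        wrapped = ckpt_io._checkpoint_io if strategy.async_save else ckpt_io
        return ckpt_io, wrapped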
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index e68b67c86f2d..c367353fbb58 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -130,7 +130,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
         save_ckpt_format (str): Distributed checkpoint format to use for checkpoint saving. Should be one of
             'torch_dist' or 'zarr'. Defaults to 'torch_dist'.
         ckpt_async_save (bool): Whether to save checkpoints asynchronously to reduce checkpointing overhead.
-            Defaults to False.
+            Defaults to True.
         ckpt_torch_dist_multiproc (int): Number of extra processes per rank used during ckpt save
             with PyTorch distributed format. Defaults to None.
         ckpt_assume_constant_structure (bool): Allows caching some computation across checkpoint saves.
@@ -140,7 +140,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
         ckpt_parallel_save_within_dp (bool): If true, save will be parallelized only within a DP group
             (whole world otherwise), which might slightly reduce the save overhead. Defaults to False.
         ckpt_parallel_load (bool): If true, each worker will load part of the dist checkpoint
-            and exchange with NCCL. Might use some extra GPU memory. Defaults to False.
+            and exchange with NCCL. Might use some extra GPU memory. Defaults to True.
         ckpt_parallel_save_optim (bool): Parallel save/load of a DistributedOptimizer. 'True'
             allows performant save and reshardable checkpoints. Set to 'False' only in order to minimize
             the number of checkpoint files.
@@ -186,12 +186,12 @@ def __init__(
         lazy_init: bool = False,
         pipeline_dtype: Optional[torch.dtype] = None,
         save_ckpt_format: str = "torch_dist",
-        ckpt_async_save: bool = False,
+        ckpt_async_save: bool = True,
         ckpt_torch_dist_multiproc: int = None,  ## TODO(ashors): put elsewhere?
         ckpt_assume_constant_structure: bool = False,
         ckpt_parallel_save: bool = True,
         ckpt_parallel_save_within_dp: bool = False,
-        ckpt_parallel_load: bool = False,
+        ckpt_parallel_load: bool = True,
         ckpt_parallel_save_optim: bool = True,
         ckpt_load_directly_on_device: bool = True,
         setup_optimizers: bool = True,
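The two flipped defaults above (ckpt_async_save and ckpt_parallel_load now True) change behavior for anyone constructing MegatronStrategy without explicit arguments. A sketch of pinning the previous synchronous, non-parallel-load behavior, using only the constructor arguments shown in this diff:

    from nemo import lightning as nl

    # Opt back out of the new defaults, e.g. for tests or debugging;
    # argument names follow the __init__ signature patched above.
    strategy = nl.MegatronStrategy(
        ckpt_async_save=False,
        ckpt_parallel_load=False,
    )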
""" - preempt_time: int = 300 + preempt_time: int = 60 callbacks: list[run.Config[Callback]] = field(default_factory=lambda: [run.Config(PreemptionCallback)]) def setup(self, task: run.Partial | run.Script, executor: run.Executor): diff --git a/tests/collections/llm/test_mnist_model_nemo2.py b/tests/collections/llm/test_mnist_model_nemo2.py index 3f0b804e8bd6..a5c2aa96fc03 100644 --- a/tests/collections/llm/test_mnist_model_nemo2.py +++ b/tests/collections/llm/test_mnist_model_nemo2.py @@ -501,6 +501,7 @@ def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(): monitor="val_loss", save_top_k=1, every_n_train_steps=5, + filename="{model_name}--{val_loss:.2f}-{step}-{consumed_samples}", # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe always_save_context=True, ) diff --git a/tests/lightning/pytorch/callbacks/test_peft.py b/tests/lightning/pytorch/callbacks/test_peft.py index 53f9016a3bac..95caca4d2784 100644 --- a/tests/lightning/pytorch/callbacks/test_peft.py +++ b/tests/lightning/pytorch/callbacks/test_peft.py @@ -18,6 +18,7 @@ from pytorch_lightning.trainer.states import TrainerFn from nemo.collections.llm import fn from nemo.lightning.pytorch.callbacks.peft import PEFT, WrappedAdapterIO +from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO class TestPEFT: @@ -48,7 +49,8 @@ def test_peft_setup(self): pl_module.model_transform = peft peft.setup(trainer, pl_module, "fit") - assert isinstance(trainer.strategy._checkpoint_io, WrappedAdapterIO) + assert isinstance(trainer.strategy._checkpoint_io, AsyncFinalizableCheckpointIO) + assert isinstance(trainer.strategy._checkpoint_io._checkpoint_io, WrappedAdapterIO) assert peft.model_transform is not None assert peft._needs_to_call is True diff --git a/tests/lightning/test_dist_ckpt.py b/tests/lightning/test_dist_ckpt.py index e6ea381fdf0b..5deb8085aa30 100644 --- a/tests/lightning/test_dist_ckpt.py +++ b/tests/lightning/test_dist_ckpt.py @@ -35,6 +35,7 @@ def set_env(): def _get_strategy(): strategy = nl.MegatronStrategy( enable_nemo_ckpt_io=False, + ckpt_async_save=False, ) return strategy