Change dist ckpt defaults (#10913) #11031

Status: Merged (2 commits, Oct 25, 2024)
3 changes: 2 additions & 1 deletion nemo/collections/llm/recipes/log/default.py
@@ -13,6 +13,7 @@
 # limitations under the License.


+from datetime import timedelta
 from typing import Optional

 from nemo_run import Config, cli
@@ -50,7 +51,7 @@ def default_log(
         nl.ModelCheckpoint,
         save_last=True,
         save_top_k=10,
-        every_n_train_steps=200,
+        train_time_interval=Config(timedelta, minutes=15),
         filename="{model_name}--{val_loss:.2f}-{step}-{consumed_samples}",
     )

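The recipe default moves from step-based to time-based checkpointing. Below is a minimal sketch of the resulting callback configuration, assuming the nemo_run / nemo.lightning APIs shown in this diff; the `from nemo import lightning as nl` alias is an assumption and the snippet is not part of the PR itself.

from datetime import timedelta

from nemo_run import Config
from nemo import lightning as nl  # assumed import alias

# Save a checkpoint at most once every 15 minutes of training time,
# rather than every fixed number of optimizer steps.
checkpoint = Config(
    nl.ModelCheckpoint,
    save_last=True,
    save_top_k=10,
    train_time_interval=Config(timedelta, minutes=15),
    filename="{model_name}--{val_loss:.2f}-{step}-{consumed_samples}",
)
# To restore the previous behaviour, drop train_time_interval and pass
# every_n_train_steps=200 instead.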
2 changes: 2 additions & 0 deletions nemo/lightning/io/connector.py
@@ -183,6 +183,8 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer, dump_io: bool = True
         output_path = Path(output_path)
         output_path.mkdir(parents=True, exist_ok=True)
         trainer.save_checkpoint(ckpt_to_weights_subdir(output_path))
+        if getattr(trainer.strategy, "async_save", False):
+            trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True)

         from nemo.lightning.io.pl import TrainerContext
         from nemo.utils.get_rank import is_global_rank_zero
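With asynchronous saving now the default, nemo_save has to wait for the background write to finish before anything reads the checkpoint directory. The same pattern as a standalone sketch; the helper name below is made up for illustration and only the attribute and method names come from this diff.

def save_checkpoint_blocking(trainer, output_path):
    """Save a checkpoint and, if the strategy saves asynchronously, block until it is on disk."""
    trainer.save_checkpoint(output_path)
    if getattr(trainer.strategy, "async_save", False):
        # With ckpt_async_save=True, save_checkpoint returns before the files are
        # fully written; maybe_finalize_save_checkpoint(blocking=True) waits for them.
        trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True)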
28 changes: 24 additions & 4 deletions nemo/lightning/pytorch/callbacks/peft.py
@@ -14,6 +14,7 @@

 import json
 from abc import ABC, abstractmethod
+from functools import partial
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

@@ -27,6 +28,7 @@
 from nemo.lightning.io.pl import ckpt_to_dir
 from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
 from nemo.utils import logging
+from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO

 if TYPE_CHECKING:
     from megatron.core.dist_checkpointing.mapping import ShardedStateDict
@@ -97,11 +99,28 @@ def __call__(self, model: nn.Module) -> nn.Module:
         return model

     def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
+        from nemo.lightning.pytorch.strategies.utils import create_checkpoint_io
+
         super().setup(trainer, pl_module, stage=stage)

         trainer.strategy.trainer = trainer
-        self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io, self)
-        trainer.strategy._checkpoint_io = self.wrapped_io
+        wrapped_io = partial(WrappedAdapterIO, peft=self)
+        ckpt_io_kwargs = {
+            "save_ckpt_format": trainer.strategy.save_ckpt_format,
+            "async_save": trainer.strategy.async_save,
+            "torch_dist_multiproc": trainer.strategy.torch_dist_multiproc,
+            "assume_constant_structure": trainer.strategy.assume_constant_structure,
+            "parallel_save": trainer.strategy.parallel_save,
+            "parallel_save_within_dp": trainer.strategy.parallel_save_within_dp,
+            "parallel_load": trainer.strategy.parallel_load,
+            "load_directly_on_device": trainer.strategy.load_directly_on_device,
+        }
+        trainer.strategy._checkpoint_io = create_checkpoint_io(wrapping_ckpt_io=wrapped_io, **ckpt_io_kwargs)
+        self.wrapped_io = (
+            trainer.strategy._checkpoint_io._checkpoint_io
+            if trainer.strategy.async_save
+            else trainer.strategy._checkpoint_io
+        )
         trainer.strategy._init_model_parallel = False
         trainer.strategy._setup_optimizers = False

@@ -257,7 +276,7 @@ def load_state_dict(self, state_dict, strict=True):
         self.adapter.load_state_dict(adapter_state_dict, strict)


-class WrappedAdapterIO(_WrappingCheckpointIO):
+class WrappedAdapterIO(_WrappingCheckpointIO, AsyncCompatibleCheckpointIO):
     peft: Optional[PEFT] = None
     model_ckpt_path: Optional[Path] = None
     adapter_ckpt_path: Optional[Path] = None
@@ -273,7 +292,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options
         checkpoint['sharded_state_dict'] = dict(
             filter(lambda item: self.peft.adapter_key_filter(item[0]), checkpoint['sharded_state_dict'].items())
         )
-        self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options)
+        request = self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options)

         from nemo.utils.get_rank import is_global_rank_zero

@@ -282,6 +301,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options
             adapter_meta_path = ckpt_to_dir(path) / _ADAPTER_META_FILENAME
             with open(adapter_meta_path, "w") as f:
                 json.dump(metadata, f)
+        return request

     @override
     def load_checkpoint(
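Taken together, the PEFT changes route the adapter-filtering IO through create_checkpoint_io so the async wrapper stays outermost and can finalize adapter saves. A minimal sketch of the resulting stack; the `peft` placeholder stands for any configured PEFT transform and is not part of this PR, while the imports and keyword names come from the diff above.

from functools import partial

from nemo.lightning.pytorch.callbacks.peft import WrappedAdapterIO
from nemo.lightning.pytorch.strategies.utils import create_checkpoint_io

peft = ...  # placeholder: any object implementing the PEFT callback interface
wrapped_io = partial(WrappedAdapterIO, peft=peft)
checkpoint_io = create_checkpoint_io(wrapping_ckpt_io=wrapped_io, async_save=True)
# async_save=True:  AsyncFinalizableCheckpointIO(WrappedAdapterIO(MegatronCheckpointIO(...)))
# async_save=False: WrappedAdapterIO(MegatronCheckpointIO(...))
# WrappedAdapterIO.save_checkpoint now returns the async save request so the
# outer AsyncFinalizableCheckpointIO can finalize it after the adapter metadata is written.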
8 changes: 4 additions & 4 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -130,7 +130,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
         save_ckpt_format (str): Distributed checkpoint format to use for checkpoint saving. Should be one of
             'torch_dist' or 'zarr'. Defaults to 'torch_dist'.
         ckpt_async_save (bool): Whether to save checkpoints asynchronously to reduce checkpointing overhead.
-            Defaults to False.
+            Defaults to True.
         ckpt_torch_dist_multiproc (int): Number of extra processes per rank used during ckpt save
             with PyTorch distributed format. Defaults to None.
         ckpt_assume_constant_structure (bool): Allows caching some computation across checkpoint saves.
@@ -140,7 +140,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
         ckpt_parallel_save_within_dp (bool): If true, save will be parallelized only within a DP group
             (whole world otherwise), which might slightly reduce the save overhead. Defaults to False.
         ckpt_parallel_load (bool): If true, each worker will load part of the dist checkpoint
-            and exchange with NCCL. Might use some extra GPU memory. Defaults to False.
+            and exchange with NCCL. Might use some extra GPU memory. Defaults to True.
         ckpt_parallel_save_optim (bool): Parallel save/load of a DistributedOptimizer. 'True'
             allows performant save and reshardable checkpoints. Set to 'False' only in order to minimize
             the number of checkpoint files.
@@ -186,12 +186,12 @@ def __init__(
         lazy_init: bool = False,
         pipeline_dtype: Optional[torch.dtype] = None,
         save_ckpt_format: str = "torch_dist",
-        ckpt_async_save: bool = False,
+        ckpt_async_save: bool = True,
         ckpt_torch_dist_multiproc: int = None, ## TODO(ashors): put elsewhere?
         ckpt_assume_constant_structure: bool = False,
         ckpt_parallel_save: bool = True,
         ckpt_parallel_save_within_dp: bool = False,
-        ckpt_parallel_load: bool = False,
+        ckpt_parallel_load: bool = True,
         ckpt_parallel_save_optim: bool = True,
         ckpt_load_directly_on_device: bool = True,
         setup_optimizers: bool = True,
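Both the async-save and parallel-load defaults flip to True here. A short sketch of opting back out when constructing the strategy; the parameter names come from this diff, while the `from nemo import lightning as nl` alias is an assumption.

from nemo import lightning as nl  # assumed import alias

strategy = nl.MegatronStrategy(
    ckpt_async_save=False,     # old default: write checkpoints synchronously
    ckpt_parallel_load=False,  # old default: each rank loads the full checkpoint itself
)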
4 changes: 3 additions & 1 deletion nemo/lightning/pytorch/strategies/utils.py
@@ -127,8 +127,10 @@ def ckpt_to_dir(filepath: Union[str, Path]) -> Path:
     return filepath


-def create_checkpoint_io(**kwargs):
+def create_checkpoint_io(wrapping_ckpt_io=None, **kwargs):
     checkpoint_io = MegatronCheckpointIO(**kwargs)
+    if wrapping_ckpt_io:
+        checkpoint_io = wrapping_ckpt_io(checkpoint_io)
     if kwargs.get("async_save", False):
         checkpoint_io = AsyncFinalizableCheckpointIO(checkpoint_io)

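The ordering matters: the optional wrapper is applied before the async wrapper, so the async finalizer always ends up outermost. A sketch of the layering this helper produces, using only names from this diff (W stands for whatever callable is passed as wrapping_ckpt_io):

# create_checkpoint_io(**kwargs)                                 -> MegatronCheckpointIO
# create_checkpoint_io(wrapping_ckpt_io=W, **kwargs)             -> W(MegatronCheckpointIO)
# create_checkpoint_io(wrapping_ckpt_io=W, async_save=True, ...) -> AsyncFinalizableCheckpointIO(W(MegatronCheckpointIO))
# Because the async wrapper sits on the outside and finalizes the whole save,
# W is expected to be async-compatible (WrappedAdapterIO now mixes in AsyncCompatibleCheckpointIO).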
4 changes: 2 additions & 2 deletions nemo/lightning/run/plugins.py
@@ -52,14 +52,14 @@ class PreemptionPlugin(run.Plugin):
         preempt_time (int): The time, in seconds, before the task's time limit at which the executor
             will send a SIGTERM preemption signal. This allows tasks to be gracefully
             stopped before reaching their time limit, reducing waste and
-            promoting fair resource usage. The default value is 300 seconds (5 minutes).
+            promoting fair resource usage. The default value is 60 seconds (1 minute).
             This is only supported for ``run.SlurmExecutor``.
         callbacks (list[run.Config[Callback]]): A list of callback configurations that the plugin
             will merge with the task's existing callbacks.
             By default, the list includes NeMo's preemption callback.
     """

-    preempt_time: int = 300
+    preempt_time: int = 60
     callbacks: list[run.Config[Callback]] = field(default_factory=lambda: [run.Config(PreemptionCallback)])

     def setup(self, task: run.Partial | run.Script, executor: run.Executor):
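The plugin now asks Slurm to send the preemption signal one minute before the time limit instead of five. Jobs that need a longer graceful-shutdown window can override it; a sketch, with the import path inferred from this file's location:

from nemo.lightning.run.plugins import PreemptionPlugin

# Keep the previous five-minute warning window for jobs with slow teardown.
plugins = [PreemptionPlugin(preempt_time=300)]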
1 change: 1 addition & 0 deletions tests/collections/llm/test_mnist_model_nemo2.py
@@ -501,6 +501,7 @@ def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu():
         monitor="val_loss",
         save_top_k=1,
         every_n_train_steps=5,
+        filename="{model_name}--{val_loss:.2f}-{step}-{consumed_samples}",
         # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
         always_save_context=True,
     )
4 changes: 3 additions & 1 deletion tests/lightning/pytorch/callbacks/test_peft.py
@@ -18,6 +18,7 @@
 from pytorch_lightning.trainer.states import TrainerFn
 from nemo.collections.llm import fn
 from nemo.lightning.pytorch.callbacks.peft import PEFT, WrappedAdapterIO
+from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO


 class TestPEFT:
@@ -48,7 +49,8 @@ def test_peft_setup(self):
         pl_module.model_transform = peft
         peft.setup(trainer, pl_module, "fit")

-        assert isinstance(trainer.strategy._checkpoint_io, WrappedAdapterIO)
+        assert isinstance(trainer.strategy._checkpoint_io, AsyncFinalizableCheckpointIO)
+        assert isinstance(trainer.strategy._checkpoint_io._checkpoint_io, WrappedAdapterIO)
         assert peft.model_transform is not None
         assert peft._needs_to_call is True

1 change: 1 addition & 0 deletions tests/lightning/test_dist_ckpt.py
@@ -35,6 +35,7 @@ def set_env():
 def _get_strategy():
     strategy = nl.MegatronStrategy(
         enable_nemo_ckpt_io=False,
+        ckpt_async_save=False,
     )
     return strategy
