Configure no restart validation loop in nl.Trainer (NVIDIA#11029)
* Configure no restart validation loop in nl.Trainer

Signed-off-by: Hemil Desai <[email protected]>

* fix

Signed-off-by: Hemil Desai <[email protected]>

* Skip validation whenever restarting=True

Signed-off-by: Hemil Desai <[email protected]>

* PR feedback

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: hemildesai <[email protected]>
Co-authored-by: hemildesai <[email protected]>
hemildesai and hemildesai authored Nov 13, 2024
1 parent 071f8bc commit 02f0932
Showing 3 changed files with 47 additions and 3 deletions.
nemo/collections/llm/api.py (9 additions, 1 deletion)
@@ -25,7 +25,14 @@
 from typing_extensions import Annotated
 
 import nemo.lightning as nl
-from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io
+from nemo.lightning import (
+    AutoResume,
+    NeMoLogger,
+    OptimizerModule,
+    Trainer,
+    configure_no_restart_validation_training_loop,
+    io,
+)
 from nemo.lightning.base import NEMO_MODELS_CACHE
 from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
 from nemo.utils import logging
@@ -680,6 +687,7 @@ def _setup(
     tokenizer: Optional[TokenizerType],
     model_transform: Optional[Union[PEFT, ModelTransform, Callable]],
 ) -> Any:  # Return type is Any because app_state's type is not specified
+    configure_no_restart_validation_training_loop(trainer)
     _log = log or NeMoLogger()
     if resume and isinstance(model_transform, PEFT) and _log.ckpt:
         logging.info("Disabling try_restore_best_ckpt restoration for adapters")
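With this hook in place, any `llm` API call that reaches `_setup` patches its trainer's epoch loop automatically. The function can also be applied by hand to a trainer built outside that path; a minimal sketch, with hypothetical trainer arguments:

    import nemo.lightning as nl

    # Hypothetical settings; nl.Trainer accepts the usual pl.Trainer kwargs.
    trainer = nl.Trainer(accelerator="gpu", devices=1, max_steps=1000)

    # Manual equivalent of the hook that _setup now runs on every call.
    nl.configure_no_restart_validation_training_loop(trainer)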
nemo/lightning/__init__.py (2 additions, 1 deletion)
@@ -33,7 +33,7 @@
 from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler
 from nemo.lightning.pytorch.strategies import FSDPStrategy, MegatronStrategy
 from nemo.lightning.pytorch.strategies.utils import RestoreConfig
-from nemo.lightning.pytorch.trainer import Trainer
+from nemo.lightning.pytorch.trainer import Trainer, configure_no_restart_validation_training_loop
 from nemo.lightning.resume import AutoResume
 
 
@@ -66,6 +66,7 @@ def _is_slurm_interactive_mode():
     "ModelCheckpoint",
     "OptimizerModule",
     "Trainer",
+    "configure_no_restart_validation_training_loop",
     "get_vocab_size",
     "teardown",
 ]
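Since the symbol is re-exported at the package root and listed in `__all__`, both import paths below resolve after this change:

    from nemo.lightning import configure_no_restart_validation_training_loop
    # or directly from the defining module:
    from nemo.lightning.pytorch.trainer import configure_no_restart_validation_training_loop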
nemo/lightning/pytorch/trainer.py (36 additions, 1 deletion)
@@ -12,19 +12,54 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from copy import deepcopy
 
 import fiddle as fdl
 import pytorch_lightning as pl
+from pytorch_lightning.loops import _TrainingEpochLoop
+from pytorch_lightning.loops.fetchers import _DataFetcher
 from typing_extensions import Self
 
 from nemo.lightning.fabric.conversion import to_fabric
 from nemo.lightning.fabric.fabric import Fabric
 from nemo.lightning.io.mixin import IOMixin, serialization, track_io
 
 
-class Trainer(pl.Trainer, IOMixin):
+class NoValOnRestartTrainingLoop(_TrainingEpochLoop):
+    """
+    Extend the PTL Epoch loop to skip validation when restarting.
+    This happens when resuming a checkpoint that has already run validation, but loading restores
+    the training state before validation has run.
+    """
+
+    def _should_check_val_fx(self, data_fetcher) -> bool:
+        if self.skip_val_on_restart:
+            return False
+        return super()._should_check_val_fx(data_fetcher)
+
+    def load_state_dict(self, state_dict: dict, prefix: str = "") -> None:
+        super().load_state_dict(state_dict, prefix)
+
+        self.skip_val_on_restart = True
+
+    def advance(self, data_fetcher: _DataFetcher) -> None:
+        super().advance(data_fetcher)
+
+        self.skip_val_on_restart = False
+
+
+def configure_no_restart_validation_training_loop(trainer: pl.Trainer) -> None:
+    if not isinstance(trainer.fit_loop.epoch_loop, _TrainingEpochLoop):
+        warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning)
+        return
+
+    ## Pass trainer object to avoid trainer getting overwritten as None
+    loop = NoValOnRestartTrainingLoop(trainer, trainer.min_steps, trainer.max_steps)
+    trainer.fit_loop.epoch_loop = loop
+
+
+class Trainer(pl.Trainer, IOMixin):
     def add_io(self, obj):
         """Recurse to the leaves of a container and add io functionality to non-serializable leaves"""
         if isinstance(obj, (dict, list)):
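In plain terms: Lightning calls `load_state_dict` on the epoch loop only when restoring from a checkpoint, which arms `skip_val_on_restart`; the first `advance` (one training step) disarms it again, so the validation check fired at the moment of restart, the one that already ran before the checkpoint was saved, is suppressed. A standalone toy sketch of that skip-once pattern (simplified; Lightning invokes the real check from inside its loop machinery):

    class ToyLoop:
        """Minimal stand-in for the patched epoch loop above."""

        def __init__(self):
            self.skip_val_on_restart = False

        def should_check_val(self) -> bool:
            # Mirrors _should_check_val_fx: answer False once after a restore.
            return not self.skip_val_on_restart

        def load_state_dict(self, state_dict: dict) -> None:
            # Called only when resuming from a checkpoint.
            self.skip_val_on_restart = True

        def advance(self) -> None:
            # One training step; afterwards validation behaves normally again.
            self.skip_val_on_restart = False


    loop = ToyLoop()
    loop.load_state_dict({})            # resume from a checkpoint
    assert not loop.should_check_val()  # the already-run validation is skipped
    loop.advance()
    assert loop.should_check_val()      # subsequent checks run as usual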
