diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py
index b405a46f729f..11d4b6ce1da1 100644
--- a/nemo/collections/llm/gpt/data/pre_training.py
+++ b/nemo/collections/llm/gpt/data/pre_training.py
@@ -4,7 +4,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
 from torch.utils import data
-from torch.utils.data import DataLoader
 
+from nemo.lightning.data import WrappedDataLoader
 from nemo.lightning.pytorch.plugins import MegatronDataSampler
 
@@ -121,24 +121,26 @@ def setup(self, stage: str = "") -> None:
         # ).build()
 
     def train_dataloader(self) -> TRAIN_DATALOADERS:
-        return self._create_dataloader(self._train_ds)
+        return self._create_dataloader(self._train_ds, mode='train')
 
     def val_dataloader(self) -> EVAL_DATALOADERS:
-        return self._create_dataloader(self._validation_ds)
+        return self._create_dataloader(self._validation_ds, mode='validation')
 
     def test_dataloader(self) -> EVAL_DATALOADERS:
-        return self._create_dataloader(self._test_ds)
+        return self._create_dataloader(self._test_ds, mode='test')
 
-    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
+    def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader:
         self.init_global_step = self.trainer.global_step
-        return DataLoader(
-            dataset,
+        dataloader = WrappedDataLoader(
+            mode=mode,
+            dataset=dataset,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             persistent_workers=self.persistent_workers,
             collate_fn=getattr(dataset, 'collate_fn', data.dataloader.default_collate),
             **kwargs,
         )
+        return dataloader
 
     @property
     def gpt_dataset_config(self) -> "GPTDatasetConfig":
@@ -185,11 +187,68 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             consistency_check=False,
         )
         current_global_batch_size = num_microbatch_calculator.current_global_batch_size
-        '''pl_module.log(
-            "global_batch_size",
-            current_global_batch_size,
-            prog_bar=True,
-            rank_zero_only=True,
-            batch_size=1,
-        )'''
-        self.if_first_step = 1
+        self.data_sampler.if_first_step = 1
+
+    def reconfigure_limit_batches(self):
+        """Rescale the trainer's train and val batch limits so they are counted in microbatches."""
+        # Override limit_train_batches in terms of num of microbatches
+        self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_ds, 'train')
+        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
+        self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_ds, 'val')
+
+    def _reconfigure_limit_batches(self, limit_batches, dataloader, mode):
+        """
+        Reconfigure trainer.limit_train_batches / trainer.limit_val_batches for pretraining.
+
+        An int limit is multiplied by the number of microbatches. A float limit is turned
+        into a microbatch count using len(dataloader) and rounded down to a multiple of the
+        number of microbatches (never below one full set of microbatches).
+
+        Args:
+            limit_batches: current trainer limit, either an int or a float fraction.
+            dataloader: object whose len() is taken as the length in microbatches; may be None.
+            mode: 'train' updates limit_train_batches; any other value updates limit_val_batches.
+
+        Raises:
+            MisconfigurationException: if a non-zero float limit selects zero microbatches.
+        """
+        # Override limit_batches in terms of num microbatches and so there are limit_batches//num_micro_batches num of global batches
+        from megatron.core.num_microbatches_calculator import get_num_microbatches
+        from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+        if isinstance(limit_batches, int):
+            limit_batches *= get_num_microbatches()
+        else:
+            assert isinstance(limit_batches, float)
+            # Don't reconfigure if limit_batches is 0.0 or if there's no dataloader
+            if limit_batches == 0.0 or dataloader is None:
+                return
+            # len(dataloader) returns len as num of microbatches
+            dl_len_in_micro_batches = len(dataloader)
+            if len(dataloader) != float("inf"):
+                if limit_batches == 1.0:
+                    limit_batches = dl_len_in_micro_batches
+                else:
+                    limit_micro_batches = int(dl_len_in_micro_batches * limit_batches)
+                    if limit_micro_batches == 0 and limit_batches > 0.0:
+                        min_percentage = 1.0 / len(dataloader)
+                        raise MisconfigurationException(
+                            f"You requested to check {limit_batches} of the val_dataloader but"
+                            f" {limit_batches} * {len(dataloader)} < 1. Please increase the"
+                            f" `limit_val_batches` argument. Try at least"
+                            f" `limit_val_batches={min_percentage}`"
+                        )
+                    # Make sure trainer.limit_val_batches is a multiple of num of microbatches
+                    if limit_micro_batches < get_num_microbatches():
+                        limit_batches = get_num_microbatches()
+                    else:
+                        # Round the microbatch count, not the float fraction (which would always give 0.0)
+                        limit_batches = limit_micro_batches - limit_micro_batches % get_num_microbatches()
+
+        if mode == 'train':
+            self.trainer.limit_train_batches = limit_batches
+        else:
+            self.trainer.limit_val_batches = limit_batches
+            # Override num sanity steps to be a multiple of num of microbatches.
+            # Done only on the non-train pass so the factor is applied once, not once per mode.
+            self.trainer.num_sanity_val_steps *= get_num_microbatches()
diff --git a/nemo/lightning/data.py b/nemo/lightning/data.py
index 58ba81a4ddac..a07f504f1009 100644
--- a/nemo/lightning/data.py
+++ b/nemo/lightning/data.py
@@ -8,6 +8,7 @@
 from torch.utils.data import DataLoader, Dataset
 
 
+## TODO: remove? unused
 def create_dataloader(
     dataset: "Dataset", drop_last: bool = True, pad_samples_to_global_batch_size=False, **kwargs
 ) -> DataLoader:
@@ -127,6 +128,14 @@ def add_megatron_sampler(
     )
 
 
+class WrappedDataLoader(DataLoader):
+    """Wrapper around torch DataLoader which stores the dataloader mode"""
+
+    def __init__(self, mode="train", **dataloader_kwargs):
+        super().__init__(**dataloader_kwargs)
+        self.mode = mode
+
+
 # TODO: Replace this with megatron.core.data.data_samplers after we upgrade
 class BaseMegatronSampler:
     def __init__(
diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py
index 8d023d3bb574..9b2b317223ce 100644
--- a/nemo/lightning/pytorch/plugins/data_sampler.py
+++ b/nemo/lightning/pytorch/plugins/data_sampler.py
@@ -43,12 +43,13 @@ def setup(self, global_rank: int) -> None:
     def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader:
         from nemo.lightning.data import add_megatron_sampler
 
+        mode = getattr(dataloader, 'mode', 'train')
         return add_megatron_sampler(
             dataloader,
             micro_batch_size=self.micro_batch_size,
             global_batch_size=self.global_batch_size,
             rampup_batch_size=self.rampup_batch_size,
-            consumed_samples=self.init_consumed_samples,
+            consumed_samples=self.init_consumed_samples if mode == 'train' else 0,
             dataloader_type=self.dataloader_type,
         )
 
diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py
index a17bdd60c77c..2219324f6b67 100644
--- a/nemo/lightning/pytorch/strategies.py
+++ b/nemo/lightning/pytorch/strategies.py
@@ -207,6 +207,9 @@ def setup(self, trainer: pl.Trainer) -> None:
         if not self.data_sampler and hasattr(datamodule, "data_sampler"):
             self.data_sampler = datamodule.data_sampler
             self.data_sampler.setup(self.cluster_environment.global_rank())
+            # Guard: a datamodule without this hook is left untouched
+            if hasattr(datamodule, "reconfigure_limit_batches"):
+                datamodule.reconfigure_limit_batches()
 
         if self.data_sampler:
             self.data_sampler.connect(trainer)