From f39352b4a75ee6cd30223409bcc17ac6aa2bed63 Mon Sep 17 00:00:00 2001
From: Hemil Desai
Date: Fri, 1 Nov 2024 09:31:37 -0700
Subject: [PATCH] Make nemo.collections.llm PreTrainingDataModule num samples configurable (#11088)

* Make nemo.collections.llm PreTrainingDataModule num samples configurable

Signed-off-by: Hemil Desai

* Apply isort and black reformatting

Signed-off-by: hemildesai

* Fix

Signed-off-by: Hemil Desai

* Apply isort and black reformatting

Signed-off-by: hemildesai

* Add explicit method to build pretraining datamodule index mapping

Signed-off-by: Hemil Desai

* Apply isort and black reformatting

Signed-off-by: hemildesai

* Fix

Signed-off-by: Hemil Desai

* Apply isort and black reformatting

Signed-off-by: hemildesai

* fix

Signed-off-by: Hemil Desai

* Apply isort and black reformatting

Signed-off-by: hemildesai

* PR feedback

Signed-off-by: Hemil Desai

---------

Signed-off-by: Hemil Desai
Signed-off-by: hemildesai
Co-authored-by: hemildesai
---
 nemo/collections/llm/gpt/data/__init__.py     |   3 +-
 nemo/collections/llm/gpt/data/pre_training.py | 125 ++++++++++++++----
 2 files changed, 103 insertions(+), 25 deletions(-)

diff --git a/nemo/collections/llm/gpt/data/__init__.py b/nemo/collections/llm/gpt/data/__init__.py
index f4e97d91e5cd..92f73069fcc2 100644
--- a/nemo/collections/llm/gpt/data/__init__.py
+++ b/nemo/collections/llm/gpt/data/__init__.py
@@ -16,7 +16,7 @@
 from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
 from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule
 from nemo.collections.llm.gpt.data.mock import MockDataModule
-from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule
+from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule, build_pretraining_datamodule
 from nemo.collections.llm.gpt.data.squad import SquadDataModule
 
 __all__ = [
@@ -25,5 +25,6 @@
     "DollyDataModule",
     "MockDataModule",
     "PreTrainingDataModule",
+    "build_pretraining_datamodule",
     "HfDatasetDataModule",
 ]
diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py
index 534922efe3a3..cfacde118b89 100644
--- a/nemo/collections/llm/gpt/data/pre_training.py
+++ b/nemo/collections/llm/gpt/data/pre_training.py
@@ -16,7 +16,7 @@
 import os
 import warnings
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -77,7 +77,7 @@ def validate_dataset_asset_accessibility(paths):
         raise ValueError("Expected path to be of string or Path type.")
 
     path = Path(paths)
-    suffices = ('.bin', '.idx')
+    suffices = (".bin", ".idx")
     if path.is_dir():
         if not os.access(path, os.R_OK):
             raise PermissionError(f"Expected {str(path)} to be readable.")
@@ -133,6 +133,9 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
             to allocate to train, validation, and test sets, respectively. Unused if ``paths`` is a dict.
         index_mapping_dir (Optional[str]): Path to a directory to write index mapping files.
         num_dataset_builder_threads (int): The number of threads to use for dataset building.
+        num_train_samples (Optional[int]): The number of samples to use for training, defaults to total train steps times global batch size.
+        num_val_samples (Optional[int]): The number of samples to use for validation, defaults to total validation steps times global batch size.
+        num_test_samples (Optional[int]): The number of samples to use for testing, defaults to total test steps times global batch size.
     """
 
     def __init__(
@@ -154,6 +157,9 @@ def __init__(
         split: str = "900,50,50",
         index_mapping_dir: Optional[str] = None,
         num_dataset_builder_threads: int = 1,
+        num_train_samples: Optional[int] = None,
+        num_val_samples: Optional[int] = None,
+        num_test_samples: Optional[int] = None,
     ) -> None:
         super().__init__()
         if not isinstance(paths, (list, tuple, dict)):
@@ -196,6 +202,9 @@ def __init__(
         self.index_mapping_dir = index_mapping_dir
         self.num_dataset_builder_threads = num_dataset_builder_threads
         self.init_global_step = 0
+        self.num_train_samples = num_train_samples
+        self.num_val_samples = num_val_samples
+        self.num_test_samples = num_test_samples
 
         from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 
@@ -207,27 +216,46 @@ def __init__(
             rampup_batch_size=rampup_batch_size,
         )
 
-    def setup(self, stage: str = "") -> None:
+    def build(
+        self,
+        trainer_max_steps: int,
+        trainer_val_check_interval: int,
+        trainer_limit_val_batches: Union[int, float],
+        trainer_limit_test_batches: Union[int, float],
+    ):
         from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
         from megatron.core.datasets.gpt_dataset import GPTDataset
 
-        assert (
-            hasattr(self, "trainer") and self.trainer is not None
-        ), "Setup should be completed when trainer and config are attached."
+        train_iters = trainer_max_steps
+        assert train_iters > 0, f"max_steps {train_iters} should be greater than 0"
+        num_train_samples = int(train_iters * self.data_sampler.global_batch_size)
+
+        if self.num_train_samples is not None:
+            assert (
+                self.num_train_samples >= num_train_samples
+            ), f"num_train_samples must be greater than or equal to {num_train_samples}."
+            num_train_samples = self.num_train_samples
+            train_iters = int(num_train_samples / self.data_sampler.global_batch_size)
 
-        # Trainer API
-        max_train_steps = self.trainer.max_steps
-        assert max_train_steps > 0, "Please specify trainer.max_steps"
-        eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
-        test_iters = self.trainer.limit_test_batches
-        num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size)
+        eval_iters = (train_iters // trainer_val_check_interval + 1) * trainer_limit_val_batches
         num_val_samples = int(eval_iters * self.data_sampler.global_batch_size)
+
+        test_iters = trainer_limit_test_batches
         num_test_samples = int(test_iters * self.data_sampler.global_batch_size)
 
+        if self.num_val_samples is not None:
+            assert self.num_val_samples > num_val_samples, f"num_val_samples must be greater than {num_val_samples}."
+            num_val_samples = self.num_val_samples
+        if self.num_test_samples is not None:
+            assert (
+                self.num_test_samples > num_test_samples
+            ), f"num_test_samples must be greater than {num_test_samples}."
+            num_test_samples = self.num_test_samples
+
         if (
-            self.trainer.limit_val_batches > 0.0
-            and self.trainer.limit_val_batches <= 1.0
-            and isinstance(self.trainer.limit_val_batches, float)
+            trainer_limit_val_batches > 0.0
+            and trainer_limit_val_batches <= 1.0
+            and isinstance(trainer_limit_val_batches, float)
         ):
             assert "blend" not in self.build_kwargs, (
                 "When using a single data distribution, limit_val_batches <= 1.0 is not supported. If you'd "
@@ -251,6 +279,18 @@ def setup(self, stage: str = "") -> None:
             config=self.gpt_dataset_config,
         ).build()
 
+    def setup(self, stage: str = "") -> None:
+        assert (
+            hasattr(self, "trainer") and self.trainer is not None
+        ), "Setup should be completed when trainer and config are attached."
+
+        self.build(
+            trainer_max_steps=self.trainer.max_steps,
+            trainer_val_check_interval=self.trainer.val_check_interval,
+            trainer_limit_val_batches=self.trainer.limit_val_batches,
+            trainer_limit_test_batches=self.trainer.limit_test_batches,
+        )
+
     # uncomment once fabric API is merged
     # def fabric_setup(
     #     self,
@@ -269,13 +309,13 @@ def setup(self, stage: str = "") -> None:
     #     ).build()
 
     def train_dataloader(self) -> TRAIN_DATALOADERS:
-        return self._create_dataloader(self._train_ds, mode='train')
+        return self._create_dataloader(self._train_ds, mode="train")
 
     def val_dataloader(self) -> EVAL_DATALOADERS:
-        return self._create_dataloader(self._validation_ds, mode='validation')
+        return self._create_dataloader(self._validation_ds, mode="validation")
 
     def test_dataloader(self) -> EVAL_DATALOADERS:
-        return self._create_dataloader(self._test_ds, mode='test')
+        return self._create_dataloader(self._test_ds, mode="test")
 
     def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader:
         self.init_global_step = self.trainer.global_step
@@ -286,7 +326,7 @@ def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader:
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             persistent_workers=self.persistent_workers,
-            collate_fn=getattr(dataset, 'collate_fn', data.dataloader.default_collate),
+            collate_fn=getattr(dataset, "collate_fn", data.dataloader.default_collate),
             **kwargs,
         )
         return dataloader
@@ -316,7 +356,7 @@ def state_dict(self) -> Dict[str, Any]:
 
         """
         consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
-        return {'consumed_samples': consumed_samples}
+        return {"consumed_samples": consumed_samples}
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         """Called when loading a checkpoint, implement to reload datamodule state given datamodule stat
@@ -332,7 +372,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
             from apex.transformer.pipeline_parallel.utils import update_num_microbatches
 
-        consumed_samples = state_dict['consumed_samples']
+        consumed_samples = state_dict["consumed_samples"]
         self.data_sampler.init_consumed_samples = consumed_samples
         self.data_sampler.prev_consumed_samples = consumed_samples
 
@@ -344,9 +384,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
 
     def reconfigure_limit_batches(self):
         # Override limit_train_batches in terms of num of microbatches
-        self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_ds, 'train')
+        self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_ds, "train")
         # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
-        self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_ds, 'val')
+        self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_ds, "val")
 
     def _reconfigure_limit_batches(self, limit_batches, dataloader, mode):
         """
@@ -388,10 +428,47 @@ def _reconfigure_limit_batches(self, limit_batches, dataloader, mode):
             else:
                 limit_batches = limit_batches - limit_batches % get_num_microbatches()
 
-        if mode == 'train':
+        if mode == "train":
             self.trainer.limit_train_batches = limit_batches
         else:
             self.trainer.limit_val_batches = limit_batches
 
         # Override num sanity steps to be a multiple of num of microbatches
         self.trainer.num_sanity_val_steps *= get_num_microbatches()
+
+
+def build_pretraining_datamodule(
+    datamodule: PreTrainingDataModule,
+    trainer_max_steps: int,
+    trainer_val_check_interval: int,
+    trainer_limit_val_batches: Union[int, float],
+    trainer_limit_test_batches: Union[int, float],
+):
+    """
+    Builds the index mapping cache for nemo.collections.llm.gpt.data.PreTrainingDataModule.
+
+    Args:
+        datamodule (PreTrainingDataModule): The pre-training data module to build.
+        trainer_max_steps (int): The max_steps set in your trainer.
+        trainer_val_check_interval (int): The interval at which to perform validation in your trainer.
+        trainer_limit_val_batches (Union[int, float]): The number of validation batches to use in your trainer.
+        trainer_limit_test_batches (Union[int, float]): The number of test batches to use in your trainer.
+
+    Returns:
+        None
+    """
+    import torch.distributed as dist
+
+    assert not dist.is_initialized(), "This function cannot be called inside an existing torch.distributed job."
+    # The indices in Megatron are built on rank 0, so we set the world size to 1 here.
+    dist.init_process_group(world_size=1, rank=0)
+
+    from nemo.utils import logging
+
+    logging.info(f"Building {datamodule}")
+    datamodule.build(
+        trainer_max_steps=trainer_max_steps,
+        trainer_val_check_interval=trainer_val_check_interval,
+        trainer_limit_val_batches=trainer_limit_val_batches,
+        trainer_limit_test_batches=trainer_limit_test_batches,
+    )
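
Worked example (illustrative note, not part of the patch): when num_train_samples, num_val_samples, and num_test_samples are left as None, PreTrainingDataModule.build() derives the sample counts from the trainer settings exactly as in the arithmetic below. The trainer values are hypothetical and serve only to make the formulas concrete.

# Hypothetical trainer settings.
global_batch_size = 256
trainer_max_steps = 1000
trainer_val_check_interval = 100
trainer_limit_val_batches = 10   # an int is treated as an absolute batch count
trainer_limit_test_batches = 10

# Defaults derived by PreTrainingDataModule.build() when no overrides are given.
num_train_samples = int(trainer_max_steps * global_batch_size)  # 256000
eval_iters = (trainer_max_steps // trainer_val_check_interval + 1) * trainer_limit_val_batches  # 110
num_val_samples = int(eval_iters * global_batch_size)  # 28160
num_test_samples = int(trainer_limit_test_batches * global_batch_size)  # 2560

Per the asserts in build(), an explicit num_train_samples must be at least the derived value (train_iters is then recomputed from it), while explicit num_val_samples and num_test_samples must be strictly greater than their derived values.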
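
Usage sketch (illustrative note, not part of the patch): pre-building the Megatron index-mapping cache with the new build_pretraining_datamodule() helper before launching distributed training. The helper and the num_*_samples keyword arguments come from the diff above; the data path, the numeric values, and the remaining constructor arguments (seq_length, micro_batch_size, global_batch_size) are assumptions for illustration and may differ in your environment.

import os

from nemo.collections.llm.gpt.data import PreTrainingDataModule, build_pretraining_datamodule

# build_pretraining_datamodule() initializes a single-rank process group; the default
# env:// rendezvous needs these variables set (assumption about the launch environment).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

datamodule = PreTrainingDataModule(
    paths=["/data/my_corpus_text_document"],  # hypothetical .bin/.idx prefix
    seq_length=2048,             # assumed existing constructor arguments,
    micro_batch_size=1,          # not shown in this diff
    global_batch_size=256,
    index_mapping_dir="/data/index_mappings",
    # New in this patch: override the derived sample counts.
    num_train_samples=512_000,   # must be >= max_steps * global_batch_size
    num_val_samples=60_000,      # must exceed the derived validation sample count
    num_test_samples=10_000,     # must exceed the derived test sample count
)

# Builds the GPTDataset index files once, on a single rank, so a later multi-rank
# training job can reuse the cached mappings from index_mapping_dir.
build_pretraining_datamodule(
    datamodule=datamodule,
    trainer_max_steps=1000,
    trainer_val_check_interval=100,
    trainer_limit_val_batches=10,
    trainer_limit_test_batches=10,
)

Passing an int for trainer_limit_val_batches avoids the fractional limit_val_batches path, which the patch only supports when separate train, validation, and test datasets are provided as a dictionary of paths.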