Make nemo.collections.llm PreTrainingDataModule num samples configurable (NVIDIA#11088)

* Make nemo.collections.llm PreTrainingDataModule num samples configurable

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

* Fix

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

* Add explicit method to build pretraining datamodule index mapping

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

* Fix

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

* fix

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

* PR feedback

Signed-off-by: Hemil Desai <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: hemildesai <[email protected]>
Co-authored-by: hemildesai <[email protected]>
2 people authored and HuiyingLi committed Nov 15, 2024
1 parent f7a88d1 commit f39352b
Showing 2 changed files with 103 additions and 25 deletions.
3 changes: 2 additions & 1 deletion nemo/collections/llm/gpt/data/__init__.py
@@ -16,7 +16,7 @@
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule, build_pretraining_datamodule
from nemo.collections.llm.gpt.data.squad import SquadDataModule

__all__ = [
@@ -25,5 +25,6 @@
"DollyDataModule",
"MockDataModule",
"PreTrainingDataModule",
"build_pretraining_datamodule",
"HfDatasetDataModule",
]
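
With build_pretraining_datamodule now exported alongside PreTrainingDataModule, the new sample-count overrides can be passed directly to the data module constructor. A minimal sketch of that usage follows; the paths, seq_length, micro_batch_size, and global_batch_size values are illustrative assumptions and are not taken from this diff:

from nemo.collections.llm.gpt.data import PreTrainingDataModule

data = PreTrainingDataModule(
    paths=["/data/my_corpus_text_document"],  # hypothetical Megatron dataset prefix
    seq_length=2048,                          # assumed constructor argument
    micro_batch_size=2,                       # assumed constructor argument
    global_batch_size=256,                    # assumed constructor argument
    num_train_samples=25_600_000,             # new: must be >= max_steps * global_batch_size
    num_val_samples=300_000,                  # new: must be > the derived validation sample count
    num_test_samples=300_000,                 # new: must be > the derived test sample count
)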
125 changes: 101 additions & 24 deletions nemo/collections/llm/gpt/data/pre_training.py
@@ -16,7 +16,7 @@
import os
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -77,7 +77,7 @@ def validate_dataset_asset_accessibility(paths):
raise ValueError("Expected path to be of string or Path type.")

path = Path(paths)
suffices = ('.bin', '.idx')
suffices = (".bin", ".idx")
if path.is_dir():
if not os.access(path, os.R_OK):
raise PermissionError(f"Expected {str(path)} to be readable.")
@@ -133,6 +133,9 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
to allocate to train, validation, and test sets, respectively. Unused if ``paths`` is a dict.
index_mapping_dir (Optional[str]): Path to a directory to write index mapping files.
num_dataset_builder_threads (int): The number of threads to use for dataset building.
num_train_samples (Optional[int]): The number of samples to use for training, defaults to total train steps times global batch size.
num_val_samples (Optional[int]): The number of samples to use for validation, defaults to total validation steps times global batch size.
num_test_samples (Optional[int]): The number of samples to use for testing, defaults to total test steps times global batch size.
"""

def __init__(
@@ -154,6 +157,9 @@ def __init__(
split: str = "900,50,50",
index_mapping_dir: Optional[str] = None,
num_dataset_builder_threads: int = 1,
num_train_samples: Optional[int] = None,
num_val_samples: Optional[int] = None,
num_test_samples: Optional[int] = None,
) -> None:
super().__init__()
if not isinstance(paths, (list, tuple, dict)):
@@ -196,6 +202,9 @@ def __init__(
self.index_mapping_dir = index_mapping_dir
self.num_dataset_builder_threads = num_dataset_builder_threads
self.init_global_step = 0
self.num_train_samples = num_train_samples
self.num_val_samples = num_val_samples
self.num_test_samples = num_test_samples

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

@@ -207,27 +216,46 @@ def __init__(
rampup_batch_size=rampup_batch_size,
)

def setup(self, stage: str = "") -> None:
def build(
self,
trainer_max_steps: int,
trainer_val_check_interval: int,
trainer_limit_val_batches: Union[int, float],
trainer_limit_test_batches: Union[int, float],
):
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDataset

assert (
hasattr(self, "trainer") and self.trainer is not None
), "Setup should be completed when trainer and config are attached."
train_iters = trainer_max_steps
assert train_iters > 0, f"max_steps {train_iters} should be greater than 0"
num_train_samples = int(train_iters * self.data_sampler.global_batch_size)

if self.num_train_samples is not None:
assert (
self.num_train_samples >= num_train_samples
), f"num_train_samples must be greater than or equal to {num_train_samples}."
num_train_samples = self.num_train_samples
train_iters = int(num_train_samples / self.data_sampler.global_batch_size)

# Trainer API
max_train_steps = self.trainer.max_steps
assert max_train_steps > 0, "Please specify trainer.max_steps"
eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
test_iters = self.trainer.limit_test_batches
num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size)
eval_iters = (train_iters // trainer_val_check_interval + 1) * trainer_limit_val_batches
num_val_samples = int(eval_iters * self.data_sampler.global_batch_size)

test_iters = trainer_limit_test_batches
num_test_samples = int(test_iters * self.data_sampler.global_batch_size)

if self.num_val_samples is not None:
assert self.num_val_samples > num_val_samples, f"num_val_samples must be greater than {num_val_samples}."
num_val_samples = self.num_val_samples
if self.num_test_samples is not None:
assert (
self.num_test_samples > num_test_samples
), f"num_test_samples must be greater than {num_test_samples}."
num_test_samples = self.num_test_samples

if (
self.trainer.limit_val_batches > 0.0
and self.trainer.limit_val_batches <= 1.0
and isinstance(self.trainer.limit_val_batches, float)
trainer_limit_val_batches > 0.0
and trainer_limit_val_batches <= 1.0
and isinstance(trainer_limit_val_batches, float)
):
assert "blend" not in self.build_kwargs, (
"When using a single data distribution, limit_val_batches <= 1.0 is not supported. If you'd "
@@ -251,6 +279,18 @@ def setup(self, stage: str = "") -> None:
config=self.gpt_dataset_config,
).build()

def setup(self, stage: str = "") -> None:
assert (
hasattr(self, "trainer") and self.trainer is not None
), "Setup should be completed when trainer and config are attached."

self.build(
trainer_max_steps=self.trainer.max_steps,
trainer_val_check_interval=self.trainer.val_check_interval,
trainer_limit_val_batches=self.trainer.limit_val_batches,
trainer_limit_test_batches=self.trainer.limit_test_batches,
)

# uncomment once fabric API is merged
# def fabric_setup(
# self,
@@ -269,13 +309,13 @@ def setup(self, stage: str = "") -> None:
# ).build()

def train_dataloader(self) -> TRAIN_DATALOADERS:
return self._create_dataloader(self._train_ds, mode='train')
return self._create_dataloader(self._train_ds, mode="train")

def val_dataloader(self) -> EVAL_DATALOADERS:
return self._create_dataloader(self._validation_ds, mode='validation')
return self._create_dataloader(self._validation_ds, mode="validation")

def test_dataloader(self) -> EVAL_DATALOADERS:
return self._create_dataloader(self._test_ds, mode='test')
return self._create_dataloader(self._test_ds, mode="test")

def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader:
self.init_global_step = self.trainer.global_step
@@ -286,7 +326,7 @@ def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader:
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers,
collate_fn=getattr(dataset, 'collate_fn', data.dataloader.default_collate),
collate_fn=getattr(dataset, "collate_fn", data.dataloader.default_collate),
**kwargs,
)
return dataloader
@@ -316,7 +356,7 @@ def state_dict(self) -> Dict[str, Any]:
"""
consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
return {'consumed_samples': consumed_samples}
return {"consumed_samples": consumed_samples}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Called when loading a checkpoint, implement to reload datamodule state given datamodule stat
@@ -332,7 +372,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
from apex.transformer.pipeline_parallel.utils import update_num_microbatches

consumed_samples = state_dict['consumed_samples']
consumed_samples = state_dict["consumed_samples"]
self.data_sampler.init_consumed_samples = consumed_samples
self.data_sampler.prev_consumed_samples = consumed_samples

@@ -344,9 +384,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

def reconfigure_limit_batches(self):
# Override limit_train_batches in terms of num of microbatches
self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_ds, 'train')
self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_ds, "train")
# Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_ds, 'val')
self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_ds, "val")

def _reconfigure_limit_batches(self, limit_batches, dataloader, mode):
"""
@@ -388,10 +428,47 @@ def _reconfigure_limit_batches(self, limit_batches, dataloader, mode):
else:
limit_batches = limit_batches - limit_batches % get_num_microbatches()

if mode == 'train':
if mode == "train":
self.trainer.limit_train_batches = limit_batches
else:
self.trainer.limit_val_batches = limit_batches

# Override num sanity steps to be a multiple of num of microbatches
self.trainer.num_sanity_val_steps *= get_num_microbatches()


def build_pretraining_datamodule(
datamodule: PreTrainingDataModule,
trainer_max_steps: int,
trainer_val_check_interval: int,
trainer_limit_val_batches: Union[int, float],
trainer_limit_test_batches: Union[int, float],
):
"""
Builds the index mapping cache for nemo.collections.llm.gpt.data.PreTrainingDataModule.
Args:
datamodule (PreTrainingDataModule): The pre-training data module to build.
trainer_max_steps (int): The max_steps set in your trainer.
trainer_val_check_interval (int): The interval at which to perform validation in your trainer.
trainer_limit_val_batches (Union[int, float]): The number of validation batches to use in your trainer.
trainer_limit_test_batches (Union[int, float]): The number of test batches to use in your trainer.
Returns:
None
"""
import torch.distributed as dist

assert not dist.is_initialized(), "This function cannot be called inside an existing torch.distributed job."
# The indices in Megatron are built on rank 0, so we set the world size to 1 here.
dist.init_process_group(world_size=1, rank=0)

from nemo.utils import logging

logging.info(f"Building {datamodule}")
datamodule.build(
trainer_max_steps=trainer_max_steps,
trainer_val_check_interval=trainer_val_check_interval,
trainer_limit_val_batches=trainer_limit_val_batches,
trainer_limit_test_batches=trainer_limit_test_batches,
)
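
Because build_pretraining_datamodule asserts that torch.distributed is not already initialized and spins up its own single-rank process group, it is intended to run as a standalone preprocessing step (for example on a login node) so the index-mapping cache is already in place when training starts. A minimal sketch under that assumption; all paths and trainer values below are illustrative, and the extra constructor arguments are assumptions not shown in this diff:

import os

from nemo.collections.llm.gpt.data import PreTrainingDataModule, build_pretraining_datamodule

# torch.distributed's default env:// rendezvous needs these even for a single-rank group.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

datamodule = PreTrainingDataModule(
    paths=["/data/my_corpus_text_document"],   # hypothetical Megatron dataset prefix
    index_mapping_dir="/data/index_mappings",  # where the index cache is written
    global_batch_size=256,                     # assumed constructor argument
)

# Mirror the trainer settings of the real run so the cached index mappings
# cover the same number of train/validation/test samples.
build_pretraining_datamodule(
    datamodule=datamodule,
    trainer_max_steps=100_000,
    trainer_val_check_interval=1_000,
    trainer_limit_val_batches=32,
    trainer_limit_test_batches=32,
)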
