Make nemo.collections.llm PreTrainingDataModule num samples configurable #11088

Merged (11 commits) on Nov 1, 2024
3 changes: 2 additions & 1 deletion nemo/collections/llm/gpt/data/__init__.py
@@ -16,7 +16,7 @@
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule, build_pretraining_datamodule
from nemo.collections.llm.gpt.data.squad import SquadDataModule

__all__ = [
@@ -25,5 +25,6 @@
"DollyDataModule",
"MockDataModule",
"PreTrainingDataModule",
"build_pretraining_datamodule",
"HfDatasetDataModule",
]
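
The new export makes build_pretraining_datamodule available directly from nemo.collections.llm.gpt.data. As a rough sketch of how the added sample-count overrides might be used (the dataset path, sequence length, and batch sizes below are illustrative assumptions, not values taken from this PR):

from nemo.collections.llm.gpt.data import PreTrainingDataModule

# All concrete values here are hypothetical; only the three num_*_samples
# keyword arguments are new in this PR.
data = PreTrainingDataModule(
    paths=["/data/my_corpus_text_document"],  # assumed preprocessed .bin/.idx prefix
    seq_length=2048,
    micro_batch_size=4,
    global_batch_size=256,
    num_train_samples=30_000_000,  # must be >= max_steps * global_batch_size
    num_val_samples=60_000,        # must exceed the derived validation default
    num_test_samples=60_000,       # must exceed the derived test default
)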
125 changes: 101 additions & 24 deletions nemo/collections/llm/gpt/data/pre_training.py
@@ -16,7 +16,7 @@
import os
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -77,7 +77,7 @@ def validate_dataset_asset_accessibility(paths):
raise ValueError("Expected path to be of string or Path type.")

path = Path(paths)
suffices = ('.bin', '.idx')
suffices = (".bin", ".idx")
if path.is_dir():
if not os.access(path, os.R_OK):
raise PermissionError(f"Expected {str(path)} to be readable.")
@@ -133,6 +133,9 @@ class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
to allocate to train, validation, and test sets, respectively. Unused if ``paths`` is a dict.
index_mapping_dir (Optional[str]): Path to a directory to write index mapping files.
num_dataset_builder_threads (int): The number of threads to use for dataset building.
num_train_samples (Optional[int]): The number of samples to use for training, defaults to total train steps times global batch size.
num_val_samples (Optional[int]): The number of samples to use for validation, defaults to total validation steps times global batch size.
num_test_samples (Optional[int]): The number of samples to use for testing, defaults to total test steps times global batch size.
"""

def __init__(
@@ -154,6 +157,9 @@ def __init__(
split: str = "900,50,50",
index_mapping_dir: Optional[str] = None,
num_dataset_builder_threads: int = 1,
num_train_samples: Optional[int] = None,
num_val_samples: Optional[int] = None,
num_test_samples: Optional[int] = None,
) -> None:
super().__init__()
if not isinstance(paths, (list, tuple, dict)):
@@ -196,6 +202,9 @@ def __init__(
self.index_mapping_dir = index_mapping_dir
self.num_dataset_builder_threads = num_dataset_builder_threads
self.init_global_step = 0
self.num_train_samples = num_train_samples
self.num_val_samples = num_val_samples
self.num_test_samples = num_test_samples

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

@@ -207,27 +216,46 @@ def __init__(
rampup_batch_size=rampup_batch_size,
)

def setup(self, stage: str = "") -> None:
def build(
self,
trainer_max_steps: int,
trainer_val_check_interval: int,
trainer_limit_val_batches: Union[int, float],
trainer_limit_test_batches: Union[int, float],
):
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDataset

assert (
hasattr(self, "trainer") and self.trainer is not None
), "Setup should be completed when trainer and config are attached."
train_iters = trainer_max_steps
assert train_iters > 0, f"max_steps {train_iters} should be greater than 0"
num_train_samples = int(train_iters * self.data_sampler.global_batch_size)

if self.num_train_samples is not None:
assert (
self.num_train_samples >= num_train_samples
), f"num_train_samples must be greater than or equal to {num_train_samples}."
num_train_samples = self.num_train_samples
train_iters = int(num_train_samples / self.data_sampler.global_batch_size)

# Trainer API
max_train_steps = self.trainer.max_steps
assert max_train_steps > 0, "Please specify trainer.max_steps"
eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
test_iters = self.trainer.limit_test_batches
num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size)
eval_iters = (train_iters // trainer_val_check_interval + 1) * trainer_limit_val_batches
num_val_samples = int(eval_iters * self.data_sampler.global_batch_size)

test_iters = trainer_limit_test_batches
num_test_samples = int(test_iters * self.data_sampler.global_batch_size)

if self.num_val_samples is not None:
assert self.num_val_samples > num_val_samples, f"num_val_samples must be greater than {num_val_samples}."
num_val_samples = self.num_val_samples
if self.num_test_samples is not None:
assert (
self.num_test_samples > num_test_samples
), f"num_test_samples must be greater than {num_test_samples}."
num_test_samples = self.num_test_samples
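# Worked example of the defaults computed above (illustrative numbers only,
# not taken from this PR): with trainer_max_steps=1000,
# trainer_val_check_interval=100, trainer_limit_val_batches=32,
# trainer_limit_test_batches=32, and global_batch_size=8:
#   num_train_samples = 1000 * 8               = 8000
#   eval_iters        = (1000 // 100 + 1) * 32 = 352
#   num_val_samples   = 352 * 8                = 2816
#   num_test_samples  = 32 * 8                 = 256
# Passing num_train_samples / num_val_samples / num_test_samples at
# construction time overrides these defaults, subject to the asserts above.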

if (
self.trainer.limit_val_batches > 0.0
and self.trainer.limit_val_batches <= 1.0
and isinstance(self.trainer.limit_val_batches, float)
trainer_limit_val_batches > 0.0
and trainer_limit_val_batches <= 1.0
and isinstance(trainer_limit_val_batches, float)
):
assert "blend" not in self.build_kwargs, (
"When using a single data distribution, limit_val_batches <= 1.0 is not supported. If you'd "
@@ -251,6 +279,18 @@ def setup(self, stage: str = "") -> None:
config=self.gpt_dataset_config,
).build()

def setup(self, stage: str = "") -> None:
assert (
hasattr(self, "trainer") and self.trainer is not None
), "Setup should be completed when trainer and config are attached."

self.build(
trainer_max_steps=self.trainer.max_steps,
trainer_val_check_interval=self.trainer.val_check_interval,
trainer_limit_val_batches=self.trainer.limit_val_batches,
trainer_limit_test_batches=self.trainer.limit_test_batches,
)

# uncomment once fabric API is merged
# def fabric_setup(
# self,
@@ -269,13 +309,13 @@ def setup(self, stage: str = "") -> None:
# ).build()

def train_dataloader(self) -> TRAIN_DATALOADERS:
return self._create_dataloader(self._train_ds, mode='train')
return self._create_dataloader(self._train_ds, mode="train")

def val_dataloader(self) -> EVAL_DATALOADERS:
return self._create_dataloader(self._validation_ds, mode='validation')
return self._create_dataloader(self._validation_ds, mode="validation")

def test_dataloader(self) -> EVAL_DATALOADERS:
return self._create_dataloader(self._test_ds, mode='test')
return self._create_dataloader(self._test_ds, mode="test")

def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader:
self.init_global_step = self.trainer.global_step
@@ -286,7 +326,7 @@ def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader:
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers,
collate_fn=getattr(dataset, 'collate_fn', data.dataloader.default_collate),
collate_fn=getattr(dataset, "collate_fn", data.dataloader.default_collate),
**kwargs,
)
return dataloader
@@ -316,7 +356,7 @@ def state_dict(self) -> Dict[str, Any]:

"""
consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
return {'consumed_samples': consumed_samples}
return {"consumed_samples": consumed_samples}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Called when loading a checkpoint, implement to reload datamodule state given datamodule stat
@@ -332,7 +372,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
from apex.transformer.pipeline_parallel.utils import update_num_microbatches

consumed_samples = state_dict['consumed_samples']
consumed_samples = state_dict["consumed_samples"]
self.data_sampler.init_consumed_samples = consumed_samples
self.data_sampler.prev_consumed_samples = consumed_samples

@@ -344,9 +384,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

def reconfigure_limit_batches(self):
# Override limit_train_batches in terms of num of microbatches
self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_ds, 'train')
self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_ds, "train")
# Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_ds, 'val')
self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_ds, "val")

def _reconfigure_limit_batches(self, limit_batches, dataloader, mode):
"""
@@ -388,10 +428,47 @@ def _reconfigure_limit_batches(self, limit_batches, dataloader, mode):
else:
limit_batches = limit_batches - limit_batches % get_num_microbatches()

if mode == 'train':
if mode == "train":
self.trainer.limit_train_batches = limit_batches
else:
self.trainer.limit_val_batches = limit_batches

# Override num sanity steps to be a multiple of num of microbatches
self.trainer.num_sanity_val_steps *= get_num_microbatches()


def build_pretraining_datamodule(
datamodule: PreTrainingDataModule,
trainer_max_steps: int,
trainer_val_check_interval: int,
trainer_limit_val_batches: Union[int, float],
trainer_limit_test_batches: Union[int, float],
):
"""
Builds the index mapping cache for nemo.collections.llm.gpt.data.PreTrainingDataModule.

Args:
datamodule (PreTrainingDataModule): The pre-training data module to build.
trainer_max_steps (int): The max_steps set in your trainer.
trainer_val_check_interval (int): The interval at which to perform validation in your trainer.
trainer_limit_val_batches (Union[int, float]): The number of validation batches to use in your trainer.
trainer_limit_test_batches (Union[int, float]): The number of test batches to use in your trainer.

Returns:
None
"""
import torch.distributed as dist

assert not dist.is_initialized(), "This function cannot be called inside an existing torch.distributed job."
# The indices in Megatron are built on rank 0, so we set the world size to 1 here.
dist.init_process_group(world_size=1, rank=0)

from nemo.utils import logging

logging.info(f"Building {datamodule}")
datamodule.build(
trainer_max_steps=trainer_max_steps,
trainer_val_check_interval=trainer_val_check_interval,
trainer_limit_val_batches=trainer_limit_val_batches,
trainer_limit_test_batches=trainer_limit_test_batches,
)
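
Since the helper refuses to run inside an existing torch.distributed job, the intended pattern is to build the index-mapping cache once in a standalone, single-process script before launching the multi-node training run. A minimal sketch under that assumption (paths and trainer numbers are illustrative, not taken from this PR):

# build_indices.py -- run as a plain Python script, not under torchrun/srun.
from nemo.collections.llm.gpt.data import PreTrainingDataModule, build_pretraining_datamodule

datamodule = PreTrainingDataModule(
    paths=["/data/my_corpus_text_document"],  # assumed preprocessed prefix
    seq_length=2048,
    micro_batch_size=4,
    global_batch_size=256,
    index_mapping_dir="/data/index_cache",    # assumed shared cache location
)

# Mirror the settings of the upcoming training job so the cached index
# mappings match what training will request.
build_pretraining_datamodule(
    datamodule=datamodule,
    trainer_max_steps=100_000,
    trainer_val_check_interval=1_000,
    trainer_limit_val_batches=32,
    trainer_limit_test_batches=32,
)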