Skip to content

Commit

Permalink
[NeMo UX] Support generating datasets using different train/valid/tes…
Browse files Browse the repository at this point in the history
…t distributions (NVIDIA#9771) (NVIDIA#9841)

* support building train/valid/test datasets from separate distributions

* add minimal test

* Apply isort and black reformatting

* set limit_val_batches for nemo 2 example

* improve assert statement

* Apply isort and black reformatting

---------

Signed-off-by: ashors1 <[email protected]>
Signed-off-by: ashors1 <[email protected]>
Co-authored-by: Anna Shors <[email protected]>
Co-authored-by: ashors1 <[email protected]>
Signed-off-by: Vivian Chen <[email protected]>
  • Loading branch information
3 people authored and Vivian Chen committed Aug 1, 2024
1 parent 5518c72 commit d8701ae
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 12 deletions.
1 change: 1 addition & 0 deletions examples/llm/megatron_gpt_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def get_args():
logger=loggers,
callbacks=callbacks,
log_every_n_steps=1,
limit_val_batches=2,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed", amp_O2=False),
)

Expand Down
48 changes: 36 additions & 12 deletions nemo/collections/llm/gpt/data/pre_training.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional

Expand All @@ -17,8 +18,7 @@
class PreTrainingDataModule(pl.LightningDataModule):
def __init__(
self,
paths: Path | List[Path],
weights: Optional[List[float]] = None,
paths: Path | List | Dict[str, List],
seq_length: int = 2048,
tokenizer: Optional["TokenizerSpec"] = None,
micro_batch_size: int = 4,
Expand All @@ -38,16 +38,30 @@ def __init__(
index_mapping_dir: Optional[str] = None,
) -> None:
super().__init__()
if not isinstance(paths, (list, tuple)):
if not isinstance(paths, (list, tuple, dict)):
paths = [paths]
if weights is not None:
assert len(weights) == len(paths)
if len(weights) == 1:
# weights must be None if there is only one dataset

from megatron.core.datasets.utils import get_blend_from_list

build_kwargs = {}
if isinstance(paths, dict):
if split is not None:
warnings.warn(
f"{split=} will be ignored since datasets are being created " f"from 3 separate distributions."
)
build_kwargs["blend_per_split"] = [
get_blend_from_list(paths["train"]),
get_blend_from_list(paths["validation"]),
get_blend_from_list(paths["test"]),
]
else:
paths, weights = get_blend_from_list(paths)
if len(paths) == 1:
weights = None
build_kwargs["blend"] = [paths, weights]
build_kwargs["split"] = split

self.paths = paths
self.weights = weights
self.build_kwargs = build_kwargs
self.seq_length = seq_length
self.tokenizer = tokenizer
self.num_train_samples = num_train_samples
Expand Down Expand Up @@ -92,8 +106,19 @@ def setup(self, stage: str = "") -> None:
num_test_samples = int(test_iters * self.data_sampler.global_batch_size)

if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float):
assert "blend" not in self.build_kwargs, (
"When using a single data distribution, limit_val_batches <= 1.0 is not supported. If you'd "
"like to run with a fractional value of limit_val_batches, please pass in separate datasets for "
"the train, validation, and test datasets by providing a dictionary of paths, e.g.: \n"
" paths={ \n "
" 'train': [PATHS FOR TRAIN], \n "
" 'validation': [PATHS FOR VALIDATION], \n "
" 'test' :[PATHS FOR TEST], \n"
" }"
)

# This is to make sure we only have one epoch on every validation iteration
num_val_samples = None if self.weights is None else 1
num_val_samples = None

train_valid_test_num_samples = [num_train_samples, num_val_samples, num_test_samples]
self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder(
Expand Down Expand Up @@ -147,15 +172,14 @@ def gpt_dataset_config(self) -> "GPTDatasetConfig":
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig

return GPTDatasetConfig(
blend=[[str(path) for path in self.paths], self.weights],
random_seed=self.seed,
sequence_length=self.seq_length,
tokenizer=self.tokenizer,
split=self.split,
path_to_cache=self.index_mapping_dir,
reset_position_ids=self.reset_position_ids,
reset_attention_mask=self.reset_attention_mask,
eod_mask_loss=self.eod_mask_loss,
**self.build_kwargs,
)

def state_dict(self) -> Dict[str, Any]:
Expand Down
66 changes: 66 additions & 0 deletions tests/collections/llm/gpt/data/test_pre_training_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pytest

import nemo.lightning as nl
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

DATA_PATH = "/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document"
VOCAB_PATH = "/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json"
MERGES_PATH = "/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt"


@pytest.fixture
def tokenizer():
return get_nmt_tokenizer(
"megatron",
"GPT2BPETokenizer",
vocab_file=VOCAB_PATH,
merges_file=MERGES_PATH,
)


@pytest.fixture
def trainer():
return nl.Trainer(
accelerator="cpu",
max_steps=1,
)


def test_single_data_distribution(tokenizer, trainer):

data = PreTrainingDataModule(
paths=[DATA_PATH],
seq_length=512,
micro_batch_size=2,
global_batch_size=2,
tokenizer=tokenizer,
)
data.trainer = trainer

## AssertioneError because we are trying to do eval on the whole
## dataset with just a single distribution
with pytest.raises(AssertionError):
data.setup(stage="dummy")

trainer.limit_val_batches = 5
## this should succeed
data.setup(stage="dummy")


def test_multiple_data_distributions(tokenizer, trainer):
data = PreTrainingDataModule(
paths={
"train": ['1', DATA_PATH],
"validation": [DATA_PATH, DATA_PATH],
"test": ['1', DATA_PATH],
},
seq_length=512,
micro_batch_size=2,
global_batch_size=2,
tokenizer=tokenizer,
)
data.trainer = trainer

## this should succeed
data.setup(stage="dummy")

0 comments on commit d8701ae

Please sign in to comment.