[NeMo UX] Support generating datasets using different train/valid/test distributions #9771

Merged (7 commits) on Jul 23, 2024
Changes from 4 commits
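For context before the diffs: the sketch below illustrates the two ways the updated PreTrainingDataModule accepts paths after this change, either a single blended distribution (optionally with interleaved weights, carved into train/validation/test by the split string) or a dict with one distribution per split. This is a hedged usage sketch, not code from the PR; the dataset path prefixes and the concrete split/weight values are placeholders, while the argument names come from the diff and tests below.

# Hedged usage sketch (not part of this PR); path prefixes are placeholders for
# preprocessed Megatron dataset prefixes.
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule

# Mode 1: one blended distribution. Weights may be interleaved with the paths,
# and the blend is carved into train/valid/test according to `split`.
single_distribution = PreTrainingDataModule(
    paths=[30, "/data/corpus_a_text_document", 70, "/data/corpus_b_text_document"],
    split="900,50,50",  # assumed comma-separated train/valid/test proportions
    seq_length=2048,
    micro_batch_size=4,
    global_batch_size=8,
)

# Mode 2: a separate distribution per split; `split` is ignored (with a warning)
# because each split already has its own blend.
per_split_distributions = PreTrainingDataModule(
    paths={
        "train": ["/data/train_text_document"],
        "validation": ["/data/valid_text_document"],
        "test": ["/data/test_text_document"],
    },
    seq_length=2048,
    micro_batch_size=4,
    global_batch_size=8,
)

Mode 1 maps onto the dataset builder's blend/split arguments and Mode 2 onto blend_per_split, as the pre_training.py diff below shows.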
1 change: 1 addition & 0 deletions examples/llm/megatron_gpt_pretraining.py
@@ -92,6 +92,7 @@ def get_args():
         logger=loggers,
         callbacks=callbacks,
         log_every_n_steps=1,
+        limit_val_batches=2,
         plugins=nl.MegatronMixedPrecision(precision="bf16-mixed", amp_O2=False),
     )

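A side note on the one-line change above: presumably it is needed because, with a single blended distribution, a float limit_val_batches <= 1.0 (Lightning's default) now trips the new assertion in PreTrainingDataModule.setup(), so the example caps validation at an integer number of batches instead. A minimal, hedged sketch of that trainer setting:

import nemo.lightning as nl

# Hedged sketch: an integer limit_val_batches requests a fixed number of validation
# batches, whereas a float <= 1.0 requests a fraction of the validation set, which
# the updated PreTrainingDataModule rejects for a single blended distribution.
trainer = nl.Trainer(
    accelerator="cpu",
    max_steps=1,
    limit_val_batches=2,  # integer batch count, compatible with a single data blend
)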
41 changes: 29 additions & 12 deletions nemo/collections/llm/gpt/data/pre_training.py
@@ -1,3 +1,4 @@
+import warnings
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Optional

@@ -17,8 +18,7 @@
 class PreTrainingDataModule(pl.LightningDataModule):
     def __init__(
         self,
-        paths: Path | List[Path],
-        weights: Optional[List[float]] = None,
+        paths: Path | List | Dict[str, List],
         seq_length: int = 2048,
         tokenizer: Optional["TokenizerSpec"] = None,
         micro_batch_size: int = 4,
@@ -38,16 +38,30 @@ def __init__(
         index_mapping_dir: Optional[str] = None,
     ) -> None:
         super().__init__()
-        if not isinstance(paths, (list, tuple)):
+        if not isinstance(paths, (list, tuple, dict)):
             paths = [paths]
-        if weights is not None:
-            assert len(weights) == len(paths)
-            if len(weights) == 1:
-                # weights must be None if there is only one dataset
+
+        from megatron.core.datasets.utils import get_blend_from_list
+
+        build_kwargs = {}
+        if isinstance(paths, dict):
+            if split is not None:
+                warnings.warn(
+                    f"{split=} will be ignored since datasets are being created " f"from 3 separate distributions."
+                )
+            build_kwargs["blend_per_split"] = [
+                get_blend_from_list(paths["train"]),
+                get_blend_from_list(paths["validation"]),
+                get_blend_from_list(paths["test"]),
Review thread on the lines above:

Collaborator: @ashors1 does get_blend_from_list(paths["train"]) work even when you have multiple data files, for example {"train": [/datafile1/, /datafile2/]}? Also, in this case, if weights are ignored, is the dataset built with all samples from both /datafile1/ and /datafile2/?

Collaborator (Author): Yes, get_blend_from_list(paths["train"]) works when you have multiple paths. You're also able to pass in weights by interleaving them with the paths. For example, the following would work:

paths={
    "train": [25, PATH1, 75, PATH2],
    "validation": [PATH3, PATH4],
    "test": ['1', PATH5],
}

The only time the weights are not used is when limit_val_batches <= 1.0, in which case we want to return the full validation dataset. In this case, users are expected not to provide weights for the paths.

+            ]
+        else:
+            paths, weights = get_blend_from_list(paths)
+            if len(paths) == 1:
+                weights = None
+            build_kwargs["blend"] = [paths, weights]
+            build_kwargs["split"] = split

-        self.paths = paths
-        self.weights = weights
+        self.build_kwargs = build_kwargs
         self.seq_length = seq_length
         self.tokenizer = tokenizer
         self.num_train_samples = num_train_samples
@@ -92,8 +106,12 @@ def setup(self, stage: str = "") -> None:
         num_test_samples = int(test_iters * self.data_sampler.global_batch_size)

         if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float):
+            assert (
+                "blend" not in self.build_kwargs
+            ), f"When using a single data distribution, limit_val_batches <= 1.0 is not supported."
+
             # This is to make sure we only have one epoch on every validation iteration
-            num_val_samples = None if self.weights is None else 1
+            num_val_samples = None

         train_valid_test_num_samples = [num_train_samples, num_val_samples, num_test_samples]
         self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder(
@@ -145,15 +163,14 @@ def gpt_dataset_config(self) -> "GPTDatasetConfig":
         from megatron.core.datasets.gpt_dataset import GPTDatasetConfig

         return GPTDatasetConfig(
-            blend=[[str(path) for path in self.paths], self.weights],
             random_seed=self.seed,
             sequence_length=self.seq_length,
             tokenizer=self.tokenizer,
-            split=self.split,
             path_to_cache=self.index_mapping_dir,
             reset_position_ids=self.reset_position_ids,
             reset_attention_mask=self.reset_attention_mask,
             eod_mask_loss=self.eod_mask_loss,
+            **self.build_kwargs,
         )

     def state_dict(self) -> Dict[str, Any]:
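Following up on the review thread attached to the blend_per_split lines above, here is a minimal sketch of how the interleaved weights/paths form is expected to be pulled apart. get_blend_from_list comes from megatron-core and is imported in the diff above; the tuple return shape is inferred from the paths, weights = get_blend_from_list(paths) call there, and the concrete values in the comments are assumptions rather than captured output.

from megatron.core.datasets.utils import get_blend_from_list

# Interleaved weights and paths, as in the reviewer's example.
blend = [25, "/data/PATH1_text_document", 75, "/data/PATH2_text_document"]
prefixes, weights = get_blend_from_list(blend)
# Assumed result: prefixes holds the two dataset prefixes and weights holds [25, 75],
# which the dataset builder then normalizes when sampling.

# Paths only: no weights are attached, so each listed file contributes in full.
prefixes_only, no_weights = get_blend_from_list(
    ["/data/PATH3_text_document", "/data/PATH4_text_document"]
)
# Assumed result: no_weights is None because no weights were interleaved.

In the dict case each of these pairs is placed into blend_per_split, while in the flat-list case the single pair goes into blend together with split.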
66 changes: 66 additions & 0 deletions tests/collections/llm/gpt/data/test_pre_training_data.py
@@ -0,0 +1,66 @@
import pytest

import nemo.lightning as nl
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

DATA_PATH = "/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document"
VOCAB_PATH = "/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json"
MERGES_PATH = "/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt"


@pytest.fixture
def tokenizer():
    return get_nmt_tokenizer(
        "megatron",
        "GPT2BPETokenizer",
        vocab_file=VOCAB_PATH,
        merges_file=MERGES_PATH,
    )


@pytest.fixture
def trainer():
    return nl.Trainer(
        accelerator="cpu",
        max_steps=1,
    )


def test_single_data_distribution(tokenizer, trainer):

    data = PreTrainingDataModule(
        paths=[DATA_PATH],
        seq_length=512,
        micro_batch_size=2,
        global_batch_size=2,
        tokenizer=tokenizer,
    )
    data.trainer = trainer

    ## AssertionError because we are trying to do eval on the whole
    ## dataset with just a single distribution
    with pytest.raises(AssertionError):
        data.setup(stage="dummy")

    trainer.limit_val_batches = 5
    ## this should succeed
    data.setup(stage="dummy")


def test_multiple_data_distributions(tokenizer, trainer):
    data = PreTrainingDataModule(
        paths={
            "train": ['1', DATA_PATH],
            "validation": [DATA_PATH, DATA_PATH],
            "test": ['1', DATA_PATH],
        },
        seq_length=512,
        micro_batch_size=2,
        global_batch_size=2,
        tokenizer=tokenizer,
    )
    data.trainer = trainer

    ## this should succeed
    data.setup(stage="dummy")
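One more hedged sketch, tying the tests above to the review thread: evaluating on the full validation set (a float limit_val_batches <= 1.0, which is the trainer's default here) is only supported with per-split distributions, and the validation blend is expected to carry no weights. The test name below is hypothetical and reuses the DATA_PATH constant and fixtures defined in this file; it is not part of the PR.

def test_full_validation_set_sketch(tokenizer, trainer):
    ## hypothetical sketch, not part of this PR; the trainer keeps its default
    ## float limit_val_batches, so the whole validation distribution is evaluated
    data = PreTrainingDataModule(
        paths={
            "train": [DATA_PATH],
            "validation": [DATA_PATH],  # no weights here, so the full set is used
            "test": [DATA_PATH],
        },
        seq_length=512,
        micro_batch_size=2,
        global_batch_size=2,
        tokenizer=tokenizer,
    )
    data.trainer = trainer

    ## succeeds: "blend" is not in build_kwargs, so the assertion in setup() passes
    data.setup(stage="dummy")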