Add SFT/PEFT HF tests (NVIDIA#11519)
* Add SFT/PEFT HF tests

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* move hf examples to examples dir

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* bot

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* use mini_squad

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* use mini_squad

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* add 2gpu DDP

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* refactor

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* use labels as passed by the user

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* update samples/ tests

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* rm unused imports

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Add tests with subset split names, e.g. train[:100]

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* add --disable-ckpt

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* use self-hosted-azure-gpus-1 for single-gpu test

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Add TRANSFORMERS_OFFLINE=1 to hf tests

Signed-off-by: Alexandros Koumparoulis <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
akoumpa and akoumpa authored Dec 12, 2024
1 parent ef68b1b commit 3d94b7e
Showing 7 changed files with 406 additions and 78 deletions.
53 changes: 50 additions & 3 deletions .github/workflows/cicd-main.yml
@@ -3574,17 +3574,60 @@ jobs:
inference.outfile_path=/tmp/nlp_mcore_t5_lora_tuning_tp2/out.jsonl
L2_HF_Transformer_PEFT:
needs: [ cicd-test-container-setup ]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_PEFT') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure-gpus-1
SCRIPT: |
TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/peft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --disable-ckpt
AFTER_SCRIPT: |
rm -rf nemo_experiments
L2_HF_Transformer_PEFT_2gpu:
needs: [ cicd-test-container-setup ]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_PEFT_2gpu') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/peft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2 --strategy ddp --disable-ckpt
AFTER_SCRIPT: |
rm -rf nemo_experiments
L2_HF_Transformer_SFT_2gpu:
needs: [ cicd-test-container-setup ]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_SFT_2gpu') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2 --strategy ddp
AFTER_SCRIPT: |
rm -rf nemo_experiments
L2_HF_Transformer_SFT:
needs: [ cicd-test-container-setup ]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_SFT') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure-gpus-1
SCRIPT: |
TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10
AFTER_SCRIPT: |
rm -rf nemo_experiments
L2_HF_Transformer_SFT_TE_Acceleration:
needs: [ cicd-test-container-setup ]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_SFT_TE_Acceleration') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
RUNNER: self-hosted-azure-gpus-1
SCRIPT: |
python examples/llm/sft/hf.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --model-accelerator te
TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --model-accelerator te --max-steps 10
AFTER_SCRIPT: |
rm -rf nemo_experiments
# L2: Megatron Mock Data Generation
L2_Megatron_Mock_Data_Generation_MockGPTDataset:
@@ -4685,6 +4728,10 @@ jobs:
- L2_NeMo_2_llama3_pretraining_recipe
- L2_NeMo_2_llama3_fault_tolerance_plugin
- L2_NeMo_2_llama3_straggler_detection
- L2_HF_Transformer_PEFT
- L2_HF_Transformer_PEFT_2gpu
- L2_HF_Transformer_SFT
- L2_HF_Transformer_SFT_2gpu
- L2_HF_Transformer_SFT_TE_Acceleration
- L2_NeMo_2_SSM_Pretraining
- L2_NeMo_2_SSM_Finetuning
17 changes: 10 additions & 7 deletions examples/llm/peft/hf.py
@@ -39,14 +39,17 @@ def formatting_prompts_func(examples):
output = output[0]
text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
ans = tokenizer(text)
tokens = ans['input_ids']
return {
'tokens': tokens,
'labels': tokens[1:] + [tokens[-1]],
}
ans['labels'] = ans['input_ids']
return ans

datamodule = llm.HFDatasetDataModule("rajpurkar/squad", split="train", pad_token_id=tokenizer.eos_token_id)
datamodule.map(formatting_prompts_func, batched=False, batch_size=2)
tokenizer = getattr(tokenizer, 'tokenizer', tokenizer)
datamodule = llm.HFDatasetDataModule("rajpurkar/squad", split="train[:100]", pad_token_id=tokenizer.eos_token_id)
datamodule.map(
formatting_prompts_func,
batched=False,
batch_size=2,
remove_columns=["id", "title", "context", "question", 'answers'],
)
return datamodule


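For context, a minimal standalone sketch of the updated mapping step above (not part of the commit; the tokenizer checkpoint, prompt template, and sample record are placeholders rather than the alpaca_prompt used by examples/llm/peft/hf.py): the formatted text is tokenized once and labels simply mirror input_ids, leaving the next-token shift to the model's loss path.

# Sketch only (not in the commit): mirrors the new formatting step, where the
# tokenized example reuses input_ids as labels.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

def format_example(example):
    # Assumed SQuAD-style columns: question, context, answers (cf. remove_columns above).
    text = (
        f"Question: {example['question']}\n"
        f"Context: {example['context']}\n"
        f"Answer: {example['answers']['text'][0]}"
    ) + tokenizer.eos_token
    ans = tokenizer(text)
    ans["labels"] = ans["input_ids"]  # labels passed through as-is; shifting happens in the model
    return ans

sample = {
    "question": "What is NeMo?",
    "context": "NeMo is a framework for building generative AI models.",
    "answers": {"text": ["A framework for building generative AI models."]},
}
print(format_example(sample)["labels"][:5])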
27 changes: 14 additions & 13 deletions nemo/collections/llm/gpt/data/hf_dataset.py
@@ -14,13 +14,19 @@

import lightning.pytorch as pl
import torch
from datasets import load_dataset
from datasets import Dataset, DatasetDict, load_dataset
from torch.utils.data import DataLoader

from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging


def clean_split(name):
if '[' in name:
return name.split('[')[0]
return name


def make_dataset_splits(dataset, split, split_aliases):
"""
Given a dataset (e.g. from datasets.load_dataset or datasets.Dataset.from_dict) it
@@ -51,19 +57,18 @@ def make_dataset_splits(dataset, split, split_aliases):
> "val": Dataset .. (with 10570 rows),
> }
"""
from datasets import Dataset, DatasetDict

split_names = ['train', 'test', 'val']
dataset_splits = {_split: None for _split in split_names}
valid_split_names = ['train', 'test', 'val']
dataset_splits = {_split: None for _split in valid_split_names}

alias_to_split = {}
for split_name, _split_aliases in split_aliases.items():
assert split_name in split_names
assert split_name in valid_split_names
for alias in _split_aliases:
alias_to_split[alias] = split_name

if isinstance(dataset, Dataset):
assert isinstance(split, str), "Expected split to be a string, but got " + str(type(split))
split = clean_split(split)
dataset_splits[split] = dataset
elif isinstance(dataset, DatasetDict):
dataset_split_names = dataset.keys()
@@ -75,7 +80,7 @@ def make_dataset_splits(dataset, split, split_aliases):
elif isinstance(split, list):
logging.info(f"Loaded HF dataset will use " + str(split) + " splits.")
assert isinstance(dataset, list)
for i, alias_split_name in enumerate(split):
for i, alias_split_name in enumerate(map(clean_split, split)):
split_name = alias_to_split[alias_split_name]
assert dataset_splits[split_name] is None
dataset_splits[split_name] = dataset[i]
@@ -93,6 +98,7 @@ def make_dataset_splits(dataset, split, split_aliases):
else:
raise ValueError("Expected split name to be None, str or a list")

assert set(valid_split_names) == set(dataset_splits.keys()), dataset_splits.keys()
num_init_splits = sum(map(lambda x: x is not None, dataset_splits.values()))
assert num_init_splits > 0, f"Expected at least one split to have been initialized {num_init_splits}"
return dataset_splits
@@ -133,8 +139,6 @@ def __init__(
) -> None:
super().__init__()
assert pad_token_id is not None
from datasets import Dataset, DatasetDict

# A dataset usually will have several splits (e.g. train, val, test, etc).
# We map synonym names to canonical names (train, test, val).
# A synonym can be a prefix/suffixed word e.g. train <> training.
@@ -172,8 +176,6 @@ def __init__(

@staticmethod
def from_dict(dataset_dict, split, **kwargs):
from datasets import Dataset

dataset = Dataset.from_dict(dataset_dict)
return HFDatasetDataModule(path_or_dataset=dataset, split=split, **kwargs)

@@ -191,7 +193,6 @@ def pad_within_micro(batch, pad_token_id):
max_len = max(map(len, batch))
return [item + [pad_token_id] * (max_len - len(item)) for item in batch]

keys = list(filter(lambda x: x in batch[0], ['tokens', 'labels', 'position_ids', 'loss_mask']))
return {
key: batchify(
torch.LongTensor(
Expand All @@ -201,7 +202,7 @@ def pad_within_micro(batch, pad_token_id):
)
)
)
for key in keys
for key in batch[0].keys()
}

def setup(self, stage: str):
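As a side note, a small sketch (not from the commit) of what the new clean_split helper buys: subset split names such as 'train[:100]' are reduced to the canonical 'train' bucket before make_dataset_splits assigns datasets, while the subset slicing itself is handled by datasets.load_dataset.

# Sketch only: reproduces the clean_split logic added above.
def clean_split(name):
    # 'train[:100]' -> 'train'; plain split names pass through unchanged.
    if '[' in name:
        return name.split('[')[0]
    return name

assert clean_split('train[:100]') == 'train'
assert clean_split('validation') == 'validation'

# The slicing itself is done by Hugging Face datasets when loading, e.g.:
# from datasets import load_dataset
# ds = load_dataset('rajpurkar/squad', split='train[:100]')  # first 100 training rows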
64 changes: 35 additions & 29 deletions nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
@@ -31,6 +31,19 @@ def masked_cross_entropy(logits, targets, mask=None):
return F.cross_entropy(logits, targets)


def align_labels(logits, labels):
logits = logits.float()
n_cls = logits.shape[-1]
if logits.shape[-2] == labels.shape[-1]:
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
elif logits.shape[-2] == labels.shape[-1] + 1:
logits = logits[..., :-1, :].contiguous()
else:
raise ValueError("Mismatched labels and logits shapes (" + str(labels.shape) + " " + str(logits.shape))
return logits.view(-1, n_cls), labels.view(-1)


class HFAutoModelForCausalLM(pl.LightningModule, io.IOMixin, fn.FNMixin):
def __init__(
self,
@@ -91,41 +104,34 @@ def configure_model(self):

self.model.train()

def forward(self, input_ids, attention_mask=None, labels=None, loss_mask=None):
outputs = self.model(
input_ids=input_ids.to(self.model.device),
attention_mask=attention_mask,
)
labels = labels.to(self.model.device)
if loss_mask is not None:
loss_mask = loss_mask.to(self.model.device).view(-1)
n_cls = outputs.logits.shape[-1]
outputs.loss = self.loss_fn(outputs.logits.view(-1, n_cls), labels.view(-1), loss_mask)
return outputs
def forward(self, batch):
return self.model(**batch)

def training_step(self, batch):
tokens = batch['tokens']
labels = batch['labels']
loss_mask = batch.get('loss_mask', None)
output = self.forward(
input_ids=tokens,
labels=labels,
loss_mask=loss_mask,
)

loss = output.loss
labels = batch.pop('labels').to(self.model.device)
loss_mask = batch.pop('loss_mask', None)

outputs = self.forward(batch)

# Prepare for loss calculation
logits, labels = align_labels(outputs.logits.float(), labels)
assert logits.shape[-2] == labels.shape[-1]

loss = self.loss_fn(logits, labels, loss_mask)
self.log('train_log', loss, on_step=True, on_epoch=True, prog_bar=True)
return loss

@torch.no_grad
def validation_step(self, batch, batch_idx):
tokens = batch['tokens']
labels = batch['labels']
output = self.forward(
input_ids=tokens,
labels=labels,
)

loss = output.loss
labels = batch.pop('labels').to(self.model.device)
loss_mask = batch.pop('loss_mask', None)

outputs = self.forward(**batch)

logits, labels = align_labels(outputs.logits.float(), labels)
assert logits.shape[-2] == labels.shape[-1]
loss = self.loss_fn(logits, labels, loss_mask)

self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)

def save_pretrained(self, path):
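To illustrate the new loss path, a self-contained sketch with dummy tensors (align_labels copied from the diff above; the random logits and labels stand in for real model outputs, and plain cross-entropy stands in for masked_cross_entropy): labels are used exactly as passed by the user and shifted for next-token prediction before a flat cross-entropy.

# Sketch only: exercises align_labels on random tensors.
import torch
import torch.nn.functional as F

def align_labels(logits, labels):
    logits = logits.float()
    n_cls = logits.shape[-1]
    if logits.shape[-2] == labels.shape[-1]:
        # Same sequence length: shift logits/labels for next-token prediction.
        logits = logits[..., :-1, :].contiguous()
        labels = labels[..., 1:].contiguous()
    elif logits.shape[-2] == labels.shape[-1] + 1:
        # Labels already shifted by one relative to logits.
        logits = logits[..., :-1, :].contiguous()
    else:
        raise ValueError(f"Mismatched labels and logits shapes {labels.shape} {logits.shape}")
    return logits.view(-1, n_cls), labels.view(-1)

batch, seq, vocab = 2, 8, 32
logits = torch.randn(batch, seq, vocab)           # stand-in for model outputs
labels = torch.randint(0, vocab, (batch, seq))    # labels as passed by the user
flat_logits, flat_labels = align_labels(logits, labels)
print(F.cross_entropy(flat_logits, flat_labels))  # scalar loss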
82 changes: 56 additions & 26 deletions tests/collections/llm/gpt/data/test_hf_datamodule.py
@@ -41,6 +41,30 @@ def test_load_single_split():
assert ds.test is None


def test_load_single_split_with_subset():
ds = llm.HFDatasetDataModule(
path_or_dataset=DATA_PATH,
split='train[:10]',
seq_length=512,
micro_batch_size=2,
global_batch_size=2,
)
from datasets.arrow_dataset import Dataset

assert isinstance(ds.dataset_splits, dict)
assert len(ds.dataset_splits) == 3
assert 'train' in ds.dataset_splits
assert ds.dataset_splits['train'] is not None
assert ds.train is not None
assert isinstance(ds.dataset_splits['train'], Dataset)
assert 'val' in ds.dataset_splits
assert ds.dataset_splits['val'] is None
assert ds.val is None
assert 'test' in ds.dataset_splits
assert ds.dataset_splits['test'] is None
assert ds.test is None


def test_load_nonexistent_split():
exception_msg = ''
expected_msg = '''Unknown split "this_split_name_should_not_exist". Should be one of ['train', 'validation'].'''
@@ -84,6 +108,33 @@ def test_load_multiple_split():
assert ds.test is None


def test_load_multiple_split_with_subset():
ds = llm.HFDatasetDataModule(
path_or_dataset=DATA_PATH,
split=['train[:100]', 'validation'],
seq_length=512,
micro_batch_size=2,
global_batch_size=2,
)
from datasets.arrow_dataset import Dataset

assert isinstance(ds.dataset_splits, dict)
assert len(ds.dataset_splits) == 3
assert 'train' in ds.dataset_splits
assert ds.dataset_splits['train'] is not None
assert ds.train is not None
assert isinstance(ds.dataset_splits['train'], Dataset)
assert isinstance(ds.train, Dataset)
assert 'val' in ds.dataset_splits
assert ds.dataset_splits['val'] is not None
assert ds.val is not None
assert isinstance(ds.dataset_splits['val'], Dataset)
assert isinstance(ds.val, Dataset)
assert 'test' in ds.dataset_splits
assert ds.dataset_splits['test'] is None
assert ds.test is None


def test_validate_dataset_asset_accessibility_file_does_not_exist():
raised_exception = False
try:
@@ -99,8 +150,9 @@ def test_validate_dataset_asset_accessibility_file_does_not_exist():
assert raised_exception == True, "Expected to raise a FileNotFoundError"


def test_validate_dataset_asset_accessibility_file_is_none(): # tokenizer, trainer):
raised_exception = False
def test_validate_dataset_asset_accessibility_file_is_none():
exception_msg = ''
expected_msg = "Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got <class 'NoneType'>"
try:
llm.HFDatasetDataModule(
path_or_dataset=None,
Expand All @@ -109,28 +161,6 @@ def test_validate_dataset_asset_accessibility_file_is_none(): # tokenizer, trai
global_batch_size=2,
)
except ValueError as e:
raised_exception = (
str(e) == "Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got <class 'NoneType'>"
)

assert raised_exception == True, "Expected to raise a ValueError"


def test_load_from_dict():
data = {'text': "Below is an instruction that describes a task, paired with an input that "}
exception_msg = str(e)

datamodule = llm.HFDatasetDataModule.from_dict(
{"text": [data['text'] for _ in range(101)]},
split='train',
global_batch_size=4,
micro_batch_size=1,
)
assert datamodule is not None
assert isinstance(datamodule, llm.HFDatasetDataModule)
assert hasattr(datamodule, 'train')
assert datamodule.train is not None
assert len(datamodule.train) == 101
assert hasattr(datamodule, 'val')
assert datamodule.val is None
assert hasattr(datamodule, 'test')
assert datamodule.test is None
assert exception_msg == expected_msg, exception_msg