Add support and recipes for HF models via AutoModelForCausalLM (#10962)
* initial hf_lit_module

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* make sft gpt dataset sanity check optional

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* HF sft example

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Rename HfLitModule to HfAutoModel

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* update default model id

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* move rank&world_size as params

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix mbs in example

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix for fsdp and logger

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* make loss_fn configurable

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* remove optim from HfAutoModel

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add pytorch native optim

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add hfAutoModel pretrain nemorun recipe

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove debug

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove stale imports

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove stale import

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* rm stale imports

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* rm stale imports

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* tokenizer fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* update example

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* rename pytorch_adam to pytorch_adam_with_cosine_annealing

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* small refactor

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix no_weight_decay_cond

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* switch to flat_lr optim for example

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* remove imports & update docstrings

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add a tokenizer setter to allow it to work with nemo/collections/llm/api.py::_use_tokenizer

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove unused import

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* allow loss_mask to be none

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Add HF-dataset lightning module

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* check if pad_token_id is None

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* rename hf_lit_module.py to hf_auto_model.py

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* class rename

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* rename

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* update example

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* HfAutoModelForCausalLM

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* rm stale import

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add option to start with random weights

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add check in megatron-strategy

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* rename param

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* drop mcore sampler from squadmodule

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* make megatron_sampler optional in HfDatasetDataModule

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* copyright

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* use is_hf_model to mark hf classes

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
akoumpa authored Oct 23, 2024
1 parent 69e3c3f commit 8f26236
Showing 18 changed files with 733 additions and 23 deletions.
91 changes: 91 additions & 0 deletions examples/llm/sft/hf.py
@@ -0,0 +1,91 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fiddle as fdl
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

from nemo import lightning as nl
from nemo.collections import llm


class SquadDataModuleWithPthDataloader(llm.SquadDataModule):
def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
return DataLoader(
dataset,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers,
collate_fn=dataset.collate_fn,
batch_size=self.micro_batch_size,
**kwargs,
)


def squad(tokenizer) -> pl.LightningDataModule:
return SquadDataModuleWithPthDataloader(
tokenizer=tokenizer,
seq_length=2048,
micro_batch_size=2,
global_batch_size=128, # assert gbs == mbs * accumulate_grad_batches
num_workers=0,
sanity_check_dist_workers=False,
)


if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='meta-llama/Llama-3.2-1B')
parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
parser.add_argument('--devices', default=1)
parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
parser.add_argument('--max-steps', type=int, default=100)
parser.add_argument('--wandb-project', type=str, default=None)
args = parser.parse_args()

wandb = None
if args.wandb_project is not None:
model = '_'.join(args.model.split('/')[-2:])
wandb = WandbLogger(
project=args.wandb_project,
name=f'{model}_dev{args.devices}_strat_{args.strategy}',
)
grad_clip = 0.5
if args.strategy == 'fsdp':
# See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
grad_clip = None
use_dist_samp = False

llm.api.finetune(
model=llm.HfAutoModelForCausalLM(args.model),
data=squad(llm.HfAutoModelForCausalLM.configure_tokenizer(args.model)),
trainer=nl.Trainer(
devices=args.devices,
max_steps=args.max_steps,
accelerator=args.accelerator,
strategy=args.strategy,
log_every_n_steps=1,
limit_val_batches=0.0,
num_sanity_val_steps=0,
accumulate_grad_batches=10,
gradient_clip_val=grad_clip,
use_distributed_sampler=use_dist_samp,
logger=wandb,
),
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(max_lr=1e-5, clip_grad=0.5)),
log=None,
)
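
The example above can be launched as a regular script, e.g. `python examples/llm/sft/hf.py --model meta-llama/Llama-3.2-1B --strategy fsdp --devices 2 --wandb-project <project>`; all of these flags are defined by the script's argument parser, while the exact multi-GPU launch setup depends on the environment.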
3 changes: 3 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -21,6 +21,7 @@
from nemo.collections.llm.gpt.data import (
DollyDataModule,
FineTuningDataModule,
HfDatasetDataModule,
MockDataModule,
PreTrainingDataModule,
SquadDataModule,
@@ -57,6 +58,7 @@
GPTConfig126M,
GPTConfig175B,
GPTModel,
HfAutoModelForCausalLM,
Llama2Config7B,
Llama2Config13B,
Llama2Config70B,
@@ -182,6 +184,7 @@
"squad",
"dolly",
"peft",
"HfAutoModelForCausalLM",
]


10 changes: 9 additions & 1 deletion nemo/collections/llm/gpt/data/__init__.py
@@ -14,8 +14,16 @@

from nemo.collections.llm.gpt.data.dolly import DollyDataModule
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule
from nemo.collections.llm.gpt.data.squad import SquadDataModule

__all__ = ["FineTuningDataModule", "SquadDataModule", "DollyDataModule", "MockDataModule", "PreTrainingDataModule"]
__all__ = [
"FineTuningDataModule",
"SquadDataModule",
"DollyDataModule",
"MockDataModule",
"PreTrainingDataModule",
"HfDatasetDataModule",
]
5 changes: 5 additions & 0 deletions nemo/collections/llm/gpt/data/fine_tuning.py
@@ -70,6 +70,7 @@ def __init__(
persistent_workers: bool = False,
pad_to_max_length: bool = False,
packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
sanity_check_dist_workers: bool = True,
):
super().__init__()
self.seq_length = seq_length
@@ -89,6 +90,7 @@
self.packed_sequence_specs = packed_sequence_specs
self.packed_sequence_size = -1 if not packed_sequence_specs else packed_sequence_specs.packed_sequence_size
self.validate_batch_size_for_packed_sequence()
self._sanity_check_dist_workers = sanity_check_dist_workers

def validate_batch_size_for_packed_sequence(self):
if self.packed_sequence_size > 0 and self.micro_batch_size > 1:
@@ -134,6 +136,7 @@ def train_dataloader(self) -> DataLoader:
self.train_path if self.packed_sequence_size <= 0 else self.train_path_packed,
max_num_samples=self.max_train_samples,
pad_to_max_length=self.pad_to_max_length,
sanity_check_dist_workers=self._sanity_check_dist_workers,
)
)

@@ -143,6 +146,7 @@ def val_dataloader(self) -> DataLoader:
self.validation_path,
is_test=True,
pad_to_max_length=self.pad_to_max_length,
sanity_check_dist_workers=self._sanity_check_dist_workers,
),
)

@@ -153,6 +157,7 @@ def test_dataloader(self) -> DataLoader:
tokens_to_generate=32,
is_test=True,
pad_to_max_length=self.pad_to_max_length,
sanity_check_dist_workers=self._sanity_check_dist_workers,
)
)

103 changes: 103 additions & 0 deletions nemo/collections/llm/gpt/data/hf_dataset.py
@@ -0,0 +1,103 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader


class HfDatasetDataModule(pl.LightningDataModule):
def __init__(
self,
dataset,
num_workers=2,
pin_memory=True,
persistent_workers=True,
micro_batch_size=2,
global_batch_size=2,
pad_token_id=0,
use_mcore_sampler=False,
mcore_dataloader_type='cyclic',
) -> None:
super().__init__()
assert pad_token_id is not None

self.dataset = dataset
self.num_workers = num_workers
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers
self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.pad_token_id = pad_token_id

self.use_mcore_sampler = use_mcore_sampler
self.mcore_dataloader_type = mcore_dataloader_type

@staticmethod
def collate_fn(batch, pad_token_id=0):
def batchify(tensor):
if tensor.ndim == 1:
return tensor.unsqueeze_(0)
return tensor

def extract_key_from_dicts(batch, key):
return list(map(lambda x: x[key], batch))

def pad_within_micro(batch, pad_token_id):
max_len = max(map(len, batch))
return [item + [pad_token_id] * (max_len - len(item)) for item in batch]

return {
key: batchify(
torch.LongTensor(
pad_within_micro(
extract_key_from_dicts(batch, key),
pad_token_id,
)
)
)
for key in ['tokens', 'labels']
}

def train_dataloader(self, collate_fn=None):
from nemo.lightning.data import add_megatron_sampler

if collate_fn is None:
collate_fn = lambda x: HfDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)

dataloader = DataLoader(
self.dataset,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers,
collate_fn=collate_fn,
batch_size=self.micro_batch_size,
)
if not self.use_mcore_sampler:
return dataloader

rank = 0
world_size = 1
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

return add_megatron_sampler(
dataloader,
self.micro_batch_size,
self.global_batch_size,
dataloader_type=self.mcore_dataloader_type,
rank=rank,
world_size=world_size,
)
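
A minimal usage sketch (not part of this diff) of how the new HfDatasetDataModule might be fed a pre-tokenized Hugging Face dataset. The dataset choice, the tokenize helper, the label shifting, and access to the same model id used in the example above are illustrative assumptions, not APIs introduced by this commit.

# Illustrative only: build a dataset of {'tokens': [...], 'labels': [...]} dicts,
# the format HfDatasetDataModule.collate_fn expects, then hand it to the data module.
from datasets import load_dataset  # Hugging Face `datasets` (assumed installed)
from transformers import AutoTokenizer

from nemo.collections import llm

tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B')

def to_tokens_and_labels(example):
    # Next-token setup: inputs are all tokens except the last, labels are shifted by one.
    ids = tokenizer(example['text'])['input_ids']
    return {'tokens': ids[:-1], 'labels': ids[1:]}

raw = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
raw = raw.filter(lambda ex: len(ex['text'].strip()) > 0)  # drop empty lines
dataset = raw.map(to_tokens_and_labels, remove_columns=raw.column_names)

datamodule = llm.HfDatasetDataModule(
    dataset,
    micro_batch_size=2,
    pad_token_id=tokenizer.pad_token_id or 0,  # collate_fn pads within each micro-batch
)
# `datamodule` can then be passed as `data=` to llm.api.finetune, as in the example above.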
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/data/squad.py
@@ -56,6 +56,7 @@ def __init__(
persistent_workers: bool = False,
pad_to_max_length: bool = False,
packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
sanity_check_dist_workers: bool = True,
):
self.force_redownload = force_redownload
self.delete_raw = delete_raw
@@ -74,6 +75,7 @@
persistent_workers=persistent_workers,
pad_to_max_length=pad_to_max_length,
packed_sequence_specs=packed_sequence_specs,
sanity_check_dist_workers=sanity_check_dist_workers,
)

def prepare_data(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
@@ -37,6 +37,7 @@
GemmaConfig7B,
GemmaModel,
)
from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HfAutoModelForCausalLM
from nemo.collections.llm.gpt.model.llama import (
CodeLlamaConfig7B,
CodeLlamaConfig13B,
@@ -166,4 +167,5 @@
"gpt_forward_step",
"transformer_engine_layer_spec",
"local_layer_spec",
"HfAutoModelForCausalLM",
]
(The remaining 11 changed files are not shown here.)