Add support and recipes for HF models via AutoModelForCausalLM (#10962)
* initial hf_lit_module Signed-off-by: Alexandros Koumparoulis <[email protected]>
* make sft gpt dataset sanity check optional Signed-off-by: Alexandros Koumparoulis <[email protected]>
* HF sft example Signed-off-by: Alexandros Koumparoulis <[email protected]>
* Rename HfLitModule to HfAutoModel Signed-off-by: Alexandros Koumparoulis <[email protected]>
* update default model id Signed-off-by: Alexandros Koumparoulis <[email protected]>
* move rank&world_size as params Signed-off-by: Alexandros Koumparoulis <[email protected]>
* fix mbs in example Signed-off-by: Alexandros Koumparoulis <[email protected]>
* fix for fsdp and logger Signed-off-by: Alexandros Koumparoulis <[email protected]>
* make loss_fn configurable Signed-off-by: Alexandros Koumparoulis <[email protected]>
* Apply isort and black reformatting Signed-off-by: akoumpa <[email protected]>
* remove optim from HfAutoModel Signed-off-by: Alexandros Koumparoulis <[email protected]>
* add pytorch native optim Signed-off-by: Alexandros Koumparoulis <[email protected]>
* add hfAutoModel pretrain nemorun recipe Signed-off-by: Alexandros Koumparoulis <[email protected]>
* remove debug Signed-off-by: Alexandros Koumparoulis <[email protected]>
* remove stale imports Signed-off-by: Alexandros Koumparoulis <[email protected]>
* remove stale import Signed-off-by: Alexandros Koumparoulis <[email protected]>
* rm stale imports Signed-off-by: Alexandros Koumparoulis <[email protected]>
* rm stale imports Signed-off-by: Alexandros Koumparoulis <[email protected]>
* tokenizer fix Signed-off-by: Alexandros Koumparoulis <[email protected]>
* update example Signed-off-by: Alexandros Koumparoulis <[email protected]>
* rename pytorch_adam to pytorch_adam_with_cosine_annealing Signed-off-by: Alexandros Koumparoulis <[email protected]>
* small refactor Signed-off-by: Alexandros Koumparoulis <[email protected]>
* fix no_weight_decay_cond Signed-off-by: Alexandros Koumparoulis <[email protected]>
* fix Signed-off-by: Alexandros Koumparoulis <[email protected]>
* fix Signed-off-by: Alexandros Koumparoulis <[email protected]>
* switch to flat_lr optim for example Signed-off-by: Alexandros Koumparoulis <[email protected]>
* Apply isort and black reformatting Signed-off-by: akoumpa <[email protected]>
* remove imports & update docstrings Signed-off-by: Alexandros Koumparoulis <[email protected]>
* add a tokenizer setter to allow it to work with nemo/collections/llm/api.py::_use_tokenizer Signed-off-by: Alexandros Koumparoulis <[email protected]>
* remove unused import Signed-off-by: Alexandros Koumparoulis <[email protected]>
* allow loss_mask to be none Signed-off-by: Alexandros Koumparoulis <[email protected]>
* Add HF-dataset lightning module Signed-off-by: Alexandros Koumparoulis <[email protected]>
* check if pad_token_id is None Signed-off-by: Alexandros Koumparoulis <[email protected]>
* rename hf_lit_module.py to hf_auto_model.py Signed-off-by: Alexandros Koumparoulis <[email protected]>
* class rename Signed-off-by: Alexandros Koumparoulis <[email protected]>
* rename Signed-off-by: Alexandros Koumparoulis <[email protected]>
* update example Signed-off-by: Alexandros Koumparoulis <[email protected]>
* HfAutoModelForCausalLM Signed-off-by: Alexandros Koumparoulis <[email protected]>
* rm stale import Signed-off-by: Alexandros Koumparoulis <[email protected]>
* add option to start with random weights Signed-off-by: Alexandros Koumparoulis <[email protected]>
* add check in megatron-strategy Signed-off-by: Alexandros Koumparoulis <[email protected]>
* rename param Signed-off-by: Alexandros Koumparoulis <[email protected]>
* drop mcore sampler from squadmodule Signed-off-by: Alexandros Koumparoulis <[email protected]>
* make megatron_sampler optional in HfDatasetDataModule Signed-off-by: Alexandros Koumparoulis <[email protected]>
* copyright Signed-off-by: Alexandros Koumparoulis <[email protected]>
* use is_hf_model to mark hf classes Signed-off-by: Alexandros Koumparoulis <[email protected]>
* Apply isort and black reformatting Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
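At a high level, the pieces added here compose roughly as in the condensed sketch below. It is distilled from the SFT example script included in this commit (shown further down); the model id, batch sizes, and learning rate are that example's illustrative defaults, not recommended settings, and the example actually uses a SquadDataModule subclass rather than the base class used here to keep the sketch self-contained.

# Condensed from the SFT example added in this commit; values are illustrative only.
import fiddle as fdl

from nemo import lightning as nl
from nemo.collections import llm

model_id = 'meta-llama/Llama-3.2-1B'

llm.api.finetune(
    # Wraps any checkpoint loadable through Hugging Face's AutoModelForCausalLM.
    model=llm.HfAutoModelForCausalLM(model_id),
    # The example script subclasses llm.SquadDataModule; the base class is used here.
    data=llm.SquadDataModule(
        tokenizer=llm.HfAutoModelForCausalLM.configure_tokenizer(model_id),
        seq_length=2048,
        micro_batch_size=2,
        global_batch_size=128,
    ),
    trainer=nl.Trainer(devices=1, max_steps=100, accelerator='gpu', strategy='auto'),
    # PyTorch-native optimizer recipe, materialized from its fiddle config.
    optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(max_lr=1e-5, clip_grad=0.5)),
    log=None,
)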
Showing 18 changed files with 733 additions and 23 deletions.
@@ -0,0 +1,91 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fiddle as fdl
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

from nemo import lightning as nl
from nemo.collections import llm


class SquadDataModuleWithPthDataloader(llm.SquadDataModule):
    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
        return DataLoader(
            dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            collate_fn=dataset.collate_fn,
            batch_size=self.micro_batch_size,
            **kwargs,
        )


def squad(tokenizer) -> pl.LightningDataModule:
    return SquadDataModuleWithPthDataloader(
        tokenizer=tokenizer,
        seq_length=2048,
        micro_batch_size=2,
        global_batch_size=128,  # assert gbs == mbs * accumulate_grad_batches
        num_workers=0,
        sanity_check_dist_workers=False,
    )


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='meta-llama/Llama-3.2-1B')
    parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
    parser.add_argument('--devices', default=1)
    parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
    parser.add_argument('--max-steps', type=int, default=100)
    parser.add_argument('--wandb-project', type=str, default=None)
    args = parser.parse_args()

    wandb = None
    if args.wandb_project is not None:
        model = '_'.join(args.model.split('/')[-2:])
        wandb = WandbLogger(
            project=args.wandb_project,
            name=f'{model}_dev{args.devices}_strat_{args.strategy}',
        )
    grad_clip = 0.5
    if args.strategy == 'fsdp':
        # See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
        grad_clip = None
    use_dist_samp = False

    llm.api.finetune(
        model=llm.HfAutoModelForCausalLM(args.model),
        data=squad(llm.HfAutoModelForCausalLM.configure_tokenizer(args.model)),
        trainer=nl.Trainer(
            devices=args.devices,
            max_steps=args.max_steps,
            accelerator=args.accelerator,
            strategy=args.strategy,
            log_every_n_steps=1,
            limit_val_batches=0.0,
            num_sanity_val_steps=0,
            accumulate_grad_batches=10,
            gradient_clip_val=grad_clip,
            use_distributed_sampler=use_dist_samp,
            logger=wandb,
        ),
        optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(max_lr=1e-5, clip_grad=0.5)),
        log=None,
    )
@@ -0,0 +1,103 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader


class HfDatasetDataModule(pl.LightningDataModule):
    def __init__(
        self,
        dataset,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
        micro_batch_size=2,
        global_batch_size=2,
        pad_token_id=0,
        use_mcore_sampler=False,
        mcore_dataloader_type='cyclic',
    ) -> None:
        super().__init__()
        assert pad_token_id is not None

        self.dataset = dataset
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.pad_token_id = pad_token_id

        self.use_mcore_sampler = use_mcore_sampler
        self.mcore_dataloader_type = mcore_dataloader_type

    @staticmethod
    def collate_fn(batch, pad_token_id=0):
        def batchify(tensor):
            if tensor.ndim == 1:
                return tensor.unsqueeze_(0)
            return tensor

        def extract_key_from_dicts(batch, key):
            return list(map(lambda x: x[key], batch))

        def pad_within_micro(batch, pad_token_id):
            max_len = max(map(len, batch))
            return [item + [pad_token_id] * (max_len - len(item)) for item in batch]

        return {
            key: batchify(
                torch.LongTensor(
                    pad_within_micro(
                        extract_key_from_dicts(batch, key),
                        pad_token_id,
                    )
                )
            )
            for key in ['tokens', 'labels']
        }

    def train_dataloader(self, collate_fn=None):
        from nemo.lightning.data import add_megatron_sampler

        if collate_fn is None:
            collate_fn = lambda x: HfDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)

        dataloader = DataLoader(
            self.dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            collate_fn=collate_fn,
            batch_size=self.micro_batch_size,
        )
        if not self.use_mcore_sampler:
            return dataloader

        rank = 0
        world_size = 1
        if torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()

        return add_megatron_sampler(
            dataloader,
            self.micro_batch_size,
            self.global_batch_size,
            dataloader_type=self.mcore_dataloader_type,
            rank=rank,
            world_size=world_size,
        )
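As a usage sketch of the module defined above: it wraps any map-style dataset whose rows expose 'tokens' and 'labels' lists, and its static collate_fn pads each key to the longest row within the micro-batch. The import path of the class is not shown in this diff, so the sketch refers to it as defined above; the toy rows are hypothetical.

# Usage sketch, assuming HfDatasetDataModule as defined above is importable.
toy_rows = [
    {'tokens': [5, 6, 7, 8], 'labels': [6, 7, 8, 9]},
    {'tokens': [5, 6], 'labels': [6, 7]},
]

# collate_fn pads every key to the longest row within the micro-batch:
batch = HfDatasetDataModule.collate_fn(toy_rows, pad_token_id=0)
print(batch['tokens'])  # rows padded to length 4: [[5, 6, 7, 8], [5, 6, 0, 0]]
print(batch['labels'])  # rows padded to length 4: [[6, 7, 8, 9], [6, 7, 0, 0]]

# With use_mcore_sampler=False (the default), train_dataloader() returns a plain
# PyTorch DataLoader that uses the collate_fn shown above.
datamodule = HfDatasetDataModule(toy_rows, num_workers=0, persistent_workers=False, micro_batch_size=2)
loader = datamodule.train_dataloader()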