diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index 8b6d2c0251bd..37d8b903afa4 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -3574,17 +3574,60 @@ jobs:
           inference.outfile_path=/tmp/nlp_mcore_t5_lora_tuning_tp2/out.jsonl
 
+  L2_HF_Transformer_PEFT:
+    needs: [ cicd-test-container-setup ]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_PEFT') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure-gpus-1
+      SCRIPT: |
+        TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/peft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --disable-ckpt
+      AFTER_SCRIPT: |
+        rm -rf nemo_experiments
+
+  L2_HF_Transformer_PEFT_2gpu:
+    needs: [ cicd-test-container-setup ]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_PEFT_2gpu') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+        TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/peft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2 --strategy ddp --disable-ckpt
+      AFTER_SCRIPT: |
+        rm -rf nemo_experiments
+
+  L2_HF_Transformer_SFT_2gpu:
+    needs: [ cicd-test-container-setup ]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_SFT_2gpu') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+        TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2 --strategy ddp
+      AFTER_SCRIPT: |
+        rm -rf nemo_experiments
+
+  L2_HF_Transformer_SFT:
+    needs: [ cicd-test-container-setup ]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_SFT') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure-gpus-1
+      SCRIPT: |
+        TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10
+      AFTER_SCRIPT: |
+        rm -rf nemo_experiments
+
   L2_HF_Transformer_SFT_TE_Acceleration:
     needs: [ cicd-test-container-setup ]
     uses: ./.github/workflows/_test_template.yml
     if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_SFT_TE_Acceleration') || needs.cicd-test-container-setup.outputs.all == 'true'
     with:
-      RUNNER: self-hosted-azure
+      RUNNER: self-hosted-azure-gpus-1
       SCRIPT: |
-        python examples/llm/sft/hf.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --model-accelerator te
+        TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --model-accelerator te --max-steps 10
       AFTER_SCRIPT: |
         rm -rf nemo_experiments
-
   # L2: Megatron Mock Data Generation
   L2_Megatron_Mock_Data_Generation_MockGPTDataset:
@@ -4685,6 +4728,10 @@ jobs:
       - L2_NeMo_2_llama3_pretraining_recipe
       - L2_NeMo_2_llama3_fault_tolerance_plugin
       - L2_NeMo_2_llama3_straggler_detection
+      - L2_HF_Transformer_PEFT
+      - L2_HF_Transformer_PEFT_2gpu
+      - L2_HF_Transformer_SFT
+      - L2_HF_Transformer_SFT_2gpu
       - L2_HF_Transformer_SFT_TE_Acceleration
       - L2_NeMo_2_SSM_Pretraining
       - L2_NeMo_2_SSM_Finetuning
diff --git a/examples/llm/peft/hf.py b/examples/llm/peft/hf.py
index c24c5958b388..3a0930732e87 100644
--- a/examples/llm/peft/hf.py
+++ b/examples/llm/peft/hf.py
@@ -39,14 +39,17 @@ def formatting_prompts_func(examples):
             output = output[0]
         text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
         ans = tokenizer(text)
-        tokens = ans['input_ids']
-        return {
-            'tokens': tokens,
-            'labels': tokens[1:] + [tokens[-1]],
-        }
+        ans['labels'] = ans['input_ids']
+        return ans
 
-    datamodule = llm.HFDatasetDataModule("rajpurkar/squad", split="train", pad_token_id=tokenizer.eos_token_id)
-    datamodule.map(formatting_prompts_func, batched=False, batch_size=2)
+    tokenizer = getattr(tokenizer, 'tokenizer', tokenizer)
+    datamodule = llm.HFDatasetDataModule("rajpurkar/squad", split="train[:100]", pad_token_id=tokenizer.eos_token_id)
+    datamodule.map(
+        formatting_prompts_func,
+        batched=False,
+        batch_size=2,
+        remove_columns=["id", "title", "context", "question", 'answers'],
+    )
 
     return datamodule
diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py
index 73b6444a6e9c..7880e26cf6b1 100644
--- a/nemo/collections/llm/gpt/data/hf_dataset.py
+++ b/nemo/collections/llm/gpt/data/hf_dataset.py
@@ -14,13 +14,19 @@
 import lightning.pytorch as pl
 import torch
-from datasets import load_dataset
+from datasets import Dataset, DatasetDict, load_dataset
 from torch.utils.data import DataLoader
 
 from nemo.lightning.pytorch.plugins import MegatronDataSampler
 from nemo.utils import logging
 
 
+def clean_split(name):
+    if '[' in name:
+        return name.split('[')[0]
+    return name
+
+
 def make_dataset_splits(dataset, split, split_aliases):
     """
     Given a dataset (e.g. from datasets.load_dataset or datasets.Dataset.from_dict) it
@@ -51,19 +57,18 @@ def make_dataset_splits(dataset, split, split_aliases):
     > "val": Dataset .. (with 10570 rows),
     > }
     """
-    from datasets import Dataset, DatasetDict
-
-    split_names = ['train', 'test', 'val']
-    dataset_splits = {_split: None for _split in split_names}
+    valid_split_names = ['train', 'test', 'val']
+    dataset_splits = {_split: None for _split in valid_split_names}
 
     alias_to_split = {}
     for split_name, _split_aliases in split_aliases.items():
-        assert split_name in split_names
+        assert split_name in valid_split_names
         for alias in _split_aliases:
             alias_to_split[alias] = split_name
 
     if isinstance(dataset, Dataset):
         assert isinstance(split, str), "Expected split to be a string, but got " + str(type(split))
+        split = clean_split(split)
         dataset_splits[split] = dataset
     elif isinstance(dataset, DatasetDict):
         dataset_split_names = dataset.keys()
@@ -75,7 +80,7 @@ def make_dataset_splits(dataset, split, split_aliases):
     elif isinstance(split, list):
         logging.info(f"Loaded HF dataset will use " + str(split) + " splits.")
         assert isinstance(dataset, list)
-        for i, alias_split_name in enumerate(split):
+        for i, alias_split_name in enumerate(map(clean_split, split)):
             split_name = alias_to_split[alias_split_name]
             assert dataset_splits[split_name] is None
             dataset_splits[split_name] = dataset[i]
@@ -93,6 +98,7 @@ def make_dataset_splits(dataset, split, split_aliases):
     else:
         raise ValueError("Expected split name to be None, str or a list")
 
+    assert set(valid_split_names) == set(dataset_splits.keys()), dataset_splits.keys()
     num_init_splits = sum(map(lambda x: x is not None, dataset_splits.values()))
     assert num_init_splits > 0, f"Expected at least one split to have been initialized {num_init_splits}"
     return dataset_splits
@@ -133,8 +139,6 @@ def __init__(
     ) -> None:
         super().__init__()
         assert pad_token_id is not None
-        from datasets import Dataset, DatasetDict
-
         # A dataset usually will have several splits (e.g. train, val, test, etc).
         # We map synonym names to canonical names (train, test, val).
         # A synonym can be a prefix/suffixed word e.g. train <> training.
@@ -172,8 +176,6 @@ def __init__(
 
     @staticmethod
     def from_dict(dataset_dict, split, **kwargs):
-        from datasets import Dataset
-
         dataset = Dataset.from_dict(dataset_dict)
         return HFDatasetDataModule(path_or_dataset=dataset, split=split, **kwargs)
 
@@ -191,7 +193,6 @@ def pad_within_micro(batch, pad_token_id):
             max_len = max(map(len, batch))
             return [item + [pad_token_id] * (max_len - len(item)) for item in batch]
 
-        keys = list(filter(lambda x: x in batch[0], ['tokens', 'labels', 'position_ids', 'loss_mask']))
         return {
             key: batchify(
                 torch.LongTensor(
@@ -201,7 +202,7 @@ def pad_within_micro(batch, pad_token_id):
                     )
                 )
             )
-            for key in keys
+            for key in batch[0].keys()
         }
 
     def setup(self, stage: str):
diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
index 481dd9a0e187..a51bbffdd6b6 100644
--- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
@@ -31,6 +31,19 @@ def masked_cross_entropy(logits, targets, mask=None):
         return F.cross_entropy(logits, targets)
 
 
+def align_labels(logits, labels):
+    logits = logits.float()
+    n_cls = logits.shape[-1]
+    if logits.shape[-2] == labels.shape[-1]:
+        logits = logits[..., :-1, :].contiguous()
+        labels = labels[..., 1:].contiguous()
+    elif logits.shape[-2] == labels.shape[-1] + 1:
+        logits = logits[..., :-1, :].contiguous()
+    else:
+        raise ValueError("Mismatched labels and logits shapes (" + str(labels.shape) + " " + str(logits.shape) + ")")
+    return logits.view(-1, n_cls), labels.view(-1)
+
+
 class HFAutoModelForCausalLM(pl.LightningModule, io.IOMixin, fn.FNMixin):
     def __init__(
         self,
@@ -91,41 +104,34 @@ def configure_model(self):
 
         self.model.train()
 
-    def forward(self, input_ids, attention_mask=None, labels=None, loss_mask=None):
-        outputs = self.model(
-            input_ids=input_ids.to(self.model.device),
-            attention_mask=attention_mask,
-        )
-        labels = labels.to(self.model.device)
-        if loss_mask is not None:
-            loss_mask = loss_mask.to(self.model.device).view(-1)
-        n_cls = outputs.logits.shape[-1]
-        outputs.loss = self.loss_fn(outputs.logits.view(-1, n_cls), labels.view(-1), loss_mask)
-        return outputs
+    def forward(self, batch):
+        return self.model(**batch)
 
     def training_step(self, batch):
-        tokens = batch['tokens']
-        labels = batch['labels']
-        loss_mask = batch.get('loss_mask', None)
-        output = self.forward(
-            input_ids=tokens,
-            labels=labels,
-            loss_mask=loss_mask,
-        )
-
-        loss = output.loss
+        labels = batch.pop('labels').to(self.model.device)
+        loss_mask = batch.pop('loss_mask', None)
+
+        outputs = self.forward(batch)
+
+        # Prepare for loss calculation
+        logits, labels = align_labels(outputs.logits.float(), labels)
+        assert logits.shape[-2] == labels.shape[-1]
+
+        loss = self.loss_fn(logits, labels, loss_mask)
         self.log('train_log', loss, on_step=True, on_epoch=True, prog_bar=True)
         return loss
 
+    @torch.no_grad
     def validation_step(self, batch, batch_idx):
-        tokens = batch['tokens']
-        labels = batch['labels']
-        output = self.forward(
-            input_ids=tokens,
-            labels=labels,
-        )
-
-        loss = output.loss
+        labels = batch.pop('labels').to(self.model.device)
+        loss_mask = batch.pop('loss_mask', None)
+
+        outputs = self.forward(batch)
+
+        logits, labels = align_labels(outputs.logits.float(), labels)
+        assert logits.shape[-2] == labels.shape[-1]
+        loss = self.loss_fn(logits, labels, loss_mask)
+
+        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
 
     def save_pretrained(self, path):
diff --git a/tests/collections/llm/gpt/data/test_hf_datamodule.py b/tests/collections/llm/gpt/data/test_hf_datamodule.py
index 58f7c02e091b..af035d91034d 100644
--- a/tests/collections/llm/gpt/data/test_hf_datamodule.py
+++ b/tests/collections/llm/gpt/data/test_hf_datamodule.py
@@ -41,6 +41,30 @@ def test_load_single_split():
     assert ds.test is None
 
 
+def test_load_single_split_with_subset():
+    ds = llm.HFDatasetDataModule(
+        path_or_dataset=DATA_PATH,
+        split='train[:10]',
+        seq_length=512,
+        micro_batch_size=2,
+        global_batch_size=2,
+    )
+    from datasets.arrow_dataset import Dataset
+
+    assert isinstance(ds.dataset_splits, dict)
+    assert len(ds.dataset_splits) == 3
+    assert 'train' in ds.dataset_splits
+    assert ds.dataset_splits['train'] is not None
+    assert ds.train is not None
+    assert isinstance(ds.dataset_splits['train'], Dataset)
+    assert 'val' in ds.dataset_splits
+    assert ds.dataset_splits['val'] is None
+    assert ds.val is None
+    assert 'test' in ds.dataset_splits
+    assert ds.dataset_splits['test'] is None
+    assert ds.test is None
+
+
 def test_load_nonexistent_split():
     exception_msg = ''
     expected_msg = '''Unknown split "this_split_name_should_not_exist". Should be one of ['train', 'validation'].'''
@@ -84,6 +108,33 @@ def test_load_multiple_split():
     assert ds.test is None
 
 
+def test_load_multiple_split_with_subset():
+    ds = llm.HFDatasetDataModule(
+        path_or_dataset=DATA_PATH,
+        split=['train[:100]', 'validation'],
+        seq_length=512,
+        micro_batch_size=2,
+        global_batch_size=2,
+    )
+    from datasets.arrow_dataset import Dataset
+
+    assert isinstance(ds.dataset_splits, dict)
+    assert len(ds.dataset_splits) == 3
+    assert 'train' in ds.dataset_splits
+    assert ds.dataset_splits['train'] is not None
+    assert ds.train is not None
+    assert isinstance(ds.dataset_splits['train'], Dataset)
+    assert isinstance(ds.train, Dataset)
+    assert 'val' in ds.dataset_splits
+    assert ds.dataset_splits['val'] is not None
+    assert ds.val is not None
+    assert isinstance(ds.dataset_splits['val'], Dataset)
+    assert isinstance(ds.val, Dataset)
+    assert 'test' in ds.dataset_splits
+    assert ds.dataset_splits['test'] is None
+    assert ds.test is None
+
+
 def test_validate_dataset_asset_accessibility_file_does_not_exist():
     raised_exception = False
     try:
@@ -99,8 +150,9 @@ def test_validate_dataset_asset_accessibility_file_does_not_exist():
     assert raised_exception == True, "Expected to raise a FileNotFoundError"
 
 
-def test_validate_dataset_asset_accessibility_file_is_none():  # tokenizer, trainer):
-    raised_exception = False
+def test_validate_dataset_asset_accessibility_file_is_none():
+    exception_msg = ''
+    expected_msg = "Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got "
     try:
         llm.HFDatasetDataModule(
             path_or_dataset=None,
@@ -109,28 +161,6 @@ def test_validate_dataset_asset_accessibility_file_is_none():  # tokenizer, trai
             global_batch_size=2,
         )
     except ValueError as e:
-        raised_exception = (
-            str(e) == "Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got "
-        )
-
-    assert raised_exception == True, "Expected to raise a ValueError"
-
-
-def test_load_from_dict():
-    data = {'text': "Below is an instruction that describes a task, paired with an input that "}
+        exception_msg = str(e)
-
-    datamodule = llm.HFDatasetDataModule.from_dict(
-        {"text": [data['text'] for _ in range(101)]},
-        split='train',
-        global_batch_size=4,
-        micro_batch_size=1,
-    )
-    assert datamodule is not None
-    assert isinstance(datamodule, llm.HFDatasetDataModule)
-    assert hasattr(datamodule, 'train')
-    assert datamodule.train is not None
-    assert len(datamodule.train) == 101
-    assert hasattr(datamodule, 'val')
-    assert datamodule.val is None
-    assert hasattr(datamodule, 'test')
-    assert datamodule.test is None
+    assert exception_msg == expected_msg, exception_msg
diff --git a/tests/collections/llm/hf/peft.py b/tests/collections/llm/hf/peft.py
new file mode 100644
index 000000000000..018774280946
--- /dev/null
+++ b/tests/collections/llm/hf/peft.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import fiddle as fdl
+from lightning.pytorch.loggers import WandbLogger
+
+from nemo import lightning as nl
+from nemo.collections import llm
+
+DATA_PATH = '/home/TestData/lite/hf_cache/squad/'
+
+
+def make_squad_hf_dataset(data_path, tokenizer):
+    EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN
+
+    def formatting_prompts_func(examples):
+        alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+        ### Instruction:
+        {}
+
+        ### Input:
+        {}
+
+        ### Response:
+        {}"""
+        instruction = examples["context"]
+        input = examples["question"]
+        output = examples["answers"]['text']
+        if isinstance(output, list):
+            output = output[0]
+        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
+        ans = tokenizer(text)
+        ans['labels'] = ans['input_ids']
+        return ans
+
+    tokenizer = getattr(tokenizer, 'tokenizer', tokenizer)
+    datamodule = llm.HFDatasetDataModule(data_path, split="train[:100]", pad_token_id=tokenizer.eos_token_id)
+
+    datamodule.map(
+        formatting_prompts_func,
+        batched=False,
+        batch_size=2,
+        remove_columns=["id", "title", "context", "question", 'answers'],
+    )
+
+    return datamodule
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', default='meta-llama/Llama-3.2-1B')
+    parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
+    parser.add_argument('--devices', default=1)
+    parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
+    parser.add_argument('--max-steps', type=int, default=100)
+    parser.add_argument('--wandb-project', type=str, default=None)
+    parser.add_argument('--disable-ckpt', action='store_false')
+    args = parser.parse_args()
+
+    wandb = None
+    if args.wandb_project is not None:
+        model = '_'.join(args.model.split('/')[-2:])
+        wandb = WandbLogger(
+            project=args.wandb_project,
+            name=f'{model}_dev{args.devices}_strat_{args.strategy}',
+        )
+    grad_clip = 0.5
+    if args.strategy == 'fsdp':
+        # See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
+        grad_clip = None
+    use_dist_samp = False
+    tokenizer = llm.HFAutoModelForCausalLM.configure_tokenizer(args.model)
+
+    llm.api.finetune(
+        model=llm.HFAutoModelForCausalLM(args.model),
+        data=make_squad_hf_dataset(DATA_PATH, tokenizer),
+        trainer=nl.Trainer(
+            devices=args.devices,
+            max_steps=args.max_steps,
+            accelerator=args.accelerator,
+            strategy=args.strategy,
+            log_every_n_steps=1,
+            limit_val_batches=0.0,
+            num_sanity_val_steps=0,
+            accumulate_grad_batches=10,
+            gradient_clip_val=grad_clip,
+            use_distributed_sampler=use_dist_samp,
+            logger=wandb,
+            enable_checkpointing=args.disable_ckpt,
+        ),
+        optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
+        log=None,
+        peft=llm.peft.LoRA(
+            target_modules=['*_proj'],
+            dim=32,
+        ),
+    )
diff --git a/tests/collections/llm/hf/sft.py b/tests/collections/llm/hf/sft.py
new file mode 100755
index 000000000000..44b0dabbb2d0
--- /dev/null
+++ b/tests/collections/llm/hf/sft.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import fiddle as fdl
+from lightning.pytorch.loggers import WandbLogger
+
+from nemo import lightning as nl
+from nemo.collections import llm
+from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated
+
+
+DATA_PATH = '/home/TestData/lite/hf_cache/squad/'
+
+
+def make_squad_hf_dataset(data_path, tokenizer):
+    EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN
+
+    def formatting_prompts_func(examples):
+        alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+        ### Instruction:
+        {}
+
+        ### Input:
+        {}
+
+        ### Response:
+        {}"""
+        instruction = examples["context"]
+        input = examples["question"]
+        output = examples["answers"]['text']
+        if isinstance(output, list):
+            output = output[0]
+        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
+        ans = tokenizer(text)
+        ans['labels'] = ans['input_ids']
+        return ans
+
+    tokenizer = getattr(tokenizer, 'tokenizer', tokenizer)
+    datamodule = llm.HFDatasetDataModule(data_path, split="train[:100]", pad_token_id=tokenizer.eos_token_id)
+
+    datamodule.map(
+        formatting_prompts_func,
+        batched=False,
+        batch_size=2,
+        remove_columns=["id", "title", "context", "question", 'answers'],
+    )
+
+    return datamodule
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', default='meta-llama/Llama-3.2-1B')
+    parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
+    parser.add_argument('--devices', default=1)
+    parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
+    parser.add_argument('--model-accelerator', default=None, choices=['te'])
+    parser.add_argument('--max-steps', type=int, default=100)
+    parser.add_argument("--fp8-autocast", default=False, action='store_true')
+    parser.add_argument('--wandb-project', type=str, default=None)
+    parser.add_argument('--model-save-path', type=str, default=None)
+    args = parser.parse_args()
+
+    wandb = None
+    if args.wandb_project is not None:
+        model = '_'.join(args.model.split('/')[-2:])
+        wandb = WandbLogger(
+            project=args.wandb_project,
+            name=f'{model}_dev{args.devices}_strat_{args.strategy}',
+        )
+    grad_clip = 0.5
+    if args.strategy == 'fsdp':
+        # See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
+        grad_clip = None
+    use_dist_samp = False
+
+    model_accelerator = None
+    if args.model_accelerator == "te":
+        from functools import partial
+        from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate
+
+        model_accelerator = partial(te_accelerate, fp8_autocast=args.fp8_autocast)
+
+    from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate
+
+    model = llm.HFAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator)
+    tokenizer = model.tokenizer
+
+    llm.api.finetune(
+        model=model,
+        data=make_squad_hf_dataset(DATA_PATH, tokenizer),
+        trainer=nl.Trainer(
+            devices=args.devices,
+            max_steps=args.max_steps,
+            accelerator=args.accelerator,
+            strategy=args.strategy,
+            log_every_n_steps=1,
+            limit_val_batches=0.0,
+            num_sanity_val_steps=0,
+            accumulate_grad_batches=10,
+            gradient_clip_val=grad_clip,
+            use_distributed_sampler=use_dist_samp,
+            callbacks=[],
+            logger=wandb,
+        ),
+        optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
+        log=None,
+    )
+
+    if args.model_accelerator:
+        if args.model_accelerator == "te":
+            te_acc = is_te_accelerated(model.model)
+            assert te_acc, "Transformer Engine acceleration was unsuccessful"
+            print("TE Accelerated: ", te_acc)
+
+    if args.model_save_path is not None:
+        model.save_pretrained(args.model_save_path)