NeMo-UX: Mistral/mixtral peft ci test (NVIDIA#11094)
* add mistral/mixtral peft ci test

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add mistral/mixtral peft ci test

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add mistral tp2

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* add tests to NEMO_CICD_Test

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Update .github/workflows/cicd-main.yml

Co-authored-by: oliver könig <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix params

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* rm devices arg

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add --dist-opt arg

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add tp=2 mixtral

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add ep test

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: akoumpa <[email protected]>
Co-authored-by: oliver könig <[email protected]>
3 people authored and HuiyingLi committed Nov 15, 2024
1 parent a74d6bb commit f7a88d1
Showing 2 changed files with 219 additions and 0 deletions.
80 changes: 80 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -4266,6 +4266,81 @@ jobs:
          --pp_size 1 \
          --mbs 1 --packed

  L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2') || needs.cicd-test-container-setup.outputs.all == 'true'
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        python tests/collections/llm/lora_mistralai.py \
          --max-steps 3 \
          --ep 1 \
          --mbs 2 \
          --model mixtral

  L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1') || needs.cicd-test-container-setup.outputs.all == 'true'
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        python tests/collections/llm/lora_mistralai.py \
          --max-steps 3 \
          --tp 1 \
          --mbs 1 \
          --model mixtral \
          --dist-opt

  L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1') || needs.cicd-test-container-setup.outputs.all == 'true'
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        python tests/collections/llm/lora_mistralai.py \
          --max-steps 3 \
          --tp 2 \
          --mbs 1 \
          --model mixtral \
          --dist-opt

  L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1') || needs.cicd-test-container-setup.outputs.all == 'true'
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        python tests/collections/llm/lora_mistralai.py \
          --max-steps 3 \
          --tp 1 \
          --mbs 1 \
          --model mistral \
          --dist-opt

  L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1') || needs.cicd-test-container-setup.outputs.all == 'true'
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        python tests/collections/llm/lora_mistralai.py \
          --max-steps 3 \
          --tp 2 \
          --mbs 1 \
          --model mistral \
          --dist-opt

  L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
@@ -4422,6 +4497,11 @@ jobs:
- L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2
- L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2
- L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED
- L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2
- L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1
- L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1
- L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1
- L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1
- L2_NeMo_2_Mixtral_Pretraining
- L2_PTQ_Llama2_FP8
- L2_Community_LLM_Checkpoints_tests_Llama3
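All five new jobs reuse _test_template.yml and call the same entry point, differing only in the parallelism flags passed to tests/collections/llm/lora_mistralai.py. The Mistral TP2 job, for example, runs:

    python tests/collections/llm/lora_mistralai.py --max-steps 3 --tp 2 --mbs 1 --model mistral --dist-opt

The same command should reproduce the test outside CI, assuming a machine with at least two GPUs and a NeMo development container.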
139 changes: 139 additions & 0 deletions tests/collections/llm/lora_mistralai.py
@@ -0,0 +1,139 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import pytorch_lightning as pl
import torch
from megatron.core.optimizer import OptimizerConfig

from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning.io.mixin import track_io


def get_args():
    parser = argparse.ArgumentParser(description='Finetune a small Mistral/Mixtral model with LoRA using NeMo 2.0')
    parser.add_argument('--model', type=str.lower, choices=['mistral', 'mixtral'], help="model family to finetune")
    parser.add_argument('--max-steps', type=int, default=9, help="maximum number of training steps")
    parser.add_argument('--mbs', type=int, default=2, help="micro batch size")
    parser.add_argument('--gbs', type=int, default=4, help="global batch size")
    parser.add_argument('--tp', type=int, default=1, help="tensor parallel size")
    parser.add_argument('--ep', type=int, default=1, help="expert parallel size")
    parser.add_argument('--dist-opt', action='store_true', help='use the distributed optimizer')
    return parser.parse_args()


def trainer(devices, tp, ep, sp, max_steps) -> nl.Trainer:
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=tp,
        expert_model_parallel_size=ep,
        sequence_parallel=sp,
    )

    return nl.Trainer(
        devices=max(ep, tp),
        max_steps=max_steps,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        log_every_n_steps=1,
        limit_val_batches=0,
        val_check_interval=0,
        num_sanity_val_steps=0,
    )
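# Note: the number of GPUs is derived from the parallelism settings
# (devices=max(ep, tp)), so the TP2 CI jobs request two GPUs; the `devices`
# parameter itself is not used inside this function.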


@track_io
class OrdTokenizer:
    def __init__(self, vocab_size=30_000, num_reserved_tokens=128, special_token_names=['bos_id', 'eos_id', 'pad_id']):
        self.vocab_size = vocab_size
        self.num_reserved_tokens = num_reserved_tokens
        self.special_token_names = special_token_names
        assert len(self.special_token_names) < num_reserved_tokens

    def __getattr__(self, name):
        if name in self.__dict__.get('special_token_names', {}):
            return self.__dict__['special_token_names'].index(name)
        elif name in self.__dict__:
            return self.__dict__[name]
        else:
            raise AttributeError

    def text_to_ids(self, text):
        token_ids = list(map(lambda x: self.num_reserved_tokens + ord(x), list(text)))
        assert max(token_ids) < self.vocab_size
        return token_ids


def logger() -> nl.NeMoLogger:
    ckpt = nl.ModelCheckpoint(
        save_last=True,
        every_n_train_steps=10,
        monitor="reduced_train_loss",
        save_top_k=1,
        save_on_train_epoch_end=True,
        save_optim_on_train_end=True,
    )

    return nl.NeMoLogger(
        name="nemo2_peft",
        log_dir="/tmp/peft_logs",
        use_datetime_version=False,  # must be false if using auto resume
        ckpt=ckpt,
        wandb=None,
    )


def squad(mbs, gbs) -> pl.LightningDataModule:
    return llm.SquadDataModule(seq_length=2048, micro_batch_size=mbs, global_batch_size=gbs, num_workers=0)


def mixtral_8x7b() -> pl.LightningModule:
    tokenizer = OrdTokenizer()
    model = llm.MixtralModel(llm.MixtralConfig8x7B(num_layers=2), tokenizer=tokenizer)
    lora = llm.peft.LoRA()
    return model, lora


def mistral_7b() -> pl.LightningModule:
    tokenizer = OrdTokenizer()
    model = llm.MistralModel(llm.MistralConfig7B(num_layers=2), tokenizer=tokenizer)
    lora = llm.peft.LoRA()
    return model, lora


if __name__ == '__main__':
    args = get_args()
    if args.model == 'mistral':
        model, lora = mistral_7b()
    else:
        model, lora = mixtral_8x7b()
    llm.finetune(
        model=model,
        data=squad(args.mbs, args.gbs),
        trainer=trainer(args.tp, args.tp, args.ep, args.tp > 1, args.max_steps),
        peft=lora,
        log=logger(),
        optim=nl.MegatronOptimizerModule(
            config=OptimizerConfig(
                optimizer="adam",
                lr=0.0001,
                adam_beta2=0.98,
                use_distributed_optimizer=args.dist_opt,
                clip_grad=1.0,
                bf16=True,
            ),
        ),
    )
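For reference, a minimal sketch of what the OrdTokenizer above produces (illustrative only, not part of the commit; assumes the class is in scope):

    # With the default of 128 reserved tokens, each character maps to
    # 128 + ord(char); special tokens resolve to their index in
    # special_token_names via __getattr__.
    tok = OrdTokenizer()
    print(tok.text_to_ids("Hi"))               # [200, 233]  (128 + 72, 128 + 105)
    print(tok.bos_id, tok.eos_id, tok.pad_id)  # 0 1 2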
