Add more recipes (#10957)
* add recipes

Signed-off-by: Chen Cui <[email protected]>

* adjust finetuning recipe

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>
cuichenx and cuichenx authored Oct 22, 2024
1 parent 70d8cc1 commit c20e892
Showing 9 changed files with 1,158 additions and 10 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/baichuan.py
@@ -215,7 +215,7 @@ def _import_qkv(ctx: io.TransformCTX, qkv_weights):
     q = qkv_weights[0].squeeze().view(*new_q_tensor_shape)
     k = qkv_weights[1].squeeze().view(*new_kv_tensor_shape)
     v = qkv_weights[2].squeeze().view(*new_kv_tensor_shape)
-    qkv_weights = torch.empty((0, head_size) + old_tensor_shape[1:])
+    qkv_weights = torch.empty((0, head_size) + old_tensor_shape[1:]).type_as(qkv_weights)
     for i in range(num_query_groups):
         qkv_weights = torch.cat((qkv_weights, q[i * heads_per_group : (i + 1) * heads_per_group, :, :]))
         qkv_weights = torch.cat((qkv_weights, k[i : i + 1, :, :]))
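Why this matters: torch.empty creates a float32 tensor on the CPU, so concatenating bf16 (or GPU-resident) checkpoint weights into it would either upcast the result or fail on a device mismatch; .type_as(...) makes the accumulator inherit the dtype and device of the incoming Hugging Face weights. The same fix is repeated in the chatglm.py hunks below. A minimal sketch of the grouped-QKV interleaving, using made-up dimensions rather than the real Baichuan2 sizes, shown only to illustrate the pattern:

import torch

# Hypothetical sizes: 2 query groups, 2 query heads per group, head size 4, hidden size 8.
num_query_groups, heads_per_group, head_size, hidden = 2, 2, 4, 8

q = torch.randn(num_query_groups * heads_per_group, head_size, hidden, dtype=torch.bfloat16)
k = torch.randn(num_query_groups, head_size, hidden, dtype=torch.bfloat16)
v = torch.randn(num_query_groups, head_size, hidden, dtype=torch.bfloat16)

# Start from an empty accumulator that matches the checkpoint dtype/device.
qkv = torch.empty((0, head_size, hidden)).type_as(q)
for i in range(num_query_groups):
    # Interleave each group's query heads, then its single key and value head.
    qkv = torch.cat((qkv, q[i * heads_per_group : (i + 1) * heads_per_group]))
    qkv = torch.cat((qkv, k[i : i + 1]))
    qkv = torch.cat((qkv, v[i : i + 1]))

assert qkv.dtype == torch.bfloat16
assert qkv.shape == (num_query_groups * (heads_per_group + 2), head_size, hidden)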
4 changes: 2 additions & 2 deletions nemo/collections/llm/gpt/model/chatglm.py
@@ -221,7 +221,7 @@ def _import_qkv_weight(ctx: io.TransformCTX, hf_qkv_weights):
     k = k.view(*new_kv_tensor_shape)
     v = v.view(*new_kv_tensor_shape)
 
-    qkv_weights = torch.empty((0, head_size, old_tensor_shape[1]))
+    qkv_weights = torch.empty((0, head_size, old_tensor_shape[1])).type_as(hf_qkv_weights)
     for i in range(num_query_groups):
         qkv_weights = torch.cat((qkv_weights, q[i * heads_per_group : (i + 1) * heads_per_group, :, :]))
         qkv_weights = torch.cat((qkv_weights, k[i : i + 1, :, :]))
@@ -251,7 +251,7 @@ def _import_qkv_bias(ctx: io.TransformCTX, hf_qkv_bias):
     q = q.view(*new_q_tensor_shape)
     k = k.view(*new_kv_tensor_shape)
     v = v.view(*new_kv_tensor_shape)
-    qkv_bias = torch.empty((0, head_size))
+    qkv_bias = torch.empty((0, head_size)).type_as(hf_qkv_bias)
     for i in range(num_query_groups):
         qkv_bias = torch.cat((qkv_bias, q[i * heads_per_group : (i + 1) * heads_per_group, :]))
         qkv_bias = torch.cat((qkv_bias, k[i : i + 1, :]))
8 changes: 8 additions & 0 deletions nemo/collections/llm/recipes/__init__.py
@@ -14,6 +14,10 @@


 from nemo.collections.llm.recipes import (
+    baichuan2_7b,
+    chatglm3_6b,
+    gemma_2b,
+    gemma_7b,
     llama3_8b,
     llama3_8b_16k,
     llama3_8b_64k,
@@ -49,6 +53,10 @@
 from nemo.collections.llm.recipes.optim import adam
 
 __all__ = [
+    "baichuan2_7b",
+    "chatglm3_6b",
+    "gemma_2b",
+    "gemma_7b",
     "llama3_8b",
     "llama3_8b_16k",
     "llama3_8b_64k",
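Adding the new modules to both the import list and __all__ makes them importable as submodules of nemo.collections.llm.recipes and available to the nemo llm CLI factories referenced in the docstrings below. A minimal usage sketch, assuming only the import paths added in this commit:

from nemo.collections.llm.recipes import baichuan2_7b

# Build the default pre-training configuration exposed by the new recipe module.
recipe = baichuan2_7b.pretrain_recipe(name="baichuan2_7b_pretrain", num_nodes=1, num_gpus_per_node=8)
print(recipe)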
285 changes: 285 additions & 0 deletions nemo/collections/llm/recipes/baichuan2_7b.py
@@ -0,0 +1,285 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Callable, Optional

import nemo_run as run
import pytorch_lightning as pl
import torch
from megatron.core.distributed import DistributedDataParallelConfig
from pytorch_lightning.callbacks.callback import Callback

from nemo import lightning as nl
from nemo.collections.llm import Baichuan2Config7B, Baichuan2Model
from nemo.collections.llm.api import finetune, pretrain
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.peft.lora import LoRA
from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "baichuan2_7b"


@run.cli.factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
    """
    Factory function to create a Baichuan2 7B model configuration.

    Returns:
        run.Config[pl.LightningModule]: Configuration for the Baichuan2 7B model.

    Examples:
        CLI usage:
            $ nemo llm pretrain model=baichuan2_7b ...

        Python API usage:
            >>> model_config = model()
            >>> print(model_config)
    """
    return run.Config(Baichuan2Model, config=run.Config(Baichuan2Config7B))


def trainer(
    tensor_parallelism: int = 1,
    pipeline_parallelism: int = 1,
    pipeline_parallelism_type: Optional[torch.dtype] = None,
    virtual_pipeline_parallelism: Optional[int] = None,
    context_parallelism: int = 2,
    sequence_parallelism: bool = False,
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    max_steps: int = 1168251,
    callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
    """
    Configure the NeMo Lightning Trainer for the Baichuan2 7B model.

    This function sets up the distributed training strategy and other training parameters.

    Args:
        tensor_parallelism (int): Degree of tensor model parallelism.
        pipeline_parallelism (int): Degree of pipeline model parallelism.
        pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism.
        virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
        context_parallelism (int): Degree of context parallelism.
        sequence_parallelism (bool): Whether to use sequence parallelism.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        max_steps (int): Maximum number of training steps.
        callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations.

    Returns:
        run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer.

    Examples:
        CLI usage:
            $ nemo llm pretrain trainer=baichuan2_7b ...

        Python API usage:
            >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8)
            >>> print(trainer_config)

    Note:
        For more information on distributed training strategies, refer to the
        NeMo documentation on multi-GPU and multi-node training.
    """
    strategy = run.Config(
        nl.MegatronStrategy,
        tensor_model_parallel_size=tensor_parallelism,
        pipeline_model_parallel_size=pipeline_parallelism,
        pipeline_dtype=pipeline_parallelism_type,
        virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
        context_parallel_size=context_parallelism,
        sequence_parallel=sequence_parallelism,
        gradient_as_bucket_view=True,
        ckpt_async_save=True,
        ckpt_parallel_load=True,
        ddp=run.Config(
            DistributedDataParallelConfig,
            check_for_nan_in_grad=True,
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=True,
            overlap_param_gather=True,
        ),
    )

    trainer = run.Config(
        nl.Trainer,
        accelerator="gpu",
        accumulate_grad_batches=1,
        callbacks=callbacks,
        devices=num_gpus_per_node,
        limit_test_batches=50,
        limit_val_batches=32,
        log_every_n_steps=10,
        max_steps=max_steps,
        num_nodes=num_nodes,
        plugins=bf16_mixed(),
        strategy=strategy,
        use_distributed_sampler=False,
        val_check_interval=2000,
    )

    return trainer


@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
    dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
    """
    Create a pre-training recipe for the Baichuan2 7B model.

    This function sets up a complete configuration for pre-training, including
    model, trainer, data, logging, optimization, and resumption settings.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.

    Returns:
        run.Partial: Partial configuration for pre-training.

    Examples:
        CLI usage:
            $ nemo llm pretrain --factory baichuan2_7b
            $ nemo llm pretrain --factory "baichuan2_7b(num_nodes=2, name='my_pretrain')"

        Python API usage:
            >>> recipe = pretrain_recipe(name="baichuan2_7b_pretrain", num_nodes=2)
            >>> print(recipe)

    Note:
        For more details on pre-training LLMs with NeMo, see the pre-training
        guide in the `examples/llm/pretrain/` directory.
    """
    return run.Partial(
        fn,
        model=model(),
        trainer=trainer(
            num_nodes=num_nodes,
            num_gpus_per_node=num_gpus_per_node,
            callbacks=[run.Config(TimingCallback)],
        ),
        data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1),
        log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
        optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4),
        resume=default_resume(),
    )


@run.cli.factory(target=pretrain, name=NAME + "_optimized")
def pretrain_recipe_performance(
    dir: Optional[str] = None,
    name: str = "default",
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    fn: Callable = pretrain,
) -> run.Partial:
    """
    Create a performance-optimized pre-training recipe for the Baichuan2 7B model.

    This recipe enables performance optimizations that may not be suitable for all use cases.
    It builds upon the standard pre-training recipe and adds additional performance enhancements.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.

    Returns:
        run.Partial: Partial configuration for performance-optimized pre-training.

    Examples:
        CLI usage:
            $ nemo llm pretrain --factory baichuan2_7b_optimized

        Python API usage:
            >>> recipe = pretrain_recipe_performance(name="baichuan2_7b_perf", num_nodes=4)
            >>> print(recipe)

    Note:
        Use this recipe with caution and only when you need maximum performance.
        It may not be suitable for all hardware configurations or use cases.
    """
    recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

    recipe.trainer.callbacks.append(
        run.Config(
            MegatronCommOverlapCallback,
            tp_comm_overlap=False,
        )
    )
    return recipe


@run.cli.factory(target=finetune, name=NAME)
def finetune_recipe(
    dir: Optional[str] = None,
    name: str = "default",
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    peft_scheme: Optional[str] = 'lora',
) -> run.Partial:
    """
    Create a fine-tuning recipe for the Baichuan2 7B model.

    This function sets up a complete configuration for fine-tuning, including
    model, trainer, data, logging, optimization, and resumption settings.
    The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the fine-tuning run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        peft_scheme (Optional[str]): Name of the PEFT scheme to use for fine-tuning.
            Allowed values: 'lora', 'none'/None.

    Returns:
        run.Partial: Partial configuration for fine-tuning.

    Examples:
        CLI usage:
            $ nemo llm finetune --factory baichuan2_7b

        Python API usage:
            >>> recipe = finetune_recipe(name="baichuan2_7b_finetune", num_nodes=2)
            >>> print(recipe)

    Note:
        This recipe uses the SQuAD dataset for fine-tuning. For more information
        on fine-tuning LLMs with NeMo, see the fine-tuning guide in the
        `examples/llm/finetune/` directory.
    """
    recipe = default_finetune_recipe(
        model(), "baichuan-inc/Baichuan2-7B-Base", dir, name, num_nodes, num_gpus_per_node
    )
    if peft_scheme is None or peft_scheme.lower() == 'none':
        recipe.trainer.strategy.tensor_model_parallel_size = 2
        recipe.optim.config.lr = 5e-6
    elif peft_scheme.lower() == 'lora':
        recipe.peft = run.Config(LoRA)
        recipe.optim.config.lr = 1e-4
    else:
        raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")
    return recipe
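The factories above are meant to be driven either through the CLI invocations shown in the docstrings or programmatically with nemo_run. A usage sketch (not part of this commit), assuming a single node with 8 local GPUs and a torchrun-style local launch; the executor settings are illustrative and would need to match your environment:

import nemo_run as run

from nemo.collections.llm.recipes import baichuan2_7b

if __name__ == "__main__":
    # LoRA fine-tuning as configured by finetune_recipe (peft_scheme='lora' by default).
    recipe = baichuan2_7b.finetune_recipe(name="baichuan2_7b_lora", num_nodes=1, num_gpus_per_node=8)

    # Example tweak: override the learning rate the recipe sets for LoRA (1e-4).
    recipe.optim.config.lr = 5e-5

    # Assumed local launch; swap in a Slurm or other executor for multi-node clusters.
    executor = run.LocalExecutor(ntasks_per_node=8, launcher="torchrun")
    run.run(recipe, executor=executor, name="baichuan2_7b_lora")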