Performance fine-tuning recipes for llama3 8b + 70b (NVIDIA#11046)
* llama3 finetuning perf recipes progress capture

Signed-off-by: Valerie Sarge <[email protected]>

* Small syntax fix

Signed-off-by: Valerie Sarge <[email protected]>

* syntax

Signed-off-by: Valerie Sarge <[email protected]>

* Apply isort and black reformatting

Signed-off-by: vysarge <[email protected]>

* Correct ddp setting

Signed-off-by: Valerie Sarge <[email protected]>

* Fix hasattr check

Signed-off-by: Valerie Sarge <[email protected]>

* bf16 grad

Signed-off-by: Valerie Sarge <[email protected]>

* Update configs for 8b + 70b

Signed-off-by: Valerie Sarge <[email protected]>

* Set wgrad_deferral_limit

Signed-off-by: Valerie Sarge <[email protected]>

---------

Signed-off-by: Valerie Sarge <[email protected]>
Signed-off-by: vysarge <[email protected]>
Co-authored-by: vysarge <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
2 people authored and Hainan Xu committed Nov 5, 2024
1 parent 697abdf commit 05b28c2
Showing 3 changed files with 189 additions and 3 deletions.
1 change: 0 additions & 1 deletion nemo/collections/llm/gpt/model/llama.py
@@ -56,7 +56,6 @@ class LlamaConfig(GPTConfig):
     persist_layer_norm: bool = True
     bias_dropout_fusion: bool = True
     apply_rope_fusion: bool = True
-    cross_entropy_loss_fusion: bool = False


 @dataclass
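With this deletion, LlamaConfig no longer pins cross_entropy_loss_fusion to False; the value is now inherited from the GPTConfig parent, and the recipe changes below disable the fusion explicitly only in the LoRA branches, where it currently does not function correctly. A minimal sketch of the resulting behavior (hypothetical illustration, not part of the diff; the effective inherited default depends on your NeMo build):

from nemo.collections.llm.gpt.model.llama import Llama3Config8B

config = Llama3Config8B()
# The flag now comes from the GPTConfig parent rather than a LlamaConfig
# override; check your NeMo version for the effective default.
print(config.cross_entropy_loss_fusion)

# The recipes below opt out per-recipe where the fused loss misbehaves with LoRA:
config.cross_entropy_loss_fusion = False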
100 changes: 99 additions & 1 deletion nemo/collections/llm/recipes/llama3_70b.py
@@ -24,13 +24,15 @@
 from nemo import lightning as nl
 from nemo.collections.llm.api import finetune, pretrain
 from nemo.collections.llm.gpt.data.mock import MockDataModule
+from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs
 from nemo.collections.llm.gpt.model.llama import Llama3Config70B, LlamaModel
 from nemo.collections.llm.peft.lora import LoRA
 from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
 from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
 from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192
+from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
 from nemo.utils.exp_manager import TimingCallback

@@ -245,7 +247,9 @@ def finetune_recipe(
     num_nodes: int = 1,
     num_gpus_per_node: int = 8,
     peft_scheme: Optional[str] = 'lora',
-    packed_sequence: bool = False,
+    seq_length: Optional[int] = None,
+    packed_sequence: Optional[bool] = None,
+    performance_mode: bool = False,
 ) -> run.Partial:
     """
     Create a fine-tuning recipe for Llama3 70B model.
@@ -260,6 +264,9 @@
         num_nodes (int): Number of compute nodes to use.
         num_gpus_per_node (int): Number of GPUs per node.
         peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.
+        seq_length (int): Maximum number of tokens per microbatch.
+        packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode.
+        performance_mode (bool): If true, enables optimizations for maximum performance.
     Returns:
         run.Partial: Partial configuration for fine-tuning.
@@ -277,6 +284,15 @@
     This recipe uses the SQuAD dataset for fine-tuning. Be aware that fine-tuning a 70B model
     requires substantial computational resources.
     """
+    # Default to unpacked data in normal mode and packed data in performance mode
+    # once packing recipe is well tested, change this default to true
+    if packed_sequence is None:
+        packed_sequence = performance_mode
+
+    # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K
+    if seq_length is None:
+        seq_length = 4096 if packed_sequence else 2048
+
     recipe = default_finetune_recipe(
         model(), "meta-llama/Meta-Llama-3-70B", dir, name, num_nodes, num_gpus_per_node, packed_sequence
     )
@@ -287,8 +303,90 @@
         recipe.optim.config.lr = 5e-6
     elif peft_scheme.lower() == 'lora':
         recipe.peft = run.Config(LoRA)
+        recipe.peft.dim = 16
+        recipe.peft.alpha = 32
+        recipe.peft.target_modules = ['linear_qkv']
+
+        # some settings currently do not function correctly with LoRA
+        recipe.model.config.cross_entropy_loss_fusion = False
+
         recipe.trainer.strategy.tensor_model_parallel_size = 8
         recipe.optim.config.lr = 1e-4
     else:
         raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")
+
+    # Sequence length settings in the model and dataset must agree
+    recipe.model.config.seq_length = seq_length
+    recipe.data.seq_length = seq_length
+    if packed_sequence:
+        recipe.data.pad_to_max_length = True
+        recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length)
+
+    if performance_mode:
+        recipe = finetune_performance_optimizations(recipe, peft_scheme)
+
     return recipe
+
+
+def finetune_performance_optimizations(
+    recipe: run.Partial,
+    peft_scheme: str,
+) -> run.Partial:
+    """
+    Modify the given recipe to optimize settings for performance.
+    This method enables performance optimizations that may not be suitable for all use cases.
+    Intended to build upon the standard fine-tuning recipe.
+    Args:
+        recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added
+        peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.
+    Returns:
+        run.Partial: Partial configuration for performance-optimized fine-tuning.
+    Note:
+        Use this method with caution and only when you need maximum performance.
+        It may not be suitable for all hardware configurations or use cases.
+    """
+
+    if not hasattr(recipe.trainer, "callbacks"):
+        recipe.trainer.callbacks = []
+
+    if peft_scheme is None or peft_scheme.lower() == 'none':
+        recipe.trainer.strategy.tensor_model_parallel_size = 4
+        recipe.trainer.strategy.pipeline_model_parallel_size = 4
+        recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5
+        recipe.trainer.plugins.grad_reduce_in_fp32 = False
+        recipe.trainer.strategy.ddp = run.Config(
+            DistributedDataParallelConfig,
+            check_for_nan_in_grad=True,
+            grad_reduce_in_fp32=False,
+            overlap_grad_reduce=True,
+            overlap_param_gather=True,
+            average_in_collective=True,
+        )
+        recipe.trainer.callbacks.append(
+            run.Config(
+                MegatronCommOverlapCallback,
+                tp_comm_overlap=True,
+                defer_embedding_wgrad_compute=True,
+                wgrad_deferral_limit=22,
+            )
+        )
+    else:
+        recipe.trainer.strategy.tensor_model_parallel_size = 2
+        recipe.trainer.strategy.pipeline_model_parallel_size = 4
+
+    recipe.trainer.strategy.sequence_parallel = True
+
+    recipe.trainer.callbacks.append(run.Config(TimingCallback))
+    recipe.trainer.callbacks.append(
+        run.Config(
+            GarbageCollectionCallback,
+            100,
+            100,
+        )
+    )
+
+    return recipe
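For orientation, a usage sketch for the recipe above (hypothetical, not part of the diff; assumes the nemo.collections.llm.recipes package from this commit and standard nemo_run tooling):

from nemo.collections.llm.recipes import llama3_70b

# Default mode: packed_sequence inherits performance_mode (False), so
# seq_length resolves to 2048 and no packing specs are attached.
recipe = llama3_70b.finetune_recipe(
    name="llama3_70b_lora",
    num_nodes=1,
    num_gpus_per_node=8,
    peft_scheme="lora",
)

# Performance mode: packed_sequence defaults to True, seq_length to 4096, and
# finetune_performance_optimizations() reconfigures parallelism, DDP overlap,
# and callbacks as defined above.
perf_recipe = llama3_70b.finetune_recipe(
    name="llama3_70b_lora_perf",
    num_nodes=1,
    num_gpus_per_node=8,
    peft_scheme="lora",
    performance_mode=True,
)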
91 changes: 90 additions & 1 deletion nemo/collections/llm/recipes/llama3_8b.py
@@ -24,13 +24,15 @@
 from nemo import lightning as nl
 from nemo.collections.llm.api import finetune, pretrain
 from nemo.collections.llm.gpt.data.mock import MockDataModule
+from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs
 from nemo.collections.llm.gpt.data.squad import SquadDataModule
 from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel
 from nemo.collections.llm.peft.lora import LoRA
 from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
 from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
+from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
 from nemo.utils.exp_manager import TimingCallback

@@ -233,7 +235,9 @@ def finetune_recipe(
     num_nodes: int = 1,
     num_gpus_per_node: int = 8,
     peft_scheme: Optional[str] = 'lora',
-    packed_sequence: bool = False,  # once packing recipe is well tested, change this default to true
+    seq_length: Optional[int] = None,
+    packed_sequence: Optional[bool] = None,
+    performance_mode: bool = False,
 ) -> run.Partial:
     """
     Create a fine-tuning recipe for Llama3 8B model.
@@ -248,6 +252,9 @@
         num_nodes (int): Number of compute nodes to use.
         num_gpus_per_node (int): Number of GPUs per node.
         peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.
+        seq_length (int): Maximum number of tokens per microbatch.
+        packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode.
+        performance_mode (bool): If true, enables optimizations for maximum performance.
     Returns:
         run.Partial: Partial configuration for fine-tuning.
@@ -265,6 +272,15 @@
     on fine-tuning LLMs with NeMo, see the fine-tuning guide in the
     `examples/llm/finetune/` directory.
     """
+    # Default to unpacked data in normal mode and packed data in performance mode
+    # once packing recipe is well tested, change this default to true
+    if packed_sequence is None:
+        packed_sequence = performance_mode
+
+    # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K
+    if seq_length is None:
+        seq_length = 4096 if packed_sequence else 2048
+
     recipe = default_finetune_recipe(
         model(), "meta-llama/Meta-Llama-3-8B", dir, name, num_nodes, num_gpus_per_node, packed_sequence
     )
@@ -273,7 +289,80 @@
         recipe.optim.config.lr = 5e-6
     elif peft_scheme.lower() == 'lora':
         recipe.peft = run.Config(LoRA)
+        recipe.peft.dim = 8
+        recipe.peft.alpha = 16
+        recipe.peft.target_modules = ['linear_qkv']
+
+        # some settings currently do not function correctly with LoRA
+        recipe.model.config.cross_entropy_loss_fusion = False
+
         recipe.optim.config.lr = 1e-4
     else:
         raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")
+
+    # Sequence length settings in the model and dataset must agree
+    recipe.model.config.seq_length = seq_length
+    recipe.data.seq_length = seq_length
+    if packed_sequence:
+        recipe.data.pad_to_max_length = True
+        recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length)
+
+    if performance_mode:
+        recipe = finetune_performance_optimizations(recipe, peft_scheme)
+
     return recipe
+
+
+def finetune_performance_optimizations(
+    recipe: run.Partial,
+    peft_scheme: str,
+) -> run.Partial:
+    """
+    Modify the given recipe to optimize settings for performance.
+    This method enables performance optimizations that may not be suitable for all use cases.
+    Intended to build upon the standard fine-tuning recipe.
+    Args:
+        recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added
+        peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.
+    Returns:
+        run.Partial: Partial configuration for performance-optimized fine-tuning.
+    Note:
+        Use this method with caution and only when you need maximum performance.
+        It may not be suitable for all hardware configurations or use cases.
+    """
+    recipe.trainer.strategy.tensor_model_parallel_size = 1
+
+    if not hasattr(recipe.trainer, "callbacks"):
+        recipe.trainer.callbacks = []
+
+    if peft_scheme is None or peft_scheme.lower() == 'none':
+        recipe.trainer.plugins.grad_reduce_in_fp32 = False
+        recipe.trainer.strategy.ddp = run.Config(
+            DistributedDataParallelConfig,
+            check_for_nan_in_grad=True,
+            grad_reduce_in_fp32=False,
+            overlap_grad_reduce=True,
+            overlap_param_gather=True,
+            average_in_collective=True,
+        )
+        recipe.trainer.callbacks.append(
+            run.Config(
+                MegatronCommOverlapCallback,
+                tp_comm_overlap=False,
+            )
+        )
+
+    recipe.trainer.callbacks.append(run.Config(TimingCallback))
+    recipe.trainer.callbacks.append(
+        run.Config(
+            GarbageCollectionCallback,
+            100,
+            100,
+        )
+    )
+
+    return recipe
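Because packed_sequence and performance_mode are now independent arguments, packing can also be enabled on its own. A sketch against the 8B recipe (hypothetical, not part of the diff):

from nemo.collections.llm.recipes import llama3_8b

# Packing without performance mode: seq_length defaults to 4096,
# pad_to_max_length is set, and PackedSequenceSpecs is attached to the data
# module, but finetune_performance_optimizations() is not applied.
recipe = llama3_8b.finetune_recipe(
    name="llama3_8b_lora_packed",
    peft_scheme="lora",
    packed_sequence=True,
)

# Per the recipe code above, model and dataset sequence lengths agree.
assert recipe.model.config.seq_length == recipe.data.seq_length == 4096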
