
Performance fine-tuning recipes for llama3 8b + 70b #11046

Merged Oct 29, 2024 · 10 commits (changes shown from 9 commits)
1 change: 0 additions & 1 deletion nemo/collections/llm/gpt/model/llama.py
@@ -56,7 +56,6 @@ class LlamaConfig(GPTConfig):
persist_layer_norm: bool = True
bias_dropout_fusion: bool = True
apply_rope_fusion: bool = True
- cross_entropy_loss_fusion: bool = False


@dataclass
99 changes: 98 additions & 1 deletion nemo/collections/llm/recipes/llama3_70b.py
@@ -24,13 +24,15 @@
from nemo import lightning as nl
from nemo.collections.llm.api import finetune, pretrain
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs
from nemo.collections.llm.gpt.model.llama import Llama3Config70B, LlamaModel
from nemo.collections.llm.peft.lora import LoRA
from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.utils.exp_manager import TimingCallback

@@ -245,7 +247,9 @@ def finetune_recipe(
num_nodes: int = 1,
num_gpus_per_node: int = 8,
peft_scheme: Optional[str] = 'lora',
- packed_sequence: bool = False,
+ seq_length: Optional[int] = None,
+ packed_sequence: Optional[bool] = None,
+ performance_mode: bool = False,
) -> run.Partial:
"""
Create a fine-tuning recipe for Llama3 70B model.
@@ -260,6 +264,9 @@
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.
seq_length (Optional[int]): Maximum number of tokens per microbatch.
packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode.
performance_mode (bool): If true, enables optimizations for maximum performance.

Returns:
run.Partial: Partial configuration for fine-tuning.
@@ -277,6 +284,15 @@
This recipe uses the SQuAD dataset for fine-tuning. Be aware that fine-tuning a 70B model
requires substantial computational resources.
"""
# Default to unpacked data in normal mode and packed data in performance mode
# Once the packing recipe is well tested, change this default to True
if packed_sequence is None:
packed_sequence = performance_mode

# For unpacked sequences, most samples in the SQuAD dataset are shorter than 2K tokens
if seq_length is None:
seq_length = 4096 if packed_sequence else 2048

recipe = default_finetune_recipe(
model(), "meta-llama/Meta-Llama-3-70B", dir, name, num_nodes, num_gpus_per_node, packed_sequence
)
Expand All @@ -287,8 +303,89 @@ def finetune_recipe(
recipe.optim.config.lr = 5e-6
elif peft_scheme.lower() == 'lora':
recipe.peft = run.Config(LoRA)
recipe.peft.dim = 16
recipe.peft.alpha = 32
recipe.peft.target_modules = ['linear_qkv']

# some settings currently do not function correctly with LoRA
recipe.model.config.cross_entropy_loss_fusion = False

recipe.trainer.strategy.tensor_model_parallel_size = 8
recipe.optim.config.lr = 1e-4
else:
raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")

# Sequence length settings in the model and dataset must agree
recipe.model.config.seq_length = seq_length
recipe.data.seq_length = seq_length
Review comment (collaborator): without packed sequence, we usually use seq_length=2048 (because most samples in SQuAD are well shorter than 2048 tokens)

Reply (author): Just altered this to use 4K for packed and 2K for unpacked by default.

if packed_sequence:
recipe.data.pad_to_max_length = True
recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length)

if performance_mode:
recipe = finetune_performance_optimizations(recipe, peft_scheme)

return recipe


def finetune_performance_optimizations(
recipe: run.Partial,
peft_scheme: str,
) -> run.Partial:
"""
Modify the given recipe to optimize settings for performance.

This method enables performance optimizations that may not be suitable for all use cases.
Intended to build upon the standard fine-tuning recipe.

Args:
recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added
peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.

Returns:
run.Partial: Partial configuration for performance-optimized fine-tuning.

Note:
Use this method with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""

if not hasattr(recipe.trainer, "callbacks"):
recipe.trainer.callbacks = []

if peft_scheme is None or peft_scheme.lower() == 'none':
recipe.trainer.strategy.tensor_model_parallel_size = 4
recipe.trainer.strategy.pipeline_model_parallel_size = 4
recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5
recipe.trainer.plugins.grad_reduce_in_fp32 = False
recipe.trainer.strategy.ddp = run.Config(
DistributedDataParallelConfig,
check_for_nan_in_grad=True,
grad_reduce_in_fp32=False,
overlap_grad_reduce=True,
overlap_param_gather=True,
average_in_collective=True,
)
recipe.trainer.callbacks.append(
run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=True,
defer_embedding_wgrad_compute=True,
)
)
else:
recipe.trainer.strategy.tensor_model_parallel_size = 2
recipe.trainer.strategy.pipeline_model_parallel_size = 4

recipe.trainer.strategy.sequence_parallel = True

recipe.trainer.callbacks.append(run.Config(TimingCallback))
recipe.trainer.callbacks.append(
run.Config(
GarbageCollectionCallback,
100,
100,
)
)

return recipe
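
For context, a minimal usage sketch of the new 70B entry points. The output directory, experiment names, and exact call sites are illustrative assumptions, not part of this PR; the module path and function signature come from the diff above.

from nemo.collections.llm.recipes import llama3_70b

# Standard fine-tuning: packed_sequence defaults to False, so seq_length resolves to 2048.
recipe = llama3_70b.finetune_recipe(
    dir="/checkpoints/llama3_70b",  # hypothetical output directory
    name="llama3_70b_lora",
    num_nodes=1,
    num_gpus_per_node=8,
    peft_scheme="lora",
)

# Performance mode: packed_sequence defaults to True, seq_length resolves to 4096,
# and finetune_performance_optimizations() adds the comm-overlap, timing, and GC callbacks.
perf_recipe = llama3_70b.finetune_recipe(
    name="llama3_70b_lora_perf",
    peft_scheme="lora",
    performance_mode=True,
)

In the full fine-tuning branch of performance mode, the 70B model is sharded TP=4, PP=4, VP=5; the two positional 100s passed to GarbageCollectionCallback are presumably the train- and validation-step collection intervals.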
91 changes: 90 additions & 1 deletion nemo/collections/llm/recipes/llama3_8b.py
@@ -24,13 +24,15 @@
from nemo import lightning as nl
from nemo.collections.llm.api import finetune, pretrain
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs
from nemo.collections.llm.gpt.data.squad import SquadDataModule
from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel
from nemo.collections.llm.peft.lora import LoRA
from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.utils.exp_manager import TimingCallback

@@ -233,7 +235,9 @@ def finetune_recipe(
num_nodes: int = 1,
num_gpus_per_node: int = 8,
peft_scheme: Optional[str] = 'lora',
- packed_sequence: bool = False,  # once packing recipe is well tested, change this default to true
+ seq_length: Optional[int] = None,
+ packed_sequence: Optional[bool] = None,
+ performance_mode: bool = False,
) -> run.Partial:
"""
Create a fine-tuning recipe for Llama3 8B model.
@@ -248,6 +252,9 @@
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.
seq_length (Optional[int]): Maximum number of tokens per microbatch.
packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode.
performance_mode (bool): If true, enables optimizations for maximum performance.

Returns:
run.Partial: Partial configuration for fine-tuning.
@@ -265,6 +272,15 @@
on fine-tuning LLMs with NeMo, see the fine-tuning guide in the
`examples/llm/finetune/` directory.
"""
# Default to unpacked data in normal mode and packed data in performance mode
# Once the packing recipe is well tested, change this default to True
if packed_sequence is None:
packed_sequence = performance_mode

# For unpacked sequences, most samples in the SQuAD dataset are shorter than 2K tokens
if seq_length is None:
seq_length = 4096 if packed_sequence else 2048

recipe = default_finetune_recipe(
model(), "meta-llama/Meta-Llama-3-8B", dir, name, num_nodes, num_gpus_per_node, packed_sequence
)
@@ -273,7 +289,80 @@
recipe.optim.config.lr = 5e-6
elif peft_scheme.lower() == 'lora':
recipe.peft = run.Config(LoRA)
recipe.peft.dim = 8
Review comment (collaborator): have you verified that dim=8 would not compromise accuracy?

Reply (author): Rank 8 follows the default from torchtune here; in tests with 8b and 70b, loss looks to decrease smoothly across 50-100 train steps.

recipe.peft.alpha = 16
recipe.peft.target_modules = ['linear_qkv']

# some settings currently do not function correctly with LoRA
recipe.model.config.cross_entropy_loss_fusion = False

recipe.optim.config.lr = 1e-4
else:
raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")

# Sequence length settings in the model and dataset must agree
recipe.model.config.seq_length = seq_length
recipe.data.seq_length = seq_length
if packed_sequence:
recipe.data.pad_to_max_length = True
recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length)

if performance_mode:
recipe = finetune_performance_optimizations(recipe, peft_scheme)

return recipe


def finetune_performance_optimizations(
recipe: run.Partial,
peft_scheme: str,
) -> run.Partial:
"""
Modify the given recipe to optimize settings for performance.

This method enables performance optimizations that may not be suitable for all use cases.
Intended to build upon the standard fine-tuning recipe.

Args:
recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added
peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.

Returns:
run.Partial: Partial configuration for performance-optimized fine-tuning.

Note:
Use this method with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe.trainer.strategy.tensor_model_parallel_size = 1
Review comment (collaborator): just checking, TP1 works for full finetuning as well?

Reply (author): With BF16 grad it should. Let me double-check.

Reply (author): Yep, both fit.


if not hasattr(recipe.trainer, "callbacks"):
recipe.trainer.callbacks = []

if peft_scheme is None or peft_scheme.lower() == 'none':
recipe.trainer.plugins.grad_reduce_in_fp32 = False
recipe.trainer.strategy.ddp = run.Config(
DistributedDataParallelConfig,
check_for_nan_in_grad=True,
grad_reduce_in_fp32=False,
overlap_grad_reduce=True,
overlap_param_gather=True,
average_in_collective=True,
)
recipe.trainer.callbacks.append(
run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=False,
)
)

recipe.trainer.callbacks.append(run.Config(TimingCallback))
recipe.trainer.callbacks.append(
run.Config(
GarbageCollectionCallback,
100,
100,
)
)

return recipe
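
Similarly, a small sketch of how the 8B defaults resolve; the experiment names and argument values here are illustrative assumptions, while the functions and parameters are from the diff above.

from nemo.collections.llm.recipes import llama3_8b

# Default path: unpacked SQuAD data, seq_length resolves to 2048.
recipe = llama3_8b.finetune_recipe(name="llama3_8b_lora", peft_scheme="lora")

# Packing without performance mode: seq_length resolves to 4096 and the data module
# is given PackedSequenceSpecs(packed_sequence_size=4096).
packed = llama3_8b.finetune_recipe(
    name="llama3_8b_lora_packed",
    peft_scheme="lora",
    packed_sequence=True,
)

# An explicitly passed seq_length always overrides the defaults; per the review thread,
# TP=1 full fine-tuning fits with BF16 gradients.
full = llama3_8b.finetune_recipe(
    name="llama3_8b_full",
    peft_scheme="none",
    seq_length=8192,
    performance_mode=True,
)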