
Commit

Merge branch 'r2.0.0' into chcui/packed_seq_recipes_2.0.0
Signed-off-by: Chen Cui <[email protected]>
cuichenx authored Oct 25, 2024
2 parents 6a8cefc + fe4d09b commit e9fc77c
Showing 39 changed files with 756 additions and 365 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/llm/api.py
@@ -436,7 +436,7 @@ def export_ckpt(
def generate(
path: Union[Path, str],
prompts: list[str],
trainer: Optional[nl.Trainer] = None,
trainer: nl.Trainer,
params_dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 4,
random_seed: Optional[int] = None,
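With this signature change, generate() no longer accepts trainer=None; callers construct and pass a trainer themselves. A minimal sketch, assuming a single-node 8-GPU MegatronStrategy and a placeholder checkpoint path (both hypothetical):

import torch
import nemo.lightning as nl
from nemo.collections.llm.api import generate

# Hypothetical settings: adjust devices and parallelism to match the checkpoint.
trainer = nl.Trainer(
    devices=8,
    accelerator="gpu",
    strategy=nl.MegatronStrategy(tensor_model_parallel_size=8),
)

results = generate(
    path="/checkpoints/my_model",         # placeholder checkpoint path
    prompts=["Q: What is NeMo 2.0? A:"],
    trainer=trainer,                      # now required, no longer Optional
    params_dtype=torch.bfloat16,
    max_batch_size=4,
)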
8 changes: 6 additions & 2 deletions nemo/collections/llm/gpt/data/mock.py
@@ -56,9 +56,13 @@ def __init__(
self.persistent_workers = persistent_workers
self.create_attention_mask = create_attention_mask or not HAVE_TE

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
if tokenizer is None:
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

self.tokenizer = get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
else:
self.tokenizer = tokenizer

self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
self.data_sampler = MegatronDataSampler(
seq_len=self.seq_length,
micro_batch_size=micro_batch_size,
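The tokenizer fallback now happens inside an explicit branch, so the NLP tokenizer utilities are imported only when no tokenizer is provided. A short sketch of both paths, assuming the usual seq_length / micro_batch_size / global_batch_size keyword arguments of MockDataModule:

from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# Default path: no tokenizer passed, so __init__ lazily imports get_nmt_tokenizer
# and builds the GPT2 BPE tokenizer itself.
data = MockDataModule(seq_length=2048, micro_batch_size=2, global_batch_size=8)

# Explicit path: a user-supplied tokenizer is stored as-is and the lazy import is skipped.
tok = get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
data_explicit = MockDataModule(seq_length=2048, tokenizer=tok, micro_batch_size=2, global_batch_size=8)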
16 changes: 12 additions & 4 deletions nemo/collections/llm/inference/base.py
@@ -16,6 +16,7 @@

import nemo.lightning as nl
from nemo.lightning import io
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
from nemo.lightning.pytorch.strategies.utils import RestoreConfig

@@ -44,6 +45,7 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.
load_optim_state=False,
)
trainer.strategy.restore_config = restore_config
trainer.strategy._setup_optimizers = False
trainer.ckpt_path = None
trainer.strategy.connect(model)
if trainer.strategy.launcher is not None:
@@ -61,16 +63,22 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.

def setup_model_and_tokenizer(
path: Path,
trainer: Optional[nl.Trainer] = None,
trainer: nl.Trainer,
params_dtype: torch.dtype = torch.bfloat16,
inference_batch_times_seqlen_threshold: int = 1000,
) -> tuple[MCoreGPTModel, MCoreTokenizerWrappper]:
model: io.TrainerContext = io.load_context(path=path, subpath="model")
trainer = trainer or io.load_context(path=path, subpath="trainer")
model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
_setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)

# This is to get the MCore model required in GPTInferenceWrapper.
mcore_model = model.module.module.module
mcore_model = model
while mcore_model:
if type(mcore_model) is MCoreGPTModel:
break
mcore_model = getattr(mcore_model, "module", None)
if mcore_model is None or type(mcore_model) is not MCoreGPTModel:
raise ValueError("Exact MCoreGPTModel instance not found in the model structure.")

inference_wrapped_model = GPTInferenceWrapper(
mcore_model,
InferenceWrapperConfig(
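Instead of the hard-coded model.module.module.module access, the MCore model is now located by walking nested .module wrappers. The same pattern in isolation, as a generic sketch:

def unwrap_to(model, target_cls):
    # Walk nested `.module` wrappers (DDP, Float16Module, etc.) until an
    # instance of target_cls is found, mirroring the loop added above.
    obj = model
    while obj is not None:
        if type(obj) is target_cls:
            return obj
        obj = getattr(obj, "module", None)
    raise ValueError(f"Exact {target_cls.__name__} instance not found in the model structure.")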
2 changes: 2 additions & 0 deletions nemo/collections/llm/recipes/__init__.py
@@ -18,6 +18,7 @@
chatglm3_6b,
gemma_2b,
gemma_7b,
gpt3_175b,
llama3_8b,
llama3_8b_16k,
llama3_8b_64k,
@@ -69,6 +70,7 @@
"nemotron4_22b_16k",
"nemotron4_22b_64k",
"nemotron4_340b",
"gpt3_175b",
"adam",
"default_log",
"default_resume",
52 changes: 22 additions & 30 deletions nemo/collections/llm/recipes/gpt3_175b.py
@@ -142,7 +142,12 @@ def trainer(

@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
performance_mode: bool = False,
fn: Callable = pretrain,
) -> run.Partial:
"""
Create a pre-training recipe for GPT3 175B model.
@@ -155,6 +160,7 @@ def pretrain_recipe(
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
performance_mode (bool): If true, enables optimizations for maximum performance.
fn (Callable): The pre-training function to use.
Returns:
@@ -172,7 +178,7 @@ def pretrain_recipe(
Note:
This recipe is optimized for the large 175B model and requires significant computational resources.
"""
return run.Partial(
recipe = run.Partial(
fn,
model=model(),
trainer=trainer(
@@ -186,49 +192,35 @@
resume=default_resume(),
)

if performance_mode:
recipe = pretrain_performance_optimizations(recipe)

@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
fn: Callable = pretrain,
) -> run.Partial:
return recipe


def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
"""
Create a performance-optimized pre-training recipe for GPT3 175B model.
This recipe enables performance optimizations that may not be suitable for all use cases.
This method enables performance optimizations that may not be suitable for all use cases.
It builds upon the standard pre-training recipe and adds additional performance enhancements.
Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.
recipe (run.Partial): Base pre-train recipe to which performance optimizations will be added
Returns:
run.Partial: Partial configuration for performance-optimized pre-training.
Examples:
CLI usage:
$ nemo llm pretrain --factory "gpt3_175b.pretrain_recipe_performance(num_nodes=64, name='perf_pretrain')"
Python API usage:
>>> recipe = pretrain_recipe_performance(name="gpt3_175b_perf", num_nodes=64)
>>> print(recipe)
Note:
Use this recipe with caution and only when you need maximum performance.
Use this method with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

# 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically by MegatronCommOverlapCallback
# They are added here for user's knowledge
# overlap_param_gather_with_optimizer_step- If true, overlap param all-gather of first bucket with optimizer step.
# align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else each PP stage launches independently as needed
# 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically
# by MegatronCommOverlapCallback. They are added here for user's knowledge.
# overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step.
# align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else
# each PP stage launches independently as needed.

recipe.trainer.callbacks.append(
run.Config(
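The separate pretrain_recipe_performance factory is folded into pretrain_recipe behind a performance_mode flag, with the optimizations moved to pretrain_performance_optimizations. Usage sketch with the names from this diff (node counts are illustrative):

from nemo.collections.llm.recipes import gpt3_175b

# Standard recipe, identical to the previous default behaviour.
recipe = gpt3_175b.pretrain_recipe(name="gpt3_175b_pretrain", num_nodes=64, num_gpus_per_node=8)

# Performance-optimized recipe: same factory, plus the MegatronCommOverlapCallback
# appended by pretrain_performance_optimizations().
perf_recipe = gpt3_175b.pretrain_recipe(
    name="gpt3_175b_perf",
    num_nodes=64,
    num_gpus_per_node=8,
    performance_mode=True,
)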
52 changes: 22 additions & 30 deletions nemo/collections/llm/recipes/llama31_405b.py
@@ -144,7 +144,12 @@ def trainer(

@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
performance_mode: bool = False,
fn: Callable = pretrain,
) -> run.Partial:
"""
Create a pre-training recipe for Llama3.1 405B model.
@@ -157,6 +162,7 @@ def pretrain_recipe(
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
performance_mode (bool): If true, enables optimizations for maximum performance.
fn (Callable): The pre-training function to use.
Returns:
@@ -174,7 +180,7 @@ def pretrain_recipe(
Note:
This recipe is optimized for the large 405B model and requires significant computational resources.
"""
return run.Partial(
recipe = run.Partial(
fn,
model=model(),
trainer=trainer(
@@ -188,49 +194,35 @@
resume=default_resume(),
)

if performance_mode:
recipe = pretrain_performance_optimizations(recipe)

@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
fn: Callable = pretrain,
) -> run.Partial:
return recipe


def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
"""
Create a performance-optimized pre-training recipe for Llama3.1 405B model.
This recipe enables performance optimizations that may not be suitable for all use cases.
This method enables performance optimizations that may not be suitable for all use cases.
It builds upon the standard pre-training recipe and adds additional performance enhancements.
Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.
recipe (run.Partial): Base pre-train recipe to which performance optimizations will be added
Returns:
run.Partial: Partial configuration for performance-optimized pre-training.
Examples:
CLI usage:
$ nemo llm pretrain --factory "llama31_405b.pretrain_recipe_performance(num_nodes=4, name='perf_pretrain')"
Python API usage:
>>> recipe = pretrain_recipe_performance(name="llama31_405b_perf", num_nodes=4)
>>> print(recipe)
Note:
Use this recipe with caution and only when you need maximum performance.
Use this method with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

# 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically by MegatronCommOverlapCallback
# They are added here for user's knowledge
# overlap_param_gather_with_optimizer_step- If true, overlap param all-gather of first bucket with optimizer step.
# align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else each PP stage launches independently as needed
# 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically
# by MegatronCommOverlapCallback. They are added here for user's knowledge.
# overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step.
# align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else
# each PP stage launches independently as needed.

recipe.trainer.callbacks.append(
run.Config(
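The same refactor applies to this recipe; performance_mode defaults to False, so existing callers of pretrain_recipe keep the previous behaviour. Sketch (node counts are illustrative):

from nemo.collections.llm.recipes import llama31_405b

base = llama31_405b.pretrain_recipe(num_nodes=4, num_gpus_per_node=8)  # unchanged default behaviour
perf = llama31_405b.pretrain_recipe(num_nodes=4, num_gpus_per_node=8, performance_mode=True)  # adds the comm-overlap callback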
51 changes: 24 additions & 27 deletions nemo/collections/llm/recipes/llama3_70b.py
@@ -13,7 +13,7 @@
# limitations under the License.


from typing import Optional
from typing import Callable, Optional

import nemo_run as run
import pytorch_lightning as pl
@@ -142,7 +142,12 @@ def trainer(

@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 4, num_gpus_per_node: int = 8, fn=pretrain
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
performance_mode: bool = False,
fn: Callable = pretrain,
) -> run.Partial:
"""
Create a pre-training recipe for Llama3 70B model.
@@ -155,6 +160,7 @@ def pretrain_recipe(
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
performance_mode (bool): If true, enables optimizations for maximum performance.
fn (Callable): The pre-training function to use.
Returns:
@@ -172,7 +178,8 @@ def pretrain_recipe(
Note:
This recipe is optimized for the large 70B model and requires significant computational resources.
"""
return run.Partial(

recipe = run.Partial(
fn,
model=model(),
trainer=trainer(
@@ -186,45 +193,35 @@
resume=default_resume(),
)

if performance_mode:
recipe = pretrain_performance_optimizations(recipe)

@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 4, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
return recipe


def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
"""
Create a performance-optimized pre-training recipe for Llama3 70B model.
This recipe enables performance optimizations that may not be suitable for all use cases.
This method enables performance optimizations that may not be suitable for all use cases.
It builds upon the standard pre-training recipe and adds additional performance enhancements.
Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.
recipe (run.Partial): Base pre-train recipe to which performance optimizations will be added
Returns:
run.Partial: Partial configuration for performance-optimized pre-training.
Examples:
CLI usage:
$ nemo llm pretrain --factory "llama3_70b.pretrain_recipe_performance(num_nodes=4, name='perf_pretrain')"
Python API usage:
>>> recipe = pretrain_recipe_performance(name="llama3_70b_perf", num_nodes=4)
>>> print(recipe)
Note:
Use this recipe with caution and only when you need maximum performance.
Use this method with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

# 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically by MegatronCommOverlapCallback
# They are added here for user's knowledge
# overlap_param_gather_with_optimizer_step- If true, overlap param all-gather of first bucket with optimizer step.
# align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else each PP stage launches independently as needed
# 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically
# by MegatronCommOverlapCallback. They are added here for user's knowledge.
# overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step.
# align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else
# each PP stage launches independently as needed.

recipe.trainer.callbacks.append(
run.Config(
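Because pretrain_performance_optimizations is now a plain function over a recipe, it can also be applied to an already-built recipe, which is what performance_mode=True does internally. Sketch:

from nemo.collections.llm.recipes import llama3_70b

recipe = llama3_70b.pretrain_recipe(name="llama3_70b_pretrain", num_nodes=1, num_gpus_per_node=8)
recipe = llama3_70b.pretrain_performance_optimizations(recipe)  # equivalent to passing performance_mode=True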