Skip to content

Commit

Permalink
Mistral-NeMo-12B recipe (#10607)
Browse files Browse the repository at this point in the history
* Mistral-NeMo-12B recipe

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* rename mistral to mistral_7b

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* include mistral_nemo_12b in __init__

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* add to __init__

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* Remove stale imports

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* TP=2

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove finetune_reci[e

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Rename MistralNeMo2407Config12B to MistralNeMoConfig12B per review's suggestion

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* update config names in tests

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* mistral-nemo-12b from llama_8b

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* TP=2; SP=True

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix overlap value

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* update mistral-nemo-base-12b finetune recipe

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
  • Loading branch information
akoumpa and akoumpa authored Oct 21, 2024
1 parent cec3e0a commit e8a801b
Show file tree
Hide file tree
Showing 9 changed files with 301 additions and 12 deletions.
2 changes: 2 additions & 0 deletions nemo/collections/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
MaskedTokenLossReduction,
MistralConfig7B,
MistralModel,
MistralNeMoConfig12B,
MixtralConfig8x3B,
MixtralConfig8x7B,
MixtralConfig8x22B,
Expand Down Expand Up @@ -115,6 +116,7 @@
"t5_forward_step",
"MaskedTokenLossReduction",
"MistralConfig7B",
"MistralNeMoConfig12B",
"MistralModel",
"MixtralConfig8x3B",
"MixtralConfig8x7B",
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
LlamaConfig,
LlamaModel,
)
from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel
from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel, MistralNeMoConfig12B
from nemo.collections.llm.gpt.model.mixtral import (
MixtralConfig8x3B,
MixtralConfig8x7B,
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/llm/gpt/model/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class MistralConfig7B(GPTConfig):


@dataclass
class MistralNeMo2407Config12B(MistralConfig7B):
class MistralNeMoConfig12B(MistralConfig7B):
"""
https://mistral.ai/news/mistral-nemo/
"""
Expand All @@ -75,7 +75,7 @@ class MistralNeMo2407Config12B(MistralConfig7B):


@dataclass
class MistralNeMo2407Config123B(MistralConfig7B):
class MistralNeMoConfig123B(MistralConfig7B):
"""
https://mistral.ai/news/mistral-large-2407/
"""
Expand Down
6 changes: 4 additions & 2 deletions nemo/collections/llm/recipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
llama3_70b_16k,
llama3_70b_64k,
llama31_405b,
mistral,
mistral_7b,
mistral_nemo_12b,
mixtral_8x7b,
mixtral_8x7b_16k,
mixtral_8x7b_64k,
Expand All @@ -48,7 +49,8 @@
"llama3_70b_16k",
"llama3_70b_64k",
"llama31_405b",
"mistral",
"mistral_7b",
"mistral_nemo_12b",
"mixtral_8x7b",
"mixtral_8x7b_16k",
"mixtral_8x7b_64k",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.utils.exp_manager import TimingCallback

NAME = "mistral"
NAME = "mistral_7b"


@run.cli.factory(name=NAME)
Expand Down
285 changes: 285 additions & 0 deletions nemo/collections/llm/recipes/mistral_nemo_12b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Callable, Optional

import nemo_run as run
import pytorch_lightning as pl
import torch
from megatron.core.distributed import DistributedDataParallelConfig
from pytorch_lightning.callbacks.callback import Callback

from nemo import lightning as nl
from nemo.collections.llm.api import finetune, pretrain
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.squad import SquadDataModule
from nemo.collections.llm.gpt.model.mistral import MistralModel, MistralNeMoConfig12B
from nemo.collections.llm.peft.lora import LoRA
from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "mistral_nemo_base_12b"


@run.cli.factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a Mistral-Nemo-Base-12B model configuration.
Returns:
run.Config[pl.LightningModule]: Configuration for the Mistral-Nemo-Base-12B model.
Examples:
CLI usage:
$ nemo llm pretrain model=mistral_nemo_base_12b ...
Python API usage:
>>> model_config = model()
>>> print(model_config)
"""
return run.Config(MistralModel, config=run.Config(MistralNeMoConfig12B))


def trainer(
tensor_parallelism: int = 2,
pipeline_parallelism: int = 1,
pipeline_parallelism_type: Optional[torch.dtype] = None,
virtual_pipeline_parallelism: Optional[int] = None,
context_parallelism: int = 2,
sequence_parallelism: bool = True,
num_nodes: int = 1,
num_gpus_per_node: int = 8,
max_steps: int = 1168251,
callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
"""
Configure the NeMo Lightning Trainer for Mistral-Nemo-Base-12B model.
This function sets up the distributed training strategy and other training parameters.
Args:
tensor_parallelism (int): Degree of tensor model parallelism.
pipeline_parallelism (int): Degree of pipeline model parallelism.
pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism.
virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
context_parallelism (int): Degree of context parallelism.
sequence_parallelism (bool): Whether to use sequence parallelism.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
max_steps (int): Maximum number of training steps.
callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations.
Returns:
run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer.
Examples:
CLI usage:
$ nemo llm pretrain trainer=mistral_nemo_base_12b ...
Python API usage:
>>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8)
>>> print(trainer_config)
Note:
For more information on distributed training strategies, refer to the
NeMo documentation on multi-GPU and multi-node training.
"""
strategy = run.Config(
nl.MegatronStrategy,
tensor_model_parallel_size=tensor_parallelism,
pipeline_model_parallel_size=pipeline_parallelism,
pipeline_dtype=pipeline_parallelism_type,
virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
context_parallel_size=context_parallelism,
sequence_parallel=sequence_parallelism,
gradient_as_bucket_view=True,
ckpt_async_save=True,
ckpt_parallel_load=True,
ddp=run.Config(
DistributedDataParallelConfig,
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
overlap_grad_reduce=True,
overlap_param_gather=True,
),
)

trainer = run.Config(
nl.Trainer,
accelerator="gpu",
accumulate_grad_batches=1,
callbacks=callbacks,
devices=num_gpus_per_node,
limit_test_batches=50,
limit_val_batches=32,
log_every_n_steps=10,
max_steps=max_steps,
num_nodes=num_nodes,
plugins=bf16_mixed(),
strategy=strategy,
use_distributed_sampler=False,
val_check_interval=2000,
)

return trainer


@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a pre-training recipe for Mistral-Nemo-Base-12B model.
This function sets up a complete configuration for pre-training, including
model, trainer, data, logging, optimization, and resumption settings.
Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.
Returns:
run.Partial: Partial configuration for pre-training.
Examples:
CLI usage:
$ nemo llm pretrain --factory mistral_nemo_base_12b
$ nemo llm pretrain --factory "mistral_nemo_base_12b(num_nodes=2, name='my_pretrain')"
Python API usage:
>>> recipe = pretrain_recipe(name="mistral_nemo_base_12b", num_nodes=2)
>>> print(recipe)
Note:
For more details on pre-training LLMs with NeMo, see the pre-training
guide in the `examples/llm/pretrain/` directory.
"""
return run.Partial(
fn,
model=model(),
trainer=trainer(
num_nodes=num_nodes,
num_gpus_per_node=num_gpus_per_node,
callbacks=[run.Config(TimingCallback)],
),
data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1),
log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4),
resume=default_resume(),
)


@run.cli.factory(target=pretrain, name=NAME + "_optimized")
def pretrain_recipe_performance(
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
fn: Callable = pretrain,
) -> run.Partial:
"""
Create a performance-optimized pre-training recipe for Mistral-Nemo-Base-12B model.
This recipe enables performance optimizations that may not be suitable for all use cases.
It builds upon the standard pre-training recipe and adds additional performance enhancements.
Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.
Returns:
run.Partial: Partial configuration for performance-optimized pre-training.
Examples:
$ nemo llm pretrain --factory mistral_nemo_base_12b_optimized
Python API usage:
>>> recipe = pretrain_recipe_performance(name="mistral_nemo_base_12b_perf", num_nodes=4)
>>> print(recipe)
Note:
Use this recipe with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

recipe.trainer.callbacks.append(
run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=True,
)
)
return recipe


@run.cli.factory(target=finetune, name=NAME)
def finetune_recipe(
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
peft_scheme: Optional[str] = 'lora',
) -> run.Partial:
"""
Create a fine-tuning recipe for Mistral-Nemo-Base-12B model.
This function sets up a complete configuration for fine-tuning, including
model, trainer, data, logging, optimization, and resumption settings.
The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None.
Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the fine-tuning run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.
Returns:
run.Partial: Partial configuration for fine-tuning.
Examples:
CLI usage:
$ nemo llm finetune --factory mistral_nemo_base_12b
Python API usage:
>>> recipe = finetune_recipe(name="mistral_nemo_base_12b_finetune", num_nodes=2)
>>> print(recipe)
Note:
This recipe uses the SQuAD dataset for fine-tuning. For more information
on fine-tuning LLMs with NeMo, see the fine-tuning guide in the
`examples/llm/finetune/` directory.
"""
recipe = default_finetune_recipe(
model(), "mistralai/Mistral-Nemo-Base-2407", dir, name, num_nodes, num_gpus_per_node
)
if peft_scheme is None or peft_scheme.lower() == 'none':
recipe.optim.config.lr = 5e-6
elif peft_scheme.lower() == 'lora':
recipe.peft = run.Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32)
recipe.optim.config.lr = 1e-4
else:
raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")
return recipe
6 changes: 3 additions & 3 deletions tests/collections/llm/gpt/model/test_mistral.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch.nn.functional as F

from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralNeMo2407Config12B, MistralNeMo2407Config123B
from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralNeMoConfig12B, MistralNeMoConfig123B


def test_mistral_config7b():
Expand All @@ -25,7 +25,7 @@ def test_mistral_config7b():


def test_mistral_nemo_config_12b():
config = MistralNeMo2407Config12B()
config = MistralNeMoConfig12B()
assert config.normalization == "RMSNorm"
assert config.activation_func == F.silu
assert config.position_embedding_type == "rope"
Expand All @@ -49,7 +49,7 @@ def test_mistral_nemo_config_12b():


def test_mistral_nemo_config_123b():
config = MistralNeMo2407Config123B()
config = MistralNeMoConfig123B()
assert config.normalization == "RMSNorm"
assert config.activation_func == F.silu
assert config.position_embedding_type == "rope"
Expand Down
2 changes: 1 addition & 1 deletion tests/collections/llm/recipes/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from nemo.collections.llm.gpt.data.squad import SquadDataModule
from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel
from nemo.collections.llm.peft.lora import LoRA
from nemo.collections.llm.recipes import mistral
from nemo.collections.llm.recipes import mistral_7b as mistral
from nemo.lightning import AutoResume, Trainer


Expand Down
Loading

0 comments on commit e8a801b

Please sign in to comment.