diff --git a/.github/workflows/_test_template.yml b/.github/workflows/_test_template.yml index 54579ab2d850..e2401cab7f64 100644 --- a/.github/workflows/_test_template.yml +++ b/.github/workflows/_test_template.yml @@ -79,7 +79,7 @@ jobs: echo "log=$(tail -c 2000 err.log | base64 -w 0)" >> "$GITHUB_OUTPUT" - potential_infra_failure=$(cat err.log | grep -Eqi "gpu|cuda|device" && echo true || echo false) + potential_infra_failure=$(cat err.log | grep -Eqiw "device" && echo true || echo false) echo "potential_infra_failure=$potential_infra_failure" >> "$GITHUB_OUTPUT" exit $EXIT_CODE diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 6b39d2a9082e..22bbb3c1a447 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -31,6 +31,7 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true + jobs: pre-flight: runs-on: ubuntu-latest @@ -2589,27 +2590,19 @@ jobs: mkdir examples/llm/auto_configurator/auto_conf_logs python examples/llm/auto_configurator/auto_config.py \ - --logs_dir=/workspace/examples/llm/auto_configurator/auto_conf_logs \ - --data_path=/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document \ - --tokenizer_path=/home/TestData/nlp/gpt2_tokenizer \ + --log_dir=/workspace/examples/llm/auto_configurator/auto_conf_logs \ --run_number=1 python examples/llm/auto_configurator/auto_config.py \ - --logs_dir=/workspace/examples/llm/auto_configurator/auto_conf_logs \ - --data_path=/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document \ - --tokenizer_path=/home/TestData/nlp/gpt2_tokenizer \ + --log_dir=/workspace/examples/llm/auto_configurator/auto_conf_logs \ --run_number=2 python examples/llm/auto_configurator/auto_config.py \ - --logs_dir=/workspace/examples/llm/auto_configurator/auto_conf_logs \ - --data_path=/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document \ - --tokenizer_path=/home/TestData/nlp/gpt2_tokenizer \ + --log_dir=/workspace/examples/llm/auto_configurator/auto_conf_logs \ --run_number=3 python examples/llm/auto_configurator/auto_config.py \ - --logs_dir=/workspace/examples/llm/auto_configurator/auto_conf_logs \ - --data_path=/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document \ - --tokenizer_path=/home/TestData/nlp/gpt2_tokenizer \ + --log_dir=/workspace/examples/llm/auto_configurator/auto_conf_logs \ --get_results AFTER_SCRIPT: | rm -rf examples/llm/auto_configurator/auto_conf_logs @@ -3887,6 +3880,34 @@ jobs: rm -rf tests/collections/llm/gpt_pretrain_results rm -rf tests/collections/llm/gpt_index_mappings + L2_NeMo_2_llama3_pretraining_recipe: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_llama3_pretraining_recipe') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + + python tests/collections/llm/llama3_pretraining.py \ + --seq-length 1024 \ + --devices=2 \ + --max-steps=6 \ + --early-stop=3 \ + --experiment-dir=/tmp/llm_tests/llama_pretrain_results \ + --data-path=/home/TestData/nlp/megatron_llama/data/rp2_sample_sentencepiece_preproc_text_document \ + --tokenizer-path=/home/TestData/nlp/megatron_llama/tokenizer.model \ + --index-mapping-dir=/tmp/llm_tests/llama_index_mappings \ + + python tests/collections/llm/llama3_pretraining.py \ + --seq-length 1024 \ + 
--devices=2 \ + --max-steps=6 \ + --experiment-dir=/tmp/llm_tests/llama_pretrain_results \ + --data-path=/home/TestData/nlp/megatron_llama/data/rp2_sample_sentencepiece_preproc_text_document \ + --tokenizer-path=/home/TestData/nlp/megatron_llama/tokenizer.model \ + --index-mapping-dir=/tmp/llm_tests/llama_index_mappings \ + --cp 1 --tp 2 --sp 1 + L2_NeMo_2_GPT_DDP_Param_Parity_check: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -4439,6 +4460,7 @@ jobs: - L2_NeMo_2_GPT_Pretraining_no_transformer_engine - L2_NeMo_2_GPT_DDP_Param_Parity_check - L2_NeMo_2_HF_MODEL_IMPORT + - L2_NeMo_2_llama3_pretraining_recipe - L2_NeMo_2_SSM_Pretraining - L2_NeMo_2_SSM_Finetuning - L2_NeMo_2_T5_Pretraining diff --git a/.github/workflows/copyright-check.yml b/.github/workflows/copyright-check.yml new file mode 100644 index 000000000000..724f3afb6177 --- /dev/null +++ b/.github/workflows/copyright-check.yml @@ -0,0 +1,59 @@ +# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: Copyright check + +on: + pull_request: + +jobs: + main: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} + fetch-depth: 0 + + - name: Check files have copyright notice + run: | + cd ${{ github.run_id }} + + # Files ending with .py should have Copyright notice in the first 10 lines + find_files_with_missing_copyright() { + find ./ -type f -name '*.py' -not -path "./.git/*" -not -path "./*__init__.py" | while read path; do + echo -en $path"\t" + head -n 10 $path | tr '\n' '\t' | sed 's/\t$/\n/' + done \ + | egrep -iv 'Copyright.*NVIDIA CORPORATION.*' \ + | egrep -iv '*MIT.*Licen.e.*' \ + | egrep -iv '*Copyright.*Apache.*' \ + | egrep -iv '*Apache.*License.*' \ + | while read line; do + echo $line | cut -d' ' -f1 + done + } + + + declare RESULT=($(find_files_with_missing_copyright)) # (..) 
= array + + if [ "${#RESULT[@]}" -gt 0 ]; then + echo "Error: Found files with missing copyright:" + for (( i=0; i<"${#RESULT[@]}"; i++ )); do + echo "path= ${RESULT[$i]}" + done + exit 1; + else + echo "Ok: All (Python) files start with copyright notice" + fi diff --git a/.github/workflows/monitor-vms.yml b/.github/workflows/monitor-vms.yml index 03f37d48e8ea..6795f87abf68 100644 --- a/.github/workflows/monitor-vms.yml +++ b/.github/workflows/monitor-vms.yml @@ -2,9 +2,8 @@ name: Reboots VMs in a controlled way on: schedule: - - cron: /15 * * * * + - cron: 0/15 * * * * workflow_dispatch: - pull_request: jobs: pre-flight: @@ -28,7 +27,7 @@ jobs: | jq -c '[ .runners[] | select(.status == "online") - | select(.name | contains("gpu") + | select(.name | contains("gpu")) | { "vm": .name, "n_gpus": [ @@ -47,8 +46,8 @@ jobs: fail-fast: false matrix: include: ${{ fromJSON(needs.pre-flight.outputs.list-of-vms )}} - uses: .github/workflows/monitor-single-vm.yml + uses: ./.github/workflows/monitor-single-vm.yml with: vm: ${{ matrix.vm }} n_gpus: ${{ matrix.n_gpus }} - secrets: inherit + secrets: inherit # pragma: allowlist secret diff --git a/.github/workflows/secrets-detector.yml b/.github/workflows/secrets-detector.yml index 0cf73e961bd4..cf8ccc189ab6 100644 --- a/.github/workflows/secrets-detector.yml +++ b/.github/workflows/secrets-detector.yml @@ -1,35 +1,37 @@ -# # Copyright (c) 2020-2021, NVIDIA CORPORATION. -# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. -# name: Secrets detector +# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+name: Secrets detector -# on: -# pull_request: +on: + pull_request: + branches: + - 'main' -# jobs: -# main: -# runs-on: ubuntu-latest -# steps: -# - name: Checkout repository -# uses: actions/checkout@v4 -# with: -# path: ${{ github.run_id }} -# fetch-depth: 0 +jobs: + main: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} + fetch-depth: 0 -# - name: Install secrets detector -# run: pip install detect-secrets + - name: Install secrets detector + run: pip install detect-secrets -# - name: Run on change-set -# run: | -# cd ${{ github.run_id }} -# git diff --name-only --diff-filter=d --merge-base origin/${{ github.base_ref }} -z | xargs -0 detect-secrets-hook --baseline .github/workflows/config/.secrets.baseline \ No newline at end of file + - name: Run on change-set + run: | + cd ${{ github.run_id }} + git diff --name-only --diff-filter=d --merge-base origin/main -z | xargs -0 detect-secrets-hook --baseline .secrets.baseline \ No newline at end of file diff --git a/.github/workflows/config/.secrets.baseline b/.secrets.baseline similarity index 99% rename from .github/workflows/config/.secrets.baseline rename to .secrets.baseline index 4a56aaad3c58..c26f70775c5a 100644 --- a/.github/workflows/config/.secrets.baseline +++ b/.secrets.baseline @@ -123,13 +123,13 @@ } ], "results": { - ".github/workflows/cicd-main.yml": [ + ".github/workflows/node-reboot.yml": [ { - "type": "Base64 High Entropy String", - "filename": ".github/workflows/cicd-main.yml", - "hashed_secret": "593951c440200143335452427205ae7c8580d463", + "type": "Secret Keyword", + "filename": ".github/workflows/node-reboot.yml", + "hashed_secret": "3e26d6750975d678acb8fa35a0f69237881576b0", "is_verified": false, - "line_number": 1503 + "line_number": 52 } ], "docs/source/nlp/question_answering.rst": [ @@ -1229,9 +1229,9 @@ { "type": "Base64 High Entropy String", "filename": "tests/infer_data_path.py", - "hashed_secret": "e3fb89ccb261c88146519164f7e8a47786d33fee", + "hashed_secret": "8e0937151cfd9750db688fbe66be37d0c53ed6ab", "is_verified": false, - "line_number": 271 + "line_number": 63 } ], "tutorials/asr/Multilang_ASR.ipynb": [ @@ -1902,7 +1902,7 @@ "filename": "tutorials/multimodal/Multimodal Data Preparation.ipynb", "hashed_secret": "b641cbe299c9e27b480cc8a823bb020d45962236", "is_verified": false, - "line_number": 660 + "line_number": 658 } ], "tutorials/nlp/ITN_with_Thutmose_Tagger.ipynb": [ @@ -2083,5 +2083,5 @@ } ] }, - "generated_at": "2024-09-08T19:00:15Z" + "generated_at": "2024-10-25T13:43:17Z" } diff --git a/examples/llm/auto_configurator/auto_config.py b/examples/llm/auto_configurator/auto_config.py index 777e0d290fcb..b9b9c7b023d5 100644 --- a/examples/llm/auto_configurator/auto_config.py +++ b/examples/llm/auto_configurator/auto_config.py @@ -14,48 +14,76 @@ import argparse import os +from dataclasses import dataclass +from functools import partial import fiddle as fdl import nemo_run as run -from nemo.collections.llm import GPTConfig126M +from nemo.collections import llm +from nemo.collections.llm.gpt.model.llama import Llama3Config, LlamaModel from nemo.collections.llm.tools.auto_configurator import AutoConfigurator, generate_configs, get_results def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--run_number", type=int, help="Number of config to run") - parser.add_argument("--logs_dir", type=str, help="Path where to save training logs") - parser.add_argument("--data_path", type=str, help="Path to the dataset") - 
parser.add_argument("--tokenizer_path", type=str, help="Path to the tokenizer") + parser.add_argument("--log_dir", type=str, help="Path where to save training logs") parser.add_argument("--get_results", action="store_true") return parser.parse_args() +@dataclass +class Llama3Config145M(Llama3Config): + num_layers: int = 12 + hidden_size: int = 768 + num_attention_heads: int = 16 + num_query_groups: int = 8 + ffn_hidden_size: int = 2688 + + +@run.cli.factory(target=llm.pretrain, name="llama3_145m") +def llama3_145m(num_nodes=1, num_gpus_per_node=1): + # Setup Llama3 145M config + recipe = partial(llm.llama3_8b.pretrain_recipe, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node)() + recipe.data.global_batch_size = 16 + recipe.data.seq_length = 2048 + + recipe.trainer.strategy.context_parallel_size = 1 + recipe.model.config.seq_length = recipe.data.seq_length + + recipe = run.Partial( + llm.pretrain, + model=run.Config(LlamaModel, config=run.Config(Llama3Config145M)), + trainer=recipe.trainer, + data=recipe.data, + log=recipe.log, + optim=recipe.optim, + resume=None, + ) + + return recipe + + def train_config(args): - # GPT-3 126M + # Llama3 145M # This example will generate 3 configs. - # It is expected that this script will be run 3 times with changing --run_number flag for each run from 0 to 2. + # It is expected that this script will be run 3 times with changing --run_number flag for each run from 1 to 3. # After all configurations are trained, please trigger the script using --get_results flag. + + # Get Auto Conf runner runner = AutoConfigurator( - model=run.Config(GPTConfig126M), - num_nodes=1, - gpus_per_node=1, + recipe=partial(llama3_145m)(), gpu_memory_gb=40, - global_batch_size=16, - seq_length=512, tensor_parallel_sizes=[1], pipeline_parallel_sizes=[1], micro_batch_sizes=[1, 2, 4], max_training_days=1, - max_steps_per_run=25, + max_steps_per_run=10, num_tokens_in_b=10, - vocab_size=51200, - tokenizer_type="autotokenizer", - tokenizer_path=args.tokenizer_path, - data_paths=args.data_path, - path_to_logs=args.logs_dir, + vocab_size=32000, + path_to_logs=args.log_dir, ) base_cfg, configs = generate_configs(runner) @@ -65,14 +93,13 @@ def train_config(args): names = list(configs.keys()) # Run pre-training - partial = partials[args.run_number - 1] - partial.log.log_dir = os.path.join(args.logs_dir, names[args.run_number - 1]) - pretrain = fdl.build(partial) + pretrain_cfg = partials[args.run_number - 1] # partial(llama3_145m)() # + pretrain = fdl.build(pretrain_cfg) pretrain() else: # # Get Auto Configurator results - get_results(base_cfg, runner, args.logs_dir) - print(f"The results were successfully saved to {args.logs_dir}.") + get_results(base_cfg, runner, args.log_dir) + print(f"The results were successfully saved to {args.log_dir}.") def main(): diff --git a/examples/nlp/dialogue/dialogue.py b/examples/nlp/dialogue/dialogue.py index 4284fed42d22..578895a2ad43 100644 --- a/examples/nlp/dialogue/dialogue.py +++ b/examples/nlp/dialogue/dialogue.py @@ -63,7 +63,7 @@ @hydra_runner(config_path="conf", config_name="dialogue_config") def main(cfg: DictConfig) -> None: pl.seed_everything(42) - logging.warning('This script is no longer supported in NeMo and is scheduled for removal in the 23.11 release.') + logging.warning('This script is no longer supported in NeMo and is scheduled for removal in the 24.11 release.') logging.info(f'Config: {OmegaConf.to_yaml(cfg)}') try: diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py index 
b48f99e061c9..5bc45b1049f3 100644 --- a/nemo/collections/llm/gpt/model/llama.py +++ b/nemo/collections/llm/gpt/model/llama.py @@ -56,7 +56,6 @@ class LlamaConfig(GPTConfig): persist_layer_norm: bool = True bias_dropout_fusion: bool = True apply_rope_fusion: bool = True - cross_entropy_loss_fusion: bool = False @dataclass diff --git a/nemo/collections/llm/quantization/utils.py b/nemo/collections/llm/quantization/utils.py index 86c343ad54ec..c4c533fe38d0 100644 --- a/nemo/collections/llm/quantization/utils.py +++ b/nemo/collections/llm/quantization/utils.py @@ -44,11 +44,12 @@ def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: def load_with_modelopt_layer_spec(nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1) -> llm.GPTModel: trainer = nl.Trainer( - devices=calib_tp * calib_pp, + devices=calib_tp, + num_nodes=calib_pp, strategy=nl.MegatronStrategy( - tensor_model_parallel_size=calib_tp, pipeline_model_parallel_size=calib_pp, pipeline_dtype=torch.float32 + tensor_model_parallel_size=calib_tp, pipeline_model_parallel_size=calib_pp, pipeline_dtype=torch.bfloat16 ), - plugins=nl.MegatronMixedPrecision(precision='32', pipeline_dtype=torch.float32), + plugins=nl.MegatronMixedPrecision(precision='bf16', pipeline_dtype=torch.bfloat16, autocast_enabled=True), ) fabric = trainer.to_fabric() fabric.launch() diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py index 2aa6eb8bf784..b02547acfffe 100644 --- a/nemo/collections/llm/recipes/__init__.py +++ b/nemo/collections/llm/recipes/__init__.py @@ -55,6 +55,14 @@ qwen2_7b, qwen2_72b, qwen2_500m, + starcoder, + starcoder2, + starcoder2_3b, + starcoder2_7b, + starcoder2_15b, + t5_3b, + t5_11b, + t5_220m, ) from nemo.collections.llm.recipes.log.default import default_log, default_resume from nemo.collections.llm.recipes.optim import adam @@ -95,6 +103,14 @@ "nemotron4_22b_16k", "nemotron4_22b_64k", "nemotron4_340b", + "t5_220m", + "t5_3b", + "t5_11b", + "starcoder", + "starcoder2", + "starcoder2_3b", + "starcoder2_7b", + "starcoder2_15b", "qwen2", "qwen2_500m", "qwen2_1p5b", diff --git a/nemo/collections/llm/recipes/llama3_70b.py b/nemo/collections/llm/recipes/llama3_70b.py index 6e9da5c5116d..5b721c7d531e 100644 --- a/nemo/collections/llm/recipes/llama3_70b.py +++ b/nemo/collections/llm/recipes/llama3_70b.py @@ -24,6 +24,7 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama3Config70B, LlamaModel from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe @@ -31,6 +32,7 @@ from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192 +from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.utils.exp_manager import TimingCallback @@ -245,7 +247,9 @@ def finetune_recipe( num_nodes: int = 1, num_gpus_per_node: int = 8, peft_scheme: Optional[str] = 'lora', - packed_sequence: bool = False, + seq_length: Optional[int] = 
None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, ) -> run.Partial: """ Create a fine-tuning recipe for Llama3 70B model. @@ -260,6 +264,9 @@ def finetune_recipe( num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. Returns: run.Partial: Partial configuration for fine-tuning. @@ -277,6 +284,15 @@ def finetune_recipe( This recipe uses the SQuAD dataset for fine-tuning. Be aware that fine-tuning a 70B model requires substantial computational resources. """ + # Default to unpacked data in normal mode and packed data in performance mode + # once packing recipe is well tested, change this default to true + if packed_sequence is None: + packed_sequence = performance_mode + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + recipe = default_finetune_recipe( model(), "meta-llama/Meta-Llama-3-70B", dir, name, num_nodes, num_gpus_per_node, packed_sequence ) @@ -287,8 +303,90 @@ def finetune_recipe( recipe.optim.config.lr = 5e-6 elif peft_scheme.lower() == 'lora': recipe.peft = run.Config(LoRA) + recipe.peft.dim = 16 + recipe.peft.alpha = 32 + recipe.peft.target_modules = ['linear_qkv'] + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.pad_to_max_length = True + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. 
+ """ + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=22, + ) + ) + else: + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + + recipe.trainer.strategy.sequence_parallel = True + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + return recipe diff --git a/nemo/collections/llm/recipes/llama3_8b.py b/nemo/collections/llm/recipes/llama3_8b.py index 394a7718b8bd..29c5c25f94fe 100644 --- a/nemo/collections/llm/recipes/llama3_8b.py +++ b/nemo/collections/llm/recipes/llama3_8b.py @@ -24,6 +24,7 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel from nemo.collections.llm.peft.lora import LoRA @@ -31,6 +32,7 @@ from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.utils.exp_manager import TimingCallback @@ -233,7 +235,9 @@ def finetune_recipe( num_nodes: int = 1, num_gpus_per_node: int = 8, peft_scheme: Optional[str] = 'lora', - packed_sequence: bool = False, # once packing recipe is well tested, change this default to true + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, ) -> run.Partial: """ Create a fine-tuning recipe for Llama3 8B model. @@ -248,6 +252,9 @@ def finetune_recipe( num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. Returns: run.Partial: Partial configuration for fine-tuning. 
@@ -265,6 +272,15 @@ def finetune_recipe( on fine-tuning LLMs with NeMo, see the fine-tuning guide in the `examples/llm/finetune/` directory. """ + # Default to unpacked data in normal mode and packed data in performance mode + # once packing recipe is well tested, change this default to true + if packed_sequence is None: + packed_sequence = performance_mode + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + recipe = default_finetune_recipe( model(), "meta-llama/Meta-Llama-3-8B", dir, name, num_nodes, num_gpus_per_node, packed_sequence ) @@ -273,7 +289,80 @@ def finetune_recipe( recipe.optim.config.lr = 5e-6 elif peft_scheme.lower() == 'lora': recipe.peft = run.Config(LoRA) + recipe.peft.dim = 8 + recipe.peft.alpha = 16 + recipe.peft.target_modules = ['linear_qkv'] + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.pad_to_max_length = True + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + recipe.trainer.strategy.tensor_model_parallel_size = 1 + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=False, + ) + ) + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + return recipe diff --git a/nemo/collections/llm/recipes/starcoder.py b/nemo/collections/llm/recipes/starcoder.py new file mode 100644 index 000000000000..b90cec0fbd7e --- /dev/null +++ b/nemo/collections/llm/recipes/starcoder.py @@ -0,0 +1,310 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.model.starcoder import StarcoderConfig15B, StarcoderModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed, fp16_mixed +from nemo.utils.exp_manager import TimingCallback + +NAME = "starcoder_15b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Starcoder 15B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Starcoder 15B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=starcoder_15b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + + return run.Config(StarcoderModel, config=run.Config(StarcoderConfig15B)) + + +def starcoder_trainer( + tensor_parallelism: int = 4, + pipeline_parallelism: int = 2, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + precision: str = "bf16-mixed", + accumulate_grad_batches: int = 1, + limit_test_batches: int = 32, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + val_check_interval: int = 2000, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Starcoder 15B models. + + This function sets up the distributed training strategy and other training parameters. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + precision (str): Precision configuration, one of fp32, 16-mixed or bf16-mixed. + accumulate_grad_batches (int): Number of steps per gradient accumulation. + limit_test_batches (int): Limit the number of test batches. + limit_val_batches (int): Limit the number of validation batches. + log_every_n_steps (int): Log every n steps. 
+ val_check_interval (int): Run validation every N steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_include_optimizer=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ) + + precision_plugin = None + if precision == "16-mixed": + precision_plugin = fp16_mixed() + elif precision == "bf16-mixed": + precision_plugin = bf16_mixed() + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + callbacks=callbacks, + devices=num_gpus_per_node, + accumulate_grad_batches=accumulate_grad_batches, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=precision_plugin, + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=val_check_interval, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + # General + dir: Optional[str] = None, + name: str = "default", + # Trainer + tensor_parallelism: int = 2, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 300000, + precision: str = "bf16-mixed", + accumulate_grad_batches: int = 1, + gradient_clip_val: float = 1.0, + limit_test_batches: int = 32, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + val_check_interval: int = 1000, + # Data + global_batch_size=32, + micro_batch_size=2, + seq_length=4096, + # Optimizer + warmup_steps=500, + constant_steps=0, + min_lr=3e-5, + max_lr=3e-4, + # Training function + fn=pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Starcoder 15B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + precision (str): Precision configuration, one of fp32, 16-mixed or bf16-mixed. + accumulate_grad_batches (int): Number of steps per gradient accumulation. + gradient_clip_val (float): Value for gradient clipping. + limit_test_batches (int): Limit the number of test batches. + limit_val_batches (int): Limit the number of validation batches. 
+ log_every_n_steps (int): Log every n steps. + val_check_interval (int): Run validation every N steps. + global_batch_size (int): Global batch size. + micro_batch_size (int): Micro batch size. + seq_length (int): Sequence length. + warmup_steps (int): Number of warmup steps. + constant_steps (int): Number of constant steps. + min_lr (float): Minimum learning rate. + max_lr (float): Maximum learning rate. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory starcoder_15b + $ nemo llm pretrain --factory "starcoder_15b(num_nodes=1, name='my_starcoder2_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="starcoder2_pretrain", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses a mock dataset, look for the finetune examples to see how to change the dataset. + """ + return run.Partial( + fn, + model=model(), + trainer=starcoder_trainer( + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, + pipeline_parallelism_type=pipeline_parallelism_type, + virtual_pipeline_parallelism=virtual_pipeline_parallelism, + context_parallelism=context_parallelism, + sequence_parallelism=sequence_parallelism, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + max_steps=max_steps, + precision=precision, + accumulate_grad_batches=accumulate_grad_batches, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, + val_check_interval=val_check_interval, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config( + MockDataModule, + seq_length=seq_length, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + ), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing( + precision=precision, + warmup_steps=warmup_steps, + constant_steps=constant_steps, + min_lr=min_lr, + max_lr=max_lr, + clip_grad=gradient_clip_val, + ), + resume=default_resume(), + ) + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', +) -> run.Partial: + """ + Create a fine-tuning recipe for Starcoder 15B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory starcoder_15b + + Python API usage: + >>> recipe = finetune_recipe(name="starcoder_15b_finetune", num_nodes=2) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. 
+ """ + recipe = default_finetune_recipe(model(), "bigcode/starcoder", dir, name, num_nodes, num_gpus_per_node) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + return recipe diff --git a/nemo/collections/llm/recipes/starcoder2.py b/nemo/collections/llm/recipes/starcoder2.py new file mode 100644 index 000000000000..c3a19326585c --- /dev/null +++ b/nemo/collections/llm/recipes/starcoder2.py @@ -0,0 +1,135 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks.callback import Callback +from nemo import lightning as nl +from nemo.collections.llm.gpt.model.starcoder2 import ( + Starcoder2Config3B, + Starcoder2Config7B, + Starcoder2Config15B, + Starcoder2Model, +) +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed, fp16_mixed + + +def starcoder2_model(version: str) -> run.Config[pl.LightningModule]: + """ + A function to create a Starcoder2 models. + + Args: + version (str): The version of the Starcoder2 model to create. one of ["starcoder2_3b", "starcoder2_7b", + "starcoder2_15b"]. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Starcoder2 model. + """ + config = None + if version == "starcoder2_3b": + config = run.Config(Starcoder2Config3B) + elif version == "starcoder2_7b": + config = run.Config(Starcoder2Config7B) + elif version == "starcoder2_15b": + config = run.Config(Starcoder2Config15B) + + assert config is not None, f"Invalid version: {version}" + return run.Config(Starcoder2Model, config=config) + + +def starcoder2_trainer( + tensor_parallelism: int = 2, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + precision: str = "bf16-mixed", + accumulate_grad_batches: int = 1, + limit_test_batches: int = 32, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + val_check_interval: int = 2000, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Starcoder2 models. + + This function sets up the distributed training strategy and other training parameters. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. 
+ virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + precision (str): Precision configuration, one of fp32, 16-mixed or bf16-mixed. + accumulate_grad_batches (int): Number of steps per gradient accumulation. + limit_test_batches (int): Limit the number of test batches. + limit_val_batches (int): Limit the number of validation batches. + log_every_n_steps (int): Log every n steps. + val_check_interval (int): Run validation every N steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_include_optimizer=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ) + + precision_plugin = None + if precision == "16-mixed": + precision_plugin = fp16_mixed() + elif precision == "bf16-mixed": + precision_plugin = bf16_mixed() + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + callbacks=callbacks, + devices=num_gpus_per_node, + accumulate_grad_batches=accumulate_grad_batches, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=precision_plugin, + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=val_check_interval, + ) + + return trainer diff --git a/nemo/collections/llm/recipes/starcoder2_15b.py b/nemo/collections/llm/recipes/starcoder2_15b.py new file mode 100644 index 000000000000..5faebb9460f3 --- /dev/null +++ b/nemo/collections/llm/recipes/starcoder2_15b.py @@ -0,0 +1,223 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch + +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.starcoder2 import starcoder2_model, starcoder2_trainer +from nemo.utils.exp_manager import TimingCallback + +NAME = "starcoder2_15b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Starcoder2 15B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Starcoder2 15b model. + + Examples: + CLI usage: + $ nemo llm pretrain model=starcoder2_15b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + + return starcoder2_model(version=NAME) + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + # General + dir: Optional[str] = None, + name: str = "default", + # Trainer + tensor_parallelism: int = 4, + pipeline_parallelism: int = 2, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 300000, + precision: str = "bf16-mixed", + accumulate_grad_batches: int = 1, + gradient_clip_val: float = 1.0, + limit_test_batches: int = 32, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + val_check_interval: int = 1000, + # Data + global_batch_size=32, + micro_batch_size=2, + seq_length=4096, + # Optimizer + warmup_steps=500, + constant_steps=0, + min_lr=3e-5, + max_lr=3e-4, + # Training function + fn=pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Starcoder2 15B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + precision (str): Precision configuration, one of fp32, 16-mixed or bf16-mixed. + accumulate_grad_batches (int): Number of steps per gradient accumulation. + gradient_clip_val (float): Value for gradient clipping. + limit_test_batches (int): Limit the number of test batches. + limit_val_batches (int): Limit the number of validation batches. + log_every_n_steps (int): Log every n steps. + val_check_interval (int): Run validation every N steps. + global_batch_size (int): Global batch size. + micro_batch_size (int): Micro batch size. 
+ seq_length (int): Sequence length. + warmup_steps (int): Number of warmup steps. + constant_steps (int): Number of constant steps. + min_lr (float): Minimum learning rate. + max_lr (float): Maximum learning rate. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory starcoder2_15b + $ nemo llm pretrain --factory "starcoder2_15b(num_nodes=1, name='my_starcoder2_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="starcoder2_pretrain", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses a mock dataset, look for the finetune examples to see how to change the dataset. + """ + return run.Partial( + fn, + model=model(), + trainer=starcoder2_trainer( + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, + pipeline_parallelism_type=pipeline_parallelism_type, + virtual_pipeline_parallelism=virtual_pipeline_parallelism, + context_parallelism=context_parallelism, + sequence_parallelism=sequence_parallelism, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + max_steps=max_steps, + precision=precision, + accumulate_grad_batches=accumulate_grad_batches, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, + val_check_interval=val_check_interval, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config( + MockDataModule, + seq_length=seq_length, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + ), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing( + precision=precision, + warmup_steps=warmup_steps, + constant_steps=constant_steps, + min_lr=min_lr, + max_lr=max_lr, + clip_grad=gradient_clip_val, + ), + resume=default_resume(), + ) + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', +) -> run.Partial: + """ + Create a fine-tuning recipe for Starcoder2 15B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory starcoder2_15b + + Python API usage: + >>> recipe = finetune_recipe(name="starcoder2_15b_finetune", num_nodes=2) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. 
+ """ + recipe = default_finetune_recipe(model(), "bigcode/starcoder2-15b", dir, name, num_nodes, num_gpus_per_node) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + return recipe diff --git a/nemo/collections/llm/recipes/starcoder2_3b.py b/nemo/collections/llm/recipes/starcoder2_3b.py new file mode 100644 index 000000000000..232f5842ff84 --- /dev/null +++ b/nemo/collections/llm/recipes/starcoder2_3b.py @@ -0,0 +1,223 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch + +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.starcoder2 import starcoder2_model, starcoder2_trainer +from nemo.utils.exp_manager import TimingCallback + +NAME = "starcoder2_3b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Starcoder2 3B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Starcoder2 3b model. + + Examples: + CLI usage: + $ nemo llm pretrain model=starcoder2_3b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + + return starcoder2_model(version=NAME) + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + # General + dir: Optional[str] = None, + name: str = "default", + # Trainer + tensor_parallelism: int = 2, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 300000, + precision: str = "bf16-mixed", + accumulate_grad_batches: int = 1, + gradient_clip_val: float = 1.0, + limit_test_batches: int = 32, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + val_check_interval: int = 1000, + # Data + global_batch_size=32, + micro_batch_size=2, + seq_length=4096, + # Optimizer + warmup_steps=500, + constant_steps=0, + min_lr=3e-5, + max_lr=3e-4, + # Training function + fn=pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Starcoder2 3B model. 
+ + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + precision (str): Precision configuration, one of fp32, 16-mixed or bf16-mixed. + accumulate_grad_batches (int): Number of steps per gradient accumulation. + gradient_clip_val (float): Value for gradient clipping. + limit_test_batches (int): Limit the number of test batches. + limit_val_batches (int): Limit the number of validation batches. + log_every_n_steps (int): Log every n steps. + val_check_interval (int): Run validation every N steps. + global_batch_size (int): Global batch size. + micro_batch_size (int): Micro batch size. + seq_length (int): Sequence length. + warmup_steps (int): Number of warmup steps. + constant_steps (int): Number of constant steps. + min_lr (float): Minimum learning rate. + max_lr (float): Maximum learning rate. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory starcoder2_3b + $ nemo llm pretrain --factory "starcoder2_3b(num_nodes=1, name='my_starcoder2_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="starcoder2_pretrain", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses a mock dataset, look for the finetune examples to see how to change the dataset. 
+ """ + return run.Partial( + fn, + model=model(), + trainer=starcoder2_trainer( + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, + pipeline_parallelism_type=pipeline_parallelism_type, + virtual_pipeline_parallelism=virtual_pipeline_parallelism, + context_parallelism=context_parallelism, + sequence_parallelism=sequence_parallelism, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + max_steps=max_steps, + precision=precision, + accumulate_grad_batches=accumulate_grad_batches, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, + val_check_interval=val_check_interval, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config( + MockDataModule, + seq_length=seq_length, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + ), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing( + precision=precision, + warmup_steps=warmup_steps, + constant_steps=constant_steps, + min_lr=min_lr, + max_lr=max_lr, + clip_grad=gradient_clip_val, + ), + resume=default_resume(), + ) + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', +) -> run.Partial: + """ + Create a fine-tuning recipe for Starcoder2 3B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory starcoder2_3b + + Python API usage: + >>> recipe = finetune_recipe(name="starcoder2_3b_finetune", num_nodes=2) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. + """ + recipe = default_finetune_recipe(model(), "bigcode/starcoder2-3b", dir, name, num_nodes, num_gpus_per_node) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + return recipe diff --git a/nemo/collections/llm/recipes/starcoder2_7b.py b/nemo/collections/llm/recipes/starcoder2_7b.py new file mode 100644 index 000000000000..ee6dacdc98e9 --- /dev/null +++ b/nemo/collections/llm/recipes/starcoder2_7b.py @@ -0,0 +1,223 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch + +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.starcoder2 import starcoder2_model, starcoder2_trainer +from nemo.utils.exp_manager import TimingCallback + +NAME = "starcoder2_7b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Starcoder2 7b model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Starcoder2 7b model. + + Examples: + CLI usage: + $ nemo llm pretrain model=starcoder2_7b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + + return starcoder2_model(version=NAME) + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + # General + dir: Optional[str] = None, + name: str = "default", + # Trainer + tensor_parallelism: int = 2, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 300000, + precision: str = "bf16-mixed", + accumulate_grad_batches: int = 1, + gradient_clip_val: float = 1.0, + limit_test_batches: int = 32, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + val_check_interval: int = 1000, + # Data + global_batch_size=32, + micro_batch_size=2, + seq_length=4096, + # Optimizer + warmup_steps=500, + constant_steps=0, + min_lr=3e-5, + max_lr=3e-4, + # Training function + fn=pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Starcoder2 7B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + precision (str): Precision configuration, one of fp32, 16-mixed or bf16-mixed. 
+ accumulate_grad_batches (int): Number of steps per gradient accumulation. + gradient_clip_val (float): Value for gradient clipping. + limit_test_batches (int): Limit the number of test batches. + limit_val_batches (int): Limit the number of validation batches. + log_every_n_steps (int): Log every n steps. + val_check_interval (int): Run validation every N steps. + global_batch_size (int): Global batch size. + micro_batch_size (int): Micro batch size. + seq_length (int): Sequence length. + warmup_steps (int): Number of warmup steps. + constant_steps (int): Number of constant steps. + min_lr (float): Minimum learning rate. + max_lr (float): Maximum learning rate. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory starcoder2_7b + $ nemo llm pretrain --factory "starcoder2_7b(num_nodes=1, name='my_starcoder2_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="starcoder2_pretrain", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses a mock dataset, look for the finetune examples to see how to change the dataset. + """ + return run.Partial( + fn, + model=model(), + trainer=starcoder2_trainer( + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, + pipeline_parallelism_type=pipeline_parallelism_type, + virtual_pipeline_parallelism=virtual_pipeline_parallelism, + context_parallelism=context_parallelism, + sequence_parallelism=sequence_parallelism, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + max_steps=max_steps, + precision=precision, + accumulate_grad_batches=accumulate_grad_batches, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, + val_check_interval=val_check_interval, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config( + MockDataModule, + seq_length=seq_length, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + ), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing( + precision=precision, + warmup_steps=warmup_steps, + constant_steps=constant_steps, + min_lr=min_lr, + max_lr=max_lr, + clip_grad=gradient_clip_val, + ), + resume=default_resume(), + ) + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', +) -> run.Partial: + """ + Create a fine-tuning recipe for Starcoder2 7B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for fine-tuning. 
+ + Examples: + CLI usage: + $ nemo llm finetune --factory starcoder2_7b + + Python API usage: + >>> recipe = finetune_recipe(name="starcoder2_7b_finetune", num_nodes=2) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. + """ + recipe = default_finetune_recipe(model(), "bigcode/starcoder2-7b", dir, name, num_nodes, num_gpus_per_node) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + return recipe diff --git a/nemo/collections/llm/recipes/t5_11b.py b/nemo/collections/llm/recipes/t5_11b.py new file mode 100644 index 000000000000..09d469879364 --- /dev/null +++ b/nemo/collections/llm/recipes/t5_11b.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.llm.t5.data.mock import MockDataModule +from nemo.collections.llm.t5.model.t5 import T5Config11B, T5Model +from nemo.lightning.pytorch.optim.lr_scheduler import WarmupAnnealingScheduler +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.utils.exp_manager import TimingCallback + +NAME = "t5_11b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a T5 11B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the T5 11B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=t5_11b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(T5Model, config=run.Config(T5Config11B)) + + +def trainer( + tensor_parallelism: int = 4, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 20, + num_gpus_per_node: int = 8, + max_steps: int = 1000000, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for T5 model. 
+ + This function sets up the distributed training strategy and other training parameters. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=t5_11b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + For more information on distributed training strategies, refer to the + NeMo documentation on multi-GPU and multi-node training. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, name: str = "default", num_nodes: int = 20, num_gpus_per_node: int = 8, fn=pretrain +) -> run.Partial: + """ + Create a pre-training recipe for T5 11b model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory t5_11b + $ nemo llm pretrain --factory "t5_11b(num_nodes=2, name='my_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="t5_11b_pretrain", num_nodes=2) + >>> print(recipe) + + Note: + For more details on pre-training LLMs with NeMo, see the pre-training + guide in the `examples/llm/pretrain/` directory. 
+ """ + + opt_config = OptimizerConfig( + optimizer='adam', + lr=0.0001, + use_distributed_optimizer=True, + bf16=True, + weight_decay=0.01, + ) + + lr_scheduler = WarmupAnnealingScheduler( + warmup_steps=None, + warmup_ratio=0.01, + max_steps=1000000, + min_lr=0.00001, + ) + + return run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config( + MockDataModule, seq_length=512, seq_length_dec=128, global_batch_size=1920, micro_batch_size=24 + ), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + resume=default_resume(), + ) diff --git a/nemo/collections/llm/recipes/t5_220m.py b/nemo/collections/llm/recipes/t5_220m.py new file mode 100644 index 000000000000..a3b2b761b65b --- /dev/null +++ b/nemo/collections/llm/recipes/t5_220m.py @@ -0,0 +1,203 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.llm.t5.data.mock import MockDataModule +from nemo.collections.llm.t5.model.t5 import T5Config220M, T5Model +from nemo.lightning.pytorch.optim.lr_scheduler import WarmupAnnealingScheduler +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.utils.exp_manager import TimingCallback + +NAME = "t5_220m" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a T5 220M model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the T5 220M model. + + Examples: + CLI usage: + $ nemo llm pretrain model=t5_220m ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(T5Model, config=run.Config(T5Config220M)) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1000000, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for T5 model. 
+ + This function sets up the distributed training strategy and other training parameters. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=t5_220m ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + For more information on distributed training strategies, refer to the + NeMo documentation on multi-GPU and multi-node training. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + # DEBUGGING + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain +) -> run.Partial: + """ + Create a pre-training recipe for T5 220m model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory t5_220m + $ nemo llm pretrain --factory "t5_220m(num_nodes=2, name='my_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="t5_220m_pretrain", num_nodes=2) + >>> print(recipe) + + Note: + For more details on pre-training LLMs with NeMo, see the pre-training + guide in the `examples/llm/pretrain/` directory. 
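
For a quick local check of this recipe, a hedged sketch using nemo_run's LocalExecutor is shown below; it assumes the usual torchrun launcher pattern, and the device count and step override are arbitrary smoke-test values:

    # Illustrative sketch: run the 220M recipe on a single machine with nemo_run.
    import nemo_run as run
    from nemo.collections.llm.recipes.t5_220m import pretrain_recipe

    recipe = pretrain_recipe(name="t5_220m_pretrain", num_nodes=1, num_gpus_per_node=2)
    recipe.trainer.max_steps = 100  # shorten the 1M-step default for a quick check

    executor = run.LocalExecutor(ntasks_per_node=2, launcher="torchrun")
    run.run(recipe, executor=executor, name="t5_220m_pretrain")
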
+ """ + + opt_config = OptimizerConfig( + optimizer='adam', + lr=0.0001, + use_distributed_optimizer=True, + bf16=True, + weight_decay=0.01, + ) + + lr_scheduler = WarmupAnnealingScheduler( + warmup_steps=None, + warmup_ratio=0.01, + max_steps=1000000, + min_lr=0.00001, + ) + + return run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=512, seq_length_dec=128, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + resume=default_resume(), + ) diff --git a/nemo/collections/llm/recipes/t5_3b.py b/nemo/collections/llm/recipes/t5_3b.py new file mode 100644 index 000000000000..08bcae895c3e --- /dev/null +++ b/nemo/collections/llm/recipes/t5_3b.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.llm.t5.data.mock import MockDataModule +from nemo.collections.llm.t5.model.t5 import T5Config3B, T5Model +from nemo.lightning.pytorch.optim.lr_scheduler import WarmupAnnealingScheduler +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.utils.exp_manager import TimingCallback + +NAME = "t5_3b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a T5 3B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the T5 3B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=t5_3b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(T5Model, config=run.Config(T5Config3B)) + + +def trainer( + tensor_parallelism: int = 2, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 20, + num_gpus_per_node: int = 8, + max_steps: int = 1000000, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for T5 model. + + This function sets up the distributed training strategy and other training parameters. 
+ + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=t5_3b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + For more information on distributed training strategies, refer to the + NeMo documentation on multi-GPU and multi-node training. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, name: str = "default", num_nodes: int = 20, num_gpus_per_node: int = 8, fn=pretrain +) -> run.Partial: + """ + Create a pre-training recipe for T5 3b model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory t5_3b + $ nemo llm pretrain --factory "t5_3b(num_nodes=2, name='my_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="t5_3b_pretrain", num_nodes=2) + >>> print(recipe) + + Note: + For more details on pre-training LLMs with NeMo, see the pre-training + guide in the `examples/llm/pretrain/` directory. 
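
Because this factory only exposes node and GPU counts, other settings are overridden on the returned run.Partial, in the same way the fine-tuning recipes earlier in this change adjust recipe.trainer.strategy. A sketch with arbitrary example values:

    # Illustrative sketch: field-by-field overrides on the returned run.Partial.
    from nemo.collections.llm.recipes.t5_3b import pretrain_recipe

    recipe = pretrain_recipe(name="t5_3b_pretrain", num_nodes=4)
    recipe.trainer.strategy.tensor_model_parallel_size = 4  # default configured above is 2
    recipe.data.global_batch_size = 512                     # scale the 1920 default down for 4 nodes
    recipe.trainer.max_steps = 200_000
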
+ """ + + opt_config = OptimizerConfig( + optimizer='adam', + lr=0.0001, + use_distributed_optimizer=True, + bf16=True, + weight_decay=0.01, + ) + + lr_scheduler = WarmupAnnealingScheduler( + warmup_steps=None, + warmup_ratio=0.01, + max_steps=1000000, + min_lr=0.00001, + ) + + return run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config( + MockDataModule, seq_length=512, seq_length_dec=128, global_batch_size=1920, micro_batch_size=24 + ), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + resume=default_resume(), + ) diff --git a/nemo/collections/llm/t5/data/mock.py b/nemo/collections/llm/t5/data/mock.py new file mode 100644 index 000000000000..eaf41d290da4 --- /dev/null +++ b/nemo/collections/llm/t5/data/mock.py @@ -0,0 +1,189 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, List, Optional + +import numpy as np +import pytorch_lightning as pl +import torch +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from torch.utils import data +from torch.utils.data import DataLoader, Dataset + +from nemo.lightning.pytorch.plugins import MegatronDataSampler +from nemo.utils.import_utils import safe_import + +_, HAVE_TE = safe_import("transformer_engine") + +if TYPE_CHECKING: + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +class MockDataModule(pl.LightningDataModule): + def __init__( + self, + seq_length: int = 512, + seq_length_dec: int = 128, + tokenizer: Optional["TokenizerSpec"] = None, + micro_batch_size: int = 4, + global_batch_size: int = 8, + rampup_batch_size: Optional[List[int]] = None, + num_train_samples: int = 10_000, + num_val_samples: int = 10_000, + num_test_samples: int = 10_000, + num_workers: int = 8, + pin_memory: bool = True, + persistent_workers: bool = False, + create_attention_mask: bool = False, + ): + super().__init__() + self.seq_length = seq_length + self.seq_length_dec = seq_length_dec + self.num_train_samples = num_train_samples + self.num_val_samples = num_val_samples + self.num_test_samples = num_test_samples + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + self.create_attention_mask = create_attention_mask or not HAVE_TE + + from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer + + self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "BertWordPieceCase") + self.data_sampler = MegatronDataSampler( + seq_len=self.seq_length, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + rampup_batch_size=rampup_batch_size, + ) + + def setup(self, stage: str = "") -> None: + self._train_ds = _MockT5Dataset( + self.tokenizer, "train", 
self.num_train_samples, self.seq_length, self.seq_length_dec + ) + self._validation_ds = _MockT5Dataset( + self.tokenizer, "valid", self.num_val_samples, self.seq_length, self.seq_length_dec + ) + self._test_ds = _MockT5Dataset( + self.tokenizer, "test", self.num_test_samples, self.seq_length, self.seq_length_dec + ) + + def train_dataloader(self) -> TRAIN_DATALOADERS: + if not hasattr(self, "_train_ds"): + self.setup() + return self._create_dataloader(self._train_ds) + + def val_dataloader(self) -> EVAL_DATALOADERS: + if not hasattr(self, "_validation_ds"): + self.setup() + return self._create_dataloader(self._validation_ds) + + def test_dataloader(self) -> EVAL_DATALOADERS: + if not hasattr(self, "_test_ds"): + self.setup() + return self._create_dataloader(self._test_ds) + + def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + return DataLoader( + dataset, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + collate_fn=dataset.collate_fn, + **kwargs, + ) + + +class _MockT5Dataset(Dataset): + def __init__( + self, + tokenizer: "TokenizerSpec", + name: str, + num_samples: int, + seq_length: int, + seq_length_dec: int, + seed: int = 42, + create_attention_mask: bool = False, + ) -> None: + super().__init__() + self.name = name + self.seq_length = seq_length + self.seq_length_dec = seq_length_dec + self.vocab_size = tokenizer.vocab_size + self.length = num_samples + self.seed = seed + self.create_attention_mask = create_attention_mask + + self.mask_encoder = torch.ones((self.seq_length, self.seq_length), device='cpu') + self.mask_decoder = torch.tril(torch.ones((self.seq_length_dec, self.seq_length_dec), device='cpu')) + self.mask_encoder_decoder = torch.ones((self.seq_length_dec, self.seq_length), device='cpu') + self.mask_encoder = self.mask_encoder < 0.5 + self.mask_decoder = self.mask_decoder < 0.5 + self.mask_encoder_decoder = self.mask_encoder_decoder < 0.5 + + self.loss_mask = torch.ones(self.seq_length_dec, dtype=torch.float) + + def __len__(self) -> int: + return self.length + + def _get_text(self, idx: int) -> np.ndarray: + np_gen = np.random.default_rng(seed=(self.seed + idx)) + return np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64) + + def __getitem__(self, idx) -> Dict[str, torch.Tensor]: + # Generate data of the expected size and datatype (based on GPTDataset). + np_gen = np.random.default_rng(seed=(self.seed + idx)) + encoder_input = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64)) + decoder_input = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length_dec], dtype=np.int64)) + labels = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length_dec], dtype=np.int64)) + + batch = { + "text_enc": encoder_input, + "text_dec": decoder_input, + "labels": labels, + "loss_mask": self.loss_mask, + "truncated": 0, + "enc_mask": self.mask_encoder, + "dec_mask": self.mask_decoder, + "enc_dec_mask": self.mask_encoder_decoder, + } + + return batch + + def _collate_fn(self, batch): + """ + A default implementation of a collation function. + Users should override this method to define custom data loaders. + """ + return data.dataloader.default_collate(batch) + + def collate_fn(self, batch): + """Method that user pass as functor to DataLoader. + + The method optionally performs neural type checking and add types to the outputs. + + Please note, subclasses of Dataset should not implement `input_types`. 
+ + # Usage: + dataloader = torch.utils.data.DataLoader( + ...., + collate_fn=dataset.collate_fn, + .... + ) + + Returns + ------- + Collated batch, with or without types. + """ + return self._collate_fn(batch) diff --git a/nemo/collections/llm/t5/model/t5.py b/nemo/collections/llm/t5/model/t5.py index e6970cba3dd8..058acaaec7b0 100644 --- a/nemo/collections/llm/t5/model/t5.py +++ b/nemo/collections/llm/t5/model/t5.py @@ -200,6 +200,38 @@ def configure_model(self, tokenizer) -> "MCoreT5Model": return model +@dataclass +class T5Config220M(T5Config): + """ + NeMo's T5 model variant + https://github.com/NVIDIA/NeMo-Framework-Launcher/blob/main/launcher_scripts/conf/training/t5/220m.yaml + """ + + num_layers: int = 12 + encoder_num_layers: int = 12 + hidden_size: int = 768 + ffn_hidden_size: int = 3072 + num_attention_heads: int = 12 + + +@dataclass +class T5Config3B(T5Config): + num_layers: int = 24 + encoder_num_layers: int = 24 + hidden_size: int = 2048 + ffn_hidden_size: int = 5120 + num_attention_heads: int = 32 + + +@dataclass +class T5Config11B(T5Config): + num_layers: int = 24 + encoder_num_layers: int = 24 + hidden_size: int = 4096 + ffn_hidden_size: int = 10240 + num_attention_heads: int = 64 + + class T5Model(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin): def __init__( self, diff --git a/nemo/collections/llm/tools/auto_configurator/core/__init__.py b/nemo/collections/llm/tools/auto_configurator/core/__init__.py deleted file mode 100644 index d9155f923f18..000000000000 --- a/nemo/collections/llm/tools/auto_configurator/core/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/nemo/collections/llm/tools/auto_configurator/core/base_config.py b/nemo/collections/llm/tools/auto_configurator/core/base_config.py index a82823c71248..b621b0567c05 100644 --- a/nemo/collections/llm/tools/auto_configurator/core/base_config.py +++ b/nemo/collections/llm/tools/auto_configurator/core/base_config.py @@ -12,214 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import TensorBoardLogger - -from nemo import lightning as nl -from nemo.collections.common.tokenizers import AutoTokenizer, SentencePieceTokenizer -from nemo.collections.llm import PreTrainingDataModule -from nemo.collections.llm.utils import Config -from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, MegatronOptimizerModule -from nemo.utils.exp_manager import TimingCallback - -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class BaseConfig: - def __init__(self, config=None): - """ - Args: - config (AutoConfigurator): auto configurator runner config. - """ - - self.config = config - - self.model = self.get_model() - self.optim = self.get_optim() - self.trainer = self.get_trainer() - self.data = self.get_data() - self.log = self.get_logger() - self.run = self.get_run_config() - self.tokenizer = self.get_tokenizer(config.tokenizer_type, config.tokenizer_path) - - def get_model(self): - """Function that returns model config. - - Returns: - Config: model config. - """ - - self.config.model.seq_length = self.config.seq_length - - return self.config.model - - def get_optim(self) -> Config[OptimizerConfig]: - """Function that returns optimizer config. - - Returns: - Config[OptimizerConfig]: optimizer config. - """ - optim_params = { - "optimizer": "adam", - "lr": 1e-4, - "min_lr": 1e-5, - "use_distributed_optimizer": True, - "bf16": True, - "adam_beta1": 0.9, - "adam_beta2": 0.95, - "clip_grad": 1.0, - "adam_eps": 1e-5, - } - - optim_config = Config( - OptimizerConfig, - **optim_params, - ) - - sched = Config( - CosineAnnealingScheduler, - warmup_steps=10, - constant_steps=0, - min_lr=optim_config.min_lr, - ) - - return Config( - MegatronOptimizerModule, - config=optim_config, - lr_scheduler=sched, - ) - - def get_trainer(self) -> Config[nl.Trainer]: - """Function that returns config for PTL trainer. - - Returns: - Config[nl.Trainer]: trainer config. - """ - - trainer_config = { - "accelerator": "gpu", - "enable_checkpointing": False, - "use_distributed_sampler": False, - "max_epochs": None, - "log_every_n_steps": 1, - "limit_val_batches": 1, - "limit_test_batches": 1, - "accumulate_grad_batches": 1, - "num_nodes": self.config.num_nodes, - "devices": self.config.num_gpus, - "max_steps": self.config.max_steps_per_run, - "val_check_interval": self.config.max_steps_per_run, - } - - strategy = Config( - nl.MegatronStrategy, - pipeline_dtype=torch.bfloat16, - ) - - return Config( - nl.Trainer, - **trainer_config, - strategy=strategy, - plugins=Config(nl.MegatronMixedPrecision, precision="bf16-mixed"), - callbacks=[Config(TimingCallback)], - ) - - def get_tokenizer(self, tokenizer_type: str, tokenizer_path: str) -> Config: - """Function that returns the tokenizer config. - - Args: - tokenizer_type (str): tokenizer type. - tokenizer_path (str): path to the tokenizer. - - Returns: - Config: tokenizer config. - """ - - if tokenizer_type == "sentencepiece": - return Config(SentencePieceTokenizer, model_path=tokenizer_path) - else: - return Config(AutoTokenizer, pretrained_model_name=tokenizer_path) - - def get_data(self) -> Config[PreTrainingDataModule]: - """Function that returns dataset config. - - Returns: - Config[PreTrainingDataModule]: data config. 
- """ - - # Data config - data_config = { - "paths": self.config.data_paths, - "seq_length": self.config.seq_length, - "global_batch_size": self.config.global_batch_size, - "num_workers": 2, - "index_mapping_dir": None, - } - - # Define the tokenizer - tokenizer = self.get_tokenizer( - self.config.tokenizer_type, - self.config.tokenizer_path, - ) - - return Config( - PreTrainingDataModule, - **data_config, - tokenizer=tokenizer, - ) - - def get_logger(self) -> Config[nl.NeMoLogger]: - """Function that returns the training strategy. - - Returns: - Config[nl.NeMoLogger]: NeMo Logger config. - """ - - # Define TensorBoard Logger - tb_logger = Config(TensorBoardLogger, save_dir="tb_logs") - - ckpt = Config( - nl.ModelCheckpoint, - monitor="reduced_train_loss", - save_last=False, - save_top_k=0, - ) - - return Config( - nl.NeMoLogger, - ckpt=ckpt, - tensorboard=tb_logger, - wandb=None, - log_dir=self.config.path_to_logs, - ) - - def get_run_config(self) -> dict: - """Function that returns config for cluster job. - - Returns: - dict: cluster job config. - """ - - run_config = { - "name": self.config.model.__class__.__name__, - "time_limit": f"0-00:{self.config.max_minutes_per_run}:00", - } - - return run_config - def calculate_model_size( gpu_count: int, diff --git a/nemo/collections/llm/tools/auto_configurator/core/calculate_performance.py b/nemo/collections/llm/tools/auto_configurator/core/calculate_performance.py index 5b7ac0ebc4d3..1620c608e549 100644 --- a/nemo/collections/llm/tools/auto_configurator/core/calculate_performance.py +++ b/nemo/collections/llm/tools/auto_configurator/core/calculate_performance.py @@ -42,11 +42,11 @@ def get_results( vocab_size = train_config.vocab_size num_nodes = train_config.num_nodes - gpus_per_node = train_config.gpus_per_node + gpus_per_node = train_config.num_gpus - layers = base_config.model.num_layers - hs = base_config.model.hidden_size - ffn_hs = base_config.model.ffn_hidden_size + layers = base_config.model.config.num_layers + hs = base_config.model.config.hidden_size + ffn_hs = base_config.model.config.ffn_hidden_size training_logs = path_to_save final_result_logs = path_to_save @@ -60,9 +60,7 @@ def get_results( "CP", "EP", "MBS", - "Act Ckpt Layers", - "Act Ckpt Micro Bathes", - "Act Ckpt Layers per Pipeline", + "VP", "Num Layers", "Hidden Size", "FFN Hidden Size", @@ -83,9 +81,7 @@ def get_results( "CP", "EP", "MBS", - "Act Ckpt Layers", - "Act Ckpt Micro Bathes", - "Act Ckpt Layers per Pipeline", + "VP", "Num Layers", "Hidden Size", "FFN Hidden Size", @@ -96,105 +92,96 @@ def get_results( ] result = [] errors = [] + training_logs = os.path.abspath(training_logs) + error_files = find_tb_logs(training_logs, "nemo_error_log") + tb_files = find_tb_logs(training_logs, "events") dirs = [f.path for f in os.scandir(training_logs) if f.is_dir()] - for candidate_dir in dirs: - logs_dir = os.path.join(training_logs, candidate_dir, "tb_logs/lightning_logs") - logs_folder = [f.path for f in os.scandir(logs_dir) if f.is_dir()][0] - tp, pp, cp, ep, mbs, act_ckpt, num_mbs_act, act_per_pipe = get_config(candidate_dir) - - for f in os.listdir(logs_folder): - if f.endswith("0.txt"): - error_file = os.path.join(logs_folder, f) - error = find_error(error_file) - if error: - errors.append( - [ - model_name, - model_size, - seq_length, - tp, - pp, - cp, - ep, - mbs, - act_ckpt, - num_mbs_act, - act_per_pipe, - layers, - hs, - ffn_hs, - global_batch_size, - num_nodes, - gpus_per_node, - error, - ] - ) - - files = os.listdir(logs_folder) - for f in files: - if 
f.startswith("events"): - event_file = os.path.join(logs_folder, f) - ea = event_accumulator.EventAccumulator(event_file) - ea.Reload() - try: - timing_list = ea.Scalars("train_step_timing in s") - if len(timing_list) <= 6: - continue - timing_list = [x.value for x in timing_list[5:]] - avg_global_step_time = round(sum(timing_list) / len(timing_list), 4) - samples_per_s = round(global_batch_size / avg_global_step_time, 2) - m_tflops, m_tflops_gpu = calculate_tflops( - model_name=model_name, - gbs=global_batch_size, - enc_seq_len=seq_length, - dec_seq_len=seq_length, - hs=hs, - ffn_hs=ffn_hs, - layers=layers, - vocab=vocab_size, - nodes=num_nodes, - gpus_per_node=gpus_per_node, - time_per_step=avg_global_step_time, - ) - config_name = f"tp{tp}_pp{pp}_cp{cp}_ep{ep}_mbs{mbs}_act_{act_ckpt}_num_mbs_act_{num_mbs_act}_act_per_pipe_{act_per_pipe}" - result.append( - [ - model_name, - model_size, - seq_length, - tp, - pp, - cp, - ep, - mbs, - act_ckpt, - num_mbs_act, - act_per_pipe, - layers, - hs, - ffn_hs, - global_batch_size, - num_nodes, - gpus_per_node, - avg_global_step_time, - samples_per_s, - m_tflops_gpu, - m_tflops, - ] - ) - finally: - continue - result.sort(key=lambda x: x[17]) + for error_file, tb_file, candidate_dir in zip(error_files, tb_files, dirs): + tp, pp, cp, ep, mbs, vp = get_config(candidate_dir) + error = find_error(error_file) + if error: + errors.append( + [ + model_name, + model_size, + seq_length, + tp, + pp, + cp, + ep, + mbs, + vp, + layers, + hs, + ffn_hs, + global_batch_size, + num_nodes, + gpus_per_node, + error, + ] + ) + + ea = event_accumulator.EventAccumulator(tb_file) + ea.Reload() + try: + timing_list = ea.Scalars("train_step_timing in s") + if len(timing_list) < 10: + continue + timing_list = [x.value for x in timing_list[1:]] + print(timing_list) + avg_global_step_time = round(sum(timing_list) / len(timing_list), 2) + samples_per_s = round(global_batch_size / avg_global_step_time, 2) + print(samples_per_s) + m_tflops, m_tflops_gpu = calculate_tflops( + model_name=model_name, + gbs=global_batch_size, + enc_seq_len=seq_length, + dec_seq_len=seq_length, + hs=hs, + ffn_hs=ffn_hs, + layers=layers, + vocab=vocab_size, + nodes=num_nodes, + gpus_per_node=gpus_per_node, + time_per_step=avg_global_step_time, + ) + result.append( + [ + model_name, + model_size, + seq_length, + tp, + pp, + cp, + ep, + mbs, + vp, + layers, + hs, + ffn_hs, + global_batch_size, + num_nodes, + gpus_per_node, + avg_global_step_time, + samples_per_s, + m_tflops_gpu, + m_tflops, + ] + ) + finally: + continue + + result.sort(key=lambda x: x[15]) print(f"Top {min(output_top_n, len(result))} configs sorted from fastest to slowest:") for i, res in enumerate(result): - print(f"Config #{i+1}: {res[-1]} with {res[17]:.4f}s per global step.") + print(f"Config #{i+1}: {res[-1]} with {res[15]:.4f}s per global step.") if i + 1 == output_top_n: break - top_config = f"{model_name}_{model_size}b_{num_nodes}nodes_tp_{result[0][3]}_pp_{result[0][4]}_cp_{result[0][5]}_ep_{result[0][6]}_mbs_{result[0][7]}_act_ckpt_{result[0][8]}_num_mbs_act_{result[0][9]}_act_per_pipe_{result[0][10]}" + top_config = f"{model_name}_{model_size}b_{num_nodes}nodes_tp_{result[0][3]}_pp_{result[0][4]}_cp_{result[0][5]}_ep_{result[0][6]}_mbs_{result[0][7]}_vp_{result[0][8]}" print("\n==================================================") - print(f"Optimal config: {top_config} with {result[0][17]:.4f}s per global step.") + print(f"Optimal config: {top_config} with {result[0][15]:.4f}s per global step.") 
print("==================================================\n") # Save results as a CSV file. @@ -310,7 +297,8 @@ def get_config(run_name: str) -> tuple: Returns: tuple: model parallelism parameters. """ - pattern = r'_(tp|pp|cp|ep|mbs|act_ckpt|num_mbs_act|act_per_pipe)_([^_]+)' + + pattern = r'_(tp|pp|cp|ep|mbs|vp)_([^_]+)' # Find all matches in the input string matches = re.findall(pattern, run_name) @@ -324,11 +312,31 @@ def get_config(run_name: str) -> tuple: params["cp"], params["ep"], params["mbs"], - params["act_ckpt"], - params["num_mbs_act"], - params["act_per_pipe"], + params["vp"], ) +def find_tb_logs(logs_dir: str, tb_prefix: str) -> list: + """Function that finds tensorboard logs + + Args: + logs_dir (str): results directory. + + Returns: + list: list of tensorboard files. + """ + + tb_files = [] + # Walk through all directories and subdirectories + for root, dirs, files in os.walk(logs_dir): + for file in files: + # Check if the file starts with the tb prefix + if file.startswith(tb_prefix): + absolute_path = os.path.abspath(os.path.join(root, file)) + tb_files.append(absolute_path) + + return tb_files + + if __name__ == "__main__": main() diff --git a/nemo/collections/llm/tools/auto_configurator/core/training_config.py b/nemo/collections/llm/tools/auto_configurator/core/training_config.py index 087bf3c6fb0e..f7bf4d30427d 100644 --- a/nemo/collections/llm/tools/auto_configurator/core/training_config.py +++ b/nemo/collections/llm/tools/auto_configurator/core/training_config.py @@ -49,17 +49,22 @@ def generate_grid_search_configs( model_name = train_cfg.model_type model_size_in_b = train_cfg.model_size_in_b + path_to_logs = train_cfg.path_to_logs # 2 * num_layers is needed because of encoder/decoder architecture. multiplier = 1 if model_name in GPT_BASED_MODELS else 2 - seq_length = base_cfg.model.seq_length - num_layers = base_cfg.model.num_layers if model_name in GPT_BASED_MODELS else base_cfg.model.encoder.num_layers + seq_length = base_cfg.model.config.seq_length + num_layers = ( + base_cfg.model.config.num_layers + if model_name in GPT_BASED_MODELS + else base_cfg.model.config.encoder.num_layers + ) if model_name in GPT_BASED_MODELS: act_method = None else: - act_method = base_cfg.model.encoder.activations_checkpoint_method + act_method = base_cfg.model.config.encoder.activations_checkpoint_method params = _calculate_tp_pp_mbs_grid( model_size_in_b=model_size_in_b, @@ -69,7 +74,6 @@ def generate_grid_search_configs( train_cfg=train_cfg, ) - max_minutes = train_cfg.max_minutes_per_run max_steps = train_cfg.max_steps_per_run num_nodes = train_cfg.num_nodes @@ -82,11 +86,11 @@ def generate_grid_search_configs( num_gpus = base_cfg.trainer.num_nodes * base_cfg.trainer.devices base_cfg.data.global_batch_size = params.gbs if model_name in GPT_BASED_MODELS: - att_heads = base_cfg.model.num_attention_heads - num_layers = base_cfg.model.num_layers + att_heads = base_cfg.model.config.num_attention_heads + num_layers = base_cfg.model.config.num_layers else: - att_heads = base_cfg.model.encoder.num_attention_heads - num_layers = base_cfg.model.encoder.num_layers + att_heads = base_cfg.model.config.encoder.num_attention_heads + num_layers = base_cfg.model.config.encoder.num_layers model_parallelism = (tp * pp * cp * ep) if (cp and ep) else (tp * pp) mod_gbs = params.gbs % (mbs * num_gpus / model_parallelism) mod_att_heads = att_heads % tp @@ -134,9 +138,9 @@ def generate_grid_search_configs( "ep": ep, "virtual_pipelines": virtual_pipelines, "mbs": mbs, - "max_minutes": max_minutes, 
"max_steps": max_steps, "num_nodes": num_nodes, + "path_to_logs": path_to_logs, "model_name": model_name, "model_size": model_size_in_b, } @@ -151,12 +155,11 @@ def generate_grid_search_configs( kwargs["act_per_pipe"] = act_per_pipe new_cfg = utils.modify_cfg(**kwargs) if new_cfg: # Save candidate cfg. - configs[new_cfg["run"]["name"]] = new_cfg + configs[new_cfg["name"]] = new_cfg else: new_cfg = utils.modify_cfg(**kwargs) if new_cfg: # Save candidate cfg. - config_name = new_cfg["run"]["name"] - new_cfg.pop("run") + config_name = new_cfg["name"] configs[config_name] = new_cfg print(f"\nAll candidate configurations created correctly. Total number of configs: {len(configs)}.\n") diff --git a/nemo/collections/llm/tools/auto_configurator/core/utils.py b/nemo/collections/llm/tools/auto_configurator/core/utils.py index 3441c7cdbf9b..aeb23c0cafce 100644 --- a/nemo/collections/llm/tools/auto_configurator/core/utils.py +++ b/nemo/collections/llm/tools/auto_configurator/core/utils.py @@ -338,7 +338,7 @@ def generic_base_config(config) -> dict: AutoConfigurator: config object for the Auto Configurator tool. """ - from nemo.collections.llm.tools.auto_configurator.core.base_config import BaseConfig, calculate_model_size + from nemo.collections.llm.tools.auto_configurator.core.base_config import calculate_model_size default_model = False if config.model_size_in_b else True @@ -350,7 +350,7 @@ def generic_base_config(config) -> dict: config.num_tokens_in_b, config.model_type, ) - base_cfg = BaseConfig(config) + base_cfg = config.recipe if default_model: params = ModelSizeParams( @@ -362,14 +362,14 @@ def generic_base_config(config) -> dict: params.init_params() if config.model_type in GPT_BASED_MODELS: - base_cfg.model.num_layers = params.layers - base_cfg.model.hidden_size = params.hs - base_cfg.model.num_attention_heads = params.att_h - base_cfg.model.kv_channels = params.kv + base_cfg.model.config.num_layers = params.layers + base_cfg.model.config.hidden_size = params.hs + base_cfg.model.config.num_attention_heads = params.att_h + base_cfg.model.config.kv_channels = params.kv if not params.ffn: - base_cfg.model.ffn_hidden_size = params.hs * 4 + base_cfg.model.config.ffn_hidden_size = params.hs * 4 else: - base_cfg.model.ffn_hidden_size = params.ffn + base_cfg.model.config.ffn_hidden_size = params.ffn config.model_size_in_b = model_size_in_b @@ -387,10 +387,10 @@ def modify_cfg( ep: int, virtual_pipelines: int, mbs: int, - max_minutes: int, max_steps: int, num_nodes: int, model_name: str, + path_to_logs: str, model_size, ) -> dict: """Modify the base configuration for the model with the new parameters that are specific to the current model, which the Auto Configurator tool heuristics selected. @@ -406,7 +406,6 @@ def modify_cfg( ep (int): Expert Parallelism (EP) value to be set for the model. virtual_pipelines (int): Virtual Pipelines value to be set for the model. mbs (int): Micro Batch Size (MBS) value to be set for the model. - max_minutes (int): maximum amount of time to run this model for. max_steps (int): maximum number of steps to run this model for. num_nodes (int): number of nodes to use for the training run. model_name (str): name of the model, i.e. gpt3, t5, mt5... 
@@ -416,18 +415,18 @@ def modify_cfg( """ if model_name in GPT_BASED_MODELS: - att_heads = base_cfg.model.num_attention_heads - num_layers = base_cfg.model.num_layers + att_heads = base_cfg.model.config.num_attention_heads + num_layers = base_cfg.model.config.num_layers else: - att_heads = base_cfg.model.encoder.num_attention_heads - num_layers = base_cfg.model.encoder.num_layers + att_heads = base_cfg.model.config.encoder.num_attention_heads + num_layers = base_cfg.model.config.encoder.num_layers # gbs = mbs * num_gpus * accumulate_grad_batches / (tp * pp) num_gpus = base_cfg.trainer.num_nodes * base_cfg.trainer.devices gbs = base_cfg.data.global_batch_size - seq_len = base_cfg.model.seq_length + seq_len = base_cfg.model.config.seq_length - new_cfg = dict(run=base_cfg.run) + new_cfg = {} # dict(run=base_cfg.run) if act is not None: if model_name in GPT_BASED_MODELS: new_cfg["activations_checkpoint_num_layers"] = act @@ -448,6 +447,8 @@ def modify_cfg( new_cfg["pipeline_model_parallel_size"] = pp new_cfg["micro_batch_size"] = mbs new_cfg["global_batch_size"] = gbs + new_cfg["max_steps"] = max_steps + new_cfg["path_to_logs"] = path_to_logs if cp is not None: new_cfg["context_parallel_size"] = cp @@ -460,11 +461,11 @@ def modify_cfg( mod_layers = num_layers % pp if mod_gbs == 0 and mod_att_heads == 0 and mod_layers == 0: # Valid config - new_cfg["run"][ - "name" - ] = f"{model_name}_{str(model_size)}b_{num_nodes}nodes_tp_{tp}_pp_{pp}_cp_{cp}_ep_{ep}_mbs_{mbs}_act_ckpt_{act}_num_mbs_act_{num_mbs_act}_act_per_pipe_{act_per_pipe}" + new_cfg["name"] = ( + f"{model_name}_{str(model_size)}b_{num_nodes}nodes_tp_{tp}_pp_{pp}_cp_{cp}_ep_{ep}_mbs_{mbs}_vp_{virtual_pipelines}" + ) print( - f"Valid config: SeqLen={seq_len}, GBS={gbs}, MBS={mbs}, TP={tp}, PP={pp}, CP={cp}, EP={ep}, act_ckpt_layers={act}, num_mbs_act={num_mbs_act}, act_per_pipe={act_per_pipe}. Adding to directory." + f"Valid config: SeqLen={seq_len}, GBS={gbs}, MBS={mbs}, TP={tp}, PP={pp}, CP={cp}, EP={ep}, VP={virtual_pipelines}. Adding to directory." ) return new_cfg return None diff --git a/nemo/collections/llm/tools/auto_configurator/runner.py b/nemo/collections/llm/tools/auto_configurator/runner.py index 0c80c9a21a9e..7afefaa3170e 100644 --- a/nemo/collections/llm/tools/auto_configurator/runner.py +++ b/nemo/collections/llm/tools/auto_configurator/runner.py @@ -13,8 +13,8 @@ # limitations under the License. 
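# Illustrative sketch: the trimmed regex in get_config() above round-trips with the run
# names emitted by modify_cfg() above. The run name and values here are hypothetical.
import re

run_name = "gpt3_5b_16nodes_tp_4_pp_2_cp_1_ep_1_mbs_1_vp_None"
pattern = r'_(tp|pp|cp|ep|mbs|vp)_([^_]+)'
params = dict(re.findall(pattern, run_name))
# -> {'tp': '4', 'pp': '2', 'cp': '1', 'ep': '1', 'mbs': '1', 'vp': 'None'}
assert params["tp"] == "4" and params["vp"] == "None"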
import copy +import os import re - from typing import List, Optional from nemo.collections.llm import GPTModel @@ -33,28 +33,15 @@ "nemotron", ] -SUPPORTED_TOKENIZERS = [ - "autotokenizer", - "sentencepiece", - "huggingface", -] - class AutoConfigurator: """Auto Configurator runner config class.""" def __init__( self, - model: Config = None, - num_nodes: int = None, - data_paths: List = None, + recipe: Partial = None, path_to_logs: str = None, - tokenizer_type: Optional[str] = "autotokenizer", - tokenizer_path: Optional[str] = "GPT2BPETokenizer", - gpus_per_node: Optional[int] = 8, gpu_memory_gb: Optional[int] = 80, - seq_length: Optional[int] = 2048, - global_batch_size: Optional[int] = "auto", tensor_parallel_sizes: Optional[List[int]] = "auto", pipeline_parallel_sizes: Optional[List[int]] = "auto", micro_batch_sizes: Optional[List[int]] = "auto", @@ -62,26 +49,18 @@ def __init__( expert_parallel_sizes: Optional[List[int]] = [1], min_model_parallel_size: Optional[int] = "auto", max_model_parallel_size: Optional[int] = "auto", - num_tokens_in_b: Optional[int] = 300, + num_tokens_in_b: Optional[int] = 1400, tflops_per_gpu: Optional[int] = 140, max_minutes_per_run: Optional[int] = 30, max_training_days: Optional[int] = 2, max_steps_per_run: Optional[int] = 50, - vocab_size: Optional[int] = 51200, + vocab_size: Optional[int] = 32000, ): """ Args: - model_type (Config): model type to be used for training. - num_nodes (int): number of nodes to be used for training. - data_paths (List): list of datafiles to be used for training. + recipe (Partial): recipe to be used for training. path_to_logs (str): path to the directory where the logs will be stored. - tokenizer_type (Optional[str]): tokenizer type. - tokenizer_path (Optional[str]): path to the tokenizer model. - model_size (Optional[int]): size of model to be trained. - gpus_per_node (Optional[int]): number of GPUs per node to be used. gpu_memory_gb (Optional[int]): memory per GPU, in GB. Currently 40GB and 80GB A100s/H100s supported. - seq_length (Optional[int]): model sequence length. Available seq_length list for GPT-based models: [2048, 4096, 8192, 16384, 32768]. - global_batch_size (Optional[int]): model global batch size. Set to "auto" if you want auto configurator to find optimal gbs. tensor_parallel_sizes (Optional[List[int]]): set to "auto" to use our recommendation, or a list, such as [1, 2, 4, 8]. pipeline_parallel_sizes (Optional[List[int]]): set to "auto" to use our recommendation, or a list, such as [1, 2, 4, 8]. micro_batch_sizes (Optional[List[int]]): set to "auto" to use our recommendation, or a list, such as [1, 2, 4, 8]. @@ -104,13 +83,20 @@ def __init__( setattr(self, key, value) logging.info(self._get_message(config)) - model_type = self._get_model_type(model) + model_type = self._get_model_type(recipe.model.config) assert model_type in SUPPORTED_MODELS, f"model_type must be set to one of {SUPPORTED_MODELS}." - assert tokenizer_type in SUPPORTED_TOKENIZERS, f"tokenizer_type must be set to one of {SUPPORTED_TOKENIZERS}." - assert num_nodes, "num_nodes value must be specified." - assert data_paths, "training data must be specified." + assert recipe.data.seq_length in [ + 2048, + 4096, + 8192, + 16384, + 32768, + ], "Available seq_length list for GPT-based models: [2048, 4096, 8192, 16384, 32768]." assert path_to_logs, f"path_to_logs parameter must be specified." 
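# Illustrative sketch of the recipe-based constructor, mirroring the updated tests in this
# patch; the log path, node counts, and parallelism bounds below are placeholder values.
from functools import partial

from nemo.collections import llm
from nemo.collections.llm.tools.auto_configurator import AutoConfigurator, generate_configs

recipe = partial(llm.llama3_70b.pretrain_recipe, num_nodes=128, num_gpus_per_node=8)()
recipe.data.global_batch_size = 2048

runner = AutoConfigurator(
    recipe=recipe,
    path_to_logs="/results/auto_conf_logs",  # placeholder
    tensor_parallel_sizes="auto",
    pipeline_parallel_sizes="auto",
    micro_batch_sizes=[1],
    min_model_parallel_size=16,
    max_model_parallel_size=64,
)
base_config, configs = generate_configs(runner)  # {run_name: run.Partial(pretrain, ...)}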
- gpu_count = num_nodes * gpus_per_node + + self.num_gpus = recipe.trainer.devices + self.num_nodes = recipe.trainer.num_nodes + gpu_count = self.num_nodes * self.num_gpus assert gpu_count > 0, "num_nodes * gpus_per_node must be an int larger than zero." assert gpu_memory_gb in ( 40, @@ -119,9 +105,10 @@ def __init__( assert max_minutes_per_run >= 10, "max_minutes_per_run must be an int and be at least 10 minutes." self.model_type = model_type - self.model_size_in_b = self._get_model_size(model) + self.model_size_in_b = self._get_model_size(recipe.model.config) self.gpu_count = gpu_count - self.num_gpus = gpus_per_node + self.seq_length = recipe.data.seq_length + self.global_batch_size = recipe.data.global_batch_size def _get_message(self, config: dict) -> str: """ @@ -203,13 +190,10 @@ def generate_configs(runner_config: AutoConfigurator = None) -> dict: """ # Generate base config for the given model size - base_cfg, train_cfg = generic_base_config(runner_config) + base_config, train_config = generic_base_config(runner_config) # Launch grid search for training constraints - base_config, train_configs = generate_grid_search_configs(base_cfg, train_cfg) - - tokenizer = base_config.tokenizer - model = Config(GPTModel, config=base_config.model, tokenizer=tokenizer) + base_config, train_configs = generate_grid_search_configs(base_config, train_config) configs = {} for name, config in train_configs.items(): @@ -231,11 +215,16 @@ def generate_configs(runner_config: AutoConfigurator = None) -> dict: ) if config.get("tensor_model_parallel_size") > 1: trainer.strategy.sequence_parallel = True + trainer.max_steps = config.get("max_steps") + trainer.log_every_n_steps = 1 + + log.log_dir = os.path.join(config.get("path_to_logs"), name) + log.ckpt.save_last = False # Set the directory where to save the logs configs[name] = Partial( pretrain, - model=model, + model=base_config.model, trainer=trainer, data=data, optim=base_config.optim, @@ -243,4 +232,4 @@ def generate_configs(runner_config: AutoConfigurator = None) -> dict: resume=None, ) - return base_cfg, configs + return base_config, configs diff --git a/nemo/collections/nlp/modules/common/hyena/hyena.py b/nemo/collections/nlp/modules/common/hyena/hyena.py index f1b4fe20f537..4808bf1eb92c 100644 --- a/nemo/collections/nlp/modules/common/hyena/hyena.py +++ b/nemo/collections/nlp/modules/common/hyena/hyena.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
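# Illustrative sketch, continuing the AutoConfigurator example above: each candidate
# returned by generate_configs() is a run.Partial wrapping pretrain, and its logger
# writes to <path_to_logs>/<run_name>, so results from different parallelism layouts
# land in separate directories. Launching them locally is optional.
import nemo_run as run

for name, candidate in configs.items():  # `configs` from the sketch above
    print(f"Launching candidate: {name}")
    run.run(candidate, direct=True)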
+ # Implementation of Hyena operator # # Michael Poli and Stefano Massaroli and Eric Nguyen and Daniel Y Fu and Tri Dao and Stephen Baccus and diff --git a/nemo/lightning/data.py b/nemo/lightning/data.py index e5acb1b5b8bf..6c7fd128e530 100644 --- a/nemo/lightning/data.py +++ b/nemo/lightning/data.py @@ -19,6 +19,7 @@ from typing import List, Literal, Optional import torch +from pytorch_lightning.overrides.distributed import _IndexBatchSamplerWrapper from torch.utils.data import DataLoader, Dataset @@ -139,6 +140,7 @@ def add_megatron_sampler( dataloader_type: Literal["single", "cyclic", "batch"] = "single", drop_last: bool = True, pad_samples_to_global_batch_size: bool = False, + dataloader_mode: Literal["train", "validation", "test", "predict"] = "train", rank: int = 0, world_size: int = 1, # data_sharding: bool = False @@ -170,6 +172,7 @@ def add_megatron_sampler( pad_samples_to_global_batch_size (bool, optional): Whether to pad the last incomplete batch to the `global_batch_size` (defaults to False, only applies when `drop_last` is False). + dataloader_mode (Literal["train", "validation", "test", "predict"]): The mode of dataloader. Returns: DataLoader: A new DataLoader instance with the configured Megatron sampler. @@ -214,6 +217,9 @@ def add_megatron_sampler( else: raise Exception(f'{dataloader_type} dataloader type is not supported.') + if dataloader_mode in ["test", "predict"]: + batch_sampler = _IndexBatchSamplerWrapper(batch_sampler) # BatchSampler wrapper to capture its indices + return DataLoader( dataloader.dataset, batch_sampler=batch_sampler, diff --git a/nemo/lightning/fabric/strategies.py b/nemo/lightning/fabric/strategies.py index 695595bca4d0..7445413b612e 100644 --- a/nemo/lightning/fabric/strategies.py +++ b/nemo/lightning/fabric/strategies.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from contextlib import ExitStack, contextmanager from datetime import timedelta from typing import ( diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py index 28f16882305c..aad2a1696d61 100644 --- a/nemo/lightning/pytorch/callbacks/peft.py +++ b/nemo/lightning/pytorch/callbacks/peft.py @@ -29,6 +29,7 @@ from nemo.lightning.io.mixin import IOMixin from nemo.lightning.io.pl import ckpt_to_dir from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule from nemo.utils import logging from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO @@ -172,6 +173,16 @@ def apply_transform(self, trainer): if trainer.state.fn == TrainerFn.FITTING: trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=True) + for cb in trainer.callbacks[::-1]: + if isinstance(cb, MegatronOptimizerModule): + cb.on_fit_start(trainer, trainer.lightning_module) + break + else: + logging.warning( + "MegatronOptimizerModule not found in trainer callbacks. 
finalize_model_grads is not " + "properly set up for PEFT." + ) + def adapter_key_filter(self, key: str) -> bool: return key in self.trainable_params or ".adapter." in key or key.endswith(".adapters") diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py index 52ba9e3220ac..f37fd38adf53 100644 --- a/nemo/lightning/pytorch/plugins/data_sampler.py +++ b/nemo/lightning/pytorch/plugins/data_sampler.py @@ -44,7 +44,6 @@ def __init__( init_consumed_samples: int = 0, init_global_step: int = 0, output_log: bool = True, - drop_last: bool = True, ): self.seq_len = seq_len self.output_log = output_log @@ -57,7 +56,6 @@ def __init__( self.if_first_step = 0 self.prev_global_batch_size = None self.init_global_step = init_global_step - self.drop_last = drop_last def setup(self, global_rank: int) -> None: from nemo.lightning.data import setup_microbatch_calculator @@ -80,7 +78,8 @@ def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0 rampup_batch_size=self.rampup_batch_size, consumed_samples=self.init_consumed_samples if mode == 'train' else 0, dataloader_type=self.dataloader_type, - drop_last=self.drop_last, + drop_last=mode not in ["test", "predict"], # don't drop the incomplete batch in test and predict methods + dataloader_mode=mode, # dataloader wrapped with nemo.lightning.data.WrappedDataLoader has mode attribute rank=data_parallel_rank, world_size=data_parallel_size, ) diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index c869fe896279..c61c3371cc3c 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -297,6 +297,14 @@ def setup(self, trainer: pl.Trainer) -> None: self.accelerator.setup(trainer) self.trainer = trainer + try: + self.model.optim.lr_scheduler.max_steps = trainer.max_steps + logging.info(f"Copying Trainer's 'max_steps' ({trainer.max_steps}) to LR scheduler's 'max_steps'.") + except AttributeError: + logging.warning( + "Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored." 
+ ) + # move the model to the correct device # self.model_to_device() diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py index 164c07fe5b80..0d71c49bf198 100644 --- a/nemo/lightning/pytorch/trainer.py +++ b/nemo/lightning/pytorch/trainer.py @@ -45,7 +45,7 @@ def io_init(self, **kwargs) -> fdl.Config[Self]: return fdl.Config(type(self), **cfg_kwargs) def to_fabric(self, callbacks=None, loggers=None) -> Fabric: - accelerator, devices, strategy, plugins = None, None, None, None + accelerator, devices, strategy, plugins, num_nodes = None, None, None, None, None if hasattr(self.__io__, "devices"): devices = self.__io__.devices if hasattr(self.__io__, "accelerator"): @@ -62,6 +62,9 @@ def to_fabric(self, callbacks=None, loggers=None) -> Fabric: plugins = fdl.build(plugins) plugins = to_fabric(plugins) + if hasattr(self.__io__, "num_nodes"): + num_nodes = self.__io__.num_nodes + out = Fabric( devices=devices, accelerator=accelerator, @@ -69,6 +72,7 @@ def to_fabric(self, callbacks=None, loggers=None) -> Fabric: plugins=plugins, callbacks=callbacks, loggers=loggers, + num_nodes=num_nodes, ) return out diff --git a/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py b/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py index 18ddb8935942..ec048e4b6f19 100644 --- a/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py @@ -50,7 +50,11 @@ def get_args(): parser = ArgumentParser() parser.add_argument( - "--input_name_or_path", type=str, default=None, required=True, help="Path to .nemo file", + "--input_name_or_path", + type=str, + default=None, + required=True, + help="Path to .nemo file", ) parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to HF .bin file") parser.add_argument( @@ -94,6 +98,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True) model_config.use_cpu_initialization = True model_config.tensor_model_parallel_size = 1 + model_config.name = "te_gpt" else: map_location, model_config = None, None diff --git a/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py b/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py index 59bc0a64bbe9..5a8e52ee8be5 100644 --- a/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py @@ -50,7 +50,11 @@ def get_args(): parser = ArgumentParser() parser.add_argument( - "--input_name_or_path", type=str, default=None, required=True, help="Path to .nemo file", + "--input_name_or_path", + type=str, + default=None, + required=True, + help="Path to .nemo file", ) parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to HF .bin file") parser.add_argument( @@ -90,6 +94,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True) model_config.tensor_model_parallel_size = 1 model_config.pipeline_model_parallel_size = 1 + model_config.name = "te_gpt" if cpu_only: map_location = torch.device('cpu') model_config.use_cpu_initialization = True @@ -168,9 +173,21 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) qkv_bias_base_name = 
f'transformer.encoder.layers.{l}.self_attention.query_key_value.bias' - q_bias = param_to_weights(qkv_bias[q_slice].reshape(-1,)) - k_bias = param_to_weights(qkv_bias[k_slice].reshape(-1,)) - v_bias = param_to_weights(qkv_bias[v_slice].reshape(-1,)) + q_bias = param_to_weights( + qkv_bias[q_slice].reshape( + -1, + ) + ) + k_bias = param_to_weights( + qkv_bias[k_slice].reshape( + -1, + ) + ) + v_bias = param_to_weights( + qkv_bias[v_slice].reshape( + -1, + ) + ) checkpoint[qkv_bias_base_name] = torch.cat((q_bias, k_bias, v_bias)) # attention dense diff --git a/scripts/checkpoint_converters/convert_falcon_nemo_to_hf.py b/scripts/checkpoint_converters/convert_falcon_nemo_to_hf.py index 997f0ac23835..da8f15b92649 100644 --- a/scripts/checkpoint_converters/convert_falcon_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_falcon_nemo_to_hf.py @@ -51,7 +51,10 @@ def get_args(): parser = ArgumentParser() parser.add_argument( - "--input_name_or_path", type=str, required=True, help="Path to .nemo file", + "--input_name_or_path", + type=str, + required=True, + help="Path to .nemo file", ) parser.add_argument("--output_path", type=str, required=True, help="Path to HF .bin file") parser.add_argument( diff --git a/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py b/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py index 8da15148dfd8..a3c40676a980 100644 --- a/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py @@ -53,7 +53,11 @@ def get_args(): parser = ArgumentParser() parser.add_argument( - "--input_name_or_path", type=str, default=None, required=True, help="Path to .nemo file or extracted folder", + "--input_name_or_path", + type=str, + default=None, + required=True, + help="Path to .nemo file or extracted folder", ) parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to HF .bin file") parser.add_argument( @@ -105,6 +109,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True) model_config.tensor_model_parallel_size = 1 model_config.pipeline_model_parallel_size = 1 + model_config.name = "te_gpt" if cpu_only: map_location = torch.device('cpu') model_config.use_cpu_initialization = True @@ -226,13 +231,26 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> def replace_hf_weights_and_tokenizer( - weights_file, dtype, input_hf_path, output_hf_path, tokenizer_path, output_hf_tokenizer, + weights_file, + dtype, + input_hf_path, + output_hf_path, + tokenizer_path, + output_hf_tokenizer, ): - model = AutoModelForCausalLM.from_pretrained(input_hf_path, local_files_only=True, torch_dtype=dtype,) + model = AutoModelForCausalLM.from_pretrained( + input_hf_path, + local_files_only=True, + torch_dtype=dtype, + ) nemo_exported = torch.load(weights_file) if tokenizer_path: - tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path, local_files_only=True, legacy=False,) + tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_path, + local_files_only=True, + legacy=False, + ) tmp_tokenizer = convert_slow_tokenizer.convert_slow_tokenizer(tokenizer) fast_tokenizer = LlamaTokenizerFast(tokenizer_object=tmp_tokenizer) tokenizer_length = len(fast_tokenizer) diff --git a/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py b/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py index 796819c38ba4..b8c30a1b929d 
100644 --- a/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py @@ -81,6 +81,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None: model_config.tensor_model_parallel_size = 1 model_config.pipeline_model_parallel_size = 1 model_config.sequence_parallel = False + model_config.name = "te_gpt" if cpu_only: map_location = torch.device('cpu') model_config.use_cpu_initialization = True diff --git a/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py b/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py index 58311d0324c2..2bac2eaad616 100644 --- a/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py @@ -83,6 +83,7 @@ def convert(in_file, precision=None) -> None: model_config = MegatronGPTModel.restore_from(in_file, trainer=dummy_trainer, return_config=True) model_config.tensor_model_parallel_size = 1 model_config.pipeline_model_parallel_size = 1 + model_config.name = "te_gpt" cpu_only = True if cpu_only: map_location = torch.device('cpu') diff --git a/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py b/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py index 7a58573278af..fc0f660cbd42 100644 --- a/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py @@ -140,6 +140,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> model_config.pipeline_model_parallel_size = 1 model_config.sequence_parallel = False model_config.transformer_engine = True + model_config.name = "te_gpt" if cpu_only: map_location = torch.device("cpu") model_config.use_cpu_initialization = True diff --git a/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py b/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py index c6a218020c21..6080499ffdf8 100644 --- a/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py @@ -108,6 +108,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True) model_config.tensor_model_parallel_size = 1 model_config.pipeline_model_parallel_size = 1 + model_config.name = "te_gpt" if cpu_only: map_location = torch.device('cpu') model_config.use_cpu_initialization = True diff --git a/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py b/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py index 043d1fd35261..4b65533b74ec 100644 --- a/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py @@ -89,6 +89,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None: model_config = MegatronGPTModel.restore_from(in_file, trainer=dummy_trainer, return_config=True) model_config.tensor_model_parallel_size = 1 model_config.pipeline_model_parallel_size = 1 + model_config.name = "te_gpt" if cpu_only: map_location = torch.device('cpu') model_config.use_cpu_initialization = True diff --git a/tests/collections/llm/auto_conf/test_base_configs.py b/tests/collections/llm/auto_conf/test_base_configs.py deleted file mode 100644 index d12f065d8168..000000000000 --- a/tests/collections/llm/auto_conf/test_base_configs.py +++ /dev/null @@ -1,353 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import nemo_run as run -import torch - -from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import TensorBoardLogger - -from nemo import lightning as nl -from nemo.collections.common.tokenizers import AutoTokenizer -from nemo.collections.llm import ( - GemmaConfig2B, - GPTConfig126M, - Llama3Config8B, - MistralConfig7B, - MixtralConfig8x3B, - Nemotron4Config22B, - PreTrainingDataModule, -) -from nemo.collections.llm.tools.auto_configurator import AutoConfigurator -from nemo.collections.llm.tools.auto_configurator.core.base_config import BaseConfig -from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, MegatronOptimizerModule -from nemo.utils.exp_manager import TimingCallback - - -def get_tokenizer() -> run.Config: - return run.Config(AutoTokenizer, pretrained_model_name="GPT2BPETokenizer") - - -def get_data(seq_length, global_batch_size) -> run.Config[PreTrainingDataModule]: - config = { - "paths": "/", - "seq_length": seq_length, - "global_batch_size": global_batch_size, - "num_workers": 2, - "index_mapping_dir": None, - } - - return run.Config( - PreTrainingDataModule, - **config, - tokenizer=get_tokenizer(), - ) - - -def get_trainer(num_nodes) -> run.Config[nl.Trainer]: - trainer_config = { - "accelerator": "gpu", - "enable_checkpointing": False, - "use_distributed_sampler": False, - "max_epochs": None, - "log_every_n_steps": 1, - "limit_val_batches": 1, - "limit_test_batches": 1, - "accumulate_grad_batches": 1, - "num_nodes": num_nodes, - "devices": 8, - "max_steps": 50, - "val_check_interval": 50, - } - - strategy = run.Config( - nl.MegatronStrategy, - pipeline_dtype=torch.bfloat16, - ) - - return run.Config( - nl.Trainer, - **trainer_config, - strategy=strategy, - plugins=run.Config(nl.MegatronMixedPrecision, precision="bf16-mixed"), - callbacks=[run.Config(TimingCallback)], - ) - - -def get_optim() -> run.Config[OptimizerConfig]: - optim_params = { - "optimizer": "adam", - "lr": 1e-4, - "min_lr": 1e-5, - "use_distributed_optimizer": True, - "bf16": True, - "adam_beta1": 0.9, - "adam_beta2": 0.95, - "clip_grad": 1.0, - "adam_eps": 1e-5, - } - - optim_config = run.Config( - OptimizerConfig, - **optim_params, - ) - - sched = run.Config( - CosineAnnealingScheduler, - warmup_steps=10, - constant_steps=0, - min_lr=optim_config.min_lr, - ) - - return run.Config( - MegatronOptimizerModule, - config=optim_config, - lr_scheduler=sched, - ) - - -def get_logger() -> run.Config[nl.NeMoLogger]: - tb_logger = run.Config(TensorBoardLogger, save_dir="tb_logs") - - ckpt = run.Config( - nl.ModelCheckpoint, - monitor="reduced_train_loss", - save_last=False, - save_top_k=0, - ) - - return run.Config( - nl.NeMoLogger, - ckpt=ckpt, - tensorboard=tb_logger, - wandb=None, - log_dir="/", - ) - - -class TestBaseConfigs: - def test_gpt3_base_config(self): - # GPT3 7B - model_config = run.Config(GPTConfig126M) - runner = AutoConfigurator(model=model_config, num_nodes=8, path_to_logs="/", data_paths="/") - base_config = 
BaseConfig(runner) - model_size = runner._get_model_size(model_config) - model_type = runner._get_model_type(model_config) - data_config = get_data(2048, 'auto') - trainer_config = get_trainer(8) - optim_config = get_optim() - logger_config = get_logger() - - assert ( - base_config.model == model_config - ), f"{model_config} is expected class object but got {base_config.model}" - assert model_size == 0.126, f"0.126 is expected size for {model_config} but got {model_size}" - assert model_type == "gpt3", f"gpt3 is expected model type for {model_config} but got {model_type}" - assert ( - base_config.data == data_config - ), f"f{data_config} is expected data config for {model_config} but got {base_config.data}" - assert ( - base_config.trainer == trainer_config - ), f"f{trainer_config} is expected trainer config for {model_config} but got {base_config.trainer}" - assert ( - base_config.optim == optim_config - ), f"f{optim_config} is expected trainer config for {model_config} but got {base_config.optim}" - assert ( - base_config.log == logger_config - ), f"f{logger_config} is expected trainer config for {model_config} but got {logger_config}" - - def test_llama_base_config(self): - # Llama3 8B - model_config = run.Config(Llama3Config8B) - runner = AutoConfigurator( - model=model_config, - num_nodes=16, - path_to_logs="/", - data_paths="/", - seq_length=8192, - global_batch_size=2048, - ) - base_config = BaseConfig(runner) - model_size = runner._get_model_size(model_config) - model_type = runner._get_model_type(model_config) - data_config = get_data(8192, 2048) - trainer_config = get_trainer(16) - optim_config = get_optim() - logger_config = get_logger() - - assert ( - base_config.model == model_config - ), f"{model_config} is expected class object but got {base_config.model}" - assert model_size == 8, f"8 is expected size for {model_config} but got {model_size}" - assert model_type == "llama", f"llama is expected model type for {model_config} but got {model_type}" - assert ( - base_config.data == data_config - ), f"f{data_config} is expected data config for {model_config} but got {base_config.data}" - assert ( - base_config.trainer == trainer_config - ), f"f{trainer_config} is expected trainer config for {model_config} but got {base_config.trainer}" - assert ( - base_config.optim == optim_config - ), f"f{optim_config} is expected trainer config for {model_config} but got {base_config.optim}" - assert ( - base_config.log == logger_config - ), f"f{logger_config} is expected trainer config for {model_config} but got {logger_config}" - - def test_mistral_base_config(self): - # Mistral 7B - model_config = run.Config(MistralConfig7B) - runner = AutoConfigurator( - model=model_config, - num_nodes=16, - path_to_logs="/", - data_paths="/", - seq_length=32768, - global_batch_size=2048, - ) - base_config = BaseConfig(runner) - model_size = runner._get_model_size(model_config) - model_type = runner._get_model_type(model_config) - data_config = get_data(32768, 2048) - trainer_config = get_trainer(16) - optim_config = get_optim() - logger_config = get_logger() - - assert ( - base_config.model == model_config - ), f"{model_config} is expected class object but got {base_config.model}" - assert model_size == 7, f"7 is expected size for {model_config} but got {model_size}" - assert model_type == "mistral", f"mistral is expected model type for {model_config} but got {model_type}" - assert ( - base_config.data == data_config - ), f"f{data_config} is expected data config for {model_config} but got 
{base_config.data}" - assert ( - base_config.trainer == trainer_config - ), f"f{trainer_config} is expected trainer config for {model_config} but got {base_config.trainer}" - assert ( - base_config.optim == optim_config - ), f"f{optim_config} is expected trainer config for {model_config} but got {base_config.optim}" - assert ( - base_config.log == logger_config - ), f"f{logger_config} is expected trainer config for {model_config} but got {logger_config}" - - def test_mixtral_base_config(self): - # Mixtral 8x3B - model_config = run.Config(MixtralConfig8x3B) - runner = AutoConfigurator( - model=model_config, - num_nodes=16, - path_to_logs="/", - data_paths="/", - seq_length=4096, - global_batch_size=2048, - ) - base_config = BaseConfig(runner) - model_size = runner._get_model_size(model_config) - model_type = runner._get_model_type(model_config) - data_config = get_data(4096, 2048) - trainer_config = get_trainer(16) - optim_config = get_optim() - logger_config = get_logger() - - assert ( - base_config.model == model_config - ), f"{model_config} is expected class object but got {base_config.model}" - assert model_size == 3, f"3 is expected size for {model_config} but got {model_size}" - assert model_type == "mixtral", f"mixtral is expected model type for {model_config} but got {model_type}" - assert ( - base_config.data == data_config - ), f"f{data_config} is expected data config for {model_config} but got {base_config.data}" - assert ( - base_config.trainer == trainer_config - ), f"f{trainer_config} is expected trainer config for {model_config} but got {base_config.trainer}" - assert ( - base_config.optim == optim_config - ), f"f{optim_config} is expected trainer config for {model_config} but got {base_config.optim}" - assert ( - base_config.log == logger_config - ), f"f{logger_config} is expected trainer config for {model_config} but got {logger_config}" - - def test_gemma_base_config(self): - # Gemma 2B - model_config = run.Config(GemmaConfig2B) - runner = AutoConfigurator( - model=model_config, - num_nodes=8, - path_to_logs="/", - data_paths="/", - seq_length=4096, - global_batch_size=1024, - ) - base_config = BaseConfig(runner) - model_size = runner._get_model_size(model_config) - model_type = runner._get_model_type(model_config) - data_config = get_data(4096, 1024) - trainer_config = get_trainer(8) - optim_config = get_optim() - logger_config = get_logger() - - assert ( - base_config.model == model_config - ), f"{model_config} is expected class object but got {base_config.model}" - assert model_size == 2, f"2 is expected size for {model_config} but got {model_size}" - assert model_type == "gemma", f"gemma is expected model type for {model_config} but got {model_type}" - assert ( - base_config.data == data_config - ), f"f{data_config} is expected data config for {model_config} but got {base_config.data}" - assert ( - base_config.trainer == trainer_config - ), f"f{trainer_config} is expected trainer config for {model_config} but got {base_config.trainer}" - assert ( - base_config.optim == optim_config - ), f"f{optim_config} is expected trainer config for {model_config} but got {base_config.optim}" - assert ( - base_config.log == logger_config - ), f"f{logger_config} is expected trainer config for {model_config} but got {logger_config}" - - def test_nemotron_base_config(self): - # Nemotron 22B - model_config = run.Config(Nemotron4Config22B) - runner = AutoConfigurator( - model=model_config, - num_nodes=64, - path_to_logs="/", - data_paths="/", - seq_length=4096, - global_batch_size=2048, - 
) - base_config = BaseConfig(runner) - model_size = runner._get_model_size(model_config) - model_type = runner._get_model_type(model_config) - data_config = get_data(4096, 2048) - trainer_config = get_trainer(64) - optim_config = get_optim() - logger_config = get_logger() - - assert ( - base_config.model == model_config - ), f"{model_config} is expected class object but got {base_config.model}" - assert model_size == 22, f"22 is expected size for {model_config} but got {model_size}" - assert model_type == "nemotron", f"nemotron is expected model type for {model_config} but got {model_type}" - assert ( - base_config.data == data_config - ), f"f{data_config} is expected data config for {model_config} but got {base_config.data}" - assert ( - base_config.trainer == trainer_config - ), f"f{trainer_config} is expected trainer config for {model_config} but got {base_config.trainer}" - assert ( - base_config.optim == optim_config - ), f"f{optim_config} is expected trainer config for {model_config} but got {base_config.optim}" - assert ( - base_config.log == logger_config - ), f"f{logger_config} is expected trainer config for {model_config} but got {logger_config}" diff --git a/tests/collections/llm/auto_conf/test_generate_configs.py b/tests/collections/llm/auto_conf/test_generate_configs.py index f10425631f98..0d3e230e39fa 100644 --- a/tests/collections/llm/auto_conf/test_generate_configs.py +++ b/tests/collections/llm/auto_conf/test_generate_configs.py @@ -12,16 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import nemo_run as run - -from nemo.collections.llm import ( - GemmaConfig7B, - GPTConfig5B, - Llama3Config70B, - MistralConfig7B, - MixtralConfig8x22B, - Nemotron3Config8B, -) +from functools import partial + +from nemo.collections import llm from nemo.collections.llm.tools.auto_configurator import AutoConfigurator, generate_configs @@ -42,58 +35,12 @@ def get_auto_configs(configs): class TestGenerateConfgis: - def test_gpt_model(self): - # GPT3 126M - runner = AutoConfigurator( - model=run.Config(GPTConfig5B), - num_nodes=16, - seq_length=2048, - global_batch_size=2048, - tensor_parallel_sizes=[4], - pipeline_parallel_sizes=[2], - micro_batch_sizes=[1, 2], - context_parallel_sizes=[1], - expert_parallel_sizes=[1], - min_model_parallel_size=8, - max_model_parallel_size=8, - data_paths="/", - path_to_logs="/", - ) - - _, configs = generate_configs(runner) - - mbs = [1, 2] - for run_name, config, mb in zip(configs.keys(), configs.values(), mbs): - assert config.data.micro_batch_size == mb - assert config.data.seq_length == 2048 - assert config.data.global_batch_size == 2048 - - assert len(configs) == 2, f"{len(configs)} configurations were generated but 2 were expected." - - auto_configs = get_auto_configs(configs) - assert auto_configs[0] == [ - 4, - 2, - 1, - 1, - 1, - ], f"[4, 2, 1, 1, 1] is expected configuration output but got {auto_configs[0]}." - - assert auto_configs[1] == [ - 4, - 2, - 1, - 1, - 2, - ], f"[4, 2, 1, 1, 2] is expected configuration output but got {auto_configs[1]}." 
- def test_llama_model(self): # Llama3 70B + recipe = partial(llm.llama3_70b.pretrain_recipe, num_nodes=128, num_gpus_per_node=8)() + recipe.data.global_batch_size = 2048 runner = AutoConfigurator( - model=run.Config(Llama3Config70B), - num_nodes=128, - seq_length=8192, - global_batch_size=2048, + recipe=recipe, tensor_parallel_sizes="auto", pipeline_parallel_sizes="auto", micro_batch_sizes=[1], @@ -101,7 +48,6 @@ def test_llama_model(self): expert_parallel_sizes=[1], min_model_parallel_size=16, max_model_parallel_size=64, - data_paths="/", path_to_logs="/", ) @@ -142,11 +88,13 @@ def test_llama_model(self): def test_mistral_model(self): # Mistral 7B + recipe = partial(llm.mistral_7b.pretrain_recipe, num_nodes=16, num_gpus_per_node=8)() + recipe.data.seq_length = 4096 + recipe.data.global_batch_size = 2048 + recipe.model.config.seq_length = recipe.data.seq_length + runner = AutoConfigurator( - model=run.Config(MistralConfig7B), - num_nodes=16, - seq_length=4096, - global_batch_size=2048, + recipe=recipe, tensor_parallel_sizes=[4], pipeline_parallel_sizes=[1, 2], micro_batch_sizes=[1], @@ -154,7 +102,6 @@ def test_mistral_model(self): expert_parallel_sizes=[1], min_model_parallel_size=4, max_model_parallel_size=8, - data_paths="/", path_to_logs="/", ) @@ -187,11 +134,13 @@ def test_mistral_model(self): def test_mixtral_model(self): # Mixtral 8x22B + recipe = partial(llm.mixtral_8x22b.pretrain_recipe, num_nodes=16, num_gpus_per_node=8)() + recipe.data.seq_length = 4096 + recipe.data.global_batch_size = 2048 + recipe.model.config.seq_length = recipe.data.seq_length + runner = AutoConfigurator( - model=run.Config(MixtralConfig8x22B), - num_nodes=16, - seq_length=4096, - global_batch_size=2048, + recipe=recipe, tensor_parallel_sizes=[4], pipeline_parallel_sizes=[1], micro_batch_sizes=[1], @@ -199,7 +148,6 @@ def test_mixtral_model(self): expert_parallel_sizes=[1, 2], min_model_parallel_size=4, max_model_parallel_size=8, - data_paths="/", path_to_logs="/", ) @@ -232,11 +180,13 @@ def test_mixtral_model(self): def test_gemma_model(self): # Gemma 7B + recipe = partial(llm.gemma_7b.pretrain_recipe, num_nodes=16, num_gpus_per_node=8)() + recipe.data.seq_length = 8192 + recipe.data.global_batch_size = 2048 + recipe.model.config.seq_length = recipe.data.seq_length + runner = AutoConfigurator( - model=run.Config(GemmaConfig7B), - num_nodes=16, - seq_length=8192, - global_batch_size=2048, + recipe=recipe, tensor_parallel_sizes=[2], pipeline_parallel_sizes=[2], micro_batch_sizes=[1, 2], @@ -244,7 +194,6 @@ def test_gemma_model(self): expert_parallel_sizes=[1], min_model_parallel_size=4, max_model_parallel_size=8, - data_paths="/", path_to_logs="/", ) @@ -277,11 +226,13 @@ def test_gemma_model(self): def test_nemotron_model(self): # Nemotron3 8B + recipe = partial(llm.nemotron3_8b.pretrain_recipe, num_nodes=16, num_gpus_per_node=8)() + recipe.data.seq_length = 4096 + recipe.data.global_batch_size = 2048 + recipe.model.config.seq_length = recipe.data.seq_length + runner = AutoConfigurator( - model=run.Config(Nemotron3Config8B), - num_nodes=16, - seq_length=4096, - global_batch_size=2048, + recipe=recipe, tensor_parallel_sizes=[1], pipeline_parallel_sizes=[4], micro_batch_sizes=[1, 2], @@ -289,7 +240,6 @@ def test_nemotron_model(self): expert_parallel_sizes=[1], min_model_parallel_size=4, max_model_parallel_size=8, - data_paths="/", path_to_logs="/", ) diff --git a/tests/collections/llm/common.py b/tests/collections/llm/common.py new file mode 100644 index 000000000000..95b8bc0de584 --- /dev/null +++ 
b/tests/collections/llm/common.py @@ -0,0 +1,190 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytorch_lightning as pl +import torch + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.common.tokenizers import SentencePieceTokenizer +from nemo.utils import logging + + +def train_data( + data_path: str, tokenizer_path: str, index_mapping_dir: str, seq_length: int +) -> llm.PreTrainingDataModule: + """Single shard dataset tokenized by SentencePiece""" + tokenizer = SentencePieceTokenizer(model_path=tokenizer_path) + return llm.PreTrainingDataModule( + paths=data_path, + tokenizer=tokenizer, + seq_length=seq_length, + micro_batch_size=4, + global_batch_size=32, + seed=1234, + index_mapping_dir=index_mapping_dir, + ) + + +def small_llama_cfg(seq_length: int) -> llm.GPTConfig: + """Small 145m model""" + return llm.Llama3Config8B( + rotary_base=500_000, + seq_length=seq_length, + num_layers=12, + hidden_size=768, + ffn_hidden_size=2688, + num_attention_heads=16, + init_method_std=0.023, + ) + + +class StopBeforeEnd(pl.Callback): + """Preemptively stop training at a given global step. Allows stopping training before reaching + the max steps. Useful for testing checkpoint save and resume. + + Args: + stop_on_step (int): Stop training when trainer.global_step reaches this value. + Checked at the start of every step. + """ + + def __init__(self, stop_on_step: int): + self.stop_on_step = stop_on_step + + def on_train_batch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx + ) -> None: + if trainer.global_step >= self.stop_on_step: + logging.info(f"Global step {trainer.global_step} >= {self.stop_on_step}, signaling Trainer to stop.") + trainer.should_stop = True + # skip EarlyStopping validation unless val_check_interval met + if trainer.global_step % trainer.val_check_interval != 0: + trainer.limit_val_batches = 0 + + +class MCoreModelAttributeValidator(pl.Callback): + """Walk through submodules and verify user-specified attributes like parallelisms.""" + + def __init__(self, attr_dict: dict): + super().__init__() + self.attr_dict = attr_dict + + def _check_attrs(self, target): + for k, v in self.attr_dict.items(): + if hasattr(target, k): + model_val = getattr(target, k) + assert ( + model_val == v + ), f"Key {k} for model ({model_val}) does not match {v} from provided attribute mapping." + + def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + def walk_fn(module: torch.nn.Module) -> torch.nn.Module: + # self._check_attrs(module) # TE DPA has 'sequence_parallel' attribute that is always False. Checking module config should be sufficient + if hasattr(module, "config"): + self._check_attrs(module.config) + + return module + + trainer.model.walk(walk_fn) + + +class MiscAttributeValidator(pl.Callback): + """Place for any miscellaneous attribute assertions. 
Extend as needed.""" + + def __init__(self, attr_dict: dict): + super().__init__() + self.attr_dict = attr_dict + + def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + if 'max_steps' in self.attr_dict: + sched_max = trainer.model.optim.lr_scheduler._scheduler['lr_scheduler']['scheduler'].max_steps + assert ( + trainer.max_steps == self.attr_dict['max_steps'] + ), f"Trainer max_steps {trainer.max_steps} did not match provided {self.attr_dict['max_steps']}" + assert ( + sched_max == self.attr_dict['max_steps'] + ), f"Scheduler max_steps {sched_max} did not match provided {self.attr_dict['max_steps']}" + + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if 'stop_on_step' in self.attr_dict: + total_steps = trainer.fit_loop.epoch_loop.batch_progress.total.completed + assert total_steps == self.attr_dict['stop_on_step'] + + +def verify_distcp_dir(ckpt_path: str) -> None: + ckpt_name = os.path.basename(ckpt_path) + + weights_dir = os.path.join(ckpt_path, 'weights') + assert os.path.isdir(weights_dir), f"Weights not found in checkpoint {ckpt_name}" + assert os.path.isfile(os.path.join(weights_dir, 'common.pt')), f"No 'common.pt' file in checkpoint {ckpt_name}" + assert os.path.isfile( + os.path.join(weights_dir, 'metadata.json') + ), f"No 'metadata.json' file in checkpoint {ckpt_name}" + + shards = [shard for shard in os.listdir(weights_dir) if shard.endswith('.distcp')] + world_size = torch.distributed.get_world_size() + assert ( + len(shards) == 2 * world_size + ), f"Wrong number of .distcp files, Expected: {2*world_size} Found: {len(shards)}" + + +def verify_ckpt_dir( + model_ckpt: nl.ModelCheckpoint, max_steps: int, val_check_interval: int, exp_dir: str, dist_ckpts: bool = True +) -> None: + """Ensures that the provided checkpoint directory has + - correct number of checkpoints + - no more than top-k checkpoints + - no unfinished checkpoints + - a checkpoint for the last step + - all checkpoints in the correct format + """ + + ckpt_dir = os.path.join(exp_dir, 'checkpoints') + ckpts = os.listdir(ckpt_dir) + + if model_ckpt.save_last: + assert any([c.endswith('-last') for c in ckpts]), "No -last checkpoint found after training" + + expected_count = (max_steps // val_check_interval) + model_ckpt.save_last + if model_ckpt.save_top_k > 0: + assert ( + len(ckpts) == expected_count or len(ckpts) == model_ckpt.save_top_k + model_ckpt.save_last + ), f"Expected {expected_count} checkpoints or at most top {model_ckpt.save_top_k} checkpoints besides '-last'" + else: + assert len(ckpts) == expected_count, f"Expected {expected_count} checkpoints" + + for ckpt_name in ckpts: + ckpt_path = os.path.join(ckpt_dir, ckpt_name) + + assert ( + '-unfinished' not in ckpt_name + ), f"Unfinished checkpoint found. 
Something went wrong with saving checkpoint {ckpt_name}" + + if ckpt_name.endswith('-last') and 'step' in model_ckpt.filename: + assert f'step={max_steps-1}' in ckpt_name, f"Last checkpoint {ckpt_name} not for final step {max_steps}" + + if dist_ckpts: + assert os.path.isdir(ckpt_path), "Checkpoint is not correct type" + verify_distcp_dir(ckpt_path) + else: + assert os.path.isfile(ckpt_path), "Checkpoint is not correct type" + + +def create_verify_precision(precision: torch.dtype): + def verify_precision(tensor: torch.Tensor) -> None: + assert tensor.dtype == precision + + return verify_precision diff --git a/tests/collections/llm/llama3_pretraining.py b/tests/collections/llm/llama3_pretraining.py new file mode 100644 index 000000000000..24eeca8f01c8 --- /dev/null +++ b/tests/collections/llm/llama3_pretraining.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test the LLaMA3 recipe with a smaller model. +""" + +import argparse +import os + +import nemo_run as run +import torch + +from nemo.collections import llm +from nemo.lightning.pytorch.callbacks.debugging import ParameterDebugger +from tests.collections.llm.common import ( + MCoreModelAttributeValidator, + MiscAttributeValidator, + StopBeforeEnd, + create_verify_precision, + small_llama_cfg, + train_data, + verify_ckpt_dir, +) + + +def get_args(): + parser = argparse.ArgumentParser(prog="", description="") + parser.add_argument('--devices', type=int, required=True, help="Number of devices to use for training") + parser.add_argument('--max-steps', type=int, required=True, help="Number of steps to train for") + parser.add_argument( + '--early-stop', + type=int, + default=None, + help="Stop training early at this global step (for testing resume training)", + ) + parser.add_argument( + '--experiment-dir', type=str, required=True, help="directory to write results and checkpoints to" + ) + parser.add_argument( + '--data-path', type=str, default=None, help="Path to data file. If not specified, uses mock data." + ) + parser.add_argument( + '--tokenizer-path', + type=str, + default=None, + help="Path to a sentencepiece tokenizer model file. If not specified, uses mock data.", + ) + parser.add_argument('--index-mapping-dir', type=str, help="directory to write index mappings to") + parser.add_argument('--seq-length', type=int, default=8192, help="Sequence length. 
default is 8k") + parser.add_argument('--tp', type=int, default=None, help="Override tensor parallelism") + parser.add_argument('--pp', type=int, default=None, help="Override pipeline parallelism") + parser.add_argument('--vp', type=int, default=None, help="Override virtual pipeline parallelism") + parser.add_argument('--cp', type=int, default=None, help="Override context parallelism") + parser.add_argument('--sp', type=int, choices=[0, 1], default=None, help="Override sequence parallel") + parser.add_argument( + '--precision', type=str, choices=['bf16', 'fp16', 'fp32'], default='bf16', help="Override recipe precision" + ) + parser.add_argument('--fp8', action='store_true', help="Enable FP8") + + return parser.parse_args() + + +def main(): + args = get_args() + + exp_name = "L2_llama3_small_pretrain_test" + pretrain_recipe = llm.llama3_8b.pretrain_recipe( + dir=args.experiment_dir, name=exp_name, num_gpus_per_node=args.devices + ) + + pretrain_recipe.model = llm.LlamaModel(small_llama_cfg(args.seq_length)) + + if args.data_path and args.tokenizer_path: + pretrain_recipe.data = train_data( + data_path=args.data_path, + tokenizer_path=args.tokenizer_path, + index_mapping_dir=args.index_mapping_dir, + seq_length=args.seq_length, + ) + + # Recipe Overrides + pretrain_recipe.trainer.max_steps = args.max_steps + pretrain_recipe.trainer.log_every_n_steps = 1 + pretrain_recipe.log.ckpt.every_n_train_steps = None + pretrain_recipe.log.ckpt.train_time_interval = None + pretrain_recipe.trainer.val_check_interval = 2 + pretrain_recipe.trainer.limit_val_batches = 2 + + if args.early_stop: + pretrain_recipe.trainer.callbacks.append(StopBeforeEnd(stop_on_step=args.early_stop)) + + if not args.precision == 'bf16' or args.fp8: # default case is bf16 without fp8 + from nemo.collections.llm.recipes.precision import mixed_precision as mp_recipes + + key = (args.precision, args.fp8) + precision_recipe = { + ("fp16", False): mp_recipes.fp16_mixed, + ("bf16", True): mp_recipes.bf16_with_fp8_mixed, + ("fp16", True): mp_recipes.fp16_with_fp8_mixed, + # Need fp32 + }[key] + pretrain_recipe.trainer.plugins = precision_recipe() + dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} + debugger_callback = ParameterDebugger( + param_fn=create_verify_precision(dtype_map[args.precision]), + grad_fn=create_verify_precision(torch.float32), + log_on_hooks=["on_train_start", "on_train_end"], + ) + pretrain_recipe.trainer.callbacks.append(debugger_callback) + + parallelisms = { + "tensor_model_parallel_size": args.tp, + "pipeline_model_parallel_size": args.pp, + "virtual_pipeline_model_parallel_size": args.vp, + "context_parallel_size": args.cp, + "sequence_parallel": bool(args.sp) if args.sp is not None else None, + } + for k, v in parallelisms.items(): + if v is not None: # use recipe default if not specified + setattr(pretrain_recipe.trainer.strategy, k, v) + parallelisms[k] = getattr(pretrain_recipe.trainer.strategy, k) + pretrain_recipe.trainer.callbacks.append(MCoreModelAttributeValidator(parallelisms)) + + misc_checker = MiscAttributeValidator( + {"max_steps": args.max_steps, "stop_on_step": args.early_stop or args.max_steps} + ) + pretrain_recipe.trainer.callbacks.append(misc_checker) + + run.run(pretrain_recipe, direct=True) + + verify_ckpt_dir( + pretrain_recipe.log.ckpt, + args.early_stop or args.max_steps, + pretrain_recipe.trainer.val_check_interval, + os.path.join(args.experiment_dir, exp_name), + ) + + +if __name__ == '__main__': + main() diff --git a/tests/collections/llm/recipes/test_mixtral_8x7b.py
b/tests/collections/llm/recipes/test_mixtral_8x7b.py index 409dc26a8aa4..62dc0db3d884 100644 --- a/tests/collections/llm/recipes/test_mixtral_8x7b.py +++ b/tests/collections/llm/recipes/test_mixtral_8x7b.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import nemo_run as run import pytest import torch diff --git a/tests/collections/llm/recipes/test_t5_11b.py b/tests/collections/llm/recipes/test_t5_11b.py new file mode 100644 index 000000000000..8c4ab8332c18 --- /dev/null +++ b/tests/collections/llm/recipes/test_t5_11b.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import nemo_run as run +import pytest + +from nemo.collections.llm.api import pretrain +from nemo.collections.llm.recipes import t5_11b +from nemo.collections.llm.t5.data.mock import MockDataModule +from nemo.collections.llm.t5.model.t5 import T5Config11B, T5Model +from nemo.lightning import Trainer + + +class TestT5_11B: + @pytest.fixture(scope="class") + def recipe_module(self): + return t5_11b + + def test_model(self, recipe_module): + model_config = recipe_module.model() + assert isinstance(model_config, run.Config) + assert model_config.__fn_or_cls__ == T5Model + assert isinstance(model_config.config, run.Config) + assert model_config.config.__fn_or_cls__ == T5Config11B + + def test_trainer(self, recipe_module): + trainer_config = recipe_module.trainer() + assert isinstance(trainer_config, run.Config) + assert trainer_config.__fn_or_cls__ == Trainer + assert trainer_config.accelerator == "gpu" + assert trainer_config.devices == 8 + assert trainer_config.num_nodes == 20 + assert trainer_config.max_steps == 1000000 + + # Check strategy configuration + assert isinstance(trainer_config.strategy, run.Config) + assert trainer_config.strategy.__fn_or_cls__.__name__ == "MegatronStrategy" + assert trainer_config.strategy.tensor_model_parallel_size == 4 + assert trainer_config.strategy.pipeline_model_parallel_size == 1 + assert trainer_config.strategy.pipeline_dtype is None + assert trainer_config.strategy.virtual_pipeline_model_parallel_size is None + assert trainer_config.strategy.context_parallel_size == 1 + assert trainer_config.strategy.sequence_parallel is False + assert trainer_config.strategy.gradient_as_bucket_view is True + assert trainer_config.strategy.ckpt_async_save is True + assert trainer_config.strategy.ckpt_parallel_load is True + + # Check other trainer 
+        assert trainer_config.accumulate_grad_batches == 1
+        assert trainer_config.limit_test_batches == 50
+        assert trainer_config.limit_val_batches == 32
+        assert trainer_config.log_every_n_steps == 10
+        assert trainer_config.use_distributed_sampler is False
+        assert trainer_config.val_check_interval == 2000
+
+        # Check plugins
+        assert isinstance(trainer_config.plugins, run.Config)
+        assert trainer_config.plugins.__fn_or_cls__.__name__ == "MegatronMixedPrecision"
+
+    def test_pretrain_recipe(self, recipe_module):
+        recipe = recipe_module.pretrain_recipe()
+        assert isinstance(recipe, run.Partial)
+        assert recipe.__fn_or_cls__ == pretrain
+        assert isinstance(recipe.model, run.Config)
+        assert recipe.model.__fn_or_cls__ == T5Model
+        assert isinstance(recipe.trainer, run.Config)
+        assert recipe.trainer.__fn_or_cls__ == Trainer
+        assert isinstance(recipe.data, run.Config)
+        assert recipe.data.__fn_or_cls__ == MockDataModule
+        assert recipe.data.seq_length == 512
+        assert recipe.data.seq_length_dec == 128
+        assert recipe.data.global_batch_size == 1920
+
+    @pytest.mark.parametrize("num_nodes,num_gpus_per_node", [(1, 8), (2, 4), (4, 2)])
+    def test_pretrain_recipe_with_different_configurations(self, recipe_module, num_nodes, num_gpus_per_node):
+        recipe = recipe_module.pretrain_recipe(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node)
+        assert recipe.trainer.num_nodes == num_nodes
+        assert recipe.trainer.devices == num_gpus_per_node
+
+    def test_trainer_parallelism_options(self, recipe_module):
+        trainer_config = recipe_module.trainer(
+            tensor_parallelism=2,
+            pipeline_parallelism=2,
+        )
+        assert trainer_config.strategy.tensor_model_parallel_size == 2
+        assert trainer_config.strategy.pipeline_model_parallel_size == 2
+
+    def test_model_config_parameters(self, recipe_module):
+        model_config = recipe_module.model()
+        t5_config = model_config.config
+        assert t5_config.num_layers == 24
+        assert t5_config.encoder_num_layers == 24
+        assert t5_config.hidden_size == 4096
+        assert t5_config.ffn_hidden_size == 10240
+        assert t5_config.num_attention_heads == 64
diff --git a/tests/collections/llm/recipes/test_t5_220m.py b/tests/collections/llm/recipes/test_t5_220m.py
new file mode 100644
index 000000000000..744598e3b01b
--- /dev/null
+++ b/tests/collections/llm/recipes/test_t5_220m.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import nemo_run as run
+import pytest
+
+from nemo.collections.llm.api import pretrain
+from nemo.collections.llm.recipes import t5_220m
+from nemo.collections.llm.t5.data.mock import MockDataModule
+from nemo.collections.llm.t5.model.t5 import T5Config220M, T5Model
+from nemo.lightning import Trainer
+
+
+class TestT5_220M:
+    @pytest.fixture(scope="class")
+    def recipe_module(self):
+        return t5_220m
+
+    def test_model(self, recipe_module):
+        model_config = recipe_module.model()
+        assert isinstance(model_config, run.Config)
+        assert model_config.__fn_or_cls__ == T5Model
+        assert isinstance(model_config.config, run.Config)
+        assert model_config.config.__fn_or_cls__ == T5Config220M
+
+    def test_trainer(self, recipe_module):
+        trainer_config = recipe_module.trainer()
+        assert isinstance(trainer_config, run.Config)
+        assert trainer_config.__fn_or_cls__ == Trainer
+        assert trainer_config.accelerator == "gpu"
+        assert trainer_config.devices == 8
+        assert trainer_config.num_nodes == 1
+        assert trainer_config.max_steps == 1000000
+
+        # Check strategy configuration
+        assert isinstance(trainer_config.strategy, run.Config)
+        assert trainer_config.strategy.__fn_or_cls__.__name__ == "MegatronStrategy"
+        assert trainer_config.strategy.tensor_model_parallel_size == 1
+        assert trainer_config.strategy.pipeline_model_parallel_size == 1
+        assert trainer_config.strategy.pipeline_dtype is None
+        assert trainer_config.strategy.virtual_pipeline_model_parallel_size is None
+        assert trainer_config.strategy.context_parallel_size == 1
+        assert trainer_config.strategy.sequence_parallel is False
+        assert trainer_config.strategy.gradient_as_bucket_view is True
+        assert trainer_config.strategy.ckpt_async_save is True
+        assert trainer_config.strategy.ckpt_parallel_load is True
+
+        # Check other trainer configurations
+        assert trainer_config.accumulate_grad_batches == 1
+        assert trainer_config.limit_test_batches == 50
+        assert trainer_config.limit_val_batches == 32
+        assert trainer_config.log_every_n_steps == 10
+        assert trainer_config.use_distributed_sampler is False
+        assert trainer_config.val_check_interval == 2000
+
+        # Check plugins
+        assert isinstance(trainer_config.plugins, run.Config)
+        assert trainer_config.plugins.__fn_or_cls__.__name__ == "MegatronMixedPrecision"
+
+    def test_pretrain_recipe(self, recipe_module):
+        recipe = recipe_module.pretrain_recipe()
+        assert isinstance(recipe, run.Partial)
+        assert recipe.__fn_or_cls__ == pretrain
+        assert isinstance(recipe.model, run.Config)
+        assert recipe.model.__fn_or_cls__ == T5Model
+        assert isinstance(recipe.trainer, run.Config)
+        assert recipe.trainer.__fn_or_cls__ == Trainer
+        assert isinstance(recipe.data, run.Config)
+        assert recipe.data.__fn_or_cls__ == MockDataModule
+        assert recipe.data.seq_length == 512
+        assert recipe.data.seq_length_dec == 128
+        assert recipe.data.global_batch_size == 512
+
+    @pytest.mark.parametrize("num_nodes,num_gpus_per_node", [(1, 8), (2, 4), (4, 2)])
+    def test_pretrain_recipe_with_different_configurations(self, recipe_module, num_nodes, num_gpus_per_node):
+        recipe = recipe_module.pretrain_recipe(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node)
+        assert recipe.trainer.num_nodes == num_nodes
+        assert recipe.trainer.devices == num_gpus_per_node
+
+    def test_trainer_parallelism_options(self, recipe_module):
+        trainer_config = recipe_module.trainer(
+            tensor_parallelism=2,
+            pipeline_parallelism=2,
+        )
+        assert trainer_config.strategy.tensor_model_parallel_size == 2
+        assert trainer_config.strategy.pipeline_model_parallel_size == 2
+
+    def test_model_config_parameters(self, recipe_module):
+        model_config = recipe_module.model()
+        t5_config = model_config.config
+        assert t5_config.num_layers == 12
+        assert t5_config.hidden_size == 768
+        assert t5_config.ffn_hidden_size == 3072
+        assert t5_config.num_attention_heads == 12
diff --git a/tests/collections/llm/recipes/test_t5_3b.py b/tests/collections/llm/recipes/test_t5_3b.py
new file mode 100644
index 000000000000..7672b95426cb
--- /dev/null
+++ b/tests/collections/llm/recipes/test_t5_3b.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import nemo_run as run
+import pytest
+
+from nemo.collections.llm.api import pretrain
+from nemo.collections.llm.recipes import t5_3b
+from nemo.collections.llm.t5.data.mock import MockDataModule
+from nemo.collections.llm.t5.model.t5 import T5Config3B, T5Model
+from nemo.lightning import Trainer
+
+
+class TestT5_3B:
+    @pytest.fixture(scope="class")
+    def recipe_module(self):
+        return t5_3b
+
+    def test_model(self, recipe_module):
+        model_config = recipe_module.model()
+        assert isinstance(model_config, run.Config)
+        assert model_config.__fn_or_cls__ == T5Model
+        assert isinstance(model_config.config, run.Config)
+        assert model_config.config.__fn_or_cls__ == T5Config3B
+
+    def test_trainer(self, recipe_module):
+        trainer_config = recipe_module.trainer()
+        assert isinstance(trainer_config, run.Config)
+        assert trainer_config.__fn_or_cls__ == Trainer
+        assert trainer_config.accelerator == "gpu"
+        assert trainer_config.devices == 8
+        assert trainer_config.num_nodes == 20
+        assert trainer_config.max_steps == 1000000
+
+        # Check strategy configuration
+        assert isinstance(trainer_config.strategy, run.Config)
+        assert trainer_config.strategy.__fn_or_cls__.__name__ == "MegatronStrategy"
+        assert trainer_config.strategy.tensor_model_parallel_size == 2
+        assert trainer_config.strategy.pipeline_model_parallel_size == 1
+        assert trainer_config.strategy.pipeline_dtype is None
+        assert trainer_config.strategy.virtual_pipeline_model_parallel_size is None
+        assert trainer_config.strategy.context_parallel_size == 1
+        assert trainer_config.strategy.sequence_parallel is False
+        assert trainer_config.strategy.gradient_as_bucket_view is True
+        assert trainer_config.strategy.ckpt_async_save is True
+        assert trainer_config.strategy.ckpt_parallel_load is True
+
+        # Check other trainer configurations
+        assert trainer_config.accumulate_grad_batches == 1
+        assert trainer_config.limit_test_batches == 50
+        assert trainer_config.limit_val_batches == 32
+        assert trainer_config.log_every_n_steps == 10
+        assert trainer_config.use_distributed_sampler is False
+        assert trainer_config.val_check_interval == 2000
+
+        # Check plugins
+        assert isinstance(trainer_config.plugins, run.Config)
+        assert trainer_config.plugins.__fn_or_cls__.__name__ == "MegatronMixedPrecision"
+
+    def test_pretrain_recipe(self, recipe_module):
+        recipe = recipe_module.pretrain_recipe()
+        assert isinstance(recipe, run.Partial)
+        assert recipe.__fn_or_cls__ == pretrain
+        assert isinstance(recipe.model, run.Config)
+        assert recipe.model.__fn_or_cls__ == T5Model
+        assert isinstance(recipe.trainer, run.Config)
+        assert recipe.trainer.__fn_or_cls__ == Trainer
+        assert isinstance(recipe.data, run.Config)
+        assert recipe.data.__fn_or_cls__ == MockDataModule
+        assert recipe.data.seq_length == 512
+        assert recipe.data.seq_length_dec == 128
+        assert recipe.data.global_batch_size == 1920
+
+    @pytest.mark.parametrize("num_nodes,num_gpus_per_node", [(1, 8), (2, 4), (4, 2)])
+    def test_pretrain_recipe_with_different_configurations(self, recipe_module, num_nodes, num_gpus_per_node):
+        recipe = recipe_module.pretrain_recipe(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node)
+        assert recipe.trainer.num_nodes == num_nodes
+        assert recipe.trainer.devices == num_gpus_per_node
+
+    def test_trainer_parallelism_options(self, recipe_module):
+        trainer_config = recipe_module.trainer(
+            tensor_parallelism=2,
+            pipeline_parallelism=2,
+        )
+        assert trainer_config.strategy.tensor_model_parallel_size == 2
+        assert trainer_config.strategy.pipeline_model_parallel_size == 2
+
+    def test_model_config_parameters(self, recipe_module):
+        model_config = recipe_module.model()
+        t5_config = model_config.config
+        assert t5_config.num_layers == 24
+        assert t5_config.encoder_num_layers == 24
+        assert t5_config.hidden_size == 2048
+        assert t5_config.ffn_hidden_size == 5120
+        assert t5_config.num_attention_heads == 32